#!/usr/bin/env bash
# Supervised fine-tuning (SFT) of Mixtral-8x7B-Instruct on GSM8K using verl's
# FSDP SFT trainer, launched on a single node with torchrun.
#
# Usage: run_mixtral_8x7B.sh <nproc_per_node> <save_path> [other_configs...]
#   nproc_per_node  value forwarded to torchrun --nproc_per_node
#   save_path       forwarded to the trainer as trainer.default_local_dir
#   other_configs   extra overrides appended verbatim to the trainer command
set -x

# Both positional arguments are mandatory and must be non-empty; an empty
# save_path would silently produce "trainer.default_local_dir=".
if [[ "$#" -lt 2 || -z "$1" || -z "$2" ]]; then
    echo "Usage: run_mixtral_8x7B.sh <nproc_per_node> <save_path> [other_configs...]" >&2
    exit 1
fi

nproc_per_node=$1
save_path=$2

# Load the DTK environment for the GPU stack.
# NOTE(review): presumably sets PATH/LD_LIBRARY_PATH for the DTK runtime;
# confirm against /opt/dtk/env.sh.
source /opt/dtk/env.sh
# export NCCL_P2P_LEVEL=PXB # SYS

# Optional runtime/communication tuning knobs for Mixtral 8x7B runs (all disabled; uncomment as needed)
# export HIP_DIRECT_DISPATCH=0
# export HSA_FORCE_FINE_GRAIN_PCIE=1
# export OMP_NUM_THREADS=1
# export GPU_MAX_HW_QUEUES=10
# export NCCL_ALGO=Ring
# export NCCL_SOCKET_IFNAME=enp33s0f3u1

# export NCCL_NCHANNELS_PER_PEER=16
# export NCCL_MIN_NCHANNELS=32 # 20
# export NCCL_MAX_NCHANNELS=32 # 20
# export NCCL_IB_TIMEOUT=22
# export CUDA_DEVICE_MAX_CONNECTIONS=1

# export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
# export NCCL_NET_GDR_LEVEL=7
# export NCCL_NET_GDR_READ=1
# export RCCL_SDMA_COPY_ENABLE=0
# export NCCL_TOPO_FILE="/public/home/fugx1/datasets/rccl-test/topo-input.xml"
# export GLOG_minloglevel=3 # intended to print only error-level NCCL logs (NOTE: in glog, 2=ERROR and 3=FATAL — verify the intended level)

# export PATH=/opt/hpc/software/mpi/hpcx/2.12.0/gcc-8.3.1/bin/:$PATH
# export LD_LIBRARY_PATH=/opt/hpc/software/mpi/hpcx/2.12.0/gcc-8.3.1/lib/:$LD_LIBRARY_PATH

# Use a custom hipBLASLt library build
# export LD_LIBRARY_PATH=/public/home/fugx1/tests1/test03/whl/hipblaslt-install-dtk-25.04-0212/lib:$LD_LIBRARY_PATH

# Use an updated rocBLAS library build
# export LD_LIBRARY_PATH=/public/home/fugx1/tests1/test03/whl/rocblas-install-0224/lib:$LD_LIBRARY_PATH


# RANK=$OMPI_COMM_WORLD_RANK
# LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
# WORLD_SIZE=$OMPI_COMM_WORLD_SIZE


# Drop <nproc_per_node> and <save_path>; "$@" now holds only the extra
# overrides, which are forwarded to the trainer unchanged.
shift 2

# Launch single-node SFT via torchrun.
# - `exec` replaces this shell so signals (e.g. SIGTERM from a scheduler) go
#   straight to torchrun; the script's exit status is still torchrun's.
# - Overrides containing [...] are quoted as whole words so the shell never
#   treats them as glob patterns. The trainer still receives exactly the same
#   strings as before, e.g. trainer.logger=[console].
# - "$@" is quoted so each user-supplied override stays a single argument
#   even if it contains spaces or glob characters.
exec torchrun --standalone --nnodes=1 --nproc_per_node="$nproc_per_node" \
    -m verl.trainer.fsdp_sft_trainer \
    data.train_files=/public/home/fugx1/ds/gsm8k/gsm8k-train.parquet \
    data.val_files=/public/home/fugx1/ds/gsm8k/gsm8k-test.parquet \
    data.prompt_key='question' \
    data.response_key='answer' \
    '+data.prompt_dict_keys=[question]' \
    '+data.response_dict_keys=[answer]' \
    data.micro_batch_size_per_gpu=4 \
    model.partial_pretrain=/public/opendas/DL_DATA/llm-models/Mixtral-8x7B-Instruct-v0.1 \
    "trainer.default_local_dir=$save_path" \
    trainer.project_name=gsm8k-sft \
    trainer.experiment_name=gsm8k-sft-mixtral-8x7B-Instruct-v0.1 \
    trainer.total_epochs=1 \
    'trainer.logger=[console]' \
    trainer.default_hdfs_dir=null \
    "$@"
