sbatch_train.sh 3.11 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/bin/bash
# Slurm batch header for the multi-node GRPO training job (submit with sbatch).
# NOTE: make sure the logs/ directory used by --output/--error exists in the
# submission directory before submitting; Slurm will not create it.
#SBATCH --job-name=grpo_train           # job name
#SBATCH --output=logs/grpo_train_%j.out # stdout log file (%j = job ID)
#SBATCH --error=logs/grpo_train_%j.out  # stderr log file; same path as --output, so both streams are merged
#SBATCH --nodes=2                         # number of nodes
#SBATCH --qos=dcudvp
#SBATCH --gres=dcu:8                      # 8 DCUs per node
#SBATCH --cpus-per-task=32                # 32 CPUs per task
#SBATCH --partition=dcu                   # DCU partition (list partitions with sinfo)
#SBATCH --ntasks-per-node=1
#SBATCH --mem=960G

# ---------------------------------------------------------------------------
# Multi-node launcher: start one torchrun agent per allocated node (one
# background `srun` step per node), then wait for every step and exit non-zero
# if any of them failed, so the Slurm job state reflects training failures.
#
# Placeholders to fill in before submitting:
#   * NCCL_SOCKET_IFNAME below (find the real IB interface name with `ifconfig`)
#   * every /path/of/... argument passed to torchrun
# ---------------------------------------------------------------------------

: "${SLURM_JOB_NODELIST:?must run inside a Slurm allocation (submit with sbatch)}"

# Expand the compact Slurm nodelist into one hostname per array element.
mapfile -t NODE_LIST < <(scontrol show hostnames "$SLURM_JOB_NODELIST")
if (( ${#NODE_LIST[@]} == 0 )); then
    echo "error: no hostnames resolved from SLURM_JOB_NODELIST=$SLURM_JOB_NODELIST" >&2
    exit 1
fi

# Job-wide launch settings, computed once here and baked into every step.
GPUS_PER_NODE=8                       # must match --gres=dcu:8 and HIP_VISIBLE_DEVICES below
NUM_NODES=${#NODE_LIST[@]}
MASTER_ADDR=${NODE_LIST[0]}           # first allocated node hosts the rendezvous
MASTER_PORT=29568
WORLD_SIZE=$((GPUS_PER_NODE * NUM_NODES))

pids=()
for RANK in "${!NODE_LIST[@]}"; do
    node="${NODE_LIST[$RANK]}"
    # The here-doc delimiter is unquoted: unescaped $VAR and $((...)) are
    # expanded by THIS shell (per-node values such as RANK), while \$VAR is
    # left for the bash that srun starts on the target node. For the same
    # reason, comments inside the here-doc must not contain $ or backticks.
    srun --nodes=1 --exclusive -w "$node" bash <<EOF &

source ~/packages/dtk-25.04.1/env.sh
source ~/miniconda3/etc/profile.d/conda.sh
# Abort this node's step instead of launching torchrun in the wrong environment.
conda activate grpo || { echo "[$node] conda activate grpo failed" >&2; exit 1; }

export DISABLE_VERSION_CHECK=1
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # all 8 DCUs on the node
export HSA_FORCE_FINE_GRAIN_PCIE=1
export ALLREDUCE_STREAM_WITH_COMPUTE=1
export HF_ENDPOINT=https://hf-mirror.com

# Rendezvous settings, filled in from the launcher's job-wide values.
export MASTER_ADDR=$MASTER_ADDR
export RANK=$RANK
export MASTER_PORT=$MASTER_PORT
export WORLD_SIZE=$WORLD_SIZE

export NCCL_SOCKET_IFNAME=ibxxxxx # PLACEHOLDER: set to the real IB interface name (see ifconfig)
export NCCL_DEBUG=INFO
export NCCL_ALGO=Ring
export NCCL_PROTO=Simple
export NCCL_MIN_NCHANNELS=32
export NCCL_MAX_NCHANNELS=32
export NCCL_MIN_P2P_NCHANNELS=32
export NCCL_MAX_P2P_NCHANNELS=32
export NCCL_NCHANNELS_PER_PEER=32
export VLLM_RPC_TIMEOUT=1800000
export NCCL_IB_TIMEOUT=30

export VLLM_MLA_DISABLE=0
export VLLM_USE_FLASH_MLA=1

echo "分配的节点列表: $SLURM_NODELIST"
echo "主节点地址: \$MASTER_ADDR"
echo "主节点端口: \$MASTER_PORT"
echo "总进程数: \$WORLD_SIZE"
# Stagger start-up by 3 seconds per node rank.
sleep \$((RANK*3))

DISTRIBUTED_ARGS=(
    --nproc_per_node=$GPUS_PER_NODE
    --nnodes=$NUM_NODES
    --node-rank=\${RANK}
    --master_addr=\${MASTER_ADDR}
    --master_port=\${MASTER_PORT}
)
# NOTE(review): learning_rate 5e-3 and repetition_penalty 50 look far outside
# typical ranges for fine-tuning and sampling a 70B model - confirm they are intended.
torchrun "\${DISTRIBUTED_ARGS[@]}" /path/of/llama-factory/src/train.py   \
    --deepspeed /path/of/deepspeed/ds_z3_config.json \
    --stage grpo \
    --do_train \
    --finetuning_type freeze \
    --freeze_trainable_layers 5 \
    --freeze_trainable_modules all \
    --model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Llama-70B \
    --dataset_dir /path/of/llama-factory-0923/data/ \
    --dataset dapo_math,hiyouga-math12k \
    --max_samples 20000 \
    --template deepseekr1 \
    --output_dir /path/of/saves/DeepSeek-R1-Distill-Llama-70B-0923/grpo/full/ \
    --overwrite_output_dir \
    --trust_remote_code \
    --warmup_ratio 0.1 \
    --max_grad_norm 1.0 \
    --weight_decay 0.1 \
    --repetition_penalty 50 \
    --top_k 50 \
    --top_p 0.8 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --preprocessing_num_workers 16 \
    --ddp_timeout 120000000 \
    --learning_rate 5e-3 \
    --lr_scheduler_type cosine \
    --optim paged_adamw_32bit \
    --logging_steps 1 \
    --cutoff_len 8192 \
    --save_steps 100 \
    --plot_loss True \
    --num_train_epochs 1 \
    --bf16 \
    --seed 42 \
    --report_to none \
    --save_only_model
EOF
    pids[RANK]=$!
done

# A bare `wait` always returns 0, which would hide failed steps; wait on each
# step's PID so a failure on any node fails the whole Slurm job.
failed=0
for RANK in "${!pids[@]}"; do
    if ! wait "${pids[$RANK]}"; then
        echo "error: training step on ${NODE_LIST[$RANK]} (node rank $RANK) failed" >&2
        failed=1
    fi
done
exit "$failed"