run-pretrain.sh 2.16 KB
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
rm -rf output_mcore_llava_pretrain/
#需要用transformers==4.45
ENV=dsw                          # 运行环境配置开关: dsw单机训练训练,dlc表示多机训练环境
MODEL_SIZE=7B                   # 模型结构参数量级: 0.5B/1.5B/3B/7B/14B/32B/72B
BATCH_SIZE=1                   # 一次迭代一个数据并行内的样本数
GLOBAL_BATCH_SIZE=256            # 一次迭代多个数据并行的总样本数
LR=0.00015                           # 学习率
MIN_LR=1e-5                       # 最小学习率
SEQ_LEN=576                      # 序列长度
DECODER_SEQ_LEN=1024              # 解码序列长度
PR=fp16                         # 训练精度: fp16, bf16, fp8
TP=4                        # 模型并行度
PP=1                        # 流水并行度
CP=1                       # 上下文并行度
DO=true                        # 是否使用Megatron版Zero-1降显存优化器: true, false
FL=true                        # 是否优先使用Flash Attention: true, false
AC=true                        # 激活检查点模式: sel, full, offload, false
OPTIMIZER_OFFLOAD=false         # 是否启用Offload optimizer: false, static, auto
SAVE_INTERVAL=100000             # 保存ckpt的间隔
DATASET_PATH=/public/new-pai/data/llava-datasets/wds              # 训练数据集路径
VALID_DATASET_PATH=/public/new-pai/data/llava-datasets/wds        # 验证数据集路径
PRETRAIN_CHECKPOINT_PATH=/public/new-pai/Pai-Megatron-Patch/examples/llava_mcore/Mistral-7B-Instruct-v0.3-to-mcore-tp2-pp1  #需要用转换后的模型
TRAIN_ITERS=20000               # 训练TOKEN或者Iter数
LR_WARMUP_ITERS=200           # 预热TOKEN或者Iter数        
OUTPUT_BASEPATH=./output_mcore_llava_pretrain     # 训练输出日志文件路径

sh run_mcore_llava-dcu.sh  \
    $ENV  \
    $MODEL_SIZE   \
    $BATCH_SIZE    \
    $GLOBAL_BATCH_SIZE \
    $LR   \
    $MIN_LR   \
    $SEQ_LEN  \
    $DECODER_SEQ_LEN  \
    $PR  \
    $TP   \
    $PP  \
    $CP \
    $DO \
    $FL   \
    $AC \
    $OPTIMIZER_OFFLOAD \
    $SAVE_INTERVAL  \
    $DATASET_PATH   \
    $VALID_DATASET_PATH   \
    $PRETRAIN_CHECKPOINT_PATH \
    $TRAIN_ITERS  \
    $LR_WARMUP_ITERS   \
    $OUTPUT_BASEPATH