Commit 64266070 authored by wxj

Update Llama_pretraining.sh

parent 40ea1bd3
Pipeline #2206 passed with stage
@@ -48,8 +48,11 @@ GPT_MODEL_ARGS=(
     --hidden-size 4096
     --ffn-hidden-size 11008
     --num-attention-heads 32
+    --seq-length 4096 #4096
     --max-position-embeddings 4096
+    --normalization RMSNorm
+    --position-embedding-type rope
+    --untie-embeddings-and-output-weights # handle the embedding and output weights separately, for more flexibility
 )
 # export NVTE_FLASH_ATTN=1 # goes through cutlass
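With this change the Llama architecture flags (RMSNorm, rotary embeddings, untied embedding/output weights) sit next to the other model-shape arguments instead of in DATA_ARGS. A minimal sketch of how the array reads after the commit; the "# ..." line stands for arguments outside this hunk, whose values are not shown in the diff:

GPT_MODEL_ARGS=(
    # ... arguments above this hunk (e.g. --num-layers) are unchanged and not shown in the diff
    --hidden-size 4096
    --ffn-hidden-size 11008
    --num-attention-heads 32
    --seq-length 4096
    --max-position-embeddings 4096
    --normalization RMSNorm                  # RMSNorm instead of LayerNorm, as in Llama 2
    --position-embedding-type rope           # rotary position embeddings
    --untie-embeddings-and-output-weights    # separate embedding and output weight matrices
)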
@@ -83,11 +86,13 @@ TRAINING_ARGS=(
     --min-lr 3.0e-6
     --lr-warmup-iters 1
     --ckpt-format torch
-    --ddp-average-in-collective
+    --ddp-average-in-collective # in data-parallel communication, gradients/params are averaged directly in the collective instead of being summed first and then averaged
     # --recompute-granularity full # enable recomputation: lower memory use at the cost of extra time
     # --recompute-num-layers 5 #0 #
     # --recompute-method block
-    --overlap-grad-reduce
+    --overlap-grad-reduce # overlap the DDP grad reduce with computation
+    # --tp-comm-overlap # overlap tensor-parallel communication with GEMM; this optimization is not yet adapted
+    # --tp-comm-overlap-rs-dgrad # overlap reduce-scatter with the dgrad GEMM; not yet adapted
     --use-flash-attn-triton
 )
 # --use-flash-attn-cutlass # cutlass flash-attention
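The recompute flags stay commented out here. A hedged sketch of one way to flip them on without editing the array, gated behind an environment variable; ENABLE_RECOMPUTE is a hypothetical name, not something this commit defines:

# Hypothetical toggle (not part of this commit): append the recompute flags only on demand.
if [ "${ENABLE_RECOMPUTE:-0}" = "1" ]; then
    TRAINING_ARGS+=(
        --recompute-granularity full   # recompute activations to save memory, at extra compute cost
        --recompute-method block
        --recompute-num-layers 5
    )
fi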
@@ -96,16 +101,13 @@ TRAINING_ARGS=(
 MODEL_PARALLEL_ARGS=(
     --sequence-parallel
     --tensor-model-parallel-size 2
-    --pipeline-model-parallel-size 4
+    --pipeline-model-parallel-size 2
 )
 DATA_ARGS=(
     --data-path $DATA_PATH
-    --seq-length 4096 #4096
     --split 949,50,1
-    --untie-embeddings-and-output-weights
-    --use-rotary-position-embeddings
-    --normalization RMSNorm
-    --no-position-embedding
     --tokenizer-type Llama2Tokenizer
     --tokenizer-model /data/model_weights/llama2_7b_hf/tokenizer.model
 )
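Dropping the pipeline-parallel size from 4 to 2 halves the number of GPUs each model replica spans. A small sketch of the resulting data-parallel size, assuming a single node with the 8 devices listed in HIP_VISIBLE_DEVICES (the node count is an assumption, not stated in the diff):

WORLD_SIZE=8   # assumed: 1 node x 8 visible devices
TP=2           # --tensor-model-parallel-size
PP=2           # --pipeline-model-parallel-size after this commit (was 4)
DP=$(( WORLD_SIZE / (TP * PP) ))   # 8 / (2*2) = 2 data-parallel replicas (was 1 with PP=4)
echo "data-parallel size: ${DP}"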
@@ -157,46 +159,46 @@ APP="python -u pretrain_gpt.py \
 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # # 4,5,6,7 #,
 # export CUDA_VISIBLE_DEVICES=4,5,6,7 # 0,1,2,3,
-# ${APP}
-case ${LOCAL_RANK} in
-[0])
-# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-# ${APP}
-numactl --cpunodebind=0 --membind=0 ${APP}
-;;
-[1])
-# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-# ${APP}
-numactl --cpunodebind=0 --membind=0 ${APP}
-;;
-[2])
-# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-# ${APP}
-numactl --cpunodebind=0 --membind=0 ${APP}
-;;
-[3])
-# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-# ${APP}
-numactl --cpunodebind=0 --membind=0 ${APP}
-;;
-[4])
-# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-# ${APP}
-numactl --cpunodebind=0 --membind=0 ${APP}
-;;
-[5])
-# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-# ${APP}
-numactl --cpunodebind=0 --membind=0 ${APP}
-;;
-[6])
-# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-# ${APP}
-numactl --cpunodebind=0 --membind=0 ${APP}
-;;
-[7])
-# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-# ${APP}
-numactl --cpunodebind=0 --membind=0 ${APP}
-;;
-esac
+${APP}
+# case ${LOCAL_RANK} in
+# [0])
+# # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
+# [1])
+# # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
+# [2])
+# # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
+# [3])
+# # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
+# [4])
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
+# [5])
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
+# [6])
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
+# [7])
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
+# esac
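Every arm of the removed case block ran the identical numactl --cpunodebind=0 --membind=0 ${APP}, so commenting the block out and calling ${APP} directly changes only one thing: local ranks are no longer pinned to NUMA node 0. If per-rank pinning is wanted again later, a hedged sketch of a rank-dependent binding follows; the two-NUMA-node layout and the rank-to-node mapping are assumptions, not something this commit sets up:

# Hypothetical alternative (not in this commit): bind each local rank to a NUMA node
# derived from LOCAL_RANK, assuming 8 local ranks spread over 2 NUMA nodes.
NUMA_NODE=$(( LOCAL_RANK / 4 ))
numactl --cpunodebind=${NUMA_NODE} --membind=${NUMA_NODE} ${APP}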