Commit 425a2473 authored by wxj
Browse files

Update Llama_pretraining.sh

parent 4923c381
Pipeline #2075 failed with stages
in 0 seconds
...@@ -28,12 +28,12 @@ TENSORBOARD_LOGS_PATH=./tmp_7b #$2 #<Specify path> ...@@ -28,12 +28,12 @@ TENSORBOARD_LOGS_PATH=./tmp_7b #$2 #<Specify path>
DATA_PATH="/datasets/oscar-1GB-llama_text_document" #<Specify path and file prefix>_text_document DATA_PATH="/datasets/oscar-1GB-llama_text_document" #<Specify path and file prefix>_text_document
GPT_MODEL_ARGS=( GPT_MODEL_ARGS=(
--num-layers 6 --num-layers 36
--hidden-size 1024 --hidden-size 4096
--ffn-hidden-size 2048 --ffn-hidden-size 11008
--num-attention-heads 16 --num-attention-heads 32
--seq-length 4096 #4096 --seq-length 4096 #4096
--max-position-embeddings 32768 --max-position-embeddings 4096
) )
# export NVTE_FLASH_ATTN=1 # 走autlass # export NVTE_FLASH_ATTN=1 # 走autlass
...@@ -69,7 +69,10 @@ TRAINING_ARGS=( ...@@ -69,7 +69,10 @@ TRAINING_ARGS=(
--lr-decay-style cosine --lr-decay-style cosine
--min-lr 3.0e-6 --min-lr 3.0e-6
--lr-warmup-iters 1 --lr-warmup-iters 1
--use-flash-attn-triton
) )
# --use-flash-attn-ck
# --use-flash-attn-triton
MODEL_PARALLEL_ARGS=( MODEL_PARALLEL_ARGS=(
--sequence-parallel --sequence-parallel
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment