Commit 2b09ea90 authored by liangjing

update

parent af4cf80e
Pipeline #1889 passed
@@ -538,8 +538,6 @@ def validate_args(args, defaults={}):
     if args.decoupled_lr is not None or args.decoupled_min_lr is not None:
         assert not args.use_legacy_models, \
             '--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'
-    # FlashAttention
-    args.use_flash_attn = args.use_flash_attn_ck or args.use_flash_attn_triton
     # Legacy RoPE arguments
     if args.use_rotary_position_embeddings:
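
For readers skimming the diff, a hedged illustration of what this hunk changes (the `Namespace` objects below are stand-ins, not repository code): before this commit the generic `use_flash_attn` flag was derived from the two backend-specific flags inside `validate_args`; after it, the flag is expected to come directly from the `--use-flash-attn` switch added further down.

```python
from argparse import Namespace

# Hypothetical illustration only -- not code from this repository.
# Old behaviour: validate_args derived the generic flag from backend flags.
old_args = Namespace(use_flash_attn_ck=False, use_flash_attn_triton=True)
old_args.use_flash_attn = old_args.use_flash_attn_ck or old_args.use_flash_attn_triton
print(old_args.use_flash_attn)  # True

# New behaviour: the flag is set directly by --use-flash-attn (see the
# _add_training_args change below); no derivation happens in validate_args.
new_args = Namespace(use_flash_attn=True)
print(new_args.use_flash_attn)  # True
```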
@@ -1220,11 +1218,9 @@ def _add_training_args(parser):
     group.add_argument('--cross-entropy-loss-fusion', action='store_true',
                        help='Enabled fusion of cross entropy loss calculation.',
                        dest='cross_entropy_loss_fusion')
-    group.add_argument('--use-flash-attn-ck', action='store_true',
+    group.add_argument('--use-flash-attn', action='store_true',
                        help='use FlashAttention implementation of attention. '
                        'https://arxiv.org/abs/2205.14135')
-    group.add_argument('--use-flash-attn-triton', action='store_true',
-                       help='use FlashAttention implementation of attention using Triton.')
     group.add_argument('--disable-bias-linear', action='store_false',
                        help='Disable bias in the linear layers',
                        dest='add_bias_linear')
...
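
A minimal, self-contained sketch of the consolidated option as it reads after this change; the stand-alone parser construction here is an assumption for illustration, not the project's actual argument module:

```python
import argparse

def _add_training_args(parser):
    # After this commit a single switch enables FlashAttention
    # (https://arxiv.org/abs/2205.14135); the backend-specific
    # --use-flash-attn-ck / --use-flash-attn-triton options are removed.
    group = parser.add_argument_group(title='training')
    group.add_argument('--use-flash-attn', action='store_true',
                       help='use FlashAttention implementation of attention.')
    return parser

if __name__ == '__main__':
    args = _add_training_args(argparse.ArgumentParser()).parse_args(['--use-flash-attn'])
    print(args.use_flash_attn)  # True
```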
@@ -38,10 +38,10 @@ GPT_MODEL_ARGS=(
 TRAINING_ARGS=(
     --log-throughput
-    --transformer-impl local
-    --use-legacy-models
+    --transformer-impl transformer_engine
+    --use-mcore-models
     --micro-batch-size 1
-    --global-batch-size 12 #512
+    --global-batch-size 12
     --train-iters 100
     --weight-decay 0.1
     --adam-beta1 0.9
@@ -50,7 +50,7 @@ TRAINING_ARGS=(
     --clip-grad 1.0
     --bf16
     --use-distributed-optimizer
-    --use-flash-attn-triton
+    --use-flash-attn
     --disable-bias-linear
     --attention-dropout 0
     --hidden-dropout 0
@@ -61,7 +61,6 @@ TRAINING_ARGS=(
     --lr-decay-style cosine
     --min-lr 3.0e-6
     --lr-warmup-iters 1
-    --use-fast-cross-entropy-loss
 )
 MODEL_PARALLEL_ARGS=(
     --sequence-parallel
...
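
The script hunks above pair `--transformer-impl transformer_engine` with `--use-mcore-models` and switch to the single `--use-flash-attn` flag. Below is a toy check (illustrative only; not the project's real argument handling) showing the spellings the updated TRAINING_ARGS now relies on:

```python
import argparse

# Toy parser mirroring only the flags touched by this commit (illustrative).
parser = argparse.ArgumentParser()
parser.add_argument('--transformer-impl', choices=['local', 'transformer_engine'])
parser.add_argument('--use-mcore-models', action='store_true')
parser.add_argument('--use-flash-attn', action='store_true')
parser.add_argument('--global-batch-size', type=int)

# The spellings used by the updated script.
args = parser.parse_args(['--transformer-impl', 'transformer_engine',
                          '--use-mcore-models', '--use-flash-attn',
                          '--global-batch-size', '12'])
print(args)
# The removed spellings (--use-legacy-models, --use-flash-attn-triton,
# --use-fast-cross-entropy-loss) no longer appear in the script.
```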