wuxk1 / Megatron-LM

Commit e7da80dd, authored Dec 16, 2024 by wxj

Merge branch 'main' into 'main'

Update torchprof support

See merge request !6

Parents: 340ddce9, fef9c0d9
Showing 3 changed files with 140 additions and 82 deletions (+140 / -82).
Llama_pretraining.sh (+85 / -45)
megatron/training/arguments.py (+4 / -2)
megatron/training/training.py (+51 / -35)
Llama_pretraining.sh

@@ -19,16 +19,34 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
 export NCCL_IB_HCA=mlx5_1,mlx5_2
 export NCCL_NET_GDR_LEVEL=SYS
 export NCCL_NET_GDR_READ=0
+export GLOG_minloglevel=3 # print error-level nccl logs
 source /opt/dtk/env.sh
 # te gemm calls require the hipblaslt library
 # export LD_LIBRARY_PATH=/data/hipblaslt-install-0904/lib:$LD_LIBRARY_PATH
+# updated rocblas
+export LD_LIBRARY_PATH=/data/rocblas-install/lib:$LD_LIBRARY_PATH
+# # add synchronization for prof collection
+# export GPU_FLUSH_ON_EXECUTION=1
+# export HIP_DIRECT_DISPATCH=0
 CHECKPOINT_PATH=./tmp_7b #$1 #<Specify path>
 TENSORBOARD_LOGS_PATH=./tmp_7b #$2 #<Specify path>
-DATA_PATH="/datasets/oscar-1GB-llama_text_document" #<Specify path and file prefix>_text_document
+DATA_PATH="/data/datasets/nemo_pretrain/oscar-1GB/oscar-1GB-llama_text_document" #<Specify path and file prefix>_text_document
+# GPT_MODEL_ARGS=(
+#     --num-layers 32
+#     --hidden-size 5120
+#     --ffn-hidden-size 13824
+#     --num-attention-heads 40
+#     --seq-length 4096 #4096
+#     --max-position-embeddings 32768 #4096
+#     --num-query-groups 40
+#     --group-query-attention
+# )
 GPT_MODEL_ARGS=(
-    --num-layers 36
+    --num-layers 6
     --hidden-size 4096
     --ffn-hidden-size 11008
     --num-attention-heads 32

@@ -36,17 +54,18 @@ GPT_MODEL_ARGS=(
     --max-position-embeddings 4096
 )
-# export NVTE_FLASH_ATTN=1 # use autlass
+# export NVTE_FLASH_ATTN=1 # use cutlass
-# export NVTE_FLASH_ATTN_TRITON=1 # use triton_fa
+export NVTE_FLASH_ATTN_TRITON=1 # use triton_fa
 # --transformer-impl transformer_engine
 # --use-mcore-models
+# --transformer-impl local
+# --use-legacy-models
 TRAINING_ARGS=(
-    --transformer-impl local
+    --transformer-impl transformer_engine
-    --use-legacy-models
+    --use-mcore-models
     --micro-batch-size 1
-    --global-batch-size 60 #240 #512 #64
+    --global-batch-size 6 #240 #60 #512 #64
-    --train-iters 100
+    --train-iters 10
     --weight-decay 0.1
     --adam-beta1 0.9
     --adam-beta2 0.95

@@ -54,24 +73,32 @@ TRAINING_ARGS=(
     --clip-grad 1.0
     --bf16
     --use-distributed-optimizer
-    --ckpt-format torch
     --disable-bias-linear
-    --overlap-grad-reduce
     --attention-dropout 0
     --hidden-dropout 0
-    --ddp-average-in-collective
-    --recompute-granularity full
-    --recompute-num-layers 5
-    --recompute-method block
     --no-gradient-accumulation-fusion
-    --add-qkv-bias
     --swiglu
     --lr 3.0e-5
     --lr-decay-style cosine
     --min-lr 3.0e-6
     --lr-warmup-iters 1
+    --ckpt-format torch
+    --ddp-average-in-collective
+    --recompute-granularity full
+    --recompute-num-layers 5 #0 #
+    --recompute-method block
+    --overlap-grad-reduce
     --use-flash-attn-triton
 )
+# --use-flash-attn-ck
+# --add-qkv-bias # qwen
+# --ckpt-format torch
+# --ddp-average-in-collective
+# --recompute-granularity full
+# --recompute-num-layers 5
+# --recompute-method block
+# --overlap-grad-reduce
+# --use-flash-attn-cutlass
 # --use-flash-attn-triton
 MODEL_PARALLEL_ARGS=(

@@ -88,7 +115,7 @@ DATA_ARGS=(
     --normalization RMSNorm
     --no-position-embedding
     --tokenizer-type Llama2Tokenizer
-    --tokenizer-model /path/to/llama2_7b_hf/tokenizer.model
+    --tokenizer-model /data/model_weights/llama2_7b_hf/tokenizer.model
 )
 EVAL_AND_LOGGING_ARGS=(

@@ -102,6 +129,15 @@ EVAL_AND_LOGGING_ARGS=(
     --tensorboard-dir $TENSORBOARD_LOGS_PATH
 )
+PROFILE_ARGS=(
+    --profile
+    --profile-step-start 4
+    --profile-step-end 5
+    --use-pytorch-profiler
+    --profile-ranks 0 3
+    --profile-dir prof_data
+)
 RANK=$OMPI_COMM_WORLD_RANK
 LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
 WORLD_SIZE=$OMPI_COMM_WORLD_SIZE

@@ -122,47 +158,51 @@ APP="python -u pretrain_gpt.py \
     ${DATA_ARGS[@]} \
     ${EVAL_AND_LOGGING_ARGS[@]} \
     ${DISTRIBUTED_ARGS[@]} \
+    ${PROFILE_ARGS[@]} \
     "
+export HIP_VISIBLE_DEVICES=4,5,6,7 # 0,1,2,3 # 4,5,6,7 #,
+# export CUDA_VISIBLE_DEVICES=4,5,6,7 # 0,1,2,3,
+# ${APP}
 case ${LOCAL_RANK} in
 [0])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
+    ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # numactl --cpunodebind=0 --membind=0 ${APP}
     ;;
 [1])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
+    ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # numactl --cpunodebind=0 --membind=0 ${APP}
     ;;
 [2])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
+    ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # numactl --cpunodebind=0 --membind=0 ${APP}
     ;;
 [3])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
+    ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # numactl --cpunodebind=0 --membind=0 ${APP}
     ;;
-[4])
+# [4])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
     # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # # numactl --cpunodebind=0 --membind=0 ${APP}
-    ;;
+    # ;;
-[5])
+# [5])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
     # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # # numactl --cpunodebind=0 --membind=0 ${APP}
-    ;;
+    # ;;
-[6])
+# [6])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
     # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # # numactl --cpunodebind=0 --membind=0 ${APP}
-    ;;
+    # ;;
-[7])
+# [7])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
     # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # # numactl --cpunodebind=0 --membind=0 ${APP}
-    ;;
+    # ;;
 esac
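For reference, a minimal sketch (not part of the commit) of how the PROFILE_ARGS values above feed the torch.profiler schedule built in megatron/training/training.py below; with --profile-step-start 4 and --profile-step-end 5 this yields wait=3, warmup=1, active=1, i.e. a one-step capture window:

import torch

# Values taken from PROFILE_ARGS above.
profile_step_start = 4  # --profile-step-start
profile_step_end = 5    # --profile-step-end

# Same arithmetic as the torch.profiler.schedule(...) call in training.py.
schedule = torch.profiler.schedule(
    wait=max(profile_step_start - 1, 0),           # 3 steps skipped
    warmup=1 if profile_step_start > 0 else 0,     # 1 warmup step
    active=profile_step_end - profile_step_start,  # 1 recorded step
    repeat=1,
)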
megatron/training/arguments.py

@@ -643,7 +643,7 @@ def validate_args(args, defaults={}):
         '--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'
     # FlashAttention
-    args.use_flash_attn = args.use_flash_attn_ck or args.use_flash_attn_triton
+    args.use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton
     # Legacy RoPE arguments
     if args.use_rotary_position_embeddings:

@@ -1265,6 +1265,8 @@ def _add_training_args(parser):
                        dest='use_pytorch_profiler')
     group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
                        help='Global ranks to profile.')
+    group.add_argument('--profile-dir', type=str, default="./",
+                       help='profile dir to save.')
     group.add_argument('--record-memory-history', action="store_true", default=False,
                        help='Record memory history in last rank.')
     group.add_argument('--memory-snapshot-path', type=str, default="snapshot.pickle",

@@ -1358,7 +1360,7 @@ def _add_training_args(parser):
     group.add_argument('--cross-entropy-loss-fusion', action='store_true',
                        help='Enabled fusion of cross entropy loss calculation.',
                        dest='cross_entropy_loss_fusion')
-    group.add_argument('--use-flash-attn-ck', action='store_true',
+    group.add_argument('--use-flash-attn-cutlass', action='store_true',
                        help='use FlashAttention implementation of attention. '
                        'https://arxiv.org/abs/2205.14135')
     group.add_argument('--use-flash-attn-triton', action='store_true',
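As a standalone illustration (this mirrors only the flags touched in this hunk, not the full Megatron-LM parser; the help texts not shown in the diff are omitted), the new --profile-dir option and the renamed --use-flash-attn-cutlass flag combine the way validate_args() now does:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--profile-dir', type=str, default="./",
                   help='profile dir to save.')
group.add_argument('--use-flash-attn-cutlass', action='store_true')
group.add_argument('--use-flash-attn-triton', action='store_true')

args = parser.parse_args(['--use-flash-attn-triton', '--profile-dir', 'prof_data'])
# Combined the same way the validate_args() change above does.
args.use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton
print(args.use_flash_attn, args.profile_dir)  # True prof_data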
megatron/training/training.py

@@ -135,6 +135,13 @@ def num_floating_point_operations(args, batch_size):
     # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
     expansion_factor = 3 * 2 * 2
+    # print(f"batch_size: {batch_size}, \
+    # query_projection_to_hidden_size_ratio: {query_projection_to_hidden_size_ratio}, \
+    # num_experts_routed_to: {num_experts_routed_to}, \
+    # gated_linear_multiplier: {gated_linear_multiplier}, \
+    # shared_expert_ffn_hidden_size: {shared_expert_ffn_hidden_size}, \
+    # gated_linear_multiplier: {gated_linear_multiplier}, \
+    # ")
     return (
         expansion_factor
         * batch_size

@@ -1214,8 +1221,8 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
         if args.use_pytorch_profiler:
             assert prof is not None
             prof.stop()
-        else:
-            torch.cuda.cudart().cudaProfilerStop()
+            print_rank_0(f"prof stop!")

     # Manual garbage collection.
     if args.manual_gc:

@@ -1401,25 +1408,34 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
     prof = None
     if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
+        def trace_handler(p):
+            from pathlib import Path
+            Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
+            print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
+            p.export_chrome_trace("{path}/trace_rank{rank}_step{step}.json".format(path=args.profile_dir, rank=torch.distributed.get_rank(), step=p.step_num))
         prof = torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,],
             schedule=torch.profiler.schedule(
                 wait=max(args.profile_step_start - 1, 0),
                 warmup=1 if args.profile_step_start > 0 else 0,
                 active=args.profile_step_end - args.profile_step_start,
                 repeat=1),
-            on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
-            record_shapes=True,
-            with_stack=True)
+            # record_shapes=True,
+            # with_stack=True,
+            on_trace_ready=trace_handler,
+            )
         prof.start()

     # Run training iterations till done.
     while iteration < args.train_iters:
         if args.profile and torch.distributed.get_rank() in args.profile_ranks:
             if args.use_pytorch_profiler:
                 prof.step()
-            elif iteration == args.profile_step_start:
-                torch.cuda.cudart().cudaProfilerStart()
-                torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()

         maybe_finalize_async_save(blocking=False)

@@ -1431,12 +1447,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
         if get_num_microbatches() != num_microbatches and iteration != 0:
             assert get_num_microbatches() > num_microbatches, \
                 (f"Number of microbatches should be increasing due to batch size rampup; "
                  f"instead going from {num_microbatches} to {get_num_microbatches()}")
             if args.save is not None:
                 save_checkpoint_and_time(iteration, model, optimizer,
                                          opt_param_scheduler,
                                          num_floating_point_operations_so_far,
                                          checkpointing_context, train_data_iterator=train_data_iterator)
         num_microbatches = get_num_microbatches()
         update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)

@@ -1444,23 +1460,23 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
         args.curr_iteration = iteration
         loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
             train_step(forward_step_func,
                        train_data_iterator,
                        model,
                        optimizer,
                        opt_param_scheduler,
                        config)
         if should_checkpoint:
             save_checkpoint_and_time(iteration, model, optimizer,
                                      opt_param_scheduler,
                                      num_floating_point_operations_so_far,
                                      checkpointing_context, train_data_iterator=train_data_iterator)
         if should_exit:
             break

         # why is skipped_iter ignored?
         iteration += 1
         batch_size = mpu.get_data_parallel_world_size() * \
                      args.micro_batch_size * \
                      get_num_microbatches()
         args.consumed_train_samples += batch_size
         num_skipped_samples_in_batch = (get_current_global_batch_size() -
                                         get_current_running_global_batch_size())

@@ -1486,11 +1502,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
         else:
             learning_rate = param_group['lr']
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           learning_rate,
                                           decoupled_learning_rate,
                                           iteration, loss_scale,
                                           report_memory_flag, skipped_iter,
                                           grad_norm, params_norm, num_zeros_in_grad)

         # Evaluation.
         if args.eval_interval and iteration % args.eval_interval == 0 and \

@@ -1504,10 +1520,10 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
             prefix = f'iteration {iteration}'
             timers('eval-time', log_level=0).start(barrier=True)
             evaluate_and_print_results(prefix, forward_step_func,
                                        valid_data_iterator, model,
                                        iteration, process_non_loss_data_func,
                                        config, verbose=False, write_to_tensorboard=True,
                                        non_loss_data_func=non_loss_data_func)
             eval_duration += timers('eval-time').elapsed()
             eval_iterations += args.eval_iters
             timers('eval-time').stop()

@@ -1527,12 +1543,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
         # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
         # Some of these only happen at specific iterations.
         post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
                                      num_floating_point_operations_since_last_log_event)

         # Checkpoint and decide whether to exit.
         should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
                                                  num_floating_point_operations_so_far,
                                                  checkpointing_context, train_data_iterator)
         if should_exit:
             break
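Below is a self-contained sketch of the torchprof pattern this commit wires into train(): a custom on_trace_ready handler that creates the output directory, prints a time table, and exports a Chrome trace. The toy model, the six-step loop, and the prof_data directory are illustrative stand-ins, not Megatron code; on a GPU box you would add ProfilerActivity.CUDA and sort by self_cuda_time_total as the diff does.

from pathlib import Path

import torch

prof_dir = "prof_data"
rank = 0  # stand-in for torch.distributed.get_rank() in a single-process demo

def trace_handler(p):
    # Mirror the handler added in training.py: mkdir, summary table, chrome trace.
    Path(prof_dir).mkdir(parents=True, exist_ok=True)
    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
    p.export_chrome_trace(f"{prof_dir}/trace_rank{rank}_step{p.step_num}.json")

model = torch.nn.Linear(256, 256)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],  # add CUDA when a GPU is present
    schedule=torch.profiler.schedule(wait=3, warmup=1, active=1, repeat=1),
    on_trace_ready=trace_handler,
) as prof:
    for _ in range(6):
        loss = model(torch.randn(32, 256)).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()
        prof.step()  # advances the schedule, like the prof.step() call in train()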