OpenDAS / Megatron-LM · Commits

Commit d2073912, authored Aug 10, 2022 by Jared Casper

Merge branch 'timing' into 'main'

Timing levels

See merge request ADLR/megatron-lm!436

Parents: 83d78670, 77efcccb

Changes: 21. Showing 20 changed files with 494 additions and 218 deletions (+494 -218).
megatron/arguments.py                              +26   -0
megatron/global_vars.py                             +5  -85
megatron/optimizer/distrib_optimizer.py            +12   -8
megatron/optimizer/optimizer.py                    +29  -16
megatron/p2p_communication.py                      +10   -9
megatron/schedules.py                              +27  -19
megatron/timers.py                                +304   -0
megatron/training.py                               +59  -59
pretrain_bert.py                                    +1   -1
pretrain_gpt.py                                     +1   -1
pretrain_ict.py                                     +1   -1
pretrain_t5.py                                      +1   -1
pretrain_vision_classify.py                         +1   -1
pretrain_vision_dino.py                             +1   -1
pretrain_vision_inpaint.py                          +1   -1
tasks/finetune_utils.py                             +7   -7
tasks/orqa/supervised/finetune.py                   +1   -1
tasks/vision/classification/classification.py      +1   -1
tasks/vision/finetune_utils.py                      +5   -5
tasks/vision/segmentation/finetune_segformer.py     +1   -1
megatron/arguments.py

@@ -411,6 +411,32 @@ def _add_logging_args(parser):
                        help='If set, calculate and log parameters norm.')
     group.add_argument('--log-num-zeros-in-grad', action='store_true',
                        help='If set, calculate and log the number of zeros in gradient.')
+    group.add_argument('--timing-log-level', type=int,
+                       default=0, choices=range(0, 3),
+                       help='Granularity level to measure and report timing. '
+                       '   0: report only iteration time and make sure timing '
+                       '      does not introduce extra overhead.'
+                       '   1: report timing for operations that are executed '
+                       '      very limited times (basically once) during '
+                       '      each iteration (such as gradient all-reduce) '
+                       '   2: report timing for operations that might be '
+                       '      executed numerous times during each iteration. '
+                       'Note that setting the level to 1 or 2 might '
+                       'cause increase in iteration time.')
+    group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
+                       help='If not set, use barrier with level 1 time '
+                       'measurements. Note that this is up to the user '
+                       'to make sure calling barrier with their timers '
+                       'will not result in hangs. This can happen if for '
+                       'example the user adds a level 1 timer that is not '
+                       'called by all ranks.',
+                       dest='barrier_with_L1_time')
+    group.add_argument('--timing-log-option', type=str, default='minmax',
+                       choices=['max', 'minmax', 'all'],
+                       help='Options for logging timing:'
+                       '  max: report the max timing across all ranks'
+                       '  minmax: report min and max timings across all ranks'
+                       '  all: report timings of all ranks.')
     group.add_argument('--tensorboard-log-interval', type=int, default=1,
                        help='Report to tensorboard interval.')
     group.add_argument('--tensorboard-queue-size', type=int, default=1000,
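Note: a minimal standalone sketch of how the three new flags behave, using only stdlib argparse (this is not the full Megatron parser). The detail worth noticing is that action='store_false' combined with dest='barrier_with_L1_time' means the level-1 barrier is on by default and the flag turns it off:

# Minimal sketch of the new timing flags, outside the full Megatron parser.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='logging')
group.add_argument('--timing-log-level', type=int,
                   default=0, choices=range(0, 3))
group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
                   dest='barrier_with_L1_time')
group.add_argument('--timing-log-option', type=str, default='minmax',
                   choices=['max', 'minmax', 'all'])

args = parser.parse_args(['--timing-log-level', '1'])
# store_false + dest: barrier_with_L1_time is True unless
# --no-barrier-with-level-1-timing is passed explicitly.
assert args.barrier_with_L1_time is True
assert args.timing_log_level == 1
assert args.timing_log_option == 'minmax'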
megatron/global_vars.py

@@ -17,7 +17,6 @@
 import os
 import sys
-import time
 from functools import reduce
 import operator

 import torch

@@ -25,6 +24,7 @@ import torch
 from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
 from .microbatches import build_num_microbatches_calculator
+from .timers import Timers

 _GLOBAL_ARGS = None
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None

@@ -108,7 +108,7 @@ def set_global_variables(args):
     _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
-    _set_timers()
+    _set_timers(args)
     _set_global_memory_buffer()

     if args.exit_signal_handler:

@@ -182,11 +182,12 @@ def _set_adlr_autoresume(args):
             _GLOBAL_ADLR_AUTORESUME = AutoResume


-def _set_timers():
+def _set_timers(args):
     """Initialize timers."""
     global _GLOBAL_TIMERS
     _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
-    _GLOBAL_TIMERS = Timers()
+    _GLOBAL_TIMERS = Timers(args.timing_log_level,
+                            args.timing_log_option)


 def _set_global_memory_buffer():
     """Initialize global buffer"""

@@ -205,87 +206,6 @@ def _ensure_var_is_not_initialized(var, name):
     assert var is None, '{} is already initialized.'.format(name)


-class _Timer:
-    """Timer."""
-
-    def __init__(self, name):
-        self.name_ = name
-        self.elapsed_ = 0.0
-        self.started_ = False
-        self.start_time = time.time()
-
-    def start(self):
-        """Start the timer."""
-        assert not self.started_, 'timer has already been started'
-        torch.cuda.synchronize()
-        self.start_time = time.time()
-        self.started_ = True
-
-    def stop(self):
-        """Stop the timer."""
-        assert self.started_, 'timer is not started'
-        torch.cuda.synchronize()
-        self.elapsed_ += (time.time() - self.start_time)
-        self.started_ = False
-
-    def reset(self):
-        """Reset timer."""
-        self.elapsed_ = 0.0
-        self.started_ = False
-
-    def elapsed(self, reset=True):
-        """Calculate the elapsed time."""
-        started_ = self.started_
-        # If timing is in progress, end it first.
-        if self.started_:
-            self.stop()
-        # Get the elapsed time.
-        elapsed_ = self.elapsed_
-        # Reset the elapsed time.
-        if reset:
-            self.reset()
-        # If timing was in progress, set it back.
-        if started_:
-            self.start()
-        return elapsed_
-
-
-class Timers:
-    """Group of timers."""
-
-    def __init__(self):
-        self.timers = {}
-
-    def __call__(self, name):
-        if name not in self.timers:
-            self.timers[name] = _Timer(name)
-        return self.timers[name]
-
-    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
-        """Write timers to a tensorboard writer."""
-        # Currently when using add_scalars,
-        # torch.utils.add_scalars makes each timer its own run, which
-        # pollutes the runs list, so we just add each as a scalar.
-        assert normalizer > 0.0
-        for name in names:
-            value = self.timers[name].elapsed(reset=reset) / normalizer
-            writer.add_scalar(name + '-time', value, iteration)
-
-    def log(self, names, normalizer=1.0, reset=True):
-        """Log a group of timers."""
-        assert normalizer > 0.0
-        string = 'time (ms)'
-        for name in names:
-            elapsed_time = self.timers[name].elapsed(
-                reset=reset) * 1000.0 / normalizer
-            string += ' | {}: {:.2f}'.format(name, elapsed_time)
-        if torch.distributed.is_initialized():
-            if torch.distributed.get_rank() == (
-                    torch.distributed.get_world_size() - 1):
-                print(string, flush=True)
-        else:
-            print(string, flush=True)
-
-
 class GlobalMemoryBuffer:
     """Global buffer to avoid dynamic memory allocations.
megatron/optimizer/distrib_optimizer.py

@@ -532,17 +532,20 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         """

         # All-reduce layer-norm grads (for sequence parallelism).
-        timers('backward-layernorm-all-reduce').start()
+        timers('layernorm-grads-all-reduce', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.allreduce_layernorm_grads(args)
-        timers('backward-layernorm-all-reduce').stop()
+        timers('layernorm-grads-all-reduce').stop()

         # All-reduce embedding grads.
-        timers('backward-embedding-all-reduce').start()
+        timers('embedding-grads-all-reduce', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.allreduce_embedding_grads(args)
-        timers('backward-embedding-all-reduce').stop()
+        timers('embedding-grads-all-reduce').stop()

         # Reduce-scatter setup.
-        timers('backward-params-all-reduce').start()
+        timers('grads-reduce-scatter', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_world_size = mpu.get_data_parallel_world_size()
         data_parallel_group = mpu.get_data_parallel_group()

@@ -563,7 +566,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             group = data_parallel_group,
         )
-        timers('backward-params-all-reduce').stop()
+        timers('grads-reduce-scatter').stop()


     def gather_model_params(self, args, timers):

@@ -575,7 +578,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         can be copied from param.main_grad to param.
         """

-        timers('backward-params-all-gather').start()
+        timers('params-all-gather', log_level=1).start(
+            barrier=args.barrier_with_L1_time)

         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_group = mpu.get_data_parallel_group()

@@ -602,7 +606,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             for param in param_map:
                 param.detach().copy_(param.main_grad)

-        timers('backward-params-all-gather').stop()
+        timers('params-all-gather').stop()


     def _collect_main_grad_data_for_unscaling(self):
megatron/optimizer/optimizer.py

@@ -294,21 +294,24 @@ class MegatronOptimizer(ABC):
         """All-reduce all grads, and all-reduce embeddings."""

         # All-reduce layer-norm grads (for sequence parallelism).
-        timers('backward-layernorm-all-reduce').start()
+        timers('layernorm-grads-all-reduce', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.allreduce_layernorm_grads(args)
-        timers('backward-layernorm-all-reduce').stop()
+        timers('layernorm-grads-all-reduce').stop()

         # All-reduce if needed.
         if args.DDP_impl == 'local':
-            timers('backward-params-all-reduce').start()
+            timers('grads-all-reduce', log_level=1).start(
+                barrier=args.barrier_with_L1_time)
             for model in self.models:
                 model.allreduce_gradients()
-            timers('backward-params-all-reduce').stop()
+            timers('grads-all-reduce').stop()

         # All-reduce embedding grads.
-        timers('backward-embedding-all-reduce').start()
+        timers('embedding-grads-all-reduce', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.allreduce_embedding_grads(args)
-        timers('backward-embedding-all-reduce').stop()
+        timers('embedding-grads-all-reduce').stop()


 class MixedPrecisionOptimizer(MegatronOptimizer):

@@ -416,7 +419,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
     def step(self, args, timers):

         # Copy gradients from model params to main params.
-        timers('optimizer-copy-to-main-grad').start()
+        timers('optimizer-copy-to-main-grad', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self._copy_model_grads_to_main_grads()
         timers('optimizer-copy-to-main-grad').stop()

@@ -425,7 +429,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         if self.grad_scaler:

             # Unscale and check for inf/nan.
-            timers('optimizer-unscale-and-check-inf').start()
+            timers('optimizer-unscale-and-check-inf', log_level=1).start(
+                barrier=args.barrier_with_L1_time)
             found_inf_flag = self._unscale_main_grads_and_check_for_nan()
             timers('optimizer-unscale-and-check-inf').stop()

@@ -438,25 +443,29 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
                 return False, None, None

         # Clip the main gradients.
-        timers('optimizer-clip-main-grad').start()
+        timers('optimizer-clip-main-grad', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         grad_norm = None
         if self.clip_grad > 0.0:
             grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()

         # Count the zeros in the grads.
-        timers('optimizer-count-zeros').start()
+        timers('optimizer-count-zeros', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         num_zeros_in_grad = self.count_zeros() if \
                             self.log_num_zeros_in_grad else None
         timers('optimizer-count-zeros').stop()

         # Step the optimizer.
-        timers('optimizer-inner-step').start()
+        timers('optimizer-inner-step', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.optimizer.step()
         timers('optimizer-inner-step').stop()

         # Update params from main params.
-        timers('optimizer-copy-main-to-model-params').start()
+        timers('optimizer-copy-main-to-model-params', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self._copy_main_params_to_model_params()
         timers('optimizer-copy-main-to-model-params').stop()

@@ -725,7 +734,8 @@ class FP32Optimizer(MegatronOptimizer):
         Always return successful since there is no overflow."""

         # Copy main_grads to grads.
-        timers('optimizer-copy-to-main-grad').start()
+        timers('optimizer-copy-to-main-grad', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         if self.params_have_main_grad:
             for param_group in self.optimizer.param_groups:
                 for param in param_group['params']:

@@ -739,20 +749,23 @@ class FP32Optimizer(MegatronOptimizer):
         timers('optimizer-copy-to-main-grad').stop()

         # Clip gradients.
-        timers('optimizer-clip-main-grad').start()
+        timers('optimizer-clip-main-grad', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         grad_norm = None
         if self.clip_grad > 0.0:
             grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()

         # count the zeros in the grads
-        timers('optimizer-count-zeros').start()
+        timers('optimizer-count-zeros', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         num_zeros_in_grad = self.count_zeros() if \
                             self.log_num_zeros_in_grad else None
         timers('optimizer-count-zeros').stop()

         # Update parameters.
-        timers('optimizer-inner-step').start()
+        timers('optimizer-inner-step', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.optimizer.step()
         timers('optimizer-inner-step').stop()
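Note: the recurring pattern in both optimizer files is that every level-1 timer now starts with barrier=args.barrier_with_L1_time, so all ranks align before the measured region and the reported min/max spread reflects the work itself rather than pre-existing skew between ranks. A torch-free sketch of those semantics, with a no-op stand-in for torch.distributed.barrier() (assumed names, not Megatron code):

import time

def fake_barrier():
    # Stand-in for torch.distributed.barrier(group=...); in a real run,
    # every rank must reach this point or the job hangs.
    pass

class SketchTimer:
    """Torch-free sketch of the barrier semantics in megatron/timers.py."""
    def __init__(self):
        self._elapsed = 0.0
        self._started = False
        self._start_time = 0.0

    def start(self, barrier=False):
        assert not self._started, 'timer has already been started'
        if barrier:
            fake_barrier()  # align ranks before starting the clock
        self._start_time = time.time()
        self._started = True

    def stop(self, barrier=False):
        assert self._started, 'timer is not started'
        if barrier:
            fake_barrier()  # align ranks before reading the clock
        self._elapsed += time.time() - self._start_time
        self._started = False

t = SketchTimer()
t.start(barrier=True)   # False when --no-barrier-with-level-1-timing is set
time.sleep(0.01)
t.stop()
print('elapsed: {:.1f} ms'.format(t._elapsed * 1000.0))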
megatron/p2p_communication.py

@@ -163,7 +163,7 @@ def recv_forward(tensor_shape=None, dtype_=None, timers=None):
         input_tensor = None
     else:
         if timers is not None:
-            timers('forward-recv').start()
+            timers('forward-recv', log_level=2).start()
         input_tensor, _ = _communicate(
             tensor_send_next=None,
             tensor_send_prev=None,

@@ -182,7 +182,7 @@ def recv_backward(tensor_shape=None, timers=None):
         output_tensor_grad = None
     else:
         if timers is not None:
-            timers('backward-recv').start()
+            timers('backward-recv', log_level=2).start()
         _, output_tensor_grad = _communicate(
             tensor_send_next=None,
             tensor_send_prev=None,

@@ -199,7 +199,7 @@ def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
-            timers('forward-send').start()
+            timers('forward-send', log_level=2).start()
         _communicate(
             tensor_send_next=output_tensor,
             tensor_send_prev=None,

@@ -215,7 +215,7 @@ def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
-            timers('backward-send').start()
+            timers('backward-send', log_level=2).start()
         _communicate(
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,

@@ -232,7 +232,7 @@ def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
         output_tensor_grad = None
     else:
         if timers is not None:
-            timers('forward-send-backward-recv').start()
+            timers('forward-send-backward-recv', log_level=2).start()
         _, output_tensor_grad = _communicate(
             tensor_send_next=output_tensor,
             tensor_send_prev=None,

@@ -250,7 +250,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
         input_tensor = None
     else:
         if timers is not None:
-            timers('backward-send-forward-recv').start()
+            timers('backward-send-forward-recv', log_level=2).start()
         input_tensor, _ = _communicate(
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,

@@ -265,7 +265,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
 def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
     """Batched recv from previous rank and send to next rank in pipeline."""
     if timers is not None:
-        timers('forward-send-forward-recv').start()
+        timers('forward-send-forward-recv', log_level=2).start()
     input_tensor, _ = _communicate(
         tensor_send_next=output_tensor,
         tensor_send_prev=None,

@@ -280,7 +280,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
 def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
     """Batched recv from next rank and send to previous rank in pipeline."""
     if timers is not None:
-        timers('backward-send-backward-recv').start()
+        timers('backward-send-backward-recv', log_level=2).start()
     _, output_tensor_grad = _communicate(
         tensor_send_next=None,
         tensor_send_prev=input_tensor_grad,

@@ -297,7 +297,8 @@ def send_forward_backward_recv_forward_backward(
         recv_next, tensor_shape=None, timers=None):
     """Batched send and recv with previous and next ranks in pipeline."""
     if timers is not None:
-        timers('forward-backward-send-forward-backward-recv').start()
+        timers('forward-backward-send-forward-backward-recv',
+               log_level=2).start()
     input_tensor, output_tensor_grad = _communicate(
         tensor_send_next=output_tensor,
         tensor_send_prev=input_tensor_grad,
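Note: two details of these hunks generalize. The communication timers are all created at log_level=2 inside an "if timers is not None" guard, and none of them ever request a barrier: pipeline send/recv pairs are not entered by all ranks, so a barrier here could hang the job, exactly the caveat the new --no-barrier-with-level-1-timing help text warns about. A toy, runnable sketch of the guard (stub names; _communicate is replaced with a placeholder):

# Toy sketch of the p2p guard pattern; _communicate_stub is a placeholder.
def _communicate_stub(tensor_send_next=None, tensor_send_prev=None):
    return None, None

def send_forward_sketch(output_tensor, timers=None):
    if timers is not None:
        timers('forward-send', log_level=2).start()  # no barrier here
    _communicate_stub(tensor_send_next=output_tensor)
    if timers is not None:
        timers('forward-send').stop()

send_forward_sketch(object())   # timers=None: zero timing overhead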
megatron/schedules.py

@@ -107,6 +107,7 @@ def forward_step(forward_step_func,
                  model,
                  input_tensor,
                  forward_data_store,
+                 timers,
                  collect_non_loss_data=False):
     """Forward step for passed-in model.

@@ -115,9 +116,9 @@ def forward_step(forward_step_func,
     Returns output tensor."""
     args = get_args()
-    timers = get_timers()

-    timers('forward-compute').start()
+    if timers is not None:
+        timers('forward-compute', log_level=2).start()
     unwrapped_model = unwrap_model(
         model, (torchDDP, LocalDDP, Float16Module))

@@ -138,7 +139,8 @@ def forward_step(forward_step_func,
             data = loss_func(output_tensor, non_loss_data=True)
             forward_data_store.append(data)

-    timers('forward-compute').stop()
+    if timers is not None:
+        timers('forward-compute').stop()

     # If T5 model (or other model with encoder and decoder)
     # and in decoder stack, then send encoder_hidden_state

@@ -151,7 +153,8 @@ def forward_step(forward_step_func,
     return [output_tensor]


-def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
+def backward_step(optimizer, input_tensor, output_tensor,
+                  output_tensor_grad, timers):
     """Backward step through passed-in output tensor.

     If last stage, output_tensor_grad is None, otherwise gradient of loss

@@ -165,8 +168,8 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     # connections.
     args = get_args()
-    timers = get_timers()

-    timers('backward-compute').start()
+    if timers is not None:
+        timers('backward-compute', log_level=2).start()

     # Retain the grad on the input_tensor.
     unwrap_input_tensor_grad = False

@@ -207,7 +210,8 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     if unwrap_input_tensor_grad:
         input_tensor_grad = input_tensor_grad[0]

-    timers('backward-compute').stop()
+    if timers is not None:
+        timers('backward-compute').stop()

     return input_tensor_grad

@@ -243,18 +247,19 @@ def forward_backward_no_pipelining(forward_step_func,
         for i in range(get_num_microbatches() - 1):
             output_tensor = forward_step(forward_step_func, data_iterator,
                                          model, input_tensor, forward_data_store,
-                                         collect_non_loss_data)
+                                         timers, collect_non_loss_data)
             if not forward_only:
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, timers)

     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
     output_tensor = forward_step(forward_step_func, data_iterator,
                                  model, input_tensor, forward_data_store,
-                                 collect_non_loss_data)
+                                 timers, collect_non_loss_data)
     if not forward_only:
-        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
+        backward_step(optimizer, input_tensor, output_tensor,
+                      output_tensor_grad, timers)

     return forward_data_store

@@ -269,6 +274,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
     communication between pipeline stages as needed.

     Returns dictionary with losses if the last stage, empty dict otherwise."""
+    args = get_args()
+
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
     forward_data_store = []

@@ -278,7 +286,6 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
     pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
     pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

-    args = get_args()
     if args.sequence_parallel:
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
     else:

@@ -337,6 +344,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
                                      model[model_chunk_id],
                                      input_tensor,
                                      forward_data_store,
+                                     timers,
                                      collect_non_loss_data)
         output_tensors[model_chunk_id].append(output_tensor)

@@ -364,7 +372,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
         backward_step(optimizer,
                       input_tensor,
                       output_tensor,
-                      output_tensor_grad)
+                      output_tensor_grad,
+                      timers)

         return input_tensor_grad

@@ -620,8 +629,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
     Returns dictionary with losses if the last stage, empty dict otherwise."""
     args = get_args()
-    timers = get_timers()

     assert len(model) == 1
     model = model[0]

@@ -656,7 +664,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
         input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, forward_data_store,
-                                     collect_non_loss_data)
+                                     timers, collect_non_loss_data)
         send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:

@@ -676,7 +684,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, forward_data_store,
-                                     collect_non_loss_data)
+                                     timers, collect_non_loss_data)
         if forward_only:
             send_forward(output_tensor, send_tensor_shapes, timers=timers)

@@ -701,7 +709,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, timers)

             if last_iteration:
                 input_tensor = None

@@ -721,7 +729,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, timers)

             send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
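Note: the net effect in the schedules is that forward_step and backward_step no longer fetch the global timers themselves; the caller threads a timers object through, and passing None disables the per-microbatch 'forward-compute'/'backward-compute' timers entirely (train_step in megatron/training.py below is who makes that choice). A toy, self-contained model of the new calling convention, with a counting stand-in for the real timers:

# Toy sketch: timers is an explicit parameter, and None is a true no-op.
calls = {'start': 0, 'stop': 0}

class CountingTimer:
    def start(self, barrier=False): calls['start'] += 1
    def stop(self, barrier=False): calls['stop'] += 1

def counting_timers(name, log_level=None):
    return CountingTimer()

def backward_step_sketch(grad, timers):
    if timers is not None:
        timers('backward-compute', log_level=2).start()
    out = grad * 2  # stand-in for the real backward work
    if timers is not None:
        timers('backward-compute').stop()
    return out

backward_step_sketch(1.0, counting_timers)   # timed: start/stop fire
backward_step_sketch(1.0, None)              # untimed: no timer calls at all
assert calls == {'start': 1, 'stop': 1}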
megatron/timers.py  (new file, 0 → 100644)

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron timers."""

from abc import ABC
from abc import abstractmethod
import time

import torch


class TimerBase(ABC):

    def __init__(self, name):
        self.name = name

    @abstractmethod
    def start(self, barrier=False):
        pass

    @abstractmethod
    def stop(self, barrier=False):
        pass

    @abstractmethod
    def reset(self):
        pass

    @abstractmethod
    def elapsed(self, reset=True, barrier=False):
        pass


class DummyTimer(TimerBase):

    def __init__(self):
        super().__init__('dummy timer')

    def start(self, barrier=False):
        return

    def stop(self, barrier=False):
        return

    def reset(self):
        return

    def elapsed(self, reset=True, barrier=False):
        raise Exception('dummy timer should not be used to '
                        'calculate elapsed time')


class Timer(TimerBase):
    """
    Comment on using `barrier`: If this flag is passed, then all
    the caller processes will wait till all reach the timing routine.
    It is up to the user to make sure all the ranks in `barrier_group`
    call it; otherwise, it will result in a hang.
    Comment on `barrier_group`: By default it is set to None which,
    in torch distributed land, will result in the global communicator.
    """

    def __init__(self, name):
        super().__init__(name)
        self._elapsed = 0.0
        self._started = False
        # Note that None will default to the global process group.
        self._barrier_group = None
        self._start_time = time.time()

    def set_barrier_group(self, barrier_group):
        self._barrier_group = barrier_group

    def start(self, barrier=False):
        """Start the timer."""
        assert not self._started, 'timer has already been started'
        if barrier:
            torch.distributed.barrier(group=self._barrier_group)
        torch.cuda.synchronize()
        self._start_time = time.time()
        self._started = True

    def stop(self, barrier=False):
        """Stop the timer."""
        assert self._started, 'timer is not started'
        if barrier:
            torch.distributed.barrier(group=self._barrier_group)
        torch.cuda.synchronize()
        self._elapsed += (time.time() - self._start_time)
        self._started = False

    def reset(self):
        """Reset timer."""
        self._elapsed = 0.0
        self._started = False

    def elapsed(self, reset=True, barrier=False):
        """Calculate the elapsed time."""
        _started = self._started
        # If timing is in progress, end it first.
        if self._started:
            self.stop(barrier=barrier)
        # Get the elapsed time.
        _elapsed = self._elapsed
        # Reset the elapsed time.
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if _started:
            self.start(barrier=barrier)
        return _elapsed


class Timers:
    """Group of timers."""

    def __init__(self, log_level, log_option):
        self._log_level = log_level
        self._log_option = log_option
        self._timers = {}
        self._log_levels = {}
        self._dummy_timer = DummyTimer()
        self._max_log_level = 2

    def __call__(self, name, log_level=None):
        # If the timer has already been set, then check that, if a log
        # level is provided, it matches the one that the timer was
        # created with.
        if name in self._timers:
            if log_level is not None:
                assert log_level == self._log_levels[name], \
                    'input log level {} does not match already existing ' \
                    'log level {} for {} timer'.format(
                        log_level, self._log_levels[name], name)
            return self._timers[name]
        # If timer does not exist and no log level is provided,
        # set it to the max log level which is 2.
        if log_level is None:
            log_level = self._max_log_level
        assert log_level <= self._max_log_level, \
            'log level {} is larger than max supported log level {}'.format(
                log_level, self._max_log_level)
        # Now if the input log level is larger than the one set for
        # the timers class, just ignore it and return a dummy timer.
        if log_level > self._log_level:
            return self._dummy_timer
        # Otherwise, initialize the timer and set the level.
        self._timers[name] = Timer(name)
        self._log_levels[name] = log_level
        return self._timers[name]

    def _get_elapsed_time_all_ranks(self, names, reset, barrier):
        """
        Assumptions:
            - All the ranks call this function.
            - `names` are identical on all ranks.
        If the above assumptions are not met, calling this function will
        result in a hang.
        Arguments:
            - names: list of timer names
            - reset: reset the timer after recording the elapsed time
            - barrier: if set, do a global barrier before time measurements
        """

        # First make sure all the callers are in sync.
        if barrier:
            torch.distributed.barrier()

        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        # Here we can use gather on the rank we want to print the
        # timing, however, there is no gather_base support in
        # pytorch yet. It is simpler to deal with a single tensor
        # and since we are only gathering a small amount of data,
        # it should be ok to use all-gather instead of gather.
        rank_name_to_time = torch.zeros((world_size, len(names)),
                                        dtype=torch.float,
                                        device=torch.cuda.current_device())
        for i, name in enumerate(names):
            if name in self._timers:
                # Here we don't need to pass the barrier flag as all
                # the processes are already in sync. This avoids the
                # issue of different timers having different barrier
                # groups inside their class.
                rank_name_to_time[rank, i] = self._timers[name].elapsed(
                    reset=reset)

        # See the note above for why we are not using gather.
        torch.distributed._all_gather_base(rank_name_to_time.view(-1),
                                           rank_name_to_time[rank, :].view(-1))

        return rank_name_to_time

    def _get_global_min_max_time(self, names, reset, barrier, normalizer):
        """Report only min and max times across all ranks."""

        rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
                                                             barrier)
        name_to_min_max_time = {}
        for i, name in enumerate(names):
            rank_to_time = rank_name_to_time[:, i]
            # Filter out the ones we did not have any timings for.
            rank_to_time = rank_to_time[rank_to_time > 0.0]
            # If the timer exists:
            if rank_to_time.numel() > 0:
                name_to_min_max_time[name] = (
                    rank_to_time.min().item() / normalizer,
                    rank_to_time.max().item() / normalizer)
        return name_to_min_max_time

    def _get_global_min_max_time_string(self, names, reset, barrier,
                                        normalizer, max_only):
        name_to_min_max_time = self._get_global_min_max_time(
            names, reset, barrier, normalizer)
        if not name_to_min_max_time:
            return None
        output_string = '(min, max) time across ranks (ms):'
        for name in name_to_min_max_time:
            min_time, max_time = name_to_min_max_time[name]
            if max_only:
                output_string += '\n    {}: {:.2f}'.format(
                    (name + ' ').ljust(48, '.'), max_time)
            else:
                output_string += '\n    {}: ({:.2f}, {:.2f})'.format(
                    (name + ' ').ljust(48, '.'), min_time, max_time)
        return output_string

    def _get_all_ranks_time_string(self, names, reset, barrier, normalizer):
        """Report times across all ranks."""
        rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
                                                             barrier)

        output_string = 'times across ranks (ms):'
        no_reported_timing = True
        for i, name in enumerate(names):
            not_yet_found = True
            for rank in range(torch.distributed.get_world_size()):
                if rank_name_to_time[rank, i] > 0:
                    no_reported_timing = False
                    if not_yet_found:
                        not_yet_found = False
                        output_string += '\n  {}:'.format(name)
                    output_string += '\n     rank {:2d}: {:.2f}'.format(
                        rank, rank_name_to_time[rank, i] / normalizer)
        if no_reported_timing:
            return None
        return output_string

    def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False):
        """Log a group of timers."""

        # Print.
        assert normalizer > 0.0
        if self._log_option in ['max', 'minmax']:
            max_only = False
            if self._log_option == 'max':
                max_only = True
            output_string = self._get_global_min_max_time_string(
                names, reset, barrier, normalizer / 1000.0, max_only)
        elif self._log_option == 'all':
            output_string = self._get_all_ranks_time_string(
                names, reset, barrier, normalizer / 1000.0)
        else:
            raise Exception('unknown timing log option {}'.format(
                self._log_option))

        # If no input rank is provided, log on last rank.
        if rank is None:
            rank = torch.distributed.get_world_size() - 1
        if rank == torch.distributed.get_rank() and output_string is not None:
            print(output_string, flush=True)

    def write(self, names, writer, iteration, normalizer=1.0,
              reset=False, barrier=False):
        """Write timers to a tensorboard writer.
        Note that we only report maximum time across ranks to tensorboard.
        """
        # Currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # pollutes the runs list, so we just add each as a scalar.
        assert normalizer > 0.0
        name_to_min_max_time = self._get_global_min_max_time(
            names, reset, barrier, normalizer)
        if writer is not None:
            for name in name_to_min_max_time:
                _, max_time = name_to_min_max_time[name]
                writer.add_scalar(name + '-time', max_time, iteration)
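Note: a hedged usage sketch of the new class (assuming megatron.timers is importable). No GPU work happens until start()/stop(), so the level-gating below runs anywhere: Timers.__call__ hands back a real Timer only when the requested log_level is at or below the level the class was built with, and a shared DummyTimer otherwise, which is how level-1 and level-2 timers cost nothing at --timing-log-level 0.

# Usage sketch; assumes megatron is on PYTHONPATH. start()/stop()/log()
# additionally need CUDA and an initialized torch.distributed job.
from megatron.timers import Timers, DummyTimer

timers = Timers(1, 'minmax')       # --timing-log-level 1, minmax log option

setup = timers('model-setup', log_level=0)    # level 0 <= 1: real Timer
comm = timers('forward-send', log_level=2)    # level 2 > 1: DummyTimer
assert isinstance(comm, DummyTimer)           # its start()/stop() are no-ops
assert timers('model-setup') is setup         # same instance on reuse

# Reusing a name with a different level is an error:
try:
    timers('model-setup', log_level=1)
except AssertionError as e:
    print(e)

# In a real multi-rank run the reporting pattern is:
#   setup.start(barrier=True); ...; setup.stop()
#   timers.log(['model-setup'], barrier=True)   # prints (min, max) in ms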
megatron/training.py

@@ -119,23 +119,28 @@ def pretrain(train_valid_test_dataset_provider,
     timers = get_timers()

     # Model, optimizer, and learning rate.
-    timers('model-and-optimizer-setup').start()
+    timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
     model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
         model_provider, model_type)
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')

     # Data stuff.
-    timers('train/valid/test-data-iterators-setup').start()
+    timers('train/valid/test-data-iterators-setup', log_level=0).start(
+        barrier=True)
     if args.virtual_pipeline_model_parallel_size is not None:
         all_data_iterators = [
-            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
+            build_train_valid_test_data_iterators(
+                train_valid_test_dataset_provider)
             for _ in range(len(model))
         ]
-        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
-        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
-        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
+        train_data_iterator = [data_iterators[0]
+                               for data_iterators in all_data_iterators]
+        valid_data_iterator = [data_iterators[1]
+                               for data_iterators in all_data_iterators]
+        test_data_iterator = [data_iterators[2]
+                              for data_iterators in all_data_iterators]
     else:
         train_data_iterator, valid_data_iterator, test_data_iterator \
             = build_train_valid_test_data_iterators(

@@ -145,7 +150,8 @@ def pretrain(train_valid_test_dataset_provider,
     # Print setup timing.
     print_rank_0('done with setup ...')
-    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
+    timers.log(['model-and-optimizer-setup',
+                'train/valid/test-data-iterators-setup'], barrier=True)
     print_rank_0('training ...')

     iteration = 0

@@ -373,13 +379,9 @@ def setup_model_and_optimizer(model_provider_func,
     if args.load is not None:
         timers = get_timers()
-        # Extra barrier is added to make sure all ranks report the
-        # max time.
-        torch.distributed.barrier()
-        timers('load-checkpoint').start()
+        timers('load-checkpoint', log_level=0).start(barrier=True)
         args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
-        torch.distributed.barrier()
-        timers('load-checkpoint').stop()
+        timers('load-checkpoint').stop(barrier=True)
         timers.log(['load-checkpoint'])
     else:
         args.iteration = 0

@@ -412,19 +414,21 @@ def train_step(forward_step_func, data_iterator,
         optimizer.zero_grad()

     # Forward pass.
+    timers('forward-backward', log_level=1).start(
+        barrier=args.barrier_with_L1_time)
     forward_backward_func = get_forward_backward_func()
+    fwd_bwd_timers = timers if args.timing_log_level > 1 else None
     losses_reduced = forward_backward_func(
         forward_step_func, data_iterator, model,
-        optimizer, timers, forward_only=False)
+        optimizer, fwd_bwd_timers, forward_only=False)
+    timers('forward-backward').stop()

     # Empty unused memory.
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()

     # Reduce gradients.
-    timers('backward-reduce-model-grads').start()
     optimizer.reduce_model_grads(args, timers)
-    timers('backward-reduce-model-grads').stop()

     # Vision gradients.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":

@@ -433,15 +437,13 @@ def train_step(forward_step_func, data_iterator,
             unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

     # Update parameters.
-    timers('optimizer').start()
+    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
     update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
     timers('optimizer').stop()

     # Gather params.
     if update_successful:
-        timers('backward-gather-model-params').start()
         optimizer.gather_model_params(args, timers)
-        timers('backward-gather-model-params').stop()

     # Vision momentum.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":

@@ -511,33 +513,32 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         nan_iters_key, 0) + int(got_nan)

     # Logging.
-    timers_to_log = []
-
-    def add_to_logging(name):
-        if name in timers.timers:
-            timers_to_log.append(name)
-    add_to_logging('forward-compute')
-    add_to_logging('forward-recv')
-    add_to_logging('forward-send')
-    add_to_logging('forward-backward-send-forward-backward-recv')
-    add_to_logging('backward-compute')
-    add_to_logging('backward-recv')
-    add_to_logging('backward-send')
-    add_to_logging('backward-send-forward-recv')
-    add_to_logging('backward-send-backward-recv')
-    add_to_logging('backward-params-all-reduce')
-    add_to_logging('backward-layernorm-all-reduce')
-    add_to_logging('backward-embedding-all-reduce')
-    add_to_logging('backward-reduce-model-grads')
-    add_to_logging('backward-gather-model-params')
-    add_to_logging('optimizer-copy-to-main-grad')
-    add_to_logging('optimizer-unscale-and-check-inf')
-    add_to_logging('optimizer-clip-main-grad')
-    add_to_logging('optimizer-count-zeros')
-    add_to_logging('optimizer-inner-step')
-    add_to_logging('optimizer-copy-main-to-model-params')
-    add_to_logging('optimizer')
-    add_to_logging('batch-generator')
+    timers_to_log = [
+        'forward-backward',
+        'forward-compute',
+        'backward-compute',
+        'batch-generator',
+        'forward-recv',
+        'forward-send',
+        'backward-recv',
+        'backward-send',
+        'forward-send-forward-recv',
+        'forward-send-backward-recv',
+        'backward-send-forward-recv',
+        'backward-send-backward-recv',
+        'forward-backward-send-forward-backward-recv',
+        'layernorm-grads-all-reduce',
+        'embedding-grads-all-reduce',
+        'grads-all-reduce',
+        'grads-reduce-scatter',
+        'params-all-gather',
+        'optimizer-copy-to-main-grad',
+        'optimizer-unscale-and-check-inf',
+        'optimizer-clip-main-grad',
+        'optimizer-count-zeros',
+        'optimizer-inner-step',
+        'optimizer-copy-main-to-model-params',
+        'optimizer']

     # Calculate batch size.
     batch_size = args.micro_batch_size * args.data_parallel_size * \

@@ -547,8 +548,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                        total_loss_dict[skipped_iters_key]

     # Tensorboard values.
-    if writer and (iteration % args.tensorboard_log_interval == 0) and \
-       is_last_rank():
+    # Timer requires all the ranks to call.
+    if args.log_timers_to_tensorboard and \
+       (iteration % args.tensorboard_log_interval == 0):
+        timers.write(timers_to_log, writer, iteration,
+                     normalizer=total_iterations)
+    if writer and (iteration % args.tensorboard_log_interval == 0):
         if args.log_learning_rate_to_tensorboard:
             writer.add_scalar('learning-rate', learning_rate, iteration)
             writer.add_scalar('learning-rate vs samples', learning_rate,

@@ -581,9 +586,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             writer.add_scalar('params-norm', params_norm, iteration)
             writer.add_scalar('params-norm vs samples', params_norm,
                               args.consumed_train_samples)
-        if args.log_timers_to_tensorboard:
-            timers.write(timers_to_log, writer, iteration,
-                         normalizer=total_iterations)
         if args.log_memory_to_tensorboard:
             mem_stats = torch.cuda.memory_stats()
             writer.add_scalar(

@@ -603,7 +605,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             )

     if iteration % args.log_interval == 0:
-        elapsed_time = timers('interval-time').elapsed()
+        elapsed_time = timers('interval-time').elapsed(barrier=True)
         elapsed_time_per_iteration = elapsed_time / total_iterations
         if writer:
             if args.log_timers_to_tensorboard:

@@ -653,11 +655,9 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
     timers = get_timers()
-    # Extra barrier is added to make sure
-    # all ranks report the max time.
-    torch.distributed.barrier()
-    timers('save-checkpoint').start()
+    timers('save-checkpoint', log_level=0).start(barrier=True)
     save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
-    torch.distributed.barrier()
-    timers('save-checkpoint').stop()
+    timers('save-checkpoint').stop(barrier=True)
     timers.log(['save-checkpoint'])

@@ -681,7 +681,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
     # Iterations.
     iteration = args.iteration

-    timers('interval-time').start()
+    timers('interval-time', log_level=0).start(barrier=True)
     print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:
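Note: the most consequential line in train_step is the gating one, fwd_bwd_timers = timers if args.timing_log_level > 1 else None. At levels 0 and 1 the pipeline schedules receive timers=None and skip the per-microbatch timer calls altogether, rather than merely degrading them to DummyTimer no-ops. A self-contained toy of that decision (stand-in names, not the real training loop):

# Toy model of the train_step gating; forward_backward is a stand-in.
class Args:
    timing_log_level = 1          # as set by --timing-log-level

def forward_backward(timers):
    if timers is not None:
        timers('forward-compute', log_level=2).start()
    loss = 0.0                    # stand-in for the real fwd/bwd work
    if timers is not None:
        timers('forward-compute').stop()
    return loss

args = Args()
timers = object()                 # would be get_timers() in Megatron
fwd_bwd_timers = timers if args.timing_log_level > 1 else None
forward_backward(fwd_bwd_timers)  # at level 1: no timer calls on the hot path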
pretrain_bert.py

@@ -104,7 +104,7 @@ def forward_step(data_iterator, model):
     timers = get_timers()

     # Get the batch.
-    timers('batch-generator').start()
+    timers('batch-generator', log_level=2).start()
     tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
         data_iterator)
     timers('batch-generator').stop()
pretrain_gpt.py

@@ -89,7 +89,7 @@ def forward_step(data_iterator, model):
     timers = get_timers()

     # Get the batch.
-    timers('batch-generator').start()
+    timers('batch-generator', log_level=2).start()
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
         data_iterator)
     timers('batch-generator').stop()
pretrain_ict.py

@@ -134,7 +134,7 @@ def forward_step(data_iterator, model):
     timers = get_timers()

     # Get the batch.
-    timers('batch-generator').start()
+    timers('batch-generator', log_level=2).start()
     query_tokens, query_mask, \
     context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
     timers('batch-generator').stop()
pretrain_t5.py

@@ -126,7 +126,7 @@ def forward_step(data_iterator, model):
     timers = get_timers()

     # Get the batch.
-    timers('batch generator').start()
+    timers('batch generator', log_level=2).start()
     tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \
         = get_batch(data_iterator)
     timers('batch generator').stop()
pretrain_vision_classify.py

@@ -77,7 +77,7 @@ def forward_step(data_iterator, model):
     timers = get_timers()

     # Get the batch.
-    timers("batch-generator").start()
+    timers("batch-generator", log_level=2).start()
     (
         images,
         labels,
pretrain_vision_dino.py

@@ -84,7 +84,7 @@ def forward_step(data_iterator, model):
     timers = get_timers()

     # Get the batch.
-    timers("batch-generator").start()
+    timers("batch-generator", log_level=2).start()
     (
         images,
         labels,
pretrain_vision_inpaint.py

@@ -91,7 +91,7 @@ def forward_step(data_iterator, model):
     timers = get_timers()

     # Get the batch.
-    timers("batch-generator").start()
+    timers("batch-generator", log_level=2).start()
     (
         images,
         masks,
tasks/finetune_utils.py

@@ -67,7 +67,7 @@ def _cross_entropy_forward_step(batch, model):
     timers = get_timers()

     # Get the batch.
-    timers('batch-generator').start()
+    timers('batch-generator', log_level=2).start()
     try:
         batch_ = next(batch)
     except BaseException:

@@ -178,7 +178,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
     report_memory_flag = True

     # For each remaining epoch
-    timers('interval-time').start()
+    timers('interval-time', log_level=0).start(barrier=True)
     for epoch in range(start_epoch, args.epochs):
         print_rank_0('working on epoch {} ...'.format(epoch + 1))

@@ -261,7 +261,7 @@ def finetune(train_valid_datasets_provider, model_provider,
         'batch size scaling is not supported for finetuning'

     # Train and validation data loaders.
-    timers('train/valid/test dataset/dataloder').start()
+    timers('train/valid/test dataset/dataloder', log_level=0).start()
     if args.epochs > 0:
         train_dataset, valid_dataset = train_valid_datasets_provider()
         train_dataloader, valid_dataloader = _build_train_valid_dataloaders(

@@ -271,21 +271,21 @@ def finetune(train_valid_datasets_provider, model_provider,
     timers('train/valid/test dataset/dataloder').stop()

     # Build callback function.
-    timers('callback function').start()
+    timers('callback function', log_level=0).start()
     end_of_epoch_callback = None
     if end_of_epoch_callback_provider is not None:
         end_of_epoch_callback = end_of_epoch_callback_provider()
     timers('callback function').stop()

     # Build model, optimizer and learning rate scheduler.
-    timers('model and optimizer').start()
+    timers('model and optimizer', log_level=0).start()
     model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
     timers('model and optimizer').stop()

     # If pretrained checkpoint is provided and we have not trained for
     # any iteration (i.e., iteration is zero), then load the pretrained
     # checkpoint.
-    timers('pretrained checkpoint').start()
+    timers('pretrained checkpoint', log_level=0).start(barrier=True)
     if args.iteration == 0 and args.pretrained_checkpoint is not None:
         original_load = args.load
         args.load = args.pretrained_checkpoint

@@ -302,7 +302,7 @@ def finetune(train_valid_datasets_provider, model_provider,
     # Print setup timing.
     print_rank_0('done with setups ...')
     timers.log(['train/valid/test dataset/dataloder', 'callback function',
-                'model and optimizer', 'pretrained checkpoint'])
+                'model and optimizer', 'pretrained checkpoint'], barrier=True)
     print_rank_0('training ...')

     # Finetune the model.
tasks/orqa/supervised/finetune.py

@@ -63,7 +63,7 @@ def orqa(Dataset):
     tokenizer = get_tokenizer()

     # Get the batch.
-    timers('batch generator').start()
+    timers('batch generator', log_level=2).start()
     try:
         batch_ = next(batch)
     except BaseException:
tasks/vision/classification/classification.py

@@ -68,7 +68,7 @@ def classification():
     timers = get_timers()

     # Get the batch.
-    timers("batch generator").start()
+    timers("batch generator", log_level=2).start()
     try:
         batch_ = next(batch)
     except BaseException:
tasks/vision/finetune_utils.py

@@ -136,7 +136,7 @@ def _train(
     report_memory_flag = True

     # For each remaining epoch
-    timers("interval-time").start()
+    timers("interval-time", log_level=0).start(barrier=True)
     for epoch in range(start_epoch, args.epochs):
         print_rank_0("working on epoch {} ...".format(epoch + 1))

@@ -218,7 +218,7 @@ def finetune(
     timers = get_timers()

     # Train and validation data loaders.
-    timers("train/valid/test dataset/dataloder").start()
+    timers("train/valid/test dataset/dataloder", log_level=0).start()
     if args.epochs > 0:
         train_dataset, valid_dataset = train_valid_datasets_provider()
         train_dataloader, valid_dataloader = _build_train_valid_dataloaders(

@@ -227,14 +227,14 @@ def finetune(
     timers("train/valid/test dataset/dataloder").stop()

     # Build callback function.
-    timers("callback function").start()
+    timers("callback function", log_level=0).start()
     end_of_epoch_callback = None
     if end_of_epoch_callback_provider is not None:
         end_of_epoch_callback = end_of_epoch_callback_provider()
     timers("callback function").stop()

     # Build model, optimizer and learning rate scheduler.
-    timers("model and optimizer").start()
+    timers("model and optimizer", log_level=0).start()
     model, optimizer, opt_param_scheduler = \
         setup_model_and_optimizer(model_provider,

@@ -246,7 +246,7 @@ def finetune(
     # If pretrained checkpoint is provided and we have not trained for
     # any iteration (i.e., iteration is zero), then load the pretrained
     # checkpoint.
-    timers("pretrained checkpoint").start()
+    timers("pretrained checkpoint", log_level=0).start(barrier=True)
     if args.iteration == 0 and args.pretrained_checkpoint is not None:
         if args.pretrained_checkpoint_type == 'default':
             original_load = args.load
tasks/vision/segmentation/finetune_segformer.py

@@ -123,7 +123,7 @@ def segmentation():
     timers = get_timers()

     # Get the batch.
-    timers("batch generator").start()
+    timers("batch generator", log_level=2).start()
     import types
     if isinstance(batch, types.GeneratorType):
         batch_ = next(batch)