OpenDAS / Megatron-LM · Commit 77efcccb

Timing levels

Authored Aug 10, 2022 by Mohammad Shoeybi; committed by Jared Casper, Aug 10, 2022.
Parent: 83d78670

Showing 20 changed files with 494 additions and 218 deletions (+494, -218):
megatron/arguments.py (+26, -0)
megatron/global_vars.py (+5, -85)
megatron/optimizer/distrib_optimizer.py (+12, -8)
megatron/optimizer/optimizer.py (+29, -16)
megatron/p2p_communication.py (+10, -9)
megatron/schedules.py (+27, -19)
megatron/timers.py (+304, -0)
megatron/training.py (+59, -59)
pretrain_bert.py (+1, -1)
pretrain_gpt.py (+1, -1)
pretrain_ict.py (+1, -1)
pretrain_t5.py (+1, -1)
pretrain_vision_classify.py (+1, -1)
pretrain_vision_dino.py (+1, -1)
pretrain_vision_inpaint.py (+1, -1)
tasks/finetune_utils.py (+7, -7)
tasks/orqa/supervised/finetune.py (+1, -1)
tasks/vision/classification/classification.py (+1, -1)
tasks/vision/finetune_utils.py (+5, -5)
tasks/vision/segmentation/finetune_segformer.py (+1, -1)
megatron/arguments.py

@@ -411,6 +411,32 @@ def _add_logging_args(parser):
                        help='If set, calculate and log parameters norm.')
     group.add_argument('--log-num-zeros-in-grad', action='store_true',
                        help='If set, calculate and log the number of zeros in gradient.')
+    group.add_argument('--timing-log-level', type=int,
+                       default=0, choices=range(0,3),
+                       help='Granularity level to measure and report timing. '
+                       '   0: report only iteration time and make sure timing '
+                       '       does not introduce extra overhead.'
+                       '   1: report timing for operations that are executed '
+                       '       very limited times (basically once) during '
+                       '       each iteration (such as gradient all-reduce) '
+                       '   2: report timing for operations that migh be '
+                       '       executed numerous times during each iteration. '
+                       'Note that setting the level to 1 or 2 might '
+                       'cause increase in iteration time.')
+    group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
+                       help='If not set, use barrier with level 1 time '
+                       'measurements. Note that this is up to the user '
+                       'to make sure calling barrier with their timers '
+                       'will not result in hangs. This can happen if for '
+                       'example the user adds a level 1 timer that is not '
+                       'called by all ranks.',
+                       dest='barrier_with_L1_time')
+    group.add_argument('--timing-log-option', type=str, default='minmax',
+                       choices=['max', 'minmax', 'all'],
+                       help='Options for logging timing:'
+                       '  max: report the max timing across all ranks'
+                       '  minmax: report min and max timings across all ranks'
+                       '  all: report timings of all ranks.')
     group.add_argument('--tensorboard-log-interval', type=int, default=1,
                        help='Report to tensorboard interval.')
     group.add_argument('--tensorboard-queue-size', type=int, default=1000,
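For orientation, here is a minimal, self-contained sketch (not part of the commit) of how the three new flags behave when parsed on their own. The argument definitions mirror the ones added to _add_logging_args above; the sample argv is made up.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--timing-log-level', type=int, default=0,
                    choices=range(0, 3))
parser.add_argument('--no-barrier-with-level-1-timing', action='store_false',
                    dest='barrier_with_L1_time')
parser.add_argument('--timing-log-option', type=str, default='minmax',
                    choices=['max', 'minmax', 'all'])

# Example: enable level-1 timing and report only the max across ranks.
args = parser.parse_args(['--timing-log-level', '1',
                          '--timing-log-option', 'max'])
print(args.timing_log_level)      # 1
print(args.barrier_with_L1_time)  # True: the 'no-barrier' flag was not passed
print(args.timing_log_option)     # 'max'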
megatron/global_vars.py

@@ -17,7 +17,6 @@
 import os
 import sys
-import time
 from functools import reduce
 import operator
 import torch
@@ -25,6 +24,7 @@ import torch
 from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
 from .microbatches import build_num_microbatches_calculator
+from .timers import Timers
 
 _GLOBAL_ARGS = None
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
@@ -108,7 +108,7 @@ def set_global_variables(args):
     _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
-    _set_timers()
+    _set_timers(args)
     _set_global_memory_buffer()
 
     if args.exit_signal_handler:
@@ -182,11 +182,12 @@ def _set_adlr_autoresume(args):
         _GLOBAL_ADLR_AUTORESUME = AutoResume
 
 
-def _set_timers():
+def _set_timers(args):
     """Initialize timers."""
     global _GLOBAL_TIMERS
     _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
-    _GLOBAL_TIMERS = Timers()
+    _GLOBAL_TIMERS = Timers(args.timing_log_level,
+                            args.timing_log_option)
 
 
 def _set_global_memory_buffer():
     """Initialize global buffer"""
@@ -205,87 +206,6 @@ def _ensure_var_is_not_initialized(var, name):
     assert var is None, '{} is already initialized.'.format(name)
 
 
-class _Timer:
-    """Timer."""
-
-    def __init__(self, name):
-        self.name_ = name
-        self.elapsed_ = 0.0
-        self.started_ = False
-        self.start_time = time.time()
-
-    def start(self):
-        """Start the timer."""
-        assert not self.started_, 'timer has already been started'
-        torch.cuda.synchronize()
-        self.start_time = time.time()
-        self.started_ = True
-
-    def stop(self):
-        """Stop the timer."""
-        assert self.started_, 'timer is not started'
-        torch.cuda.synchronize()
-        self.elapsed_ += (time.time() - self.start_time)
-        self.started_ = False
-
-    def reset(self):
-        """Reset timer."""
-        self.elapsed_ = 0.0
-        self.started_ = False
-
-    def elapsed(self, reset=True):
-        """Calculate the elapsed time."""
-        started_ = self.started_
-        # If the timing in progress, end it first.
-        if self.started_:
-            self.stop()
-        # Get the elapsed time.
-        elapsed_ = self.elapsed_
-        # Reset the elapsed time
-        if reset:
-            self.reset()
-        # If timing was in progress, set it back.
-        if started_:
-            self.start()
-        return elapsed_
-
-
-class Timers:
-    """Group of timers."""
-
-    def __init__(self):
-        self.timers = {}
-
-    def __call__(self, name):
-        if name not in self.timers:
-            self.timers[name] = _Timer(name)
-        return self.timers[name]
-
-    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
-        """Write timers to a tensorboard writer"""
-        # currently when using add_scalars,
-        # torch.utils.add_scalars makes each timer its own run, which
-        # polutes the runs list, so we just add each as a scalar
-        assert normalizer > 0.0
-        for name in names:
-            value = self.timers[name].elapsed(reset=reset) / normalizer
-            writer.add_scalar(name + '-time', value, iteration)
-
-    def log(self, names, normalizer=1.0, reset=True):
-        """Log a group of timers."""
-        assert normalizer > 0.0
-        string = 'time (ms)'
-        for name in names:
-            elapsed_time = self.timers[name].elapsed(
-                reset=reset) * 1000.0 / normalizer
-            string += ' | {}: {:.2f}'.format(name, elapsed_time)
-        if torch.distributed.is_initialized():
-            if torch.distributed.get_rank() == (
-                    torch.distributed.get_world_size() - 1):
-                print(string, flush=True)
-        else:
-            print(string, flush=True)
-
-
 class GlobalMemoryBuffer:
     """Global buffer to avoid dynamic memory allocations.
megatron/optimizer/distrib_optimizer.py

@@ -532,17 +532,20 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         """
 
         # All-reduce layer-norm grads (for sequence parallelism).
-        timers('backward-layernorm-all-reduce').start()
+        timers('layernorm-grads-all-reduce', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.allreduce_layernorm_grads(args)
-        timers('backward-layernorm-all-reduce').stop()
+        timers('layernorm-grads-all-reduce').stop()
 
         # All-reduce embedding grads.
-        timers('backward-embedding-all-reduce').start()
+        timers('embedding-grads-all-reduce', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.allreduce_embedding_grads(args)
-        timers('backward-embedding-all-reduce').stop()
+        timers('embedding-grads-all-reduce').stop()
 
         # Reduce-scatter setup.
-        timers('backward-params-all-reduce').start()
+        timers('grads-reduce-scatter', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_world_size = mpu.get_data_parallel_world_size()
         data_parallel_group = mpu.get_data_parallel_group()
@@ -563,7 +566,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             group = data_parallel_group,
         )
 
-        timers('backward-params-all-reduce').stop()
+        timers('grads-reduce-scatter').stop()
 
 
     def gather_model_params(self, args, timers):
@@ -575,7 +578,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        can be copied from param.main_grad to param.
        """
 
-        timers('backward-params-all-gather').start()
+        timers('params-all-gather', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
 
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_group = mpu.get_data_parallel_group()
@@ -602,7 +606,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         for param in param_map:
             param.detach().copy_(param.main_grad)
 
-        timers('backward-params-all-gather').stop()
+        timers('params-all-gather').stop()
 
 
     def _collect_main_grad_data_for_unscaling(self):
megatron/optimizer/optimizer.py

@@ -294,21 +294,24 @@ class MegatronOptimizer(ABC):
         """All-reduce all grads, and all-reduce embeddings."""
 
         # All-reduce layer-norm grads (for sequence parallelism).
-        timers('backward-layernorm-all-reduce').start()
+        timers('layernorm-grads-all-reduce', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.allreduce_layernorm_grads(args)
-        timers('backward-layernorm-all-reduce').stop()
+        timers('layernorm-grads-all-reduce').stop()
 
         # All-reduce if needed.
         if args.DDP_impl == 'local':
-            timers('backward-params-all-reduce').start()
+            timers('grads-all-reduce', log_level=1).start(
+                barrier=args.barrier_with_L1_time)
             for model in self.models:
                 model.allreduce_gradients()
-            timers('backward-params-all-reduce').stop()
+            timers('grads-all-reduce').stop()
 
         # All-reduce embedding grads.
-        timers('backward-embedding-all-reduce').start()
+        timers('embedding-grads-all-reduce', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.allreduce_embedding_grads(args)
-        timers('backward-embedding-all-reduce').stop()
+        timers('embedding-grads-all-reduce').stop()
 
 
 class MixedPrecisionOptimizer(MegatronOptimizer):
@@ -416,7 +419,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
     def step(self, args, timers):
 
         # Copy gradients from model params to main params.
-        timers('optimizer-copy-to-main-grad').start()
+        timers('optimizer-copy-to-main-grad', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self._copy_model_grads_to_main_grads()
         timers('optimizer-copy-to-main-grad').stop()
@@ -425,7 +429,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         if self.grad_scaler:
 
             # Unscale and check for inf/nan.
-            timers('optimizer-unscale-and-check-inf').start()
+            timers('optimizer-unscale-and-check-inf', log_level=1).start(
+                barrier=args.barrier_with_L1_time)
             found_inf_flag = self._unscale_main_grads_and_check_for_nan()
             timers('optimizer-unscale-and-check-inf').stop()
@@ -438,25 +443,29 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
                 return False, None, None
 
         # Clip the main gradients.
-        timers('optimizer-clip-main-grad').start()
+        timers('optimizer-clip-main-grad', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         grad_norm = None
         if self.clip_grad > 0.0:
             grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()
 
         # Count the zeros in the grads.
-        timers('optimizer-count-zeros').start()
+        timers('optimizer-count-zeros', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         num_zeros_in_grad = self.count_zeros() if \
                             self.log_num_zeros_in_grad else None
         timers('optimizer-count-zeros').stop()
 
         # Step the optimizer.
-        timers('optimizer-inner-step').start()
+        timers('optimizer-inner-step', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.optimizer.step()
         timers('optimizer-inner-step').stop()
 
         # Update params from main params.
-        timers('optimizer-copy-main-to-model-params').start()
+        timers('optimizer-copy-main-to-model-params', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self._copy_main_params_to_model_params()
         timers('optimizer-copy-main-to-model-params').stop()
@@ -725,7 +734,8 @@ class FP32Optimizer(MegatronOptimizer):
         Always return successful since there is no overflow."""
 
         # Copy main_grads to grads.
-        timers('optimizer-copy-to-main-grad').start()
+        timers('optimizer-copy-to-main-grad', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         if self.params_have_main_grad:
             for param_group in self.optimizer.param_groups:
                 for param in param_group['params']:
@@ -739,20 +749,23 @@ class FP32Optimizer(MegatronOptimizer):
         timers('optimizer-copy-to-main-grad').stop()
 
         # Clip gradients.
-        timers('optimizer-clip-main-grad').start()
+        timers('optimizer-clip-main-grad', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         grad_norm = None
         if self.clip_grad > 0.0:
             grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()
 
         # count the zeros in the grads
-        timers('optimizer-count-zeros').start()
+        timers('optimizer-count-zeros', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         num_zeros_in_grad = self.count_zeros() if \
                             self.log_num_zeros_in_grad else None
         timers('optimizer-count-zeros').stop()
 
         # Update parameters.
-        timers('optimizer-inner-step').start()
+        timers('optimizer-inner-step', log_level=1).start(
+            barrier=args.barrier_with_L1_time)
         self.optimizer.step()
         timers('optimizer-inner-step').stop()
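The optimizer changes above all follow the same level-1 pattern: name the timer, tag it log_level=1, and optionally synchronize ranks before starting so the min/max report is meaningful. A hedged sketch of that pattern follows; the helper name timed_level1 is hypothetical (not part of the commit), while timers and args stand for the usual Megatron objects.

def timed_level1(timers, args, name, fn, *fn_args):
    # Level-1 timers wrap operations that run roughly once per iteration
    # (gradient all-reduce, optimizer sub-steps, and so on). The optional
    # barrier keeps ranks aligned before the measurement starts.
    timers(name, log_level=1).start(barrier=args.barrier_with_L1_time)
    result = fn(*fn_args)
    timers(name).stop()
    return result

# Hypothetical usage inside an optimizer step:
# timed_level1(timers, args, 'optimizer-copy-to-main-grad',
#              self._copy_model_grads_to_main_grads)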
megatron/p2p_communication.py

@@ -163,7 +163,7 @@ def recv_forward(tensor_shape=None, dtype_=None, timers=None):
         input_tensor = None
     else:
         if timers is not None:
-            timers('forward-recv').start()
+            timers('forward-recv', log_level=2).start()
         input_tensor, _ = _communicate(
             tensor_send_next=None,
             tensor_send_prev=None,
@@ -182,7 +182,7 @@ def recv_backward(tensor_shape=None, timers=None):
         output_tensor_grad = None
     else:
         if timers is not None:
-            timers('backward-recv').start()
+            timers('backward-recv', log_level=2).start()
         _, output_tensor_grad = _communicate(
             tensor_send_next=None,
             tensor_send_prev=None,
@@ -199,7 +199,7 @@ def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
-            timers('forward-send').start()
+            timers('forward-send', log_level=2).start()
         _communicate(
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
@@ -215,7 +215,7 @@ def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
-            timers('backward-send').start()
+            timers('backward-send', log_level=2).start()
         _communicate(
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
@@ -232,7 +232,7 @@ def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
         output_tensor_grad = None
     else:
         if timers is not None:
-            timers('forward-send-backward-recv').start()
+            timers('forward-send-backward-recv', log_level=2).start()
         _, output_tensor_grad = _communicate(
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
@@ -250,7 +250,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
         input_tensor = None
     else:
         if timers is not None:
-            timers('backward-send-forward-recv').start()
+            timers('backward-send-forward-recv', log_level=2).start()
         input_tensor, _ = _communicate(
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
@@ -265,7 +265,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
 def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
     """Batched recv from previous rank and send to next rank in pipeline."""
     if timers is not None:
-        timers('forward-send-forward-recv').start()
+        timers('forward-send-forward-recv', log_level=2).start()
     input_tensor, _ = _communicate(
         tensor_send_next=output_tensor,
         tensor_send_prev=None,
@@ -280,7 +280,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
 def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
     """Batched recv from next rank and send to previous rank in pipeline."""
     if timers is not None:
-        timers('backward-send-backward-recv').start()
+        timers('backward-send-backward-recv', log_level=2).start()
     _, output_tensor_grad = _communicate(
         tensor_send_next=None,
         tensor_send_prev=input_tensor_grad,
@@ -297,7 +297,8 @@ def send_forward_backward_recv_forward_backward(
         recv_next, tensor_shape=None, timers=None):
     """Batched send and recv with previous and next ranks in pipeline."""
     if timers is not None:
-        timers('forward-backward-send-forward-backward-recv').start()
+        timers('forward-backward-send-forward-backward-recv',
+               log_level=2).start()
     input_tensor, output_tensor_grad = _communicate(
         tensor_send_next=output_tensor,
         tensor_send_prev=input_tensor_grad,
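All of the p2p hooks above use the same guarded, level-2 convention: these calls fire once per microbatch, so they are tagged log_level=2 (measured only under --timing-log-level 2) and skipped entirely when the caller passes timers=None. A small illustrative sketch of that convention, where communicate_fn is a stand-in for the real _communicate call and timed_p2p is a hypothetical helper:

def timed_p2p(timers, name, communicate_fn):
    # Per-microbatch communication: only time it when level-2 timing is
    # enabled and a timers object was actually handed down.
    if timers is not None:
        timers(name, log_level=2).start()
    result = communicate_fn()
    if timers is not None:
        timers(name).stop()
    return result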
megatron/schedules.py

@@ -107,6 +107,7 @@ def forward_step(forward_step_func,
                  model,
                  input_tensor,
                  forward_data_store,
+                 timers,
                  collect_non_loss_data=False):
     """Forward step for passed-in model.
@@ -115,9 +116,9 @@ def forward_step(forward_step_func,
     Returns output tensor."""
     args = get_args()
-    timers = get_timers()
 
-    timers('forward-compute').start()
+    if timers is not None:
+        timers('forward-compute', log_level=2).start()
 
     unwrapped_model = unwrap_model(
         model, (torchDDP, LocalDDP, Float16Module))
@@ -138,7 +139,8 @@ def forward_step(forward_step_func,
             data = loss_func(output_tensor, non_loss_data=True)
             forward_data_store.append(data)
 
-    timers('forward-compute').stop()
+    if timers is not None:
+        timers('forward-compute').stop()
 
     # If T5 model (or other model with encoder and decoder)
     # and in decoder stack, then send encoder_hidden_state
@@ -151,7 +153,8 @@ def forward_step(forward_step_func,
     return [output_tensor]
 
 
-def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
+def backward_step(optimizer, input_tensor, output_tensor,
+                  output_tensor_grad, timers):
     """Backward step through passed-in output tensor.
 
     If last stage, output_tensor_grad is None, otherwise gradient of loss
@@ -165,8 +168,8 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     # connections.
     args = get_args()
-    timers = get_timers()
 
-    timers('backward-compute').start()
+    if timers is not None:
+        timers('backward-compute', log_level=2).start()
 
     # Retain the grad on the input_tensor.
     unwrap_input_tensor_grad = False
@@ -207,7 +210,8 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     if unwrap_input_tensor_grad:
         input_tensor_grad = input_tensor_grad[0]
 
-    timers('backward-compute').stop()
+    if timers is not None:
+        timers('backward-compute').stop()
 
     return input_tensor_grad
@@ -243,18 +247,19 @@ def forward_backward_no_pipelining(forward_step_func,
         for i in range(get_num_microbatches() - 1):
             output_tensor = forward_step(forward_step_func, data_iterator,
                                          model, input_tensor, forward_data_store,
-                                         collect_non_loss_data)
+                                         timers, collect_non_loss_data)
             if not forward_only:
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, timers)
 
     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
     output_tensor = forward_step(forward_step_func, data_iterator,
                                  model, input_tensor, forward_data_store,
-                                 collect_non_loss_data)
+                                 timers, collect_non_loss_data)
     if not forward_only:
-        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
+        backward_step(optimizer, input_tensor, output_tensor,
+                      output_tensor_grad, timers)
 
     return forward_data_store
@@ -269,6 +274,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
     communication between pipeline stages as needed.
 
     Returns dictionary with losses if the last stage, empty dict otherwise."""
+    args = get_args()
+
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
     forward_data_store = []
@@ -278,7 +286,6 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
     pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
     pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
 
-    args = get_args()
     if args.sequence_parallel:
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
     else:
@@ -337,6 +344,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
                                      model[model_chunk_id],
                                      input_tensor,
                                      forward_data_store,
+                                     timers,
                                      collect_non_loss_data)
         output_tensors[model_chunk_id].append(output_tensor)
@@ -364,7 +372,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
             backward_step(optimizer,
                           input_tensor,
                           output_tensor,
-                          output_tensor_grad)
+                          output_tensor_grad,
+                          timers)
 
         return input_tensor_grad
@@ -620,8 +629,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
     Returns dictionary with losses if the last stage, empty dict otherwise."""
     args = get_args()
-    timers = get_timers()
 
     assert len(model) == 1
     model = model[0]
@@ -656,7 +664,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
         input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, forward_data_store,
-                                     collect_non_loss_data)
+                                     timers, collect_non_loss_data)
         send_forward(output_tensor, send_tensor_shapes, timers=timers)
 
         if not forward_only:
@@ -676,7 +684,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, forward_data_store,
-                                     collect_non_loss_data)
+                                     timers, collect_non_loss_data)
 
         if forward_only:
             send_forward(output_tensor, send_tensor_shapes, timers=timers)
@@ -701,7 +709,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, timers)
 
             if last_iteration:
                 input_tensor = None
@@ -721,7 +729,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, timers)
 
             send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
megatron/timers.py (new file, 0 → 100644)

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron timers."""

from abc import ABC
from abc import abstractmethod
import time

import torch


class TimerBase(ABC):

    def __init__(self, name):
        self.name = name

    @abstractmethod
    def start(self, barrier=False):
        pass

    @abstractmethod
    def stop(self, barrier=False):
        pass

    @abstractmethod
    def reset(self):
        pass

    @abstractmethod
    def elapsed(self, reset=True, barrier=False):
        pass


class DummyTimer(TimerBase):

    def __init__(self):
        super().__init__('dummy timer')

    def start(self, barrier=False):
        return

    def stop(self, barrier=False):
        return

    def reset(self):
        return

    def elapsed(self, reset=True, barrier=False):
        raise Exception('dummy timer should not be used to '
                        'calculate elapsed time')


class Timer(TimerBase):
    """
    Comment on using `barrier`: If this flag is passed, then all
    the caller processes will wait till all reach the timing routine.
    It is up to the user to make sure all the ranks in `barrier_group`
    call it otherwise, it will result in a hang.
    Comment on `barrier_group`: By default it is set to None which
    in torch distributed land, it will result in the global communicator.
    """

    def __init__(self, name):
        super().__init__(name)
        self._elapsed = 0.0
        self._started = False
        # Note that None will default to the global process group
        self._barrier_group = None
        self._start_time = time.time()


    def set_barrier_group(self, barrier_group):
        self._barrier_group = barrier_group


    def start(self, barrier=False):
        """Start the timer."""
        assert not self._started, 'timer has already been started'
        if barrier:
            torch.distributed.barrier(group=self._barrier_group)
        torch.cuda.synchronize()
        self._start_time = time.time()
        self._started = True


    def stop(self, barrier=False):
        """Stop the timer."""
        assert self._started, 'timer is not started'
        if barrier:
            torch.distributed.barrier(group=self._barrier_group)
        torch.cuda.synchronize()
        self._elapsed += (time.time() - self._start_time)
        self._started = False


    def reset(self):
        """Reset timer."""
        self._elapsed = 0.0
        self._started = False


    def elapsed(self, reset=True, barrier=False):
        """Calculate the elapsed time."""
        _started = self._started
        # If the timing in progress, end it first.
        if self._started:
            self.stop(barrier=barrier)
        # Get the elapsed time.
        _elapsed = self._elapsed
        # Reset the elapsed time
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if _started:
            self.start(barrier=barrier)
        return _elapsed


class Timers:
    """Group of timers."""

    def __init__(self, log_level, log_option):
        self._log_level = log_level
        self._log_option = log_option
        self._timers = {}
        self._log_levels = {}
        self._dummy_timer = DummyTimer()
        self._max_log_level = 2


    def __call__(self, name, log_level=None):
        # If the timer has already been set, then check if the log-level
        # is provided, it matches the one that the timer was created with.
        if name in self._timers:
            if log_level is not None:
                assert log_level == self._log_levels[name], \
                    'input log level {} does not match already existing '\
                    'log level {} for {} timer'.format(
                        log_level, self._log_levels[name], name)
            return self._timers[name]
        # If timer does not exist and no log level is provided,
        # set it to the max log level which is 2.
        if log_level is None:
            log_level = self._max_log_level
        assert log_level <= self._max_log_level, \
            'log level {} is larger than max supported log level {}'.format(
                log_level, self._max_log_level)
        # Now if the input log level is larger than the one set for
        # the timers class, just ignore it and return a dummy timer.
        if log_level > self._log_level:
            return self._dummy_timer
        # Otherwise, initalize the timer and set the level.
        self._timers[name] = Timer(name)
        self._log_levels[name] = log_level
        return self._timers[name]


    def _get_elapsed_time_all_ranks(self, names, reset, barrier):
        """
        Assumptions:
            - All the ranks call this function.
            - `names` are identical on all ranks.
        If the above assumptions are not met, calling this function will
        result in hang.
        Arguments:
            - names: list of timer names
            - reset: reset the timer after recording the elapsed time
            - barrier: if set, do a global barrier before time measurments
        """

        # First make sure all the callers are in sync.
        if barrier:
            torch.distributed.barrier()

        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        # Here we can use gather on the rank we want to print the
        # timing, however, there is no gather_base support in
        # pytorch yet. It is simpler to deal with a single tensor
        # and since we are only gathering a small amount of data,
        # it should be ok to use all-gather instead of gather.
        rank_name_to_time = torch.zeros((world_size, len(names)),
                                        dtype=torch.float,
                                        device=torch.cuda.current_device())
        for i, name in enumerate(names):
            if name in self._timers:
                # Here we don't need to pass the barrier flag as all
                # the processes are already in sync. This avoids the
                # issue of different timers having different barrier
                # groups inside their class.
                rank_name_to_time[rank, i] = self._timers[name].elapsed(
                    reset=reset)

        # See the note above for why we are not using gather.
        torch.distributed._all_gather_base(rank_name_to_time.view(-1),
                                           rank_name_to_time[rank, :].view(-1))

        return rank_name_to_time


    def _get_global_min_max_time(self, names, reset, barrier, normalizer):
        """Report only min and max times across all ranks."""

        rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
                                                             barrier)
        name_to_min_max_time = {}
        for i, name in enumerate(names):
            rank_to_time = rank_name_to_time[:, i]
            # filter out the ones we did not have any timings for
            rank_to_time = rank_to_time[rank_to_time > 0.0]
            # If the timer exists:
            if rank_to_time.numel() > 0:
                name_to_min_max_time[name] = (
                    rank_to_time.min().item() / normalizer,
                    rank_to_time.max().item() / normalizer)
        return name_to_min_max_time


    def _get_global_min_max_time_string(self, names, reset, barrier,
                                        normalizer, max_only):
        name_to_min_max_time = self._get_global_min_max_time(
            names, reset, barrier, normalizer)
        if not name_to_min_max_time:
            return None
        output_string = '(min, max) time across ranks (ms):'
        for name in name_to_min_max_time:
            min_time, max_time = name_to_min_max_time[name]
            if max_only:
                output_string += '\n    {}: {:.2f}'.format(
                    (name+' ').ljust(48, '.'), max_time)
            else:
                output_string += '\n    {}: ({:.2f}, {:.2f})'.format(
                    (name+' ').ljust(48, '.'), min_time, max_time)
        return output_string


    def _get_all_ranks_time_string(self, names, reset, barrier, normalizer):
        """Report times across all ranks."""
        rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
                                                             barrier)

        output_string = 'times across ranks (ms):'
        no_reported_timing = True
        for i, name in enumerate(names):
            not_yet_found = True
            for rank in range(torch.distributed.get_world_size()):
                if rank_name_to_time[rank, i] > 0:
                    no_reported_timing = False
                    if not_yet_found:
                        not_yet_found = False
                        output_string += '\n  {}:'.format(name)
                    output_string += '\n     rank {:2d}: {:.2f}'.format(
                        rank, rank_name_to_time[rank, i] / normalizer)
        if no_reported_timing:
            return None
        return output_string


    def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False):
        """Log a group of timers."""

        # Print.
        assert normalizer > 0.0
        if self._log_option in ['max', 'minmax']:
            max_only = False
            if self._log_option == 'max':
                max_only = True
            output_string = self._get_global_min_max_time_string(
                names, reset, barrier, normalizer/1000.0, max_only)
        elif self._log_option == 'all':
            output_string = self._get_all_ranks_time_string(names,
                                                            reset, barrier,
                                                            normalizer/1000.0)
        else:
            raise Exception('unknown timing log option {}'.format(
                self._log_option))

        # If no input rank is provided, log on last rank.
        if rank is None:
            rank = torch.distributed.get_world_size() - 1
        if rank == torch.distributed.get_rank() and output_string is not None:
            print(output_string, flush=True)


    def write(self, names, writer, iteration, normalizer=1.0,
              reset=False, barrier=False):
        """Write timers to a tensorboard writer
        Note that we only report maximum time across ranks to tensorboard.
        """
        # currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # polutes the runs list, so we just add each as a scalar
        assert normalizer > 0.0
        name_to_min_max_time = self._get_global_min_max_time(
            names, reset, barrier, normalizer)
        if writer is not None:
            for name in name_to_min_max_time:
                _, max_time = name_to_min_max_time[name]
                writer.add_scalar(name + '-time', max_time, iteration)
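To make the control flow of the new module concrete, a usage sketch follows. It assumes a CUDA device and an initialized torch.distributed process group (as inside Megatron training); the timer names here are illustrative, not taken from the commit.

from megatron.timers import Timers

# Timing level 1, report min/max across ranks (mirrors the new CLI defaults
# when --timing-log-level 1 is passed).
timers = Timers(log_level=1, log_option='minmax')

# Level-0 timer: always measured; barrier=True lines all ranks up first.
timers('iteration', log_level=0).start(barrier=True)

# Level-1 timer: measured because 1 <= the configured log level.
timers('grads-all-reduce', log_level=1).start(barrier=False)
# ... gradient all-reduce work would go here ...
timers('grads-all-reduce').stop()

# Level-2 timer: above the configured level, so a DummyTimer is returned
# and start()/stop() are no-ops, adding no overhead.
timers('forward-recv', log_level=2).start()
timers('forward-recv').stop()

timers('iteration').stop()

# Print "(min, max) time across ranks (ms)" on the last rank.
timers.log(['iteration', 'grads-all-reduce'], normalizer=1.0, barrier=True)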
megatron/training.py

@@ -119,23 +119,28 @@ def pretrain(train_valid_test_dataset_provider,
     timers = get_timers()
 
     # Model, optimizer, and learning rate.
-    timers('model-and-optimizer-setup').start()
-    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
-                                                                      model_type)
+    timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
+        model_provider, model_type)
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')
 
     # Data stuff.
-    timers('train/valid/test-data-iterators-setup').start()
+    timers('train/valid/test-data-iterators-setup', log_level=0).start(
+        barrier=True)
     if args.virtual_pipeline_model_parallel_size is not None:
         all_data_iterators = [
             build_train_valid_test_data_iterators(
                 train_valid_test_dataset_provider)
             for _ in range(len(model))
         ]
-        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
-        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
-        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
+        train_data_iterator = [data_iterators[0]
+                               for data_iterators in all_data_iterators]
+        valid_data_iterator = [data_iterators[1]
+                               for data_iterators in all_data_iterators]
+        test_data_iterator = [data_iterators[2]
+                              for data_iterators in all_data_iterators]
     else:
         train_data_iterator, valid_data_iterator, test_data_iterator \
             = build_train_valid_test_data_iterators(
@@ -145,7 +150,8 @@ def pretrain(train_valid_test_dataset_provider,
 
     # Print setup timing.
     print_rank_0('done with setup ...')
-    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
+    timers.log(['model-and-optimizer-setup',
+                'train/valid/test-data-iterators-setup'], barrier=True)
     print_rank_0('training ...')
 
     iteration = 0
@@ -373,13 +379,9 @@ def setup_model_and_optimizer(model_provider_func,
     if args.load is not None:
         timers = get_timers()
-        # Extra barrier is added to make sure all ranks report the
-        # max time.
-        torch.distributed.barrier()
-        timers('load-checkpoint').start()
+        timers('load-checkpoint', log_level=0).start(barrier=True)
         args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
-        torch.distributed.barrier()
-        timers('load-checkpoint').stop()
+        timers('load-checkpoint').stop(barrier=True)
         timers.log(['load-checkpoint'])
     else:
         args.iteration = 0
@@ -412,19 +414,21 @@ def train_step(forward_step_func, data_iterator,
     optimizer.zero_grad()
 
     # Forward pass.
+    timers('forward-backward', log_level=1).start(
+        barrier=args.barrier_with_L1_time)
     forward_backward_func = get_forward_backward_func()
+    fwd_bwd_timers = timers if args.timing_log_level > 1 else None
     losses_reduced = forward_backward_func(
         forward_step_func, data_iterator, model,
-        optimizer, timers, forward_only=False)
+        optimizer, fwd_bwd_timers, forward_only=False)
+    timers('forward-backward').stop()
 
     # Empty unused memory.
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()
 
     # Reduce gradients.
-    timers('backward-reduce-model-grads').start()
     optimizer.reduce_model_grads(args, timers)
-    timers('backward-reduce-model-grads').stop()
 
     # Vision gradients.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
@@ -433,15 +437,13 @@ def train_step(forward_step_func, data_iterator,
         unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
 
     # Update parameters.
-    timers('optimizer').start()
+    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
     update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
     timers('optimizer').stop()
 
     # Gather params.
     if update_successful:
-        timers('backward-gather-model-params').start()
         optimizer.gather_model_params(args, timers)
-        timers('backward-gather-model-params').stop()
 
     # Vision momentum.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
@@ -511,33 +513,32 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         nan_iters_key, 0) + int(got_nan)
 
     # Logging.
-    timers_to_log = []
-
-    def add_to_logging(name):
-        if name in timers.timers:
-            timers_to_log.append(name)
-    add_to_logging('forward-compute')
-    add_to_logging('forward-recv')
-    add_to_logging('forward-send')
-    add_to_logging('forward-backward-send-forward-backward-recv')
-    add_to_logging('backward-compute')
-    add_to_logging('backward-recv')
-    add_to_logging('backward-send')
-    add_to_logging('backward-send-forward-recv')
-    add_to_logging('backward-send-backward-recv')
-    add_to_logging('backward-params-all-reduce')
-    add_to_logging('backward-layernorm-all-reduce')
-    add_to_logging('backward-embedding-all-reduce')
-    add_to_logging('backward-reduce-model-grads')
-    add_to_logging('backward-gather-model-params')
-    add_to_logging('optimizer-copy-to-main-grad')
-    add_to_logging('optimizer-unscale-and-check-inf')
-    add_to_logging('optimizer-clip-main-grad')
-    add_to_logging('optimizer-count-zeros')
-    add_to_logging('optimizer-inner-step')
-    add_to_logging('optimizer-copy-main-to-model-params')
-    add_to_logging('optimizer')
-    add_to_logging('batch-generator')
+    timers_to_log = [
+        'forward-backward',
+        'forward-compute',
+        'backward-compute',
+        'batch-generator',
+        'forward-recv',
+        'forward-send',
+        'backward-recv',
+        'backward-send',
+        'forward-send-forward-recv',
+        'forward-send-backward-recv',
+        'backward-send-forward-recv',
+        'backward-send-backward-recv',
+        'forward-backward-send-forward-backward-recv',
+        'layernorm-grads-all-reduce',
+        'embedding-grads-all-reduce',
+        'grads-all-reduce',
+        'grads-reduce-scatter',
+        'params-all-gather',
+        'optimizer-copy-to-main-grad',
+        'optimizer-unscale-and-check-inf',
+        'optimizer-clip-main-grad',
+        'optimizer-count-zeros',
+        'optimizer-inner-step',
+        'optimizer-copy-main-to-model-params',
+        'optimizer']
 
     # Calculate batch size.
     batch_size = args.micro_batch_size * args.data_parallel_size * \
@@ -547,8 +548,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                           total_loss_dict[skipped_iters_key]
 
     # Tensorboard values.
-    if writer and (iteration % args.tensorboard_log_interval == 0) and \
-       is_last_rank():
+    # Timer requires all the ranks to call.
+    if args.log_timers_to_tensorboard and \
+       (iteration % args.tensorboard_log_interval == 0):
+        timers.write(timers_to_log, writer, iteration,
+                     normalizer=total_iterations)
+    if writer and (iteration % args.tensorboard_log_interval == 0):
         if args.log_learning_rate_to_tensorboard:
             writer.add_scalar('learning-rate', learning_rate, iteration)
             writer.add_scalar('learning-rate vs samples', learning_rate,
@@ -581,9 +586,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             writer.add_scalar('params-norm', params_norm, iteration)
             writer.add_scalar('params-norm vs samples', params_norm,
                               args.consumed_train_samples)
-        if args.log_timers_to_tensorboard:
-            timers.write(timers_to_log, writer, iteration,
-                         normalizer=total_iterations)
         if args.log_memory_to_tensorboard:
             mem_stats = torch.cuda.memory_stats()
             writer.add_scalar(
@@ -603,7 +605,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         )
 
     if iteration % args.log_interval == 0:
-        elapsed_time = timers('interval-time').elapsed()
+        elapsed_time = timers('interval-time').elapsed(barrier=True)
         elapsed_time_per_iteration = elapsed_time / total_iterations
         if writer:
             if args.log_timers_to_tensorboard:
@@ -653,11 +655,9 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
     timers = get_timers()
     # Extra barrier is added to make sure
     # all ranks report the max time.
-    torch.distributed.barrier()
-    timers('save-checkpoint').start()
+    timers('save-checkpoint', log_level=0).start(barrier=True)
     save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
-    torch.distributed.barrier()
-    timers('save-checkpoint').stop()
+    timers('save-checkpoint').stop(barrier=True)
     timers.log(['save-checkpoint'])
@@ -681,7 +681,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
     # Iterations.
     iteration = args.iteration
 
-    timers('interval-time').start()
+    timers('interval-time', log_level=0).start(barrier=True)
     print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:
pretrain_bert.py
View file @ 77efcccb
...
@@ -104,7 +104,7 @@ def forward_step(data_iterator, model):
    timers = get_timers()
    # Get the batch.
-    timers('batch-generator').start()
+    timers('batch-generator', log_level=2).start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)
    timers('batch-generator').stop()
...
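The batch-generator timer runs every training iteration, which is why it is tagged log_level=2 and therefore off by default: accurate timing of GPU-side work generally needs a device synchronization, and synchronizing every step can itself slow training, as the help text for the new timing-log-level argument warns. The toy comparison below is not part of the commit; it only illustrates that per-step synchronization has a measurable cost.

# Toy illustration (not from Megatron): cost of synchronizing in a hot loop.
import time
import torch

def run(steps, sync):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1024, 1024, device=device)
    start = time.time()
    for _ in range(steps):
        y = x @ x                        # stand-in for per-iteration work
        if sync and torch.cuda.is_available():
            torch.cuda.synchronize()     # what an accurate per-step timer would need
    if torch.cuda.is_available():
        torch.cuda.synchronize()         # drain remaining work before reading the clock
    return time.time() - start

if __name__ == '__main__':
    print('without per-step sync:', run(100, sync=False))
    print('with per-step sync:   ', run(100, sync=True))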
pretrain_gpt.py
View file @ 77efcccb
...
@@ -89,7 +89,7 @@ def forward_step(data_iterator, model):
    timers = get_timers()
    # Get the batch.
-    timers('batch-generator').start()
+    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch-generator').stop()
...
pretrain_ict.py
View file @ 77efcccb
...
@@ -134,7 +134,7 @@ def forward_step(data_iterator, model):
    timers = get_timers()
    # Get the batch.
-    timers('batch-generator').start()
+    timers('batch-generator', log_level=2).start()
    query_tokens, query_mask, \
        context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()
...
pretrain_t5.py
View file @ 77efcccb
...
@@ -126,7 +126,7 @@ def forward_step(data_iterator, model):
    timers = get_timers()
    # Get the batch.
-    timers('batch generator').start()
+    timers('batch generator', log_level=2).start()
    tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \
        = get_batch(data_iterator)
    timers('batch generator').stop()
...
pretrain_vision_classify.py
View file @ 77efcccb
...
@@ -77,7 +77,7 @@ def forward_step(data_iterator, model):
    timers = get_timers()
    # Get the batch.
-    timers("batch-generator").start()
+    timers("batch-generator", log_level=2).start()
    (
        images,
        labels,
...
pretrain_vision_dino.py
View file @ 77efcccb
...
@@ -84,7 +84,7 @@ def forward_step(data_iterator, model):
    timers = get_timers()
    # Get the batch.
-    timers("batch-generator").start()
+    timers("batch-generator", log_level=2).start()
    (
        images,
        labels,
...
pretrain_vision_inpaint.py
View file @ 77efcccb
...
@@ -91,7 +91,7 @@ def forward_step(data_iterator, model):
    timers = get_timers()
    # Get the batch.
-    timers("batch-generator").start()
+    timers("batch-generator", log_level=2).start()
    (
        images,
        masks,
...
tasks/finetune_utils.py
View file @ 77efcccb
...
@@ -67,7 +67,7 @@ def _cross_entropy_forward_step(batch, model):
    timers = get_timers()
    # Get the batch.
-    timers('batch-generator').start()
+    timers('batch-generator', log_level=2).start()
    try:
        batch_ = next(batch)
    except BaseException:
...
@@ -178,7 +178,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
    report_memory_flag = True
    # For each remaining epoch
-    timers('interval-time').start()
+    timers('interval-time', log_level=0).start(barrier=True)
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))
...
@@ -261,7 +261,7 @@ def finetune(train_valid_datasets_provider, model_provider,
        'batch size scaling is not supported for finetuning'
    # Train and validation data loaders.
-    timers('train/valid/test dataset/dataloder').start()
+    timers('train/valid/test dataset/dataloder', log_level=0).start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
...
@@ -271,21 +271,21 @@ def finetune(train_valid_datasets_provider, model_provider,
    timers('train/valid/test dataset/dataloder').stop()
    # Build calback function.
-    timers('callback function').start()
+    timers('callback function', log_level=0).start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers('callback function').stop()
    # Build model, optimizer and learning rate scheduler.
-    timers('model and optimizer').start()
+    timers('model and optimizer', log_level=0).start()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
    timers('model and optimizer').stop()
    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
-    timers('pretrained checkpoint').start()
+    timers('pretrained checkpoint', log_level=0).start(barrier=True)
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        original_load = args.load
        args.load = args.pretrained_checkpoint
...
@@ -302,7 +302,7 @@ def finetune(train_valid_datasets_provider, model_provider,
    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['train/valid/test dataset/dataloder', 'callback function',
-                'model and optimizer', 'pretrained checkpoint'])
+                'model and optimizer', 'pretrained checkpoint'], barrier=True)
    print_rank_0('training ...')
    # Finetune the model.
...
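Here the one-off setup timers are level 0, and timers.log(..., barrier=True) synchronizes ranks before reporting so the printed setup times are comparable across the job. How per-rank times are then summarized is controlled by the new timing-log-option argument (max, minmax, or all). A rough sketch of such a reduction, assuming each rank holds a local elapsed time in seconds (illustrative, not the megatron/timers.py code):

# Illustrative sketch of reducing per-rank timings per the timing log option.
import torch

def summarize_timing(name, local_elapsed, option='minmax'):
    """Gather one timer's elapsed seconds from every rank and format a report line."""
    world_size = torch.distributed.get_world_size()
    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
    local = torch.tensor([local_elapsed], dtype=torch.float, device=device)
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, local)   # every rank sees every timing
    times = [t.item() for t in gathered]
    if option == 'max':
        return '{}: max={:.2f}s'.format(name, max(times))
    if option == 'minmax':
        return '{}: min={:.2f}s max={:.2f}s'.format(name, min(times), max(times))
    return '{}: '.format(name) + ' '.join('{:.2f}'.format(t) for t in times)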
tasks/orqa/supervised/finetune.py
View file @ 77efcccb
...
@@ -63,7 +63,7 @@ def orqa(Dataset):
    tokenizer = get_tokenizer()
    # Get the batch.
-    timers('batch generator').start()
+    timers('batch generator', log_level=2).start()
    try:
        batch_ = next(batch)
    except BaseException:
...
tasks/vision/classification/classification.py
View file @ 77efcccb
...
@@ -68,7 +68,7 @@ def classification():
    timers = get_timers()
    # Get the batch.
-    timers("batch generator").start()
+    timers("batch generator", log_level=2).start()
    try:
        batch_ = next(batch)
    except BaseException:
...
tasks/vision/finetune_utils.py
View file @ 77efcccb
...
@@ -136,7 +136,7 @@ def _train(
    report_memory_flag = True
    # For each remaining epoch
-    timers("interval-time").start()
+    timers("interval-time", log_level=0).start(barrier=True)
    for epoch in range(start_epoch, args.epochs):
        print_rank_0("working on epoch {} ...".format(epoch + 1))
...
@@ -218,7 +218,7 @@ def finetune(
    timers = get_timers()
    # Train and validation data loaders.
-    timers("train/valid/test dataset/dataloder").start()
+    timers("train/valid/test dataset/dataloder", log_level=0).start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
...
@@ -227,14 +227,14 @@ def finetune(
    timers("train/valid/test dataset/dataloder").stop()
    # Build calback function.
-    timers("callback function").start()
+    timers("callback function", log_level=0).start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers("callback function").stop()
    # Build model, optimizer and learning rate scheduler.
-    timers("model and optimizer").start()
+    timers("model and optimizer", log_level=0).start()
    model, optimizer, opt_param_scheduler = \
        setup_model_and_optimizer(
            model_provider,
...
@@ -246,7 +246,7 @@ def finetune(
    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
-    timers("pretrained checkpoint").start()
+    timers("pretrained checkpoint", log_level=0).start(barrier=True)
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        if args.pretrained_checkpoint_type == 'default':
            original_load = args.load
...
tasks/vision/segmentation/finetune_segformer.py
View file @ 77efcccb
...
@@ -123,7 +123,7 @@ def segmentation():
    timers = get_timers()
    # Get the batch.
-    timers("batch generator").start()
+    timers("batch generator", log_level=2).start()
    import types
    if isinstance(batch, types.GeneratorType):
        batch_ = next(batch)
...