Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
8e922d5b
"vscode:/vscode.git/clone" did not exist on "ef29b24fda25ce2637d90e6077d85b4cf3d6c701"
Commit
8e922d5b
authored
Jan 09, 2021
by
Deepak Narayanan
Browse files
Put in barriers in appropriate places to measure length of pipeline stall
parent
dd889062
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
2 deletions
+22
-2
megatron/schedules.py
megatron/schedules.py
+15
-2
megatron/training.py
megatron/training.py
+7
-0
No files found.
megatron/schedules.py
View file @
8e922d5b
...
@@ -95,6 +95,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
...
@@ -95,6 +95,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor_grads
=
[[]
for
_
in
range
(
len
(
model
))]
output_tensor_grads
=
[[]
for
_
in
range
(
len
(
model
))]
pipeline_parallel_size
=
mpu
.
get_pipeline_model_parallel_world_size
()
pipeline_parallel_size
=
mpu
.
get_pipeline_model_parallel_world_size
()
pipeline_parallel_rank
=
mpu
.
get_pipeline_model_parallel_rank
()
# Compute number of warmup and remaining microbatches.
# Compute number of warmup and remaining microbatches.
num_model_chunks
=
len
(
model
)
num_model_chunks
=
len
(
model
)
...
@@ -108,8 +109,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
...
@@ -108,8 +109,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
all_warmup_microbatches
=
True
all_warmup_microbatches
=
True
else
:
else
:
num_warmup_microbatches
=
\
num_warmup_microbatches
=
\
(
pipeline_parallel_size
-
(
pipeline_parallel_size
-
pipeline_parallel_rank
-
1
)
*
2
mpu
.
get_pipeline_model_parallel_rank
()
-
1
)
*
2
num_warmup_microbatches
+=
(
num_model_chunks
-
1
)
*
pipeline_parallel_size
num_warmup_microbatches
+=
(
num_model_chunks
-
1
)
*
pipeline_parallel_size
num_warmup_microbatches
=
min
(
num_warmup_microbatches
,
num_microbatches
)
num_warmup_microbatches
=
min
(
num_warmup_microbatches
,
num_microbatches
)
num_microbatches_remaining
=
\
num_microbatches_remaining
=
\
...
@@ -272,6 +272,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
...
@@ -272,6 +272,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
def
forward_backward_pipelining
(
forward_step_func
,
data_iterator
,
model
,
def
forward_backward_pipelining
(
forward_step_func
,
data_iterator
,
model
,
optimizer
,
timers
,
forward_only
):
optimizer
,
timers
,
forward_only
):
"""Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
"""Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
timers
=
get_timers
()
assert
len
(
model
)
==
1
assert
len
(
model
)
==
1
model
=
model
[
0
]
model
=
model
[
0
]
...
@@ -295,11 +297,22 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
...
@@ -295,11 +297,22 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
input_tensor
=
recv_forward
(
timers
)
input_tensor
=
recv_forward
(
timers
)
output_tensor
=
forward_step
(
forward_step_func
,
data_iterator
,
model
,
output_tensor
=
forward_step
(
forward_step_func
,
data_iterator
,
model
,
input_tensor
,
losses_reduced
)
input_tensor
,
losses_reduced
)
# Barrier before first receive to measure forward stall.
if
i
==
(
num_warmup_microbatches
-
1
):
timers
(
'forward-pipeline-stall'
).
start
()
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_pipeline_model_parallel_group
())
timers
(
'forward-pipeline-stall'
).
stop
()
send_forward
(
output_tensor
,
timers
)
send_forward
(
output_tensor
,
timers
)
input_tensors
.
append
(
input_tensor
)
input_tensors
.
append
(
input_tensor
)
output_tensors
.
append
(
output_tensor
)
output_tensors
.
append
(
output_tensor
)
# Barrier before first receive to measure forward stall.
if
num_warmup_microbatches
==
0
:
timers
(
'forward-pipeline-stall'
).
start
()
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_pipeline_model_parallel_group
())
timers
(
'forward-pipeline-stall'
).
stop
()
# Before running 1F1B, need to receive first forward tensor.
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
# receive this tensor here.
...
...
megatron/training.py
View file @
8e922d5b
...
@@ -354,6 +354,11 @@ def train_step(forward_step_func, data_iterator,
...
@@ -354,6 +354,11 @@ def train_step(forward_step_func, data_iterator,
fp32_allreduce
=
args
.
fp32_allreduce
)
fp32_allreduce
=
args
.
fp32_allreduce
)
timers
(
'backward-params-all-reduce'
).
stop
()
timers
(
'backward-params-all-reduce'
).
stop
()
# Barrier to measure backward stall.
timers
(
'backward-pipeline-stall'
).
start
()
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_pipeline_model_parallel_group
())
timers
(
'backward-pipeline-stall'
).
stop
()
# All-reduce word_embeddings' grad across first and last stages to ensure
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# This should only run for models that support pipelined model parallelism
...
@@ -443,10 +448,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
...
@@ -443,10 +448,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
if
name
in
timers
.
timers
:
if
name
in
timers
.
timers
:
timers_to_log
.
append
(
name
)
timers_to_log
.
append
(
name
)
add_to_logging
(
'forward-compute'
)
add_to_logging
(
'forward-compute'
)
add_to_logging
(
'forward-pipeline-stall'
)
add_to_logging
(
'forward-recv'
)
add_to_logging
(
'forward-recv'
)
add_to_logging
(
'forward-send'
)
add_to_logging
(
'forward-send'
)
add_to_logging
(
'forward-backward-send-forward-backward-recv'
)
add_to_logging
(
'forward-backward-send-forward-backward-recv'
)
add_to_logging
(
'backward-compute'
)
add_to_logging
(
'backward-compute'
)
add_to_logging
(
'backward-pipeline-stall'
)
add_to_logging
(
'backward-recv'
)
add_to_logging
(
'backward-recv'
)
add_to_logging
(
'backward-send'
)
add_to_logging
(
'backward-send'
)
add_to_logging
(
'backward-send-forward-recv'
)
add_to_logging
(
'backward-send-forward-recv'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment