wuxk1 / Megatron-LM · Commit 626645c0

Commit 626645c0, authored Feb 10, 2021 by Deepak Narayanan

    Comments in megatron/schedules.py and address a few more comments

Parent: cc691cbf
Showing 2 changed files with 89 additions and 32 deletions:

    megatron/schedules.py   +76  -21
    megatron/training.py    +13  -11
megatron/schedules.py
@@ -23,7 +23,12 @@ from megatron import p2p_communication
 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
-    """Forward step."""
+    """Forward step for passed-in model.
+
+    If first stage, input tensor is obtained from data_iterator, otherwise
+    passed-in input_tensor is used.
+
+    Returns output tensor."""
     timers = get_timers()

     timers('forward-compute').start()
@@ -38,7 +43,13 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
 def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
-    """Backward step."""
+    """Backward step through passed-in output tensor.
+
+    If last stage, output_tensor_grad is None, otherwise gradient of loss
+    with respect to stage's output tensor.
+
+    Returns gradient of loss with respect to input tensor (None if first
+    stage)."""
     args = get_args()
     timers = get_timers()
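
Note on the expanded backward_step docstring: it spells out the stage-local contract of pipeline-parallel training, where the gradient of the loss with respect to the stage's output arrives from the next stage (None on the last stage, which computes the loss itself), and the gradient with respect to the stage's input is what gets sent upstream. A minimal stand-alone sketch of that contract in plain PyTorch; stage_backward and the toy shapes are illustrative only, not Megatron's backward_step (which also drives the optimizer's loss scaling):

    import torch

    def stage_backward(input_tensor, output_tensor, output_tensor_grad):
        # Backward through one stage's subgraph. output_tensor_grad is None on
        # the last stage (the stage computed the scalar loss itself); otherwise
        # it is dL/d(output) received from the next stage. Returns dL/d(input),
        # or None on the first stage, which has no upstream activation.
        if output_tensor_grad is None:
            torch.autograd.backward(output_tensor)
        else:
            torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
        return None if input_tensor is None else input_tensor.grad

    # Toy single-op "stage": y = W(x); dL/dy arrives from downstream.
    x = torch.randn(4, 8, requires_grad=True)
    y = torch.nn.Linear(8, 8)(x)
    dldy = torch.ones_like(y)           # pretend this came from the next stage
    dldx = stage_backward(x, y, dldy)   # gradient to send to the previous stage
    print(dldx.shape)                   # torch.Size([4, 8])
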
@@ -65,7 +76,10 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
 def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                    optimizer, timers, forward_only):
-    """Run forward and backward passes without inter-stage communication."""
+    """Run forward and backward passes with no pipeline parallelism
+    (no inter-stage communication).
+
+    Returns dictionary with losses."""
     assert len(model) == 1
     model = model[0]
@@ -83,7 +97,10 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
 def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                                                   optimizer, timers, forward_only):
-    """Run interleaved 1F1B schedule."""
+    """Run interleaved 1F1B schedule (model split into model chunks), with
+    communication between pipeline stages as needed.
+
+    Returns dictionary with losses if the last stage, empty dict otherwise."""
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
     losses_reduced = []
@@ -100,18 +117,27 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     if forward_only:
         num_warmup_microbatches = num_microbatches
     else:
+        # Run all forward passes and then all backward passes if number of
+        # microbatches is just the number of pipeline stages.
+        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
+        # all workers, followed by more microbatches after depending on
+        # stage ID (more forward passes for earlier stages, later stages can
+        # immediately start with 1F1B).
         if get_num_microbatches() == pipeline_parallel_size:
             num_warmup_microbatches = num_microbatches
             all_warmup_microbatches = True
         else:
             num_warmup_microbatches = \
                 (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
             num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
             num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
     num_microbatches_remaining = \
         num_microbatches - num_warmup_microbatches

     def get_model_chunk_id(k, forward):
+        """Helper method to get the model chunk ID given the iteration number."""
         k_in_group = k % (pipeline_parallel_size * num_model_chunks)
         i = k_in_group // pipeline_parallel_size
         if not forward:
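
To make the new warmup comment concrete, the stand-alone sketch below re-implements the formulas shown in this hunk plus the chunk-ID mapping; the reversal for backward passes is an assumption about the branch collapsed out of the hunk:

    def warmup_microbatches(pipeline_parallel_size, pipeline_parallel_rank,
                            num_model_chunks, num_microbatches):
        # Mirrors the warmup computation in the hunk above.
        n = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
        n += (num_model_chunks - 1) * pipeline_parallel_size
        return min(n, num_microbatches)

    def model_chunk_id(k, pipeline_parallel_size, num_model_chunks, forward=True):
        # Mirrors get_model_chunk_id(): microbatches cycle through the chunks
        # in groups of pipeline_parallel_size. The reversal for backward passes
        # is an assumption about the branch collapsed out of the hunk.
        k_in_group = k % (pipeline_parallel_size * num_model_chunks)
        i = k_in_group // pipeline_parallel_size
        return i if forward else num_model_chunks - i - 1

    # 4 pipeline stages, 2 model chunks per stage, 16 microbatches:
    print([warmup_microbatches(4, r, 2, 16) for r in range(4)])   # [10, 8, 6, 4]
    print([model_chunk_id(k, 4, 2) for k in range(8)])            # [0, 0, 0, 0, 1, 1, 1, 1]

Earlier stages do more warmup forward passes; the last stage can start 1F1B almost immediately, as the comment says.
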
@@ -119,14 +145,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         return i

     def forward_step_helper(k):
+        """Helper method to run forward step with model split into chunks
+        (run set_virtual_pipeline_model_parallel_rank() before calling
+        forward_step())."""
         model_chunk_id = get_model_chunk_id(k, forward=True)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

         if mpu.is_pipeline_first_stage():
-            if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
+            if len(input_tensors[model_chunk_id]) == \
+                    len(output_tensors[model_chunk_id]):
                 input_tensors[model_chunk_id].append(None)
         input_tensor = input_tensors[model_chunk_id][-1]
-        output_tensor = forward_step(forward_step_func, data_iterator[model_chunk_id],
+        output_tensor = forward_step(forward_step_func,
+                                     data_iterator[model_chunk_id],
                                      model[model_chunk_id], input_tensor, losses_reduced)
         output_tensors[model_chunk_id].append(output_tensor)
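
forward_step_helper sets the virtual pipeline rank before running a chunk because, with interleaving, one physical stage hosts several model chunks, and is_pipeline_first_stage()/is_pipeline_last_stage() answer relative to the current chunk unless ignore_virtual=True is passed. A toy stand-in for that bookkeeping (not the real mpu module):

    class ToyMPU:
        # Toy stand-in for the virtual-rank bookkeeping: with 4 physical stages
        # and 2 chunks per stage, "first stage" means physical stage 0 running
        # chunk 0, unless the check is asked to ignore the virtual rank.
        def __init__(self, pipeline_parallel_size, pipeline_parallel_rank,
                     num_model_chunks):
            self.size = pipeline_parallel_size
            self.rank = pipeline_parallel_rank
            self.chunks = num_model_chunks
            self.virtual_rank = 0

        def set_virtual_pipeline_model_parallel_rank(self, chunk_id):
            self.virtual_rank = chunk_id

        def is_pipeline_first_stage(self, ignore_virtual=False):
            if ignore_virtual:
                return self.rank == 0
            return self.rank == 0 and self.virtual_rank == 0

        def is_pipeline_last_stage(self, ignore_virtual=False):
            if ignore_virtual:
                return self.rank == self.size - 1
            return (self.rank == self.size - 1 and
                    self.virtual_rank == self.chunks - 1)

    mpu = ToyMPU(4, 0, 2)
    mpu.set_virtual_pipeline_model_parallel_rank(1)
    print(mpu.is_pipeline_first_stage())                     # False: chunk 1 is not chunk 0
    print(mpu.is_pipeline_first_stage(ignore_virtual=True))  # True: physical stage 0
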
@@ -134,6 +165,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         return output_tensor

     def backward_step_helper(k):
+        """Helper method to run backward step with model split into chunks
+        (run set_virtual_pipeline_model_parallel_rank() before calling
+        backward_step())."""
         model_chunk_id = get_model_chunk_id(k, forward=False)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
@@ -144,15 +178,21 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         output_tensor = output_tensors[model_chunk_id].pop(0)
         output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
         input_tensor_grad = \
             backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

         return input_tensor_grad

     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
     input_tensors[0].append(p2p_communication.recv_forward(timers, use_ring_exchange=True))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)

         # Determine if tensor should be received from previous stage.
         next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
         recv_prev = True
         if mpu.is_pipeline_first_stage(ignore_virtual=True):
@@ -160,8 +200,13 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
             recv_prev = False
         if k == (num_microbatches - 1):
             recv_prev = False

         # Don't send tensor downstream if on last stage.
         if mpu.is_pipeline_last_stage():
             output_tensor = None

+        # Send and receive tensors as appropriate (send tensors computed
+        # in this iteration; receive tensors for next iteration).
         if k == (num_warmup_microbatches - 1) and not forward_only and \
                 not all_warmup_microbatches:
             input_tensor_grad = None
@@ -176,7 +221,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
             output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
         else:
             input_tensor = \
                 p2p_communication.send_forward_recv_forward(output_tensor, recv_prev, timers)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)

     # Run 1F1B in steady state.
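
The warmup loop overlaps sending this iteration's output downstream with receiving the next iteration's input from upstream. A rough illustration of that combined exchange using plain torch.distributed point-to-point ops; this is not Megatron's p2p_communication module, and it assumes an already-initialized process group plus a single tensor of known shape:

    import torch
    import torch.distributed as dist

    def send_forward_recv_forward(output_tensor, recv_prev, shape,
                                  prev_rank, next_rank, device):
        # Send this stage's output to next_rank and, if recv_prev, receive the
        # next input from prev_rank, batching both transfers so they overlap.
        ops = []
        recv_buf = None
        if output_tensor is not None:
            ops.append(dist.P2POp(dist.isend, output_tensor, next_rank))
        if recv_prev:
            recv_buf = torch.empty(shape, device=device)
            ops.append(dist.P2POp(dist.irecv, recv_buf, prev_rank))
        if ops:
            for req in dist.batch_isend_irecv(ops):
                req.wait()
        return recv_buf
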
@@ -215,7 +261,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                 recv_prev = False
             next_forward_model_chunk_id += 1
         else:
             next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

         recv_next = True
         if mpu.is_pipeline_last_stage(ignore_virtual=True):
@@ -226,10 +273,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                 recv_next = False
             next_backward_model_chunk_id -= 1
         else:
             next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

-        # If last iteration, don't receive; we already received one extra before the
-        # start of the for loop.
+        # If last iteration, don't receive; we already received one extra
+        # before the start of the for loop.
         if k == (num_microbatches_remaining - 1):
             recv_prev = False
@@ -240,13 +288,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                 recv_prev=recv_prev, recv_next=recv_next,
                 timers=timers)

-        # Put input_tensor and output_tensor_grad in data structures in the right location.
+        # Put input_tensor and output_tensor_grad in data structures in the
+        # right location.
         if recv_prev:
             input_tensors[next_forward_model_chunk_id].append(input_tensor)
         if recv_next:
             output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

-    # Run cooldown backward passes.
+    # Run cooldown backward passes (flush out pipeline).
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks - 1].append(
@@ -269,7 +319,10 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
 def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                 optimizer, timers, forward_only):
-    """Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
+    """Run non-interleaved 1F1B schedule, with communication between pipeline
+    stages.
+
+    Returns dictionary with losses if the last stage, empty dict otherwise."""
     timers = get_timers()

     assert len(model) == 1
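
The reworded docstring distinguishes this non-interleaved 1F1B schedule from the interleaved variant above. The toy enumeration below walks one stage through its warmup, steady-state, and cooldown phases; the warmup count of pipeline_parallel_size - rank - 1 is an assumption here, since that expression sits in the unchanged part of the function:

    def one_f_one_b_phases(pipeline_parallel_size, rank, num_microbatches):
        # Return the (phase, microbatch index) sequence one stage executes.
        warmup = min(pipeline_parallel_size - rank - 1, num_microbatches)
        remaining = num_microbatches - warmup
        schedule = [("warmup-forward", i) for i in range(warmup)]
        for i in range(remaining):                    # steady state: 1F1B pairs
            schedule.append(("steady-forward", warmup + i))
            schedule.append(("steady-backward", i))
        for i in range(remaining, num_microbatches):  # cooldown: drain backwards
            schedule.append(("cooldown-backward", i))
        return schedule

    # The last stage (rank 3 of 4) starts 1F1B immediately; rank 0 warms up longest.
    print(len(one_f_one_b_phases(4, 3, 8)))   # 16 steps: 8 forwards + 8 backwards
    print(one_f_one_b_phases(4, 0, 8)[:5])    # three warmup forwards, then 1F1B
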
@@ -327,7 +380,8 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             p2p_communication.send_forward(output_tensor, timers)
         else:
             output_tensor_grad = \
                 p2p_communication.send_forward_recv_backward(output_tensor, timers)

         # Add input_tensor and output_tensor to end of list, then pop from the
         # start of the list for backward pass.
@@ -349,7 +403,8 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             p2p_communication.send_backward(input_tensor_grad, timers)
         else:
             input_tensor = \
                 p2p_communication.send_backward_recv_forward(input_tensor_grad, timers)

     # Run cooldown backward passes.
     if not forward_only:
megatron/training.py
@@ -120,13 +120,13 @@ def pretrain(train_valid_test_dataset_provider,
     # Data stuff.
     timers('train/valid/test-data-iterators-setup').start()
     if args.virtual_pipeline_model_parallel_size is not None:
-        data_iterators = [
+        all_data_iterators = [
             build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
             for _ in range(len(model))
         ]
-        train_data_iterator = [x[0] for x in data_iterators]
-        valid_data_iterator = [x[1] for x in data_iterators]
-        test_data_iterator = [x[2] for x in data_iterators]
+        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
+        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
+        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
     else:
         train_data_iterator, valid_data_iterator, test_data_iterator \
             = build_train_valid_test_data_iterators(
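
Renaming the outer list to all_data_iterators and indexing each (train, valid, test) triple explicitly reads more clearly than x[0]/x[1]/x[2]. A toy illustration of the resulting shapes, where build_triple is just a stand-in for build_train_valid_test_data_iterators:

    def build_triple(chunk):
        # Stand-in: one (train, valid, test) iterator triple per model chunk.
        return (iter([f"train-{chunk}"]),
                iter([f"valid-{chunk}"]),
                iter([f"test-{chunk}"]))

    num_model_chunks = 2
    all_data_iterators = [build_triple(c) for c in range(num_model_chunks)]

    # One list per split, indexed by model chunk, which is what the interleaved
    # schedule expects (data_iterator[model_chunk_id] in schedules.py).
    train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
    valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
    test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]

    print(next(train_data_iterator[1]))   # train-1
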
@@ -311,17 +311,18 @@ def setup_model_and_optimizer(model_provider_func):
     # We only support local DDP with multiple micro-batches.
     if get_num_microbatches() > 1:
         assert args.DDP_impl == 'local'
     if len(model) == 1:
         assert args.DDP_impl == 'local'
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         assert args.DDP_impl == 'local'

     # get model without FP16 and/or TorchDDP wrappers
     model = unwrap_model(model)
     for module in model:
         unwrapped_module = module
         while hasattr(unwrapped_module, 'module'):
             unwrapped_module = unwrapped_module.module
-        if args.iteration == 0 and hasattr(unwrapped_module,
+        if args.iteration == 0 and hasattr(module,
                                            'init_state_dict_from_bert'):
             print("Initializing ICT from pretrained BERT model", flush=True)
-            unwrapped_module.init_state_dict_from_bert()
+            module.init_state_dict_from_bert()

     return model, optimizer, lr_scheduler
@@ -364,7 +365,8 @@ def train_step(forward_step_func, data_iterator,
     # This should only run for models that support pipelined model parallelism
     # (BERT and GPT-2).
     timers('backward-embedding-all-reduce').start()
     if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
         mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
             mpu.get_pipeline_model_parallel_world_size() > 1:
         if mpu.is_pipeline_first_stage(ignore_virtual=True):
             unwrapped_model = model[0]
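
The guarded block that begins in this hunk keeps the tied input/output embedding weights in sync when the first and last pipeline stages hold separate copies. A rough sketch of the underlying operation with plain torch.distributed; the embedding_group handle and the weight argument are assumptions here, since Megatron wires these up through mpu and the unwrapped model:

    import torch.distributed as dist

    def allreduce_embedding_grads(word_embedding_weight, embedding_group):
        # Sum the gradient of the shared embedding weight across the process
        # group containing the first and last pipeline stages, so both tied
        # copies apply the same update.
        grad = word_embedding_weight.grad
        if grad is not None:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=embedding_group)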