Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
318d68c2
Commit
318d68c2
authored
Oct 23, 2020
by
Deepak Narayanan
Browse files
Refactor communication code in main training loop to helper method
parent
e805f0bd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
96 additions
and
62 deletions
+96
-62
megatron/arguments.py
megatron/arguments.py
+4
-2
megatron/training.py
megatron/training.py
+92
-60
No files found.
megatron/arguments.py
View file @
318d68c2
...
@@ -59,6 +59,8 @@ def parse_args(extra_args_provider=None, defaults={},
...
@@ -59,6 +59,8 @@ def parse_args(extra_args_provider=None, defaults={},
args
.
pipeline_model_parallel_size
=
min
(
args
.
pipeline_model_parallel_size
=
min
(
args
.
pipeline_model_parallel_size
,
args
.
pipeline_model_parallel_size
,
(
args
.
world_size
//
args
.
tensor_model_parallel_size
))
(
args
.
world_size
//
args
.
tensor_model_parallel_size
))
if
args
.
num_microbatches_in_minibatch
is
None
:
args
.
num_microbatches_in_minibatch
=
1
if
args
.
rank
==
0
:
if
args
.
rank
==
0
:
print
(
'using world size: {}, tensor-model-parallel size: {}, pipeline-model-parallel size: {} '
.
format
(
print
(
'using world size: {}, tensor-model-parallel size: {}, pipeline-model-parallel size: {} '
.
format
(
args
.
world_size
,
args
.
tensor_model_parallel_size
,
args
.
pipeline_model_parallel_size
))
args
.
world_size
,
args
.
tensor_model_parallel_size
,
args
.
pipeline_model_parallel_size
))
...
@@ -223,6 +225,8 @@ def _add_training_args(parser):
...
@@ -223,6 +225,8 @@ def _add_training_args(parser):
help
=
'Batch size per model instance (local batch size). '
help
=
'Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'Global batch size is local batch size times data '
'parallel size.'
)
'parallel size.'
)
group
.
add_argument
(
'--num-microbatches-in-minibatch'
,
type
=
int
,
default
=
None
,
help
=
'Number of microbatches in minibatch'
)
group
.
add_argument
(
'--checkpoint-activations'
,
action
=
'store_true'
,
group
.
add_argument
(
'--checkpoint-activations'
,
action
=
'store_true'
,
help
=
'Checkpoint activation to allow for training '
help
=
'Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.'
)
'with larger models, sequences, and batch sizes.'
)
...
@@ -368,8 +372,6 @@ def _add_distributed_args(parser):
...
@@ -368,8 +372,6 @@ def _add_distributed_args(parser):
help
=
'Degree of tensor model parallelism.'
)
help
=
'Degree of tensor model parallelism.'
)
group
.
add_argument
(
'--pipeline-model-parallel-size'
,
type
=
int
,
default
=
1
,
group
.
add_argument
(
'--pipeline-model-parallel-size'
,
type
=
int
,
default
=
1
,
help
=
'Degree of pipeline model parallelism.'
)
help
=
'Degree of pipeline model parallelism.'
)
group
.
add_argument
(
'--use-pipelining'
,
action
=
'store_true'
,
help
=
'Use pipelining to increase throughput of pipeline model parallelism'
)
group
.
add_argument
(
'--distributed-backend'
,
default
=
'nccl'
,
group
.
add_argument
(
'--distributed-backend'
,
default
=
'nccl'
,
choices
=
[
'nccl'
,
'gloo'
],
choices
=
[
'nccl'
,
'gloo'
],
help
=
'Which backend to use for distributed training.'
)
help
=
'Which backend to use for distributed training.'
)
...
...
megatron/training.py
View file @
318d68c2
...
@@ -138,7 +138,7 @@ def get_model(model_provider_func):
...
@@ -138,7 +138,7 @@ def get_model(model_provider_func):
model
=
FP16_Module
(
model
)
model
=
FP16_Module
(
model
)
# Wrap model for distributed training."""
# Wrap model for distributed training."""
if
args
.
use_pipelining
:
if
args
.
num_microbatches_in_minibatch
>
1
:
assert
args
.
DDP_impl
==
'local'
assert
args
.
DDP_impl
==
'local'
if
args
.
DDP_impl
==
'torch'
:
if
args
.
DDP_impl
==
'torch'
:
...
@@ -291,6 +291,67 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_g
...
@@ -291,6 +291,67 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_g
return
input_tensor_grad
return
input_tensor_grad
def
forward_step_with_communication
(
forward_step_func
,
data_iterator
,
model
,
input_tensors
,
output_tensors
,
losses_reduced
,
timers
):
if
not
mpu
.
is_pipeline_first_stage
():
input_tensor
,
_
=
communicate
(
tensor_send_next
=
None
,
tensor_send_prev
=
None
,
recv_forward
=
True
,
recv_backward
=
False
)
else
:
input_tensor
=
None
# Forward model for one step.
timers
(
'forward'
).
start
()
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
timers
(
'forward'
).
stop
()
if
mpu
.
is_pipeline_last_stage
():
loss
,
loss_reduced
=
output_tensor
output_tensor
=
loss
losses_reduced
.
append
(
loss_reduced
)
else
:
communicate
(
tensor_send_next
=
output_tensor
,
tensor_send_prev
=
None
,
recv_forward
=
False
,
recv_backward
=
False
)
input_tensors
.
append
(
input_tensor
)
output_tensors
.
append
(
output_tensor
)
def
backward_step_with_communication
(
optimizer
,
model
,
input_tensors
,
output_tensors
,
timers
):
"""Backward step."""
input_tensor
=
input_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
if
mpu
.
is_pipeline_last_stage
():
output_tensor_grad
=
None
else
:
_
,
output_tensor_grad
=
communicate
(
tensor_send_next
=
None
,
tensor_send_prev
=
None
,
recv_forward
=
False
,
recv_backward
=
True
)
# Backward pass for one step.
# TODO: This timer is a bit redundant now with backward-backward.
timers
(
'backward'
).
start
()
input_grad_tensor
=
\
backward_step
(
optimizer
,
model
,
input_tensor
,
output_tensor
,
output_tensor_grad
)
timers
(
'backward'
).
stop
()
if
not
mpu
.
is_pipeline_first_stage
():
communicate
(
tensor_send_next
=
None
,
tensor_send_prev
=
input_grad_tensor
,
recv_forward
=
False
,
recv_backward
=
False
)
def
train_step
(
forward_step_func
,
data_iterator
,
def
train_step
(
forward_step_func
,
data_iterator
,
model
,
optimizer
,
lr_scheduler
):
model
,
optimizer
,
lr_scheduler
):
"""Single training step."""
"""Single training step."""
...
@@ -304,70 +365,41 @@ def train_step(forward_step_func, data_iterator,
...
@@ -304,70 +365,41 @@ def train_step(forward_step_func, data_iterator,
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
# Compute number of microbatches in a minibatch.
# Compute number of microbatches in a minibatch.
num_microbatches_to_pipeline
=
args
.
pipeline_model_parallel_size
\
num_microbatches_in_minibatch
=
args
.
num_microbatches_in_minibatch
if
args
.
use_pipelining
else
1
# TODO: Switch to the following schedule when async communication is supported
# so that we can facilitate mroe memory-efficient training.
# num_warmup_microbatches = \
# (torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) -
# torch.distributed.get_rank(group=mpu.get_pipeline_model_parallel_group()) - 1)
# num_warmup_microbatches = min(
# num_warmup_microbatches,
# num_microbatches_in_minibatch)
num_warmup_microbatches
=
num_microbatches_in_minibatch
input_tensors
=
[]
input_tensors
=
[]
output_tensors
=
[]
output_tensors
=
[]
losses_reduced
=
[]
losses_reduced
=
[]
# Run forward pass for all microbatches in minibatch.
# Run warmup forward passes.
for
i
in
range
(
num_microbatches_to_pipeline
):
for
i
in
range
(
num_warmup_microbatches
):
if
not
mpu
.
is_pipeline_first_stage
():
forward_step_with_communication
(
input_tensor
,
_
=
communicate
(
forward_step_func
,
data_iterator
,
model
,
tensor_send_next
=
None
,
input_tensors
,
output_tensors
,
tensor_send_prev
=
None
,
losses_reduced
,
timers
)
recv_forward
=
True
,
recv_backward
=
False
)
# Run 1F1B.
else
:
for
i
in
range
(
num_microbatches_in_minibatch
-
num_warmup_microbatches
):
input_tensor
=
None
forward_step_with_communication
(
forward_step_func
,
data_iterator
,
model
,
# Forward model for one step.
input_tensors
,
output_tensors
,
timers
(
'forward'
).
start
()
losses_reduced
,
timers
)
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
backward_step_with_communication
(
timers
(
'forward'
).
stop
()
optimizer
,
model
,
input_tensors
,
output_tensors
,
timers
)
if
mpu
.
is_pipeline_last_stage
():
# Run cooldown backward passes.
loss
,
loss_reduced
=
output_tensor
for
i
in
range
(
num_warmup_microbatches
):
output_tensor
=
loss
backward_step_with_communication
(
losses_reduced
.
append
(
loss_reduced
)
optimizer
,
model
,
input_tensors
,
output_tensors
,
timers
)
else
:
communicate
(
tensor_send_next
=
output_tensor
,
tensor_send_prev
=
None
,
recv_forward
=
False
,
recv_backward
=
False
)
input_tensors
.
append
(
input_tensor
)
output_tensors
.
append
(
output_tensor
)
# Run backward pass for all microbatches in minibatch.
for
i
in
range
(
num_microbatches_to_pipeline
):
input_tensor
=
input_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
if
mpu
.
is_pipeline_last_stage
():
output_grad_tensor
=
None
else
:
_
,
output_grad_tensor
=
communicate
(
tensor_send_next
=
None
,
tensor_send_prev
=
None
,
recv_forward
=
False
,
recv_backward
=
True
)
# Backward pass for one step.
# TODO: This timer is a bit redundant now with backward-backward.
timers
(
'backward'
).
start
()
input_grad_tensor
=
\
backward_step
(
optimizer
,
model
,
input_tensor
,
output_tensor
,
output_grad_tensor
)
timers
(
'backward'
).
stop
()
if
not
mpu
.
is_pipeline_first_stage
():
communicate
(
tensor_send_next
=
None
,
tensor_send_prev
=
input_grad_tensor
,
recv_forward
=
False
,
recv_backward
=
False
)
# All-reduce if needed.
# All-reduce if needed.
if
args
.
DDP_impl
==
'local'
:
if
args
.
DDP_impl
==
'local'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment