OpenDAS / Megatron-LM

Commit 5489bda9, authored Feb 13, 2021 by Deepak Narayanan

More comments and some cleanup (e.g., better variable names)

Parent: 626645c0
Showing 5 changed files with 85 additions and 36 deletions.
megatron/model/transformer.py    +13  -1
megatron/mpu/initialize.py       +13  -9
megatron/p2p_communication.py    +42  -10
megatron/schedules.py            +12  -11
megatron/training.py             +5   -5
megatron/model/transformer.py
@@ -554,12 +554,24 @@ class ParallelTransformer(MegatronModule):
                 self_attn_mask_type=self_attn_mask_type)

         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
-                'num_layers_per_stage must be divisible by virtual_pipeline_model_parallel_size'
+                'num_layers_per_stage must be divisible by ' \
+                'virtual_pipeline_model_parallel_size'
+            # Number of layers in each model chunk is the number of layers in the stage,
+            # divided by the number of model chunks in a stage.
             self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0]  [2]  [4]  [6]
+            # Stage 1: [1]  [3]  [5]  [7]
+            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0, 1]  [4, 5]
+            # Stage 1: [2, 3]  [6, 7]
             offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                 args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                 (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
+            # Each stage gets a contiguous set of layers.
             offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
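Note: to see concretely what the new comments describe, here is a minimal standalone sketch (not Megatron code; the loop variables stand in for mpu's rank getters) that reproduces the interleaved layer-to-stage assignment for 8 layers, 2 pipeline stages, and 2 virtual stages using the same offset arithmetic as the diff:

num_layers = 8
pipeline_size = 2    # number of pipeline stages
virtual_size = 2     # virtual_pipeline_model_parallel_size (model chunks per stage)
layers_per_chunk = num_layers // pipeline_size // virtual_size

for stage in range(pipeline_size):
    chunks = []
    for virtual_rank in range(virtual_size):
        # Same formula as the diff: the virtual rank selects the block of layers,
        # the pipeline rank selects the chunk inside that block.
        offset = virtual_rank * (num_layers // virtual_size) + stage * layers_per_chunk
        chunks.append(list(range(offset, offset + layers_per_chunk)))
    print(f"Stage {stage}: {chunks}")

# Prints:
# Stage 0: [[0, 1], [4, 5]]
# Stage 1: [[2, 3], [6, 7]]

Setting virtual_size = 4 instead reproduces the other commented assignment (Stage 0: [0] [2] [4] [6], Stage 1: [1] [3] [5] [7]).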
megatron/mpu/initialize.py
@@ -271,10 +271,8 @@ def get_pipeline_model_parallel_rank():

 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-        if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
-            _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != 0:
+        if get_virtual_pipeline_model_parallel_world_size() is not None and \
+            get_virtual_pipeline_model_parallel_rank() != 0:
             return False
     return get_pipeline_model_parallel_rank() == 0

@@ -282,11 +280,11 @@ def is_pipeline_first_stage(ignore_virtual=False):

 def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-        if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
-            _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != (
-                _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - 1):
+        virtual_pipeline_model_parallel_world_size = \
+            get_virtual_pipeline_model_parallel_world_size()
+        if virtual_pipeline_model_parallel_world_size is not None and \
+            get_virtual_pipeline_model_parallel_rank() != (
+                virtual_pipeline_model_parallel_world_size - 1):
             return False
     return get_pipeline_model_parallel_rank() == (
         get_pipeline_model_parallel_world_size() - 1)

@@ -304,6 +302,12 @@ def set_virtual_pipeline_model_parallel_rank(rank):
     _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


+def get_virtual_pipeline_model_parallel_world_size():
+    """Return the virtual pipeline-parallel world size."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+
+
 def get_tensor_model_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
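Note: the cleanup routes all reads of the virtual-pipeline globals through getters instead of touching the module-level variables directly. A minimal self-contained sketch of the same pattern (simplified; Megatron's real initialize.py tracks many more parallel-state globals):

# Module-level parallel state, mirroring megatron/mpu/initialize.py (simplified).
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None

def set_virtual_pipeline_model_parallel_rank(rank):
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank

def get_virtual_pipeline_model_parallel_rank():
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK

def get_virtual_pipeline_model_parallel_world_size():
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

# With no virtual pipelining configured, the world size stays None, so the
# virtual-rank checks in is_pipeline_first_stage()/is_pipeline_last_stage()
# are skipped entirely.
assert get_virtual_pipeline_model_parallel_world_size() is None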
megatron/p2p_communication.py
@@ -23,7 +23,24 @@ from megatron import mpu

 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                  use_ring_exchange=False):
-    """Communicate tensors between stages."""
+    """Communicate tensors between stages. Used as helper method in other
+    communication methods that are used in megatron/schedules.py.
+
+    Takes the following arguments:
+        tensor_send_next: tensor to send to next rank (no tensor sent if
+                          set to None).
+        tensor_send_prev: tensor to send to prev rank (no tensor sent if
+                          set to None).
+        recv_prev: boolean for whether tensor should be received from
+                   previous rank.
+        recv_next: boolean for whether tensor should be received from
+                   next rank.
+        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
+                           API should be used.
+
+    Returns:
+        (tensor_recv_prev, tensor_recv_next)
+    """
     args = get_args()

     # Create placeholder tensors for receive in forward and backward directions

@@ -50,6 +67,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

+    # Split tensor into smaller chunks if using scatter-gather optimization.
     if args.scatter_gather_tensors_in_pipeline:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
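Note: the scatter-gather optimization referenced by the new comment sends only a 1/t slice of the activation over the pipeline when tensor parallelism of size t would otherwise replicate it, then reassembles on the receiving side. A rough standalone sketch of the split/gather halves; the signatures here are hypothetical (Megatron's mpu helpers take only the tensor, reading rank and group from parallel state, and gather with an all_gather):

import torch

def split_tensor_into_1d_equal_chunks(tensor, num_chunks, rank):
    # Flatten and keep only this rank's contiguous slice.
    flat = tensor.view(-1)
    chunk = flat.numel() // num_chunks   # assumes numel divisible by num_chunks
    return flat[rank * chunk:(rank + 1) * chunk]

def gather_split_1d_tensor(chunks):
    # Reassemble the slices (stand-in for an all_gather across the group).
    return torch.cat(chunks)

x = torch.arange(8.0)
parts = [split_tensor_into_1d_equal_chunks(x, num_chunks=4, rank=r) for r in range(4)]
assert torch.equal(gather_split_1d_tensor(parts), x)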
@@ -67,27 +85,32 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     else:
         ops = []
         if tensor_send_prev is not None:
-            send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
-                                                   mpu.get_pipeline_model_parallel_prev_rank())
+            send_prev_op = torch.distributed.P2POp(
+                torch.distributed.isend, tensor_send_prev,
+                mpu.get_pipeline_model_parallel_prev_rank())
             ops.append(send_prev_op)
         if tensor_recv_prev is not None:
-            recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
-                                                   mpu.get_pipeline_model_parallel_prev_rank())
+            recv_prev_op = torch.distributed.P2POp(
+                torch.distributed.irecv, tensor_recv_prev,
+                mpu.get_pipeline_model_parallel_prev_rank())
             ops.append(recv_prev_op)
         if tensor_send_next is not None:
-            send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
-                                                   mpu.get_pipeline_model_parallel_next_rank())
+            send_next_op = torch.distributed.P2POp(
+                torch.distributed.isend, tensor_send_next,
+                mpu.get_pipeline_model_parallel_next_rank())
             ops.append(send_next_op)
         if tensor_recv_next is not None:
-            recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
-                                                   mpu.get_pipeline_model_parallel_next_rank())
+            recv_next_op = torch.distributed.P2POp(
+                torch.distributed.irecv, tensor_recv_next,
+                mpu.get_pipeline_model_parallel_next_rank())
             ops.append(recv_next_op)
         reqs = torch.distributed.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()
+    # To protect against race condition when using batch_isend_irecv().
     torch.cuda.synchronize()

-    tensor_recv_prev_before = tensor_recv_prev
+    # If using scatter-gather optimization, gather smaller chunks.
     if args.scatter_gather_tensors_in_pipeline:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
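Note: for reference, a minimal two-rank script exercising the same batched P2P pattern the diff reformats. This is a sketch assuming a CUDA/NCCL backend and a launch like torchrun --nproc_per_node=2; exchange_with_neighbors is a hypothetical helper, not part of Megatron:

import torch
import torch.distributed as dist

def exchange_with_neighbors(tensor_send_next, tensor_recv_prev):
    # Build the op list the same way _communicate does, then issue one batch.
    rank = dist.get_rank()
    world = dist.get_world_size()
    ops = []
    if tensor_send_next is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_next, (rank + 1) % world))
    if tensor_recv_prev is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, (rank - 1) % world))
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    # Guard against the race condition noted in the diff's new comment.
    torch.cuda.synchronize()
    return tensor_recv_prev

if __name__ == '__main__':
    dist.init_process_group('nccl')
    torch.cuda.set_device(dist.get_rank())
    send = torch.full((4,), float(dist.get_rank()), device='cuda')
    recv = torch.empty(4, device='cuda')
    print(dist.get_rank(), exchange_with_neighbors(send, recv))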
@@ -101,6 +124,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,

 def recv_forward(timers=None, use_ring_exchange=False):
+    """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
     else:

@@ -118,6 +142,7 @@ def recv_forward(timers=None, use_ring_exchange=False):

 def recv_backward(timers=None, use_ring_exchange=False):
+    """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
     else:

@@ -135,6 +160,7 @@ def recv_backward(timers=None, use_ring_exchange=False):

 def send_forward(output_tensor, timers=None, use_ring_exchange=False):
+    """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
             timers('forward-send').start()

@@ -149,6 +175,7 @@ def send_forward(output_tensor, timers=None, use_ring_exchange=False):

 def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
+    """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
             timers('backward-send').start()

@@ -163,6 +190,7 @@ def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):

 def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
+    """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
     else:

@@ -180,6 +208,7 @@ def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):

 def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
+    """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
     else:

@@ -197,6 +226,7 @@ def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):

 def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
+    """Batched recv from previous rank and send to next rank in pipeline."""
     if timers is not None:
         timers('forward-send-forward-recv').start()
     input_tensor, _ = _communicate(

@@ -211,6 +241,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):

 def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
+    """Batched recv from next rank and send to previous rank in pipeline."""
     if timers is not None:
         timers('backward-send-backward-recv').start()
     _, output_tensor_grad = _communicate(

@@ -227,6 +258,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):

 def send_forward_backward_recv_forward_backward(
         output_tensor, input_tensor_grad,
         recv_prev, recv_next, timers=None):
+    """Batched send and recv with previous and next ranks in pipeline."""
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').start()
     input_tensor, output_tensor_grad = _communicate(
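Note: the new docstrings make explicit how these wrappers pair up in the schedules. As an illustration only, a hedged sketch of one steady-state step of the non-interleaved 1F1B schedule in this vocabulary (one_f_one_b_step, forward_step, and backward_step are hypothetical stand-ins; the real schedule in megatron/schedules.py also handles warmup/cooldown phases and timers):

def one_f_one_b_step(forward_step, backward_step, input_tensor):
    # One steady-state iteration: forward on the newest microbatch,
    # then backward on the oldest.
    output_tensor = forward_step(input_tensor)
    # Ship activations downstream and pick up gradients in one batched call.
    output_tensor_grad = send_forward_recv_backward(output_tensor)
    input_tensor_grad = backward_step(input_tensor, output_tensor,
                                      output_tensor_grad)
    # Ship gradients upstream and pick up the next microbatch's activations.
    return send_backward_recv_forward(input_tensor_grad)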
megatron/schedules.py
@@ -136,19 +136,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
     num_microbatches_remaining = \
         num_microbatches - num_warmup_microbatches

-    def get_model_chunk_id(k, forward):
+    def get_model_chunk_id(microbatch_id, forward):
         """Helper method to get the model chunk ID given the iteration number."""
-        k_in_group = k % (pipeline_parallel_size * num_model_chunks)
-        i = k_in_group // pipeline_parallel_size
+        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
+        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
         if not forward:
-            i = (num_model_chunks - i - 1)
-        return i
+            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
+        return model_chunk_id

-    def forward_step_helper(k):
+    def forward_step_helper(microbatch_id):
         """Helper method to run forward step with model split into chunks
         (run set_virtual_pipeline_model_parallel_rank() before calling
         forward_step())."""
-        model_chunk_id = get_model_chunk_id(k, forward=True)
+        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

         if mpu.is_pipeline_first_stage():
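Note: a quick standalone check of the renamed helper's arithmetic, with the same body as the diff but pipeline_parallel_size and num_model_chunks bound as plain variables instead of closure state:

pipeline_parallel_size = 2
num_model_chunks = 2

def get_model_chunk_id(microbatch_id, forward):
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = (num_model_chunks - model_chunk_id - 1)
    return model_chunk_id

# Forward microbatches cycle through the chunks in order; backward visits
# them in reverse, since gradients flow through the chunks back-to-front.
assert [get_model_chunk_id(m, forward=True) for m in range(8)] == [0, 0, 1, 1, 0, 0, 1, 1]
assert [get_model_chunk_id(m, forward=False) for m in range(8)] == [1, 1, 0, 0, 1, 1, 0, 0]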
@@ -164,11 +164,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
         return output_tensor

-    def backward_step_helper(k):
+    def backward_step_helper(microbatch_id):
         """Helper method to run backward step with model split into chunks
         (run set_virtual_pipeline_model_parallel_rank() before calling
         backward_step())."""
-        model_chunk_id = get_model_chunk_id(k, forward=False)
+        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

         if mpu.is_pipeline_last_stage():
@@ -317,8 +317,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
     return losses_reduced


-def forward_backward_pipelining(forward_step_func, data_iterator, model,
-                                optimizer, timers, forward_only):
+def forward_backward_pipelining_without_interleaving(
+        forward_step_func, data_iterator, model,
+        optimizer, timers, forward_only):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
     stages.
megatron/training.py
@@ -50,7 +50,7 @@ from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
 from megatron.schedules import forward_backward_no_pipelining
-from megatron.schedules import forward_backward_pipelining
+from megatron.schedules import forward_backward_pipelining_without_interleaving
 from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory
@@ -340,7 +340,7 @@ def train_step(forward_step_func, data_iterator,
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
         else:
-            forward_backward_func = forward_backward_pipelining
+            forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
         forward_backward_func = forward_backward_no_pipelining
     losses_reduced = forward_backward_func(
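Note: the same three-way dispatch appears in both train_step() and evaluate(). A condensed sketch of the selection logic as a standalone function (a hypothetical refactor for illustration, not code from this commit; the schedule functions are the ones imported above):

def get_forward_backward_func(pipeline_model_parallel_size,
                              virtual_pipeline_model_parallel_size):
    # Interleaved schedule only when a virtual pipeline size is configured;
    # no pipelining at all when there is a single pipeline stage.
    if pipeline_model_parallel_size > 1:
        if virtual_pipeline_model_parallel_size is not None:
            return forward_backward_pipelining_with_interleaving
        return forward_backward_pipelining_without_interleaving
    return forward_backward_no_pipelining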
@@ -681,7 +681,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
         else:
-            forward_backward_func = forward_backward_pipelining
+            forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
         forward_backward_func = forward_backward_no_pipelining
     loss_dicts = forward_backward_func(
@@ -692,8 +692,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
             # Reduce across processes.
             for loss_dict in loss_dicts:
                 for key in loss_dict:
-                    total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
-                        loss_dict[key]
+                    total_loss_dict[key] = total_loss_dict.get(
+                        key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

         args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                        * args.micro_batch_size \
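Note: the reflowed accumulation is simply a defaulting dict sum across per-microbatch loss dicts. A CPU-only illustration of the same pattern (torch.cuda.FloatTensor swapped for torch.zeros so it runs anywhere; the 'lm loss' key and values are made up):

import torch

loss_dicts = [{'lm loss': torch.tensor([2.0])},
              {'lm loss': torch.tensor([1.5])}]

total_loss_dict = {}
for loss_dict in loss_dicts:
    for key in loss_dict:
        # Same pattern as evaluate(): default to zero the first time a key is seen.
        total_loss_dict[key] = total_loss_dict.get(
            key, torch.zeros(1)) + loss_dict[key]

print(total_loss_dict)  # {'lm loss': tensor([3.5])}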