OpenDAS / Megatron-LM · Commits

Commit 3c92fa93, authored Mar 23, 2023 by Jared Casper
Parent: 0b44909c

    Move pipeline parallel functionality into core with associated changes.

Showing 20 changed files with 578 additions and 325 deletions (+578 -325)
examples/detxoify_lm/finetune_gpt.py                    +2    -1
megatron/core/enums.py                                  +7    -0
megatron/core/parallel_state.py                         +41   -8
megatron/core/pipeline_parallel/__init__.py             +1    -0
megatron/core/pipeline_parallel/p2p_communication.py    +187  -136
megatron/core/pipeline_parallel/schedules.py            +282  -158
megatron/core/tensor_parallel/layers.py                 +6    -2
megatron/core/utils.py                                  +15   -0
megatron/model/__init__.py                              +0    -1
megatron/model/enums.py                                 +0    -4
megatron/model/retro_transformer.py                     +2    -1
megatron/model/transformer.py                           +2    -1
megatron/training.py                                    +22   -6
pretrain_bert.py                                        +2    -1
pretrain_gpt.py                                         +2    -1
pretrain_ict.py                                         +1    -1
pretrain_retro.py                                       +2    -1
pretrain_t5.py                                          +2    -1
pretrain_vision_classify.py                             +1    -1
pretrain_vision_dino.py                                 +1    -1
examples/detxoify_lm/finetune_gpt.py

@@ -17,7 +17,8 @@ from megatron import print_rank_0
 from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import GPTModel, ModelType
+from megatron.model import GPTModel
+from megatron.core.enums import ModelType
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group
megatron/core/enums.py (new file, mode 100644)

+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+
+import enum
+
+class ModelType(enum.Enum):
+    encoder_or_decoder = 1
+    encoder_and_decoder = 2
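The enum itself is unchanged from its old home in megatron/model/enums.py; only the import path moves into core. A hedged sketch of how downstream code can branch on it after this commit (the helper below is illustrative, not part of the commit):

    from megatron.core.enums import ModelType

    def encoder_decoder_split(model_type: ModelType, pipeline_size: int, split_rank: int):
        # Hypothetical helper: how many pipeline ranks serve the encoder vs. the decoder.
        if model_type == ModelType.encoder_and_decoder:
            return split_rank, pipeline_size - split_rank
        return pipeline_size, 0

    print(encoder_decoder_split(ModelType.encoder_and_decoder, 8, 3))  # (3, 5)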
megatron/core/parallel_state.py

@@ -58,12 +58,40 @@ def initialize_model_parallel(
     Initialize model data parallel groups.

     Arguments:
-        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
-        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
-        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
-                                              pipeline).
-        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
-                                            rank in pipeline with split point.
+        tensor_model_parallel_size (int, default = 1):
+            The number of GPUs to split individual tensors across.
+
+        pipeline_model_parallel_size (int, default = 1):
+            The number of tensor parallel GPU groups to split the
+            Transformer layers across. For example, if
+            tensor_model_parallel_size is 4 and
+            pipeline_model_parallel_size is 2, the model will be split
+            into 2 groups of 4 GPUs.
+
+        virtual_pipeline_model_parallel_size (int, optional):
+            The number of stages that each pipeline group will have,
+            interleaving as necessary. If None, no interleaving is
+            performed. For example, if tensor_model_parallel_size is 1,
+            pipeline_model_parallel_size is 4,
+            virtual_pipeline_model_parallel_size is 2, and there are
+            16 transformer layers in the model, the model will be
+            split into 8 stages with two layers each and each GPU
+            would get 2 stages as such (layer number starting with 1):
+            GPU 0: [1, 2] [9, 10]
+            GPU 1: [3, 4] [11, 12]
+            GPU 2: [5, 6] [13, 14]
+            GPU 3: [7, 8] [15, 16]
+
+        pipeline_model_parallel_split_rank (int, optional):
+            For models with both an encoder and decoder, the rank in
+            pipeline to switch between encoder and decoder (i.e. the
+            first rank of the decoder). This allows the user to set
+            the pipeline parallel size of the encoder and decoder
+            independently. For example, if
+            pipeline_model_parallel_size is 8 and
+            pipeline_model_parallel_split_rank is 3, then ranks 0-2
+            will be the encoder and ranks 3-7 will be the decoder.

     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize

@@ -298,8 +326,8 @@ def set_pipeline_model_parallel_rank(rank):

 def set_pipeline_model_parallel_split_rank(rank):
     """Set pipeline model parallel split rank."""
-    global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
-    _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank

 def get_tensor_model_parallel_rank():

@@ -318,6 +346,11 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())

+def get_pipeline_model_parallel_split_rank():
+    """Return pipeline model parallel split rank."""
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+
 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
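A minimal sketch of calling the API documented above, assuming the process group has already been launched (e.g. via torchrun) with one GPU per rank; the sizes are illustrative and follow the 16-GPU example in the docstring:

    import torch
    from megatron.core import parallel_state

    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

    # 2-way tensor parallelism x 4-way pipeline parallelism on 16 GPUs
    # leaves 2-way data parallelism, matching the docstring's layout.
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=4,
        virtual_pipeline_model_parallel_size=None,
    )
    print(parallel_state.get_pipeline_model_parallel_rank())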
megatron/core/pipeline_parallel/__init__.py (new file, mode 100644)

+from .schedules import get_forward_backward_func
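With this re-export, callers no longer need to know the internal module layout to reach the scheduler entry point. A hedged usage sketch (assumes parallel state has been initialized as above):

    from megatron.core.pipeline_parallel import get_forward_backward_func

    # Returns forward_backward_no_pipelining or one of the pipelining schedules,
    # depending on the pipeline/virtual-pipeline world sizes in parallel_state.
    forward_backward_func = get_forward_backward_func()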
megatron/p2p_communication.py → megatron/core/pipeline_parallel/p2p_communication.py

@@ -2,15 +2,24 @@
 from functools import reduce
 import operator
+from typing import Optional, List, Union, Callable, Tuple

 import torch

-from megatron import get_args, core
-from megatron.core import mpu
+from megatron import core
+from megatron.core.parallel_state import (
+    get_pipeline_model_parallel_group,
+    get_pipeline_model_parallel_prev_rank,
+    get_pipeline_model_parallel_next_rank,
+)
+
+# Types
+Shape = Union[List[int], torch.Size]

 def _communicate_shapes(tensor_send_next, tensor_send_prev,
-                        recv_prev, recv_next):
+                        recv_prev, recv_next,
+                        use_ring_exchange_p2p):
     """Communicate tensor shapes between stages. Used to communicate
     tensor shapes before the actual tensor communication happens.
     This is required when the sequence lengths across micro batches
     are not uniform.

@@ -28,7 +37,6 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
            (recv_prev_shape, recv_next_shape)
     """
-    args = get_args()
     recv_prev_shape_tensor = None
     recv_next_shape_tensor = None
     send_prev_shape_tensor = None

@@ -50,7 +58,7 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
                                                device=torch.cuda.current_device(),
                                                dtype=torch.int64)
-    if args.use_ring_exchange_p2p:
+    if use_ring_exchange_p2p:
         torch.distributed.ring_exchange(tensor_send_prev=send_prev_shape_tensor,
                                         tensor_recv_prev=recv_prev_shape_tensor,
                                         tensor_send_next=send_next_shape_tensor,
@@ -98,46 +106,70 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
     return recv_prev_shape, recv_next_shape

-def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 tensor_shape,
-                 dtype_=None):
+def _communicate(*, tensor_send_next: Optional[torch.Tensor],
+                 tensor_send_prev: Optional[torch.Tensor],
+                 recv_prev: bool,
+                 recv_next: bool,
+                 tensor_shape: Shape,
+                 dtype: Optional[torch.dtype],
+                 variable_seq_lengths: bool = False,
+                 use_ring_exchange_p2p: bool = False,
+                 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.

-    Takes the following arguments:
-        tensor_send_next: tensor to send to next rank (no tensor sent if
-                          set to None).
-        tensor_send_prev: tensor to send to prev rank (no tensor sent if
-                          set to None).
-        recv_prev: boolean for whether tensor should be received from
-                   previous rank.
-        recv_next: boolean for whether tensor should be received from
-                   next rank.
-        tensor_shape: shape of tensor to receive (this method assumes that all
-                      tensors sent and received in a single function call are
-                      the same shape).
-        dtype_: optional, this is used when the tensor that needs to be
-                communicated is different from args.params_dtype.
+    Arguments:
+        tensor_send_next (torch.Tensor, optional):
+            Tensor to send to next rank (no tensor sent if None)
+
+        tensor_send_prev (torch.Tensor, optional):
+            Tensor to send to prev rank (no tensor sent if None)
+
+        recv_prev (boolean, required):
+            whether tensor should be received from previous rank.
+
+        recv_next (boolean, required):
+            whether tensor should be received from next rank.
+
+        tensor_shape (List[int] or torch.Size, required):
+            shape of tensor to receive (this method assumes that all
+            tensors sent and received in a single function call are
+            the same shape).
+
+        dtype (torch.dtype, required if either recv_{prev,next} is True):
+            this must be the type of the tensors that will be
+            received, will typically be params_dtype, but in the case
+            of fp32 residual connections might be torch.float.
+
+        variable_seq_lengths (bool, optional, default=False):
+            Support for variable sequence lengths across
+            microbatches. Setting this communicates the size of
+            tensors during pipeline parallelism communication, because
+            of this extra overhead it should only be set if the
+            sequence length is not constant during training.
+
+        use_ring_exchange_p2p (bool, optional, default = False):
+            Use custom ring_exchange kernel instead of
+            torch.distributed.batch_isend_irecv(). Requires custom
+            built torch with torch.distributed.ring_exchange.
+
     Returns:
-        (tensor_recv_prev, tensor_recv_next)
+        tuple containing
+
+        - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
+        - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
     """
-    args = get_args()

     # Create placeholder tensors for receive in forward and backward directions
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None

-    # Some legacy inference code doesn't set the tensor shape, do so now
-    # for the normal values for gpt/bert. This could be removed if inference
-    # code is changed to provide tensor_shape.
-    if not args.variable_seq_lengths:
-        if tensor_shape is None:
-            recv_prev_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-            recv_next_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-        else:
-            recv_prev_shape = tensor_shape
-            recv_next_shape = tensor_shape
+    if not variable_seq_lengths:
+        recv_prev_shape = tensor_shape
+        recv_next_shape = tensor_shape
     else:
         recv_prev_shape, recv_next_shape = \
             _communicate_shapes(tensor_send_next,
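The hunk above makes _communicate keyword-only and moves shape and dtype decisions to the caller instead of global args. A minimal, runnable sketch of the new contract; the sizes and dtype are illustrative, and the buffer allocation mirrors what _communicate does when a receive is requested:

    import torch

    seq_length, micro_batch_size, hidden_size = 2048, 4, 4096
    tensor_shape = (seq_length, micro_batch_size, hidden_size)  # caller-provided, no args.seq_length fallback
    dtype = torch.bfloat16          # would be params_dtype, or torch.float for fp32 residual connections

    device = "cuda" if torch.cuda.is_available() else "cpu"     # the real code always uses the current CUDA device
    tensor_recv_prev = torch.empty(tensor_shape, requires_grad=True, device=device, dtype=dtype)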
@@ -145,116 +177,81 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                                 recv_prev,
                                 recv_next)

-    override_scatter_gather_tensors_in_pipeline = False
-    if args.scatter_gather_tensors_in_pipeline and \
-            not args.sequence_parallel:
-        recv_prev_chunk_shape = reduce(operator.mul, recv_prev_shape, 1)
-        recv_next_chunk_shape = reduce(operator.mul, recv_next_shape, 1)
-        if recv_prev_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0 and \
-                recv_next_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
-            recv_prev_chunk_shape = recv_prev_chunk_shape // \
-                mpu.get_tensor_model_parallel_world_size()
-            recv_next_chunk_shape = recv_next_chunk_shape // \
-                mpu.get_tensor_model_parallel_world_size()
-        else:
-            recv_prev_chunk_shape = recv_prev_shape
-            recv_next_chunk_shape = recv_next_shape
-            override_scatter_gather_tensors_in_pipeline = True
-    else:
-        recv_prev_chunk_shape = recv_prev_shape
-        recv_next_chunk_shape = recv_next_shape
-
-    dtype = args.params_dtype
-    if args.fp32_residual_connection:
-        dtype = torch.float
-
-    requires_grad = True
-    if dtype_ is not None:
-        dtype = dtype_
-        requires_grad = False
-
     if recv_prev:
-        tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
-                                       requires_grad=requires_grad,
+        if dtype is None:
+            raise RuntimeError("dtype must be provided if recv_prev is True")
+        if tensor_shape is None:
+            raise RuntimeError(
+                "tensor_shape must be specified if recv_prev is True. "
+                "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)")
+        tensor_recv_prev = torch.empty(recv_prev_shape,
+                                       requires_grad=True,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
     if recv_next:
-        tensor_recv_next = torch.empty(recv_next_chunk_shape,
-                                       requires_grad=requires_grad,
+        if dtype is None:
+            raise RuntimeError("dtype must be provided if recv_next is True")
+        if tensor_shape is None:
+            raise RuntimeError(
+                "tensor_shape must be specified if recv_next is True. "
+                "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)")
+        tensor_recv_next = torch.empty(recv_next_shape,
+                                       requires_grad=True,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)

-    # Split tensor into smaller chunks if using scatter-gather optimization.
-    if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline and \
-            not args.sequence_parallel:
-        if tensor_send_next is not None:
-            tensor_send_next = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)
-        if tensor_send_prev is not None:
-            tensor_send_prev = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)
-
     # Send tensors in both the forward and backward directions as appropriate.
-    if args.use_ring_exchange_p2p:
+    if use_ring_exchange_p2p:
         torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                         tensor_recv_prev=tensor_recv_prev,
                                         tensor_send_next=tensor_send_next,
                                         tensor_recv_next=tensor_recv_next,
-                                        group=mpu.get_pipeline_model_parallel_group())
+                                        group=get_pipeline_model_parallel_group())
     else:
         ops = []
         if tensor_send_prev is not None:
             send_prev_op = torch.distributed.P2POp(
                 torch.distributed.isend, tensor_send_prev,
-                mpu.get_pipeline_model_parallel_prev_rank())
+                get_pipeline_model_parallel_prev_rank())
             ops.append(send_prev_op)
         if tensor_recv_prev is not None:
             recv_prev_op = torch.distributed.P2POp(
                 torch.distributed.irecv, tensor_recv_prev,
-                mpu.get_pipeline_model_parallel_prev_rank())
+                get_pipeline_model_parallel_prev_rank())
             ops.append(recv_prev_op)
         if tensor_send_next is not None:
             send_next_op = torch.distributed.P2POp(
                 torch.distributed.isend, tensor_send_next,
-                mpu.get_pipeline_model_parallel_next_rank())
+                get_pipeline_model_parallel_next_rank())
             ops.append(send_next_op)
         if tensor_recv_next is not None:
             recv_next_op = torch.distributed.P2POp(
                 torch.distributed.irecv, tensor_recv_next,
-                mpu.get_pipeline_model_parallel_next_rank())
+                get_pipeline_model_parallel_next_rank())
             ops.append(recv_next_op)
         if len(ops) > 0:
             reqs = torch.distributed.batch_isend_irecv(ops)
             for req in reqs:
                 req.wait()
     # To protect against race condition when using batch_isend_irecv().
+    # User should assert that we have a modern enough PyTorch to not need this
     torch.cuda.synchronize()

-    # If using scatter-gather optimization, gather smaller chunks.
-    if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline and \
-            not args.sequence_parallel:
-        if recv_prev:
-            tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
-                tensor_recv_prev).view(recv_prev_shape).requires_grad_()
-            tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
-                                                               requires_grad=True,
-                                                               keep_graph=False)
-        if recv_next:
-            tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
-                tensor_recv_next).view(recv_next_shape).requires_grad_()
-            tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
-                                                               requires_grad=True,
-                                                               keep_graph=False)
-
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(tensor_shape=None, dtype_=None, timers=None):
-    """Receive tensor from previous rank in pipeline (forward receive)."""
+def recv_forward(tensor_shape: Shape,
+                 dtype: torch.dtype,
+                 timers: Callable = None) -> torch.Tensor:
+    """ Receive tensor from previous rank in pipeline (forward receive).
+
+    See _communicate for argument details.
+    """

-    if mpu.is_pipeline_first_stage():
+    if core.parallel_state.is_pipeline_first_stage():
         input_tensor = None
     else:
         if timers is not None:
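The rewritten body keeps the same batched point-to-point idiom, just with parallel_state helpers instead of mpu and no global args. A self-contained sketch of that torch.distributed.P2POp / batch_isend_irecv pattern on a toy ring (hypothetical demo script, assuming a backend with point-to-point support such as gloo or nccl):

    # Run with: torchrun --nproc_per_node=2 p2p_demo.py   (file name is illustrative)
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")
    rank, world = dist.get_rank(), dist.get_world_size()

    send_buf = torch.full((4,), float(rank))
    recv_buf = torch.empty(4)

    # Each rank sends to the next rank and receives from the previous one,
    # batching both operations the same way _communicate does.
    ops = [
        dist.P2POp(dist.isend, send_buf, (rank + 1) % world),
        dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    print(f"rank {rank} received {recv_buf.tolist()}")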
@@ -265,15 +262,20 @@ def recv_forward(tensor_shape=None, dtype_=None, timers=None):
             recv_prev=True,
             recv_next=False,
             tensor_shape=tensor_shape,
-            dtype_=dtype_)
+            dtype=dtype)
         if timers is not None:
             timers('forward-recv').stop()
     return input_tensor


-def recv_backward(tensor_shape=None, timers=None):
-    """Receive tensor from next rank in pipeline (backward receive)."""
-    if mpu.is_pipeline_last_stage():
+def recv_backward(tensor_shape: Shape,
+                  dtype: torch.dtype,
+                  timers: Callable = None) -> torch.Tensor:
+    """Receive tensor from next rank in pipeline (backward receive).
+
+    See _communicate for argument details.
+    """
+    if core.parallel_state.is_pipeline_last_stage():
         output_tensor_grad = None
     else:
         if timers is not None:

@@ -283,16 +285,21 @@ def recv_backward(tensor_shape=None, timers=None):
             tensor_send_prev=None,
             recv_prev=False,
             recv_next=True,
-            tensor_shape=tensor_shape)
+            tensor_shape=tensor_shape,
+            dtype=dtype)
         if timers is not None:
             timers('backward-recv').stop()
     return output_tensor_grad


-def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
-    """Send tensor to next rank in pipeline (forward send)."""
+def send_forward(output_tensor: torch.Tensor,
+                 timers: Callable = None) -> None:
+    """Send tensor to next rank in pipeline (forward send).
+
+    See _communicate for argument details.
+    """

-    if not mpu.is_pipeline_last_stage():
+    if not core.parallel_state.is_pipeline_last_stage():
         if timers is not None:
             timers('forward-send', log_level=2).start()
         _communicate(
@@ -300,15 +307,19 @@ def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
             tensor_send_prev=None,
             recv_prev=False,
             recv_next=False,
-            tensor_shape=tensor_shape,
-            dtype_=dtype_)
+            tensor_shape=None,
+            dtype=None)
         if timers is not None:
             timers('forward-send').stop()


-def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
-    """Send tensor to previous rank in pipeline (backward send)."""
-    if not mpu.is_pipeline_first_stage():
+def send_backward(input_tensor_grad: torch.Tensor,
+                  timers: Callable = None) -> None:
+    """Send tensor to previous rank in pipeline (backward send).
+
+    See _communicate for argument details.
+    """
+    if not core.parallel_state.is_pipeline_first_stage():
         if timers is not None:
             timers('backward-send', log_level=2).start()
         _communicate(

@@ -316,14 +327,21 @@ def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
             tensor_send_prev=input_tensor_grad,
             recv_prev=False,
             recv_next=False,
-            tensor_shape=tensor_shape)
+            tensor_shape=None,
+            dtype=None)
         if timers is not None:
             timers('backward-send').stop()


-def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
-    """Batched send and recv with next rank in pipeline."""
-    if mpu.is_pipeline_last_stage():
+def send_forward_recv_backward(output_tensor: torch.Tensor,
+                               tensor_shape: Shape,
+                               dtype: torch.dtype,
+                               timers: Callable = None) -> torch.Tensor:
+    """Batched send and recv with next rank in pipeline.
+
+    See _communicate for argument details.
+    """
+    if core.parallel_state.is_pipeline_last_stage():
         output_tensor_grad = None
     else:
         if timers is not None:
@@ -333,15 +351,22 @@ def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
             tensor_send_prev=None,
             recv_prev=False,
             recv_next=True,
-            tensor_shape=tensor_shape)
+            tensor_shape=tensor_shape,
+            dtype=dtype)
         if timers is not None:
             timers('forward-send-backward-recv').stop()
     return output_tensor_grad


-def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
-    """Batched send and recv with previous rank in pipeline."""
-    if mpu.is_pipeline_first_stage():
+def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
+                               tensor_shape: Shape,
+                               dtype: torch.dtype,
+                               timers: Callable = None) -> torch.Tensor:
+    """Batched send and recv with previous rank in pipeline.
+
+    See _communicate for argument details.
+    """
+    if core.parallel_state.is_pipeline_first_stage():
         input_tensor = None
     else:
         if timers is not None:

@@ -351,14 +376,22 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None
             tensor_send_prev=input_tensor_grad,
             recv_prev=True,
             recv_next=False,
-            tensor_shape=tensor_shape)
+            tensor_shape=tensor_shape,
+            dtype=dtype)
         if timers is not None:
             timers('backward-send-forward-recv').stop()
     return input_tensor


-def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
-    """Batched recv from previous rank and send to next rank in pipeline."""
+def send_forward_recv_forward(output_tensor: torch.Tensor,
+                              recv_prev: bool,
+                              tensor_shape: Shape,
+                              dtype: torch.dtype,
+                              timers: Callable = None) -> torch.Tensor:
+    """Batched recv from previous rank and send to next rank in pipeline.
+
+    See _communicate for argument details.
+    """
     if timers is not None:
         timers('forward-send-forward-recv', log_level=2).start()
     input_tensor, _ = _communicate(
@@ -366,14 +399,22 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timer
         tensor_send_prev=None,
         recv_prev=recv_prev,
         recv_next=False,
-        tensor_shape=tensor_shape)
+        tensor_shape=tensor_shape,
+        dtype=dtype)
     if timers is not None:
         timers('forward-send-forward-recv').stop()
     return input_tensor


-def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
-    """Batched recv from next rank and send to previous rank in pipeline."""
+def send_backward_recv_backward(input_tensor_grad: torch.Tensor,
+                                recv_next: bool,
+                                tensor_shape: Shape,
+                                dtype: torch.dtype,
+                                timers: Callable = None) -> torch.Tensor:
+    """Batched recv from next rank and send to previous rank in pipeline.
+
+    See _communicate for argument details.
+    """
     if timers is not None:
         timers('backward-send-backward-recv', log_level=2).start()
     _, output_tensor_grad = _communicate(

@@ -381,16 +422,25 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None,
         tensor_send_prev=input_tensor_grad,
         recv_prev=False,
         recv_next=recv_next,
-        tensor_shape=tensor_shape)
+        tensor_shape=tensor_shape,
+        dtype=dtype)
     if timers is not None:
         timers('backward-send-backward-recv').stop()
     return output_tensor_grad


 def send_forward_backward_recv_forward_backward(
-        output_tensor, input_tensor_grad, recv_prev,
-        recv_next, tensor_shape=None, timers=None):
-    """Batched send and recv with previous and next ranks in pipeline."""
+        output_tensor: torch.Tensor,
+        input_tensor_grad: torch.Tensor,
+        recv_prev: bool,
+        recv_next: bool,
+        tensor_shape: Shape,
+        dtype: torch.dtype,
+        timers: Callable = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Batched send and recv with previous and next ranks in pipeline.
+
+    See _communicate for argument details.
+    """
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv',
                log_level=2).start()

@@ -399,7 +449,8 @@ def send_forward_backward_recv_forward_backward(
         tensor_send_prev=input_tensor_grad,
         recv_prev=recv_prev,
         recv_next=recv_next,
-        tensor_shape=tensor_shape)
+        tensor_shape=tensor_shape,
+        dtype=dtype)
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').stop()
     return input_tensor, output_tensor_grad
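Taken together, these helpers are the communication vocabulary the schedules are built from. A heavily simplified, hedged sketch of how a middle stage pairs them in the steady state of a non-interleaved schedule (the real loop lives in schedules.py and also handles warmup, cooldown, deallocation and timers):

    from megatron.core.pipeline_parallel.p2p_communication import (
        recv_forward, send_forward_recv_backward, send_backward)

    def steady_state_1f1b(num_microbatches, tensor_shape, dtype, forward, backward):
        # forward/backward are caller-supplied callables; requires initialized parallel state.
        for _ in range(num_microbatches):
            input_tensor = recv_forward(tensor_shape, dtype)              # from previous stage
            output_tensor = forward(input_tensor)
            output_tensor_grad = send_forward_recv_backward(
                output_tensor, tensor_shape, dtype)                       # to/from next stage
            input_tensor_grad = backward(input_tensor, output_tensor, output_tensor_grad)
            send_backward(input_tensor_grad)                              # to previous stage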
megatron/schedules.py → megatron/core/pipeline_parallel/schedules.py

 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
+from typing import Optional, List, Union, Callable, Any

 import torch
 from torch.autograd.variable import Variable
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import get_args
-from megatron import get_num_microbatches
-from megatron import get_timers
-from megatron import p2p_communication
-from megatron.core import mpu
-from megatron.utils import unwrap_model
-from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model import Float16Module
-from megatron.model import ModelType
+from megatron.core import parallel_state
+from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core.enums import ModelType
+from megatron.core.utils import get_attr_wrapped_model, get_model_type
+
+# Types
+Shape = Union[List[int], torch.Size]

 def get_forward_backward_func():
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        if args.virtual_pipeline_model_parallel_size is not None:
+    """Retrieves the appropriate forward_backward function given the
+    configuration of parallel_state.
+
+    Returns a function that will perform all of the forward and
+    backward passes of the model given the pipeline model parallel
+    world size and virtual pipeline model parallel world size in the
+    global parallel_state.
+
+    The function returned takes the following arguments:
+
+    forward_step_func (required): A function that takes a data
+        iterator and a model as its arguments and return the model's
+        forward output and the loss function. The loss function should
+        take one torch.Tensor and return a torch.Tensor of loss and a
+        dictionary of string -> torch.Tensor.
+
+        For example:
+
+        def loss_func(loss_mask, output_tensor):
+            losses = output_tensor.float()
+            loss_mask = loss_mask.view(-1).float()
+            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+
+            # Reduce loss for logging.
+            averaged_loss = average_losses_across_data_parallel_group([loss])
+
+            return loss, {'lm loss': averaged_loss[0]}
+
+        def forward_step(data_iterator, model):
+            data, loss_mask = next(data_iterator)
+            output = model(data)
+            return output, partial(loss_func, loss_mask)
+
+        forward_backward_func(forward_step_func=forward_step, ...)
+
+    data_iterator (required): an iterator over the data, will be
+        passed as is to forward_step_func
+
+    model (required): the actual model. A torch.nn.Module or, in the
+        case or iterleaving, a list of torch.nn.Module
+
+    num_microbatches (int, required):
+        The number of microbatches to go through
+
+    dtype (required when using pipeline parallelism): dtype used in
+        p2p communication, usually params_dtype
+
+    tensor_shape (required when using pipeline parallelism): Shape of
+        tensor. The tensor is expected to be 3D and its order of
+        dimension is supposed to be ``(sequence, batch, hidden)``.
+
+    decoder_seq_length (int, required for ModelType.encoder_and_decoder models):
+        Sequence length of the decoder portion, used to determine tensor shapes.
+
+    grad_scaler (optional, default=None): If using loss scaling,
+        this function should take the loss and return the scaled
+        loss. If None, no function is called on the loss.
+
+    sequence_parallel (optional, default=False):
+        Set to :obj:`True` for this function to handle sequence
+        length. When :obj:`True`, the sequence length on each tensor
+        model parallel rank is updated to
+        :math:`original\_sequence\_length /
+        tensor\_model\_parallel\_world\_size`.
+        TODO: Do we need this? Just roll into tensor_shape arg?
+
+    forward_only (optional, default=False): Perform only the forward step
+
+    timers (optional, default=None): TODO
+
+    collect_non_loss_data: TODO
+
+    """
+    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+    if pipeline_model_parallel_size > 1:
+        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
-            assert get_num_microbatches() % \
-                args.pipeline_model_parallel_size == 0, \
-                'number of microbatches (%d) is not divisible by pipeline-' \
-                'model-parallel-size (%d) when using interleaved schedule' % (
-                    get_num_microbatches(),
-                    args.pipeline_model_parallel_size,
-                )
         else:
             forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
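The docstring's fragments assemble into a complete call. A hedged sketch, assuming parallel_state is already initialized, `model` is a torch.nn.Module prepared the usual Megatron way, and `train_iterator` yields (data, loss_mask) pairs; data-parallel loss averaging is omitted for brevity:

    from functools import partial
    import torch
    from megatron.core.pipeline_parallel import get_forward_backward_func

    def loss_func(loss_mask, output_tensor):
        losses = output_tensor.float()
        loss_mask = loss_mask.view(-1).float()
        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
        return loss, {'lm loss': loss.detach()}

    def forward_step(data_iterator, model):
        data, loss_mask = next(data_iterator)
        output = model(data)
        return output, partial(loss_func, loss_mask)

    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step,
        data_iterator=train_iterator,
        model=[model],                     # the schedules expect a list of model chunks
        num_microbatches=8,
        dtype=torch.bfloat16,              # p2p communication dtype
        tensor_shape=(2048, 4, 4096),      # (sequence, batch, hidden), illustrative
        grad_scaler=None,
        forward_only=False,
    )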
@@ -52,7 +119,7 @@ def deallocate_output_tensor(out):
         device=out.device,
         dtype=out.dtype,
     )

 def custom_backward(output, grad_output):
     '''Directly call C++ autograd engine.

@@ -87,11 +154,15 @@ def custom_backward(output, grad_output):
         allow_unreachable=True,
         accumulate_grad=True,
     )

 def forward_step(forward_step_func,
                  data_iterator,
                  model,
+                 num_microbatches,
                  input_tensor,
                  forward_data_store,
                  timers,
@@ -102,25 +173,26 @@ def forward_step(forward_step_func,
     passed-in input_tensor is used.

     Returns output tensor."""
-    args = get_args()
     if timers is not None:
         timers('forward-compute', log_level=2).start()

-    unwrapped_model = unwrap_model(
-        model, (torchDDP, LocalDDP, Float16Module))
     unwrap_output_tensor = False
     if not isinstance(input_tensor, list):
         input_tensor = [input_tensor]
         unwrap_output_tensor = True

-    unwrapped_model.set_input_tensor(input_tensor)
-    output_tensor, loss_func = forward_step_func(data_iterator, model)
-    if mpu.is_pipeline_last_stage():
+    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
+    set_input_tensor(input_tensor)
+
+    context_manager = torch.autocast("cuda") if torch.is_autocast_enabled() else nullcontext()
+    with context_manager:
+        output_tensor, loss_func = forward_step_func(data_iterator, model)
+
+    if parallel_state.is_pipeline_last_stage():
         if not collect_non_loss_data:
             output_tensor = loss_func(output_tensor)
             loss, loss_reduced = output_tensor
-            output_tensor = loss / get_num_microbatches()
+            output_tensor = loss / num_microbatches
             forward_data_store.append(loss_reduced)
         else:
             data = loss_func(output_tensor, non_loss_data=True)
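The conditional context manager introduced above is a small reusable idiom: re-enter autocast inside the schedule only if the caller was already running under autocast, rather than forcing a dtype. A self-contained sketch of the same pattern (toy model, illustrative):

    from contextlib import nullcontext
    import torch

    def run_forward(model, batch):
        # Propagate the caller's autocast state instead of deciding a dtype here.
        ctx = torch.autocast("cuda") if torch.is_autocast_enabled() else nullcontext()
        with ctx:
            return model(batch)

    model = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8)
    with torch.autocast("cuda", enabled=torch.cuda.is_available()):
        y = run_forward(model, x)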
@@ -132,16 +204,18 @@ def forward_step(forward_step_func,
     # If T5 model (or other model with encoder and decoder)
     # and in decoder stack, then send encoder_hidden_state
     # downstream as well.
-    if mpu.is_pipeline_stage_after_split() and \
-            args.model_type == ModelType.encoder_and_decoder:
+    model_type = get_model_type(model)
+    if parallel_state.is_pipeline_stage_after_split() and \
+            model_type == ModelType.encoder_and_decoder:
         return [output_tensor, input_tensor[-1]]
     if unwrap_output_tensor:
         return output_tensor
     return [output_tensor]


-def backward_step(optimizer, input_tensor, output_tensor,
-                  output_tensor_grad, timers):
+def backward_step(grad_scaler, input_tensor, output_tensor,
+                  output_tensor_grad, model_type, timers):
     """Backward step through passed-in output tensor.

     If last stage, output_tensor_grad is None, otherwise gradient of loss

@@ -153,7 +227,6 @@ def backward_step(optimizer, input_tensor, output_tensor,
     # NOTE: This code currently can handle at most one skip connection. It
     # needs to be modified slightly to support arbitrary numbers of skip
     # connections.
-    args = get_args()
     if timers is not None:
         timers('backward-compute', log_level=2).start()

@@ -173,8 +246,8 @@ def backward_step(optimizer, input_tensor, output_tensor,
         output_tensor_grad = [output_tensor_grad]

     # Backward pass.
-    if output_tensor_grad[0] is None:
-        output_tensor = optimizer.scale_loss(output_tensor[0])
+    if output_tensor_grad[0] is None and grad_scaler is not None:
+        output_tensor = grad_scaler(output_tensor[0])
     custom_backward(output_tensor[0], output_tensor_grad[0])

     # Collect the grad of the input_tensor.

@@ -189,9 +262,9 @@ def backward_step(optimizer, input_tensor, output_tensor,
     # Handle single skip connection if it exists (encoder_hidden_state in
     # model with encoder and decoder).
-    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
-            mpu.is_pipeline_stage_after_split() and \
-            args.model_type == ModelType.encoder_and_decoder:
+    if parallel_state.get_pipeline_model_parallel_world_size() > 1 and \
+            parallel_state.is_pipeline_stage_after_split() and \
+            model_type == ModelType.encoder_and_decoder:
         if output_tensor_grad[1] is not None:
             input_tensor_grad[-1].add_(output_tensor_grad[1])
     if unwrap_input_tensor_grad:
@@ -211,16 +284,27 @@ def dummy_handler():
...
@@ -211,16 +284,27 @@ def dummy_handler():
pass
pass
def
forward_backward_no_pipelining
(
forward_step_func
,
def
forward_backward_no_pipelining
(
*
,
data_iterator
,
model
,
forward_step_func
,
optimizer
,
data_iterator
,
timers
,
model
:
Union
[
torch
.
nn
.
Module
,
List
[
torch
.
nn
.
Module
]],
forward_only
,
num_microbatches
:
int
,
collect_non_loss_data
=
False
):
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
# unused
tensor_shape
:
Optional
[
Shape
]
=
None
,
# unused
decoder_seq_length
:
Optional
[
int
]
=
None
,
# unused
grad_scaler
:
Callable
=
None
,
sequence_parallel
:
bool
=
False
,
# unused
forward_only
:
bool
=
False
,
timers
:
Callable
=
None
,
collect_non_loss_data
:
bool
=
False
):
"""Run forward and backward passes with no pipeline parallelism
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
(no inter-stage communication).
Returns dictionary with losses."""
Returns dictionary with losses.
See get_forward_backward_func() for argument details
"""
assert
len
(
model
)
==
1
assert
len
(
model
)
==
1
model
=
model
[
0
]
model
=
model
[
0
]
@@ -228,64 +312,86 @@ def forward_backward_no_pipelining(forward_step_func,
     if isinstance(model, torchDDP):
         context_handler = model.no_sync

+    model_type = get_model_type(model)
+
     forward_data_store = []
     input_tensor, output_tensor_grad = None, None
     with context_handler():
-        for i in range(get_num_microbatches() - 1):
+        for i in range(num_microbatches - 1):
             output_tensor = forward_step(forward_step_func, data_iterator,
-                                         model, input_tensor, forward_data_store,
+                                         model, num_microbatches, input_tensor, forward_data_store,
                                          timers, collect_non_loss_data)
             if not forward_only:
-                backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad, timers)
+                backward_step(grad_scaler, input_tensor, output_tensor,
+                              output_tensor_grad, model_type, timers)

     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
     output_tensor = forward_step(forward_step_func, data_iterator,
-                                 model, input_tensor, forward_data_store,
+                                 model, num_microbatches, input_tensor, forward_data_store,
                                  timers, collect_non_loss_data)
     if not forward_only:
-        backward_step(optimizer, input_tensor, output_tensor,
-                      output_tensor_grad, timers)
+        backward_step(grad_scaler, input_tensor, output_tensor,
+                      output_tensor_grad, model_type, timers)

     return forward_data_store


-def forward_backward_pipelining_with_interleaving(forward_step_func,
-                                                  data_iterator, model,
-                                                  optimizer,
-                                                  timers,
-                                                  forward_only,
-                                                  collect_non_loss_data=False):
+def forward_backward_pipelining_with_interleaving(*,
+                                                  forward_step_func,
+                                                  data_iterator,
+                                                  model: Union[torch.nn.Module, List[torch.nn.Module]],
+                                                  num_microbatches: int,
+                                                  dtype: torch.dtype,
+                                                  tensor_shape: Shape,
+                                                  decoder_seq_length: Optional[int] = None,
+                                                  grad_scaler: Callable = None,
+                                                  sequence_parallel: bool = False,
+                                                  forward_only: bool = False,
+                                                  timers: Callable = None,
+                                                  collect_non_loss_data: bool = False):
     """Run interleaved 1F1B schedule (model split into model chunks), with
     communication between pipeline stages as needed.

     Returns dictionary with losses if the last stage, empty dict otherwise."""
-    args = get_args()
-
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
     forward_data_store = []
     if not forward_only:
         output_tensor_grads = [[] for _ in range(len(model))]

-    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
-    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
+    pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+    pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+    if num_microbatches % pipeline_parallel_size != 0:
+        msg = f'number of microbatches ({num_microbatches}) is not divisible by '
+        msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
+        msg += 'when using interleaved schedule'
+        raise RuntimeError(msg)
+
+    model_type = get_model_type(model[0])
+    if model_type == ModelType.encoder_and_decoder:
+        raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
+
+    if decoder_seq_length is not None and decoder_seq_length != tensor_shape[0]:
+        raise RuntimeError("Interleaving is not supported with a different decoder sequence length.")
+
+    if sequence_parallel:
+        seq_length, batch_size, hidden = tensor_shape
+        tensor_shape = (
+            seq_length // parallel_state.get_tensor_model_parallel_world_size(),
+            batch_size,
+            hidden,
+        )

-    if args.sequence_parallel:
-        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
-    else:
-        seq_length = args.seq_length
-    tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)

     # Compute number of warmup and remaining microbatches.
     num_model_chunks = len(model)
-    num_microbatches = get_num_microbatches() * num_model_chunks
+    total_num_microbatches = num_microbatches * num_model_chunks
     all_warmup_microbatches = False
     if forward_only:
-        num_warmup_microbatches = num_microbatches
+        num_warmup_microbatches = total_num_microbatches
     else:
         # Run all forward passes and then all backward passes if number of
         # microbatches is just the number of pipeline stages.
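The warmup bookkeeping in the hunks around here is easier to see with concrete numbers. A small sketch that reproduces the counts for the interleaved schedule; the first term of the formula is not visible in this excerpt and is taken from the surrounding Megatron code, so treat it as an assumption:

    def interleaved_warmup(num_microbatches, pipeline_parallel_size,
                           pipeline_parallel_rank, num_model_chunks):
        total_num_microbatches = num_microbatches * num_model_chunks
        if num_microbatches == pipeline_parallel_size:
            return total_num_microbatches          # all microbatches are warmup
        warmup = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2   # assumed term
        warmup += (num_model_chunks - 1) * pipeline_parallel_size
        return min(warmup, total_num_microbatches)

    # 4 pipeline stages, 8 microbatches, 2 model chunks per GPU:
    for rank in range(4):
        print(rank, interleaved_warmup(8, 4, rank, 2))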
@@ -293,8 +399,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
         # all workers, followed by more microbatches after depending on
         # stage ID (more forward passes for earlier stages, later stages can
         # immediately start with 1F1B).
-        if get_num_microbatches() == pipeline_parallel_size:
-            num_warmup_microbatches = num_microbatches
+        if num_microbatches == pipeline_parallel_size:
+            num_warmup_microbatches = total_num_microbatches
             all_warmup_microbatches = True
         else:
             num_warmup_microbatches = \

@@ -302,9 +408,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
             num_warmup_microbatches += (
                 num_model_chunks - 1) * pipeline_parallel_size
             num_warmup_microbatches = min(num_warmup_microbatches,
-                                          num_microbatches)
+                                          total_num_microbatches)
     num_microbatches_remaining = \
-        num_microbatches - num_warmup_microbatches
+        total_num_microbatches - num_warmup_microbatches

     def get_model_chunk_id(microbatch_id, forward):
         """Helper method to get the model chunk ID given the iteration number."""

@@ -319,10 +425,10 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
         (run set_virtual_pipeline_model_parallel_rank() before calling
         forward_step())."""
         model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
-        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

         # forward step
-        if mpu.is_pipeline_first_stage():
+        if parallel_state.is_pipeline_first_stage():
             if len(input_tensors[model_chunk_id]) == \
                     len(output_tensors[model_chunk_id]):
                 input_tensors[model_chunk_id].append(None)
@@ -330,7 +436,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
         output_tensor = forward_step(forward_step_func,
                                      data_iterator[model_chunk_id],
                                      model[model_chunk_id],
-                                     input_tensor,
+                                     num_microbatches,
+                                     input_tensor,
                                      forward_data_store,
                                      timers,
                                      collect_non_loss_data)

@@ -348,41 +455,42 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
         (run set_virtual_pipeline_model_parallel_rank() before calling
         backward_step())."""
         model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
-        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

-        if mpu.is_pipeline_last_stage():
+        if parallel_state.is_pipeline_last_stage():
             if len(output_tensor_grads[model_chunk_id]) == 0:
                 output_tensor_grads[model_chunk_id].append(None)
         input_tensor = input_tensors[model_chunk_id].pop(0)
         output_tensor = output_tensors[model_chunk_id].pop(0)
         output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
         input_tensor_grad = \
-            backward_step(optimizer,
+            backward_step(grad_scaler,
                           input_tensor,
                           output_tensor,
                           output_tensor_grad,
+                          model_type,
                           timers)

         return input_tensor_grad

     # Run warmup forward passes.
-    mpu.set_virtual_pipeline_model_parallel_rank(0)
+    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
     input_tensors[0].append(
-        p2p_communication.recv_forward(tensor_shape, timers=timers))
+        p2p_communication.recv_forward(tensor_shape, dtype, timers=timers))

     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)

         # Determine if tensor should be received from previous stage.
         next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
         recv_prev = True
-        if mpu.is_pipeline_first_stage(ignore_virtual=True):
+        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
             if next_forward_model_chunk_id == 0:
                 recv_prev = False
-        if k == (num_microbatches - 1):
+        if k == (total_num_microbatches - 1):
             recv_prev = False

         # Don't send tensor downstream if on last stage.
-        if mpu.is_pipeline_last_stage():
+        if parallel_state.is_pipeline_last_stage():
             output_tensor = None

         # Send and receive tensors as appropriate (send tensors computed

@@ -391,20 +499,20 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
                 not all_warmup_microbatches:
             input_tensor_grad = None
             recv_next = True
-            if mpu.is_pipeline_last_stage(ignore_virtual=True):
+            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                 recv_next = False
             input_tensor, output_tensor_grad = \
                 p2p_communication.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     recv_prev=recv_prev, recv_next=recv_next,
                     tensor_shape=tensor_shape,
+                    dtype=dtype,
                     timers=timers)
             output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
         else:
             input_tensor = \
                 p2p_communication.send_forward_recv_forward(
                     output_tensor, recv_prev=recv_prev,
                     tensor_shape=tensor_shape,
+                    dtype=dtype,
                     timers=timers)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)
         deallocate_output_tensor(output_tensor)
...
@@ -425,19 +533,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
...
@@ -425,19 +533,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
# Determine if current stage has anything to send in either direction,
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
# otherwise set tensor to None.
forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
mpu
.
set_virtual_pipeline_model_parallel_rank
(
forward_model_chunk_id
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
forward_model_chunk_id
)
if
mpu
.
is_pipeline_last_stage
():
if
parallel_state
.
is_pipeline_last_stage
():
output_tensor
=
None
output_tensor
=
None
backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
,
forward
=
False
)
backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
,
forward
=
False
)
mpu
.
set_virtual_pipeline_model_parallel_rank
(
backward_model_chunk_id
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
backward_model_chunk_id
)
if
mpu
.
is_pipeline_first_stage
():
if
parallel_state
.
is_pipeline_first_stage
():
input_tensor_grad
=
None
input_tensor_grad
=
None
# Determine if peers are sending, and where in data structure to put
# Determine if peers are sending, and where in data structure to put
# received tensors.
# received tensors.
recv_prev
=
True
recv_prev
=
True
if
mpu
.
is_pipeline_first_stage
(
ignore_virtual
=
True
):
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
True
):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id
=
get_model_chunk_id
(
next_forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
-
(
pipeline_parallel_size
-
1
),
forward
=
True
)
forward_k
-
(
pipeline_parallel_size
-
1
),
forward
=
True
)
...
@@ -449,7 +557,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
...
@@ -449,7 +557,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
forward
=
True
)
forward
=
True
)
recv_next
=
True
recv_next
=
True
if
mpu
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id
=
get_model_chunk_id
(
next_backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
-
(
pipeline_parallel_size
-
1
),
forward
=
False
)
backward_k
-
(
pipeline_parallel_size
-
1
),
forward
=
False
)
...
@@ -470,7 +578,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
...
@@ -470,7 +578,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
p2p_communication
.
send_forward_backward_recv_forward_backward
(
p2p_communication
.
send_forward_backward_recv_forward_backward
(
output_tensor
,
input_tensor_grad
,
output_tensor
,
input_tensor_grad
,
recv_prev
=
recv_prev
,
recv_next
=
recv_next
,
recv_prev
=
recv_prev
,
recv_next
=
recv_next
,
tensor_shape
=
tensor_shape
,
timers
=
timers
)
tensor_shape
=
tensor_shape
,
dtype
=
dtype
,
timers
=
timers
)
deallocate_output_tensor
(
output_tensor
)
deallocate_output_tensor
(
output_tensor
)
# Put input_tensor and output_tensor_grad in data structures in the
# Put input_tensor and output_tensor_grad in data structures in the
...
@@ -486,25 +594,29 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
...
@@ -486,25 +594,29 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
if
all_warmup_microbatches
:
if
all_warmup_microbatches
:
output_tensor_grads
[
num_model_chunks
-
1
].
append
(
output_tensor_grads
[
num_model_chunks
-
1
].
append
(
p2p_communication
.
recv_backward
(
tensor_shape
,
timers
=
timers
))
p2p_communication
.
recv_backward
(
tensor_shape
,
timers
=
timers
))
for
k
in
range
(
num_microbatches_remaining
,
num_microbatches
):
for
k
in
range
(
num_microbatches_remaining
,
total_
num_microbatches
):
input_tensor_grad
=
backward_step_helper
(
k
)
input_tensor_grad
=
backward_step_helper
(
k
)
next_backward_model_chunk_id
=
get_model_chunk_id
(
k
+
1
,
forward
=
False
)
next_backward_model_chunk_id
=
get_model_chunk_id
(
k
+
1
,
forward
=
False
)
recv_next
=
True
recv_next
=
True
if
mpu
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
if
next_backward_model_chunk_id
==
(
num_model_chunks
-
1
):
if
next_backward_model_chunk_id
==
(
num_model_chunks
-
1
):
recv_next
=
False
recv_next
=
False
if
k
==
(
num_microbatches
-
1
):
if
k
==
(
total_
num_microbatches
-
1
):
recv_next
=
False
recv_next
=
False
output_tensor_grads
[
next_backward_model_chunk_id
].
append
(
output_tensor_grads
[
next_backward_model_chunk_id
].
append
(
p2p_communication
.
send_backward_recv_backward
(
p2p_communication
.
send_backward_recv_backward
(
input_tensor_grad
,
recv_next
=
recv_next
,
input_tensor_grad
,
recv_next
=
recv_next
,
tensor_shape
=
tensor_shape
,
tensor_shape
=
tensor_shape
,
dtype
=
dtype
,
timers
=
timers
))
timers
=
timers
))
return
forward_data_store
return
forward_data_store
def
get_tensor_shapes
(
*
,
def
get_tensor_shapes
(
rank
,
model_type
):
rank
:
int
,
model_type
:
ModelType
,
tensor_shape
:
Shape
,
decoder_seq_length
:
int
,
sequence_parallel
:
bool
):
# Determine right tensor sizes (based on position of rank with respect to split
# Determine right tensor sizes (based on position of rank with respect to split
# rank) and model size.
# rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage:
# Send two tensors if model is T5 and rank is in decoder stage:
...
@@ -513,48 +625,50 @@ def get_tensor_shapes(rank, model_type):
...
@@ -513,48 +625,50 @@ def get_tensor_shapes(rank, model_type):
# If model is T5 and rank is at the boundary:
# If model is T5 and rank is at the boundary:
# send one tensor (post-transpose from encoder).
# send one tensor (post-transpose from encoder).
# Otherwise, send one tensor (pre-transpose).
# Otherwise, send one tensor (pre-transpose).
args
=
get_args
()
tensor_shapes
=
[]
tensor_shapes
=
[]
if
args
.
sequence_parallel
:
assert
(
seq_length
=
args
.
seq_length
//
mpu
.
get_tensor_model_parallel_world_size
()
len
(
tensor_shape
)
==
3
else
:
),
f
"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but
{
tensor_shape
}
"
seq_length
=
args
.
seq_length
seq_length
,
micro_batch_size
,
hidden_size
=
tensor_shape
if
sequence_parallel
:
seq_length
=
seq_length
//
parallel_state
.
get_tensor_model_parallel_world_size
()
if
model_type
==
ModelType
.
encoder_and_decoder
:
if
model_type
==
ModelType
.
encoder_and_decoder
:
if
args
.
sequence_parallel
:
if
sequence_parallel
:
decoder_seq_length
=
args
.
decoder_seq_length
//
mpu
.
get_tensor_model_parallel_world_size
()
decoder_seq_length
=
decoder_seq_length
//
parallel_state
.
get_tensor_model_parallel_world_size
()
else
:
decoder_seq_length
=
args
.
decoder_seq_length
if
mpu
.
is_pipeline_stage_before_split
(
rank
):
if
parallel_state
.
is_pipeline_stage_before_split
(
rank
):
tensor_shapes
.
append
((
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
))
tensor_shapes
.
append
((
seq_length
,
micro_batch_size
,
hidden_size
))
else
:
else
:
tensor_shapes
.
append
((
decoder_seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
))
tensor_shapes
.
append
((
decoder_seq_length
,
micro_batch_size
,
hidden_size
))
tensor_shapes
.
append
((
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
))
tensor_shapes
.
append
((
seq_length
,
micro_batch_size
,
hidden_size
))
else
:
else
:
tensor_shapes
.
append
((
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
))
tensor_shapes
.
append
((
seq_length
,
micro_batch_size
,
hidden_size
))
return
tensor_shapes
return
tensor_shapes
def
recv_forward
(
tensor_shapes
,
timers
):
def
recv_forward
(
tensor_shapes
,
dtype
,
timers
):
input_tensors
=
[]
input_tensors
=
[]
for
tensor_shape
in
tensor_shapes
:
for
tensor_shape
in
tensor_shapes
:
if
tensor_shape
is
None
:
if
tensor_shape
is
None
:
input_tensors
.
append
(
None
)
input_tensors
.
append
(
None
)
else
:
else
:
input_tensors
.
append
(
p2p_communication
.
recv_forward
(
tensor_shape
,
input_tensors
.
append
(
p2p_communication
.
recv_forward
(
tensor_shape
,
dtype
,
timers
=
timers
))
timers
=
timers
))
return
input_tensors
return
input_tensors
def
recv_backward
(
tensor_shapes
,
timers
):
def
recv_backward
(
tensor_shapes
,
dtype
,
timers
):
output_tensor_grads
=
[]
output_tensor_grads
=
[]
for
tensor_shape
in
tensor_shapes
:
for
tensor_shape
in
tensor_shapes
:
if
tensor_shape
is
None
:
if
tensor_shape
is
None
:
output_tensor_grads
.
append
(
None
)
output_tensor_grads
.
append
(
None
)
else
:
else
:
output_tensor_grads
.
append
(
p2p_communication
.
recv_backward
(
tensor_shape
,
output_tensor_grads
.
append
(
p2p_communication
.
recv_backward
(
tensor_shape
,
dtype
,
timers
=
timers
))
timers
=
timers
))
return
output_tensor_grads
return
output_tensor_grads
...
@@ -565,7 +679,7 @@ def send_forward(output_tensors, tensor_shapes, timers):
...
@@ -565,7 +679,7 @@ def send_forward(output_tensors, tensor_shapes, timers):
for
(
output_tensor
,
tensor_shape
)
in
zip
(
output_tensors
,
tensor_shapes
):
for
(
output_tensor
,
tensor_shape
)
in
zip
(
output_tensors
,
tensor_shapes
):
if
tensor_shape
is
None
:
if
tensor_shape
is
None
:
continue
continue
p2p_communication
.
send_forward
(
output_tensor
,
tensor_shape
,
timers
=
timers
)
p2p_communication
.
send_forward
(
output_tensor
,
timers
=
timers
)
def
send_backward
(
input_tensor_grads
,
tensor_shapes
,
timers
):
def
send_backward
(
input_tensor_grads
,
tensor_shapes
,
timers
):
...
@@ -574,10 +688,10 @@ def send_backward(input_tensor_grads, tensor_shapes, timers):
...
@@ -574,10 +688,10 @@ def send_backward(input_tensor_grads, tensor_shapes, timers):
for
(
input_tensor_grad
,
tensor_shape
)
in
zip
(
input_tensor_grads
,
tensor_shapes
):
for
(
input_tensor_grad
,
tensor_shape
)
in
zip
(
input_tensor_grads
,
tensor_shapes
):
if
tensor_shape
is
None
:
if
tensor_shape
is
None
:
continue
continue
p2p_communication
.
send_backward
(
input_tensor_grad
,
tensor_shape
,
timers
=
timers
)
p2p_communication
.
send_backward
(
input_tensor_grad
,
timers
=
timers
)
def
send_forward_recv_backward
(
output_tensors
,
tensor_shapes
,
timers
):
def
send_forward_recv_backward
(
output_tensors
,
tensor_shapes
,
dtype
,
timers
):
if
not
isinstance
(
output_tensors
,
list
):
if
not
isinstance
(
output_tensors
,
list
):
output_tensors
=
[
output_tensors
]
output_tensors
=
[
output_tensors
]
output_tensor_grads
=
[]
output_tensor_grads
=
[]
...
@@ -586,12 +700,12 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
...
@@ -586,12 +700,12 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
output_tensor_grads
.
append
(
None
)
output_tensor_grads
.
append
(
None
)
continue
continue
output_tensor_grad
=
p2p_communication
.
send_forward_recv_backward
(
output_tensor_grad
=
p2p_communication
.
send_forward_recv_backward
(
output_tensor
,
tensor_shape
,
timers
=
timers
)
output_tensor
,
tensor_shape
,
dtype
,
timers
=
timers
)
output_tensor_grads
.
append
(
output_tensor_grad
)
output_tensor_grads
.
append
(
output_tensor_grad
)
return
output_tensor_grads
return
output_tensor_grads
def
send_backward_recv_forward
(
input_tensor_grads
,
tensor_shapes
,
timers
):
def
send_backward_recv_forward
(
input_tensor_grads
,
tensor_shapes
,
dtype
,
timers
):
if
not
isinstance
(
input_tensor_grads
,
list
):
if
not
isinstance
(
input_tensor_grads
,
list
):
input_tensor_grads
=
[
input_tensor_grads
]
input_tensor_grads
=
[
input_tensor_grads
]
input_tensors
=
[]
input_tensors
=
[]
...
@@ -600,44 +714,55 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
...
@@ -600,44 +714,55 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
input_tensors
.
append
(
None
)
input_tensors
.
append
(
None
)
continue
continue
input_tensor
=
p2p_communication
.
send_backward_recv_forward
(
input_tensor
=
p2p_communication
.
send_backward_recv_forward
(
input_tensor_grad
,
tensor_shape
,
timers
=
timers
)
input_tensor_grad
,
tensor_shape
,
dtype
,
timers
=
timers
)
input_tensors
.
append
(
input_tensor
)
input_tensors
.
append
(
input_tensor
)
return
input_tensors
return
input_tensors
def
forward_backward_pipelining_without_interleaving
(
forward_step_func
,
def
forward_backward_pipelining_without_interleaving
(
*
,
forward_step_func
,
data_iterator
,
data_iterator
,
model
,
model
:
Union
[
torch
.
nn
.
Module
,
List
[
torch
.
nn
.
Module
]],
optimizer
,
num_microbatches
:
int
,
timers
,
dtype
:
torch
.
dtype
,
forward_only
,
tensor_shape
:
Shape
,
collect_non_loss_data
=
False
):
decoder_seq_length
:
Optional
[
int
]
=
None
,
grad_scaler
:
Callable
=
None
,
sequence_parallel
:
bool
=
False
,
forward_only
:
bool
=
False
,
timers
:
Callable
=
None
,
collect_non_loss_data
:
bool
=
False
):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
stages.
Returns dictionary with losses if the last stage, empty dict otherwise."""
Returns dictionary with losses if the last stage, empty dict otherwise."""
args
=
get_args
()
assert
len
(
model
)
==
1
assert
len
(
model
)
==
1
model
=
model
[
0
]
model
=
model
[
0
]
# Compute number of warmup microbatches.
# Compute number of warmup microbatches.
num_microbatches
=
get_num_microbatches
()
num_warmup_microbatches
=
\
num_warmup_microbatches
=
\
(
mpu
.
get_pipeline_model_parallel_world_size
()
-
(
parallel_state
.
get_pipeline_model_parallel_world_size
()
-
mpu
.
get_pipeline_model_parallel_rank
()
-
1
)
parallel_state
.
get_pipeline_model_parallel_rank
()
-
1
)
num_warmup_microbatches
=
min
(
num_warmup_microbatches
=
min
(
num_warmup_microbatches
,
num_warmup_microbatches
,
num_microbatches
)
num_microbatches
)
num_microbatches_remaining
=
\
num_microbatches_remaining
=
\
num_microbatches
-
num_warmup_microbatches
num_microbatches
-
num_warmup_microbatches
unwrapped_model
=
unwrap_model
(
model_type
=
get_model_type
(
model
)
model
,
(
torchDDP
,
LocalDDP
,
Float16Module
))
model_type
=
unwrapped_model
.
model_type
rank
=
parallel_state
.
get_pipeline_model_parallel_rank
()
rank
=
mpu
.
get_pipeline_model_parallel_rank
()
recv_tensor_shapes
=
get_tensor_shapes
(
rank
=
rank
-
1
,
recv_tensor_shapes
=
get_tensor_shapes
(
rank
-
1
,
model_type
)
model_type
=
model_type
,
send_tensor_shapes
=
get_tensor_shapes
(
rank
,
model_type
)
tensor_shape
=
tensor_shape
,
decoder_seq_length
=
decoder_seq_length
,
sequence_parallel
=
sequence_parallel
)
send_tensor_shapes
=
get_tensor_shapes
(
rank
=
rank
,
model_type
=
model_type
,
tensor_shape
=
tensor_shape
,
decoder_seq_length
=
decoder_seq_length
,
sequence_parallel
=
sequence_parallel
)
# Input, output tensors only need to be saved when doing backward passes
# Input, output tensors only need to be saved when doing backward passes
input_tensors
=
None
input_tensors
=
None
...
@@ -649,10 +774,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
...
@@ -649,10 +774,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
# Run warmup forward passes.
# Run warmup forward passes.
for
i
in
range
(
num_warmup_microbatches
):
for
i
in
range
(
num_warmup_microbatches
):
input_tensor
=
recv_forward
(
recv_tensor_shapes
,
timers
=
timers
)
input_tensor
=
recv_forward
(
recv_tensor_shapes
,
dtype
,
timers
=
timers
)
output_tensor
=
forward_step
(
forward_step_func
,
data_iterator
,
model
,
output_tensor
=
forward_step
(
forward_step_func
,
data_iterator
,
model
,
num_microbatches
,
input_tensor
,
forward_data_store
,
timers
,
collect_non_loss_data
)
input_tensor
,
forward_data_store
,
timers
,
collect_non_loss_data
)
send_forward
(
output_tensor
,
send_tensor_shapes
,
timers
=
timers
)
send_forward
(
output_tensor
,
send_tensor_shapes
,
timers
=
timers
)
if
not
forward_only
:
if
not
forward_only
:
...
@@ -664,25 +787,26 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
...
@@ -664,25 +787,26 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
# If all microbatches are run in warmup / cooldown phase, then no need to
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
# receive this tensor here.
if
num_microbatches_remaining
>
0
:
if
num_microbatches_remaining
>
0
:
input_tensor
=
recv_forward
(
recv_tensor_shapes
,
timers
=
timers
)
input_tensor
=
recv_forward
(
recv_tensor_shapes
,
dtype
,
timers
=
timers
)
# Run 1F1B in steady state.
# Run 1F1B in steady state.
for
i
in
range
(
num_microbatches_remaining
):
for
i
in
range
(
num_microbatches_remaining
):
last_iteration
=
(
i
==
(
num_microbatches_remaining
-
1
))
last_iteration
=
(
i
==
(
num_microbatches_remaining
-
1
))
output_tensor
=
forward_step
(
forward_step_func
,
data_iterator
,
model
,
output_tensor
=
forward_step
(
forward_step_func
,
data_iterator
,
model
,
num_microbatches
,
input_tensor
,
forward_data_store
,
input_tensor
,
forward_data_store
,
timers
,
collect_non_loss_data
)
timers
,
collect_non_loss_data
)
if
forward_only
:
if
forward_only
:
send_forward
(
output_tensor
,
send_tensor_shapes
,
timers
=
timers
)
send_forward
(
output_tensor
,
send_tensor_shapes
,
timers
=
timers
)
if
not
last_iteration
:
if
not
last_iteration
:
input_tensor
=
recv_forward
(
recv_tensor_shapes
,
timers
=
timers
)
input_tensor
=
recv_forward
(
recv_tensor_shapes
,
dtype
,
timers
=
timers
)
else
:
else
:
output_tensor_grad
=
\
output_tensor_grad
=
\
send_forward_recv_backward
(
output_tensor
,
send_forward_recv_backward
(
output_tensor
,
send_tensor_shapes
,
send_tensor_shapes
,
dtype
,
timers
=
timers
)
timers
=
timers
)
# Add input_tensor and output_tensor to end of list.
# Add input_tensor and output_tensor to end of list.
...
@@ -696,8 +820,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
...
@@ -696,8 +820,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
output_tensor
=
output_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
input_tensor_grad
=
\
input_tensor_grad
=
\
backward_step
(
optimiz
er
,
input_tensor
,
output_tensor
,
backward_step
(
grad_scal
er
,
input_tensor
,
output_tensor
,
output_tensor_grad
,
timers
)
output_tensor_grad
,
model_type
,
timers
)
if
last_iteration
:
if
last_iteration
:
input_tensor
=
None
input_tensor
=
None
...
@@ -705,7 +829,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
...
@@ -705,7 +829,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
else
:
else
:
input_tensor
=
\
input_tensor
=
\
send_backward_recv_forward
(
send_backward_recv_forward
(
input_tensor_grad
,
recv_tensor_shapes
,
timers
=
timers
)
input_tensor_grad
,
recv_tensor_shapes
,
dtype
,
timers
=
timers
)
# Run cooldown backward passes.
# Run cooldown backward passes.
if
not
forward_only
:
if
not
forward_only
:
...
@@ -713,11 +837,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
...
@@ -713,11 +837,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
input_tensor
=
input_tensors
.
pop
(
0
)
input_tensor
=
input_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
output_tensor_grad
=
recv_backward
(
send_tensor_shapes
,
timers
=
timers
)
output_tensor_grad
=
recv_backward
(
send_tensor_shapes
,
dtype
,
timers
=
timers
)
input_tensor_grad
=
\
input_tensor_grad
=
\
backward_step
(
optimiz
er
,
input_tensor
,
output_tensor
,
backward_step
(
grad_scal
er
,
input_tensor
,
output_tensor
,
output_tensor_grad
,
timers
)
output_tensor_grad
,
model_type
,
timers
)
send_backward
(
input_tensor_grad
,
recv_tensor_shapes
,
timers
=
timers
)
send_backward
(
input_tensor_grad
,
recv_tensor_shapes
,
timers
=
timers
)
...
...
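The reworked schedules above no longer read the global `args`: the caller passes `num_microbatches`, `dtype`, `tensor_shape`, and a `grad_scaler` callable explicitly, and `backward_step` receives `grad_scaler` plus `model_type` instead of the whole optimizer. As a minimal sketch (not part of the commit), the warmup/steady-state/cooldown split used by the non-interleaved 1F1B schedule reduces to the following arithmetic, which you can run standalone:

# Standalone sketch of the 1F1B phase split visible in the diff above.
def split_1f1b_phases(pipeline_parallel_size: int,
                      pipeline_parallel_rank: int,
                      num_microbatches: int):
    # Earlier stages run more warmup forward passes before the first
    # backward pass arrives from downstream stages.
    num_warmup = min(pipeline_parallel_size - pipeline_parallel_rank - 1,
                     num_microbatches)
    num_steady = num_microbatches - num_warmup   # 1F1B steady state
    num_cooldown = num_warmup                    # remaining backward passes
    return num_warmup, num_steady, num_cooldown

# Example: 4 pipeline stages, 8 microbatches.
for rank in range(4):
    print(rank, split_1f1b_phases(4, rank, 8))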
megatron/core/tensor_parallel/layers.py
@@ -13,6 +13,8 @@ import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter

+from torch.cuda.amp import custom_fwd, custom_bwd
+
 from megatron.core.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -214,6 +216,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     """See linear_with_grad_accumulation_and_async_allreduce"""

     @staticmethod
+    @custom_fwd
     def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                 async_grad_allreduce, sequence_parallel):
         ctx.save_for_backward(input, weight)
@@ -243,6 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
@@ -407,8 +411,8 @@ def linear_with_grad_accumulation_and_async_allreduce(
                 "maximum speedup")
             linear_with_grad_accumulation_and_async_allreduce.warned = True

-    with torch.cuda.amp.autocast(enabled=False):
-        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
+    return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)

 linear_with_grad_accumulation_and_async_allreduce.warned = False

 class ColumnParallelLinear(torch.nn.Module):
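The `@custom_fwd`/`@custom_bwd` decorators added above are the standard `torch.cuda.amp` hooks for making a hand-written `autograd.Function` behave correctly under autocast, which is why the explicit `with torch.cuda.amp.autocast(enabled=False):` wrapper could be dropped. A minimal toy illustration of the same pattern (not Megatron code; the `ToyLinear` name is made up for this sketch):

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class ToyLinear(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, input, weight):
        # y = x @ W^T
        ctx.save_for_backward(input, weight)
        return input.matmul(weight.t())

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # grads w.r.t. (input, weight)
        return grad_output.matmul(weight), grad_output.t().matmul(input)

# Usage sketch (autocast only has an effect on a CUDA device):
# with torch.cuda.amp.autocast():
#     out = ToyLinear.apply(x, w)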
megatron/core/utils.py
@@ -20,6 +20,21 @@ def divide(numerator, denominator):
     ensure_divisibility(numerator, denominator)
     return numerator // denominator

+def get_attr_wrapped_model(model, attr):
+    """Get an attribute from a wrapped model"""
+    if isinstance(model, list):
+        raise RuntimeError("_get_attr_wrapped_model given a list of models")
+
+    while not hasattr(model, attr):
+        if not hasattr(model, "module"):
+            raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
+
+        model = model.module
+    return getattr(model, attr)
+
+def get_model_type(model):
+    return get_attr_wrapped_model(model, 'model_type')
+
 class GlobalMemoryBuffer:
     """Global buffer to avoid dynamic memory allocations.
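`get_attr_wrapped_model` lets the core schedules call `get_model_type(model)` without knowing which DDP or Float16 wrappers sit around the underlying module: it simply descends through `.module` attributes until the requested attribute is found. A toy illustration of the behaviour (not Megatron code; the `Inner`/`Wrapper` classes are invented for this sketch):

class Inner:
    model_type = "encoder_or_decoder"

class Wrapper:
    # Mimics a DDP-style wrapper that exposes the wrapped module as .module
    def __init__(self, module):
        self.module = module

def get_attr_wrapped_model(model, attr):
    while not hasattr(model, attr):
        if not hasattr(model, "module"):
            raise RuntimeError(f"couldn't find attribute {attr}")
        model = model.module
    return getattr(model, attr)

wrapped = Wrapper(Wrapper(Inner()))
print(get_attr_wrapped_model(wrapped, "model_type"))  # -> "encoder_or_decoder"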
megatron/model/__init__.py
@@ -8,4 +8,3 @@ from .gpt_model import GPTModel
 from .t5_model import T5Model
 from .language_model import get_language_model
 from .module import Float16Module
-from .enums import ModelType
megatron/model/enums.py
@@ -2,10 +2,6 @@
 import enum

-class ModelType(enum.Enum):
-    encoder_or_decoder = 1
-    encoder_and_decoder = 2
-
 class LayerType(enum.Enum):
     encoder = 1
     decoder = 2
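With `ModelType` deleted from `megatron/model/enums.py` and from the `megatron.model` package exports, every call site updated in the files below switches to the new location in core; the enum values themselves are unchanged. A one-line sketch of the new import path:

# before: from megatron.model import ModelType
from megatron.core.enums import ModelType

model_type = ModelType.encoder_or_decoder  # values are unchanged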
megatron/model/retro_transformer.py
@@ -20,7 +20,8 @@ from megatron import get_args, get_retro_args, get_tensorboard_writer
 from megatron.core import parallel_state
 from megatron.core import tensor_parallel
 from megatron.core import utils as core_utils
-from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
+from megatron.core.enums import ModelType
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
megatron/model/transformer.py
@@ -9,7 +9,8 @@ import torch.nn.functional as F
 from megatron import get_timers, get_args, core, get_num_microbatches
 from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
-from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
+from megatron.core.enums import ModelType
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
megatron/training.py
@@ -25,8 +25,8 @@ from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.model import Float16Module
-from megatron.model import ModelType
 from megatron.model import GPTModel
+from megatron.core.enums import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
@@ -37,7 +37,7 @@ from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
-from megatron.schedules import get_forward_backward_func
+from megatron.core.pipeline_parallel import get_forward_backward_func
 from megatron.utils import report_memory
 from megatron.model.vision.knn_monitor import compute_feature_bank
@@ -400,6 +400,7 @@ def setup_model_and_optimizer(model_provider_func,
     return model, optimizer, opt_param_scheduler

 def train_step(forward_step_func, data_iterator,
                model, optimizer, opt_param_scheduler):
     """Single training step."""
@@ -418,8 +419,16 @@ def train_step(forward_step_func, data_iterator,
     forward_backward_func = get_forward_backward_func()
     fwd_bwd_timers = timers if args.timing_log_level > 1 else None
     losses_reduced = forward_backward_func(
-        forward_step_func, data_iterator, model,
-        optimizer, fwd_bwd_timers, forward_only=False)
+        forward_step_func=forward_step_func,
+        data_iterator=data_iterator,
+        model=model,
+        num_microbatches=get_num_microbatches(),
+        dtype=args.params_dtype,
+        tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
+        grad_scaler=optimizer.scale_loss,
+        sequence_parallel=args.sequence_parallel,
+        forward_only=False,
+        timers=fwd_bwd_timers)
     timers('forward-backward').stop()

     # Empty unused memory.
@@ -794,8 +803,15 @@ def evaluate(forward_step_func,
             forward_backward_func = get_forward_backward_func()
             loss_dicts = forward_backward_func(
-                forward_step_func, data_iterator, model, optimizer=None,
-                timers=None, forward_only=True)
+                forward_step_func=forward_step_func,
+                data_iterator=data_iterator,
+                model=model,
+                num_microbatches=get_num_microbatches(),
+                dtype=args.params_dtype,
+                tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
+                sequence_parallel=args.sequence_parallel,
+                forward_only=True,
+                timers=None)

             # Empty unused memory
             if args.empty_unused_memory_level >= 1:
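Both `train_step` and `evaluate` now spell out every schedule input as a keyword argument, with `grad_scaler=optimizer.scale_loss` standing in for the optimizer object the old API received. A hypothetical wrapper (not part of the commit; the `run_eval_schedule` name and its parameters are invented for illustration) showing the same calling convention in isolation:

from megatron.core.pipeline_parallel import get_forward_backward_func

def run_eval_schedule(forward_step_func, data_iterator, model,
                      num_microbatches, params_dtype,
                      seq_length, micro_batch_size, hidden_size,
                      sequence_parallel=False):
    # Forward-only (evaluation-style) invocation of the core schedule;
    # the training path additionally passes grad_scaler=optimizer.scale_loss.
    forward_backward_func = get_forward_backward_func()
    return forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=data_iterator,
        model=model,
        num_microbatches=num_microbatches,
        dtype=params_dtype,
        tensor_shape=(seq_length, micro_batch_size, hidden_size),
        sequence_parallel=sequence_parallel,
        forward_only=True,
        timers=None)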
pretrain_bert.py
@@ -11,8 +11,9 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron.core import tensor_parallel
+from megatron.core.enums import ModelType
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import BertModel, ModelType
+from megatron.model import BertModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
pretrain_gpt.py
@@ -9,8 +9,9 @@ from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
 from megatron.core import tensor_parallel
+from megatron.core.enums import ModelType
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import GPTModel, ModelType
+from megatron.model import GPTModel
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group
pretrain_ict.py
@@ -13,9 +13,9 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron.core import mpu
+from megatron.core.enums import ModelType
 from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import ModelType
 from megatron.model.biencoder_model import biencoder_model_provider
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
pretrain_retro.py
@@ -10,7 +10,8 @@ from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import print_rank_0
 from megatron.core import mpu, tensor_parallel
-from megatron.model import GPTModel, ModelType
+from megatron.core.enums import ModelType
+from megatron.model import GPTModel
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from tools.retro.pretraining.retro_dataset import get_retro_datasets
pretrain_t5.py
@@ -12,8 +12,9 @@ from megatron import (
     print_rank_0
 )
 from megatron.core import tensor_parallel
+from megatron.core.enums import ModelType
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import T5Model, ModelType
+from megatron.model import T5Model
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
pretrain_vision_classify.py
@@ -6,8 +6,8 @@ import torch
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers, print_rank_0
+from megatron.core.enums import ModelType
 from megatron.data.vit_dataset import build_train_valid_datasets
-from megatron.model import ModelType
 from megatron.model.vision.classification import VitClassificationModel
 from megatron.model.vision.classification import MitClassificationModel
 from megatron.training import pretrain
pretrain_vision_dino.py
@@ -7,6 +7,7 @@ import numpy as np
 import torch.distributed as dist
 from functools import partial
 from megatron import get_args, get_timers, print_rank_0
+from megatron.core.enums import ModelType
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.dino import DINOPretrainModel
 from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
@@ -15,7 +16,6 @@ from megatron.utils import average_losses_across_data_parallel_group, unwrap_mod
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
-from megatron.model import ModelType

 def model_provider(pre_process=True, post_process=True):
     """Build the model."""