OpenDAS / Megatron-LM · Commits

Commit 3aca1415
Authored Apr 29, 2024 by liangjing
Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04. See merge request !1

Parents: 0024a5c6, 1005e9d3
Pipeline #1806 passed
Changes: 204 · Pipelines: 3
Showing 20 changed files with 2699 additions and 985 deletions (+2699, −985)
megatron/core/package_info.py  (+29, −0)
megatron/core/parallel_state.py  (+152, −67)
megatron/core/pipeline_parallel/p2p_communication.py  (+311, −196)
megatron/core/pipeline_parallel/schedules.py  (+759, −363)
megatron/core/requirements.txt  (+1, −0)
megatron/core/tensor_parallel/__init__.py  (+9, −18)
megatron/core/tensor_parallel/cross_entropy.py  (+17, −18)
megatron/core/tensor_parallel/data.py  (+12, −13)
megatron/core/tensor_parallel/layers.py  (+427, −244)
megatron/core/tensor_parallel/mappings.py  (+30, −26)
megatron/core/tensor_parallel/random.py  (+18, −22)
megatron/core/tensor_parallel/utils.py  (+23, −18)
megatron/core/transformer/__init__.py  (+3, −0)
megatron/core/transformer/attention.py  (+368, −0)
megatron/core/transformer/custom_layers/__init__.py  (+0, −0)
megatron/core/transformer/custom_layers/transformer_engine.py  (+249, −0)
megatron/core/transformer/dot_product_attention.py  (+165, −0)
megatron/core/transformer/enums.py  (+25, −0)
megatron/core/transformer/identity_op.py  (+14, −0)
megatron/core/transformer/mlp.py  (+87, −0)
megatron/core/package_info.py  (new file, mode 0 → 100644)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

MAJOR = 0
MINOR = 3
PATCH = 0
PRE_RELEASE = ''

# Use the following formatting: (major, minor, patch, pre-release)
VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)

__shortversion__ = '.'.join(map(str, VERSION[:3]))
__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:])

__package_name__ = 'megatron_core'
__contact_names__ = 'NVIDIA'
__contact_emails__ = 'nemo-toolkit@nvidia.com'  # use NeMo Email
__homepage__ = (
    'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/'  # use NeMo homepage
)
__repository_url__ = 'https://github.com/NVIDIA/Megatron-LM/megatron/core'
__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
__description__ = (
    'Megatron Core - a library for efficient and scalable training of transformer based models'
)
__license__ = 'BSD-3'
__keywords__ = (
    'deep learning, machine learning, gpu, NLP, NLU, language, transformer, nvidia, pytorch, torch'
)
megatron/core/parallel_state.py
@@ -2,9 +2,11 @@
 """Model and data parallel groups."""
 
-import torch
+import os
 from typing import Optional
 
+import torch
+
 from .utils import GlobalMemoryBuffer
 
 # Intra-layer model parallel group that the current rank belongs to.
@@ -19,6 +21,9 @@ _EMBEDDING_GROUP = None
 _POSITION_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
+_DATA_PARALLEL_GROUP_GLOO = None
+# FP8 amax reduction group.
+_AMAX_REDUCTION_GROUP = None
 
 _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
 _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
@@ -53,9 +58,10 @@ def initialize_model_parallel(
     pipeline_model_parallel_size: int = 1,
     virtual_pipeline_model_parallel_size: Optional[int] = None,
     pipeline_model_parallel_split_rank: Optional[int] = None,
+    use_fp8: bool = False,
+    use_sharp: bool = False,
 ) -> None:
-    """
-    Initialize model data parallel groups.
+    """Initialize model data parallel groups.
 
     Arguments:
         tensor_model_parallel_size (int, default = 1):
@@ -93,6 +99,17 @@ def initialize_model_parallel(
         pipeline_model_parallel_split_rank is 3, then ranks 0-2
         will be the encoder and ranks 3-7 will be the decoder.
 
+        use_fp8 (bool, default = False):
+            Construct GPU groups needed for FP8 training, namely for
+            amax reduction across the product of the data-parallel and
+            tensor-parallel groups.
+
+        use_sharp (bool, default = False):
+            Set the use of SHARP for the collective communications of
+            data-parallel process groups. When `True`, run barrier
+            within each data-parallel process group, which specifies
+            the SHARP application target groups.
+
     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
     the model pipeline. The present function will
@@ -108,6 +125,7 @@ def initialize_model_parallel(
     are on the same DGX box. For example if we are using 2 DGX-1 boxes
     with a total of 16 GPUs, rank 0 to 7 belong to the first box and
     ranks 8 to 15 belong to the second box.
     """
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
@@ -119,17 +137,19 @@ def initialize_model_parallel(
         f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
     )
 
-    data_parallel_size: int = world_size // (tensor_model_parallel_size *
-                                             pipeline_model_parallel_size)
+    data_parallel_size: int = world_size // (
+        tensor_model_parallel_size * pipeline_model_parallel_size)
 
     num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
     num_data_parallel_groups: int = world_size // data_parallel_size
 
     if virtual_pipeline_model_parallel_size is not None:
         if not pipeline_model_parallel_size > 2:
-            raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
-                               "interleaved schedule")
+            raise RuntimeError(
+                "pipeline-model-parallel size should be greater than 2 with interleaved schedule"
+            )
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
         _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
@@ -143,6 +163,7 @@ def initialize_model_parallel(
     # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
+    global _DATA_PARALLEL_GROUP_GLOO
     global _DATA_PARALLEL_GLOBAL_RANKS
     assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
     all_data_parallel_group_ranks = []
@@ -153,27 +174,50 @@ def initialize_model_parallel(
             ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
             all_data_parallel_group_ranks.append(list(ranks))
             group = torch.distributed.new_group(ranks)
+            group_gloo = torch.distributed.new_group(ranks, backend="gloo")
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
+                _DATA_PARALLEL_GROUP_GLOO = group_gloo
                 _DATA_PARALLEL_GLOBAL_RANKS = ranks
 
+    # Apply SHARP to DP process groups
+    if use_sharp:
+        if rank == 0:
+            print(
+                "The number of process groups to use SHARP with depends on the type "
+                "of the network switch. Nvidia QM1 switch supports SAHRP up to 8 "
+                "process groups and QM2 supports up to 256 process groups. We apply "
+                "SHARP to the communications of the data-parallel domain. If the "
+                "number of data-parallel process groups is larger than the max "
+                "process groups that the network switch supports, the communication "
+                "will fall back to non-SHARP operators. To enable SHARP, "
+                "`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
+            )
+        torch.distributed.barrier(
+            group=get_data_parallel_group(), device_ids=[torch.cuda.current_device()]
+        )
+        # Set `NCCL_SHARP_DISABLE=1` to restrict SHARP application to DP process groups
+        os.environ["NCCL_SHARP_DISABLE"] = "1"
+
     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
     assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
     for i in range(data_parallel_size):
-        ranks = [data_parallel_group_ranks[i]
-                 for data_parallel_group_ranks in all_data_parallel_group_ranks]
+        ranks = [
+            data_parallel_group_ranks[i]
+            for data_parallel_group_ranks in all_data_parallel_group_ranks
+        ]
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
             _MODEL_PARALLEL_GROUP = group
 
     # Build the tensor model-parallel groups.
     global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
-        'tensor model parallel group is already initialized'
+    assert (
+        _TENSOR_MODEL_PARALLEL_GROUP is None
+    ), 'tensor model parallel group is already initialized'
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
             _TENSOR_MODEL_PARALLEL_GROUP = group
@@ -182,15 +226,15 @@ def initialize_model_parallel(
     # (first and last rank in each pipeline model-parallel group).
     global _PIPELINE_MODEL_PARALLEL_GROUP
     global _PIPELINE_GLOBAL_RANKS
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
-        'pipeline model parallel group is already initialized'
+    assert (
+        _PIPELINE_MODEL_PARALLEL_GROUP is None
+    ), 'pipeline model parallel group is already initialized'
     global _EMBEDDING_GROUP
     global _EMBEDDING_GLOBAL_RANKS
     assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
     global _POSITION_EMBEDDING_GROUP
     global _POSITION_EMBEDDING_GLOBAL_RANKS
-    assert _POSITION_EMBEDDING_GROUP is None, \
-        'position embedding group is already initialized'
+    assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size, num_pipeline_model_parallel_groups)
         group = torch.distributed.new_group(ranks)
@@ -204,12 +248,13 @@ def initialize_model_parallel(
             position_embedding_ranks = [ranks[0]]
             if pipeline_model_parallel_split_rank is not None:
                 if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
-                    embedding_ranks = [ranks[0],
-                                       ranks[pipeline_model_parallel_split_rank],
-                                       ranks[-1]]
+                    embedding_ranks = [
+                        ranks[0],
+                        ranks[pipeline_model_parallel_split_rank],
+                        ranks[-1],
+                    ]
                 if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                     position_embedding_ranks = [ranks[0],
                                                 ranks[pipeline_model_parallel_split_rank]]
         else:
             embedding_ranks = ranks
             position_embedding_ranks = ranks
@@ -226,6 +271,20 @@ def initialize_model_parallel(
         if rank in ranks:
             _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
 
+    # Build the FP8 groups.
+    global _AMAX_REDUCTION_GROUP
+    assert _AMAX_REDUCTION_GROUP is None, 'FP8 amax reduction group is already initialized'
+    if use_fp8:
+        amax_group_size: int = tensor_model_parallel_size * data_parallel_size
+        num_amax_groups: int = world_size // amax_group_size
+        for i in range(num_amax_groups):
+            start_rank = i * amax_group_size
+            end_rank = (i + 1) * amax_group_size
+            ranks = range(start_rank, end_rank)
+            group = torch.distributed.new_group(ranks)
+            if rank in ranks:
+                _AMAX_REDUCTION_GROUP = group
+
     # Initialize global memory buffer
     # This isn't really "parallel state" but there isn't another good place to
     # put this. If we end up with a more generic initialization of megatron-core
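For reference, the grouping arithmetic in the hunks above can be checked with a small, torch-free sketch. It reproduces the 16-GPU example from the docstring (tensor_model_parallel_size=2, pipeline_model_parallel_size=4) and, assuming use_fp8 is set, the amax-reduction groups that span the product of the tensor- and data-parallel dimensions. The variable names below are local to the sketch, not Megatron APIs.

# Torch-free sketch of the rank-grouping loops used by initialize_model_parallel.
world_size, tp, pp = 16, 2, 4
dp = world_size // (tp * pp)                      # data-parallel size = 2

# Tensor model-parallel groups: consecutive blocks of `tp` ranks.
tensor_groups = [list(range(i * tp, (i + 1) * tp)) for i in range(world_size // tp)]

# Data-parallel groups: strided by `tp` within each pipeline stage block.
num_pp_groups = world_size // pp
data_groups = []
for i in range(pp):
    start_rank, end_rank = i * num_pp_groups, (i + 1) * num_pp_groups
    for j in range(tp):
        data_groups.append(list(range(start_rank + j, end_rank, tp)))

# FP8 amax-reduction groups: contiguous blocks of tp * dp ranks.
amax_size = tp * dp
amax_groups = [list(range(i * amax_size, (i + 1) * amax_size))
               for i in range(world_size // amax_size)]

print(tensor_groups)  # [[0, 1], [2, 3], ..., [14, 15]]
print(data_groups)    # [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]
print(amax_groups)    # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]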
@@ -240,55 +299,68 @@ def is_unitialized():
 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
-    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
-        _PIPELINE_MODEL_PARALLEL_GROUP is None or \
-        _DATA_PARALLEL_GROUP is None:
+    if (
+        _TENSOR_MODEL_PARALLEL_GROUP is None
+        or _PIPELINE_MODEL_PARALLEL_GROUP is None
+        or _DATA_PARALLEL_GROUP is None
+    ):
         return False
     return True
 
 
 def get_model_parallel_group():
     """Get the model parallel group the caller rank belongs to."""
-    assert _MODEL_PARALLEL_GROUP is not None, \
-        'model parallel group is not initialized'
+    assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized'
     return _MODEL_PARALLEL_GROUP
 
 
-def get_tensor_model_parallel_group():
+def get_tensor_model_parallel_group(check_initialized=True):
     """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
-        'intra_layer_model parallel group is not initialized'
+    if check_initialized:
+        assert (
+            _TENSOR_MODEL_PARALLEL_GROUP is not None
+        ), 'tensor model parallel group is not initialized'
     return _TENSOR_MODEL_PARALLEL_GROUP
 
 
 def get_pipeline_model_parallel_group():
     """Get the pipeline model parallel group the caller rank belongs to."""
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
-        'pipeline_model parallel group is not initialized'
+    assert (
+        _PIPELINE_MODEL_PARALLEL_GROUP is not None
+    ), 'pipeline_model parallel group is not initialized'
     return _PIPELINE_MODEL_PARALLEL_GROUP
 
 
 def get_data_parallel_group():
     """Get the data parallel group the caller rank belongs to."""
-    assert _DATA_PARALLEL_GROUP is not None, \
-        'data parallel group is not initialized'
+    assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
     return _DATA_PARALLEL_GROUP
 
 
+def get_data_parallel_group_gloo():
+    """Get the data parallel group-gloo the caller rank belongs to."""
+    assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized'
+    return _DATA_PARALLEL_GROUP_GLOO
+
+
 def get_embedding_group():
     """Get the embedding group the caller rank belongs to."""
-    assert _EMBEDDING_GROUP is not None, \
-        'embedding group is not initialized'
+    assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized'
     return _EMBEDDING_GROUP
 
 
 def get_position_embedding_group():
     """Get the position embedding group the caller rank belongs to."""
-    assert _POSITION_EMBEDDING_GROUP is not None, \
-        'position embedding group is not initialized'
+    assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized'
     return _POSITION_EMBEDDING_GROUP
 
 
+def get_amax_reduction_group():
+    """Get the FP8 amax reduction group the caller rank belongs to."""
+    assert _AMAX_REDUCTION_GROUP is not None, 'FP8 amax reduction group is not initialized'
+    return _AMAX_REDUCTION_GROUP
+
+
 def set_tensor_model_parallel_world_size(world_size):
     """Set the tensor model parallel size"""
     global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
@@ -301,6 +373,12 @@ def set_pipeline_model_parallel_world_size(world_size):
     _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
 
 
+def set_virtual_pipeline_model_parallel_world_size(world_size):
+    """Set the pipeline model parallel size"""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
+
+
 def get_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
     global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
@@ -360,8 +438,10 @@ def get_pipeline_model_parallel_split_rank():
 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
-        if get_virtual_pipeline_model_parallel_world_size() is not None and \
-            get_virtual_pipeline_model_parallel_rank() != 0:
+        if (
+            get_virtual_pipeline_model_parallel_world_size() is not None
+            and get_virtual_pipeline_model_parallel_rank() != 0
+        ):
             return False
     return get_pipeline_model_parallel_rank() == 0
@@ -369,14 +449,14 @@ def is_pipeline_first_stage(ignore_virtual=False):
 def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
-        virtual_pipeline_model_parallel_world_size = \
-            get_virtual_pipeline_model_parallel_world_size()
-        if virtual_pipeline_model_parallel_world_size is not None and \
-            get_virtual_pipeline_model_parallel_rank() != (
-                virtual_pipeline_model_parallel_world_size - 1):
+        virtual_pipeline_model_parallel_world_size = (
+            get_virtual_pipeline_model_parallel_world_size()
+        )
+        if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
+            virtual_pipeline_model_parallel_world_size - 1
+        ):
             return False
     return get_pipeline_model_parallel_rank() == (
         get_pipeline_model_parallel_world_size() - 1)
 
 
 def is_rank_in_embedding_group(ignore_virtual=False):
@@ -437,8 +517,7 @@ def is_pipeline_stage_at_split():
     stage executes encoder block for a model with both encoder and
     decoder."""
     rank = get_pipeline_model_parallel_rank()
-    return is_pipeline_stage_before_split(rank) and \
-        is_pipeline_stage_after_split(rank + 1)
+    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
 
 
 def get_virtual_pipeline_model_parallel_rank():
@@ -459,12 +538,6 @@ def get_virtual_pipeline_model_parallel_world_size():
     return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
 
 
-def set_virtual_pipeline_model_parallel_world_size(world_size):
-    """Set the virtual pipeline-parallel world size"""
-    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
-
-
 def get_tensor_model_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
@@ -476,31 +549,28 @@ def get_tensor_model_parallel_src_rank():
 def get_data_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
     in the data parallel group."""
-    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
-        "Data parallel group is not initialized"
+    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
     return _DATA_PARALLEL_GLOBAL_RANKS[0]
 
 
 def get_pipeline_model_parallel_first_rank():
     """Return the global rank of the first process in the pipeline for the
     current tensor parallel group"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, \
-        "Pipeline parallel group is not initialized"
+    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
     return _PIPELINE_GLOBAL_RANKS[0]
 
 
 def get_pipeline_model_parallel_last_rank():
     """Return the global rank of the last process in the pipeline for the
     current tensor parallel group"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, \
-        "Pipeline parallel group is not initialized"
+    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
     return _PIPELINE_GLOBAL_RANKS[last_rank_local]
 
 
 def get_pipeline_model_parallel_next_rank():
     """Return the global rank that follows the caller in the pipeline"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, \
-        "Pipeline parallel group is not initialized"
+    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
@@ -508,8 +578,7 @@ def get_pipeline_model_parallel_next_rank():
 def get_pipeline_model_parallel_prev_rank():
     """Return the global rank that preceeds the caller in the pipeline"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, \
-        "Pipeline parallel group is not initialized"
+    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
@@ -517,12 +586,19 @@ def get_pipeline_model_parallel_prev_rank():
 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
-    return torch.distributed.get_world_size(group=get_data_parallel_group())
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        return torch.distributed.get_world_size(group=get_data_parallel_group())
+    else:
+        return 0
 
 
 def get_data_parallel_rank():
     """Return my rank for the data parallel group."""
-    return torch.distributed.get_rank(group=get_data_parallel_group())
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        return torch.distributed.get_rank(group=get_data_parallel_group())
+    else:
+        return 0
 
 
 def _set_global_memory_buffer():
     """Initialize global buffer"""
@@ -530,12 +606,19 @@ def _set_global_memory_buffer():
     assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
     _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
 
 
 def get_global_memory_buffer():
     """Return the global GlobalMemoryBuffer object"""
     assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
     return _GLOBAL_MEMORY_BUFFER
 
 
+def destroy_global_memory_buffer():
+    """Sets the global memory buffer to None"""
+    global _GLOBAL_MEMORY_BUFFER
+    _GLOBAL_MEMORY_BUFFER = None
+
+
 def destroy_model_parallel():
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
@@ -550,6 +633,8 @@ def destroy_model_parallel():
     _EMBEDDING_GROUP = None
     global _POSITION_EMBEDDING_GROUP
     _POSITION_EMBEDDING_GROUP = None
+    global _AMAX_REDUCTION_GROUP
+    _AMAX_REDUCTION_GROUP = None
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
     _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
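One behavioural consequence of the get_data_parallel_world_size / get_data_parallel_rank change above is that both now degrade gracefully outside a distributed job instead of raising. A minimal torch-only sketch of the same guard pattern follows; the helper name is ours, not a Megatron API.

import torch

def safe_data_parallel_world_size(group=None):
    # Mirrors the guard added in this diff: only query the process group when
    # torch.distributed is both available and initialized; otherwise return 0.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size(group=group)
    return 0

print(safe_data_parallel_world_size())  # prints 0 when run outside torchrun/srun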
megatron/core/pipeline_parallel/p2p_communication.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
-from functools import reduce
 import operator
-from typing import Optional, List, Union, Callable, Tuple
+from functools import reduce
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 
 from megatron import core
+from megatron.core import ModelParallelConfig
 from megatron.core.parallel_state import (
     get_pipeline_model_parallel_group,
-    get_pipeline_model_parallel_prev_rank,
     get_pipeline_model_parallel_next_rank,
+    get_pipeline_model_parallel_prev_rank,
+    get_pipeline_model_parallel_rank,
 )
 
 # Types
 Shape = Union[List[int], torch.Size]
 
 
-def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                        use_ring_exchange_p2p):
+def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config):
     """Communicate tensor shapes between stages. Used to communicate
     tensor shapes before the actual tensor communication happens.
     This is required when the sequence lengths across micro batches
@@ -42,49 +43,59 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
     send_prev_shape_tensor = None
     send_next_shape_tensor = None
     if recv_prev:
-        recv_prev_shape_tensor = torch.empty((3),
-                                             device=torch.cuda.current_device(),
-                                             dtype=torch.int64)
+        recv_prev_shape_tensor = torch.empty(
+            (3), device=torch.cuda.current_device(), dtype=torch.int64
+        )
     if recv_next:
-        recv_next_shape_tensor = torch.empty((3),
-                                             device=torch.cuda.current_device(),
-                                             dtype=torch.int64)
+        recv_next_shape_tensor = torch.empty(
+            (3), device=torch.cuda.current_device(), dtype=torch.int64
+        )
     if tensor_send_prev is not None:
-        send_prev_shape_tensor = torch.tensor(tensor_send_prev.size(),
-                                              device=torch.cuda.current_device(),
-                                              dtype=torch.int64)
+        send_prev_shape_tensor = torch.tensor(
+            tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64
+        )
     if tensor_send_next is not None:
-        send_next_shape_tensor = torch.tensor(tensor_send_next.size(),
-                                              device=torch.cuda.current_device(),
-                                              dtype=torch.int64)
+        send_next_shape_tensor = torch.tensor(
+            tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64
+        )
 
-    if use_ring_exchange_p2p:
-        torch.distributed.ring_exchange(tensor_send_prev=send_prev_shape_tensor,
-                                        tensor_recv_prev=recv_prev_shape_tensor,
-                                        tensor_send_next=send_next_shape_tensor,
-                                        tensor_recv_next=recv_next_shape_tensor,
-                                        group=mpu.get_pipeline_model_parallel_group())
+    if config.use_ring_exchange_p2p:
+        torch.distributed.ring_exchange(
+            tensor_send_prev=send_prev_shape_tensor,
+            tensor_recv_prev=recv_prev_shape_tensor,
+            tensor_send_next=send_next_shape_tensor,
+            tensor_recv_next=recv_next_shape_tensor,
+            group=get_pipeline_model_parallel_group(),
+        )
     else:
         ops = []
         if send_prev_shape_tensor is not None:
-            send_prev_op = torch.distributed.P2POp(torch.distributed.isend,
-                                                   send_prev_shape_tensor,
-                                                   mpu.get_pipeline_model_parallel_prev_rank())
+            send_prev_op = torch.distributed.P2POp(
+                torch.distributed.isend,
+                send_prev_shape_tensor,
+                get_pipeline_model_parallel_prev_rank(),
+            )
             ops.append(send_prev_op)
         if recv_prev_shape_tensor is not None:
-            recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv,
-                                                   recv_prev_shape_tensor,
-                                                   mpu.get_pipeline_model_parallel_prev_rank())
+            recv_prev_op = torch.distributed.P2POp(
+                torch.distributed.irecv,
+                recv_prev_shape_tensor,
+                get_pipeline_model_parallel_prev_rank(),
+            )
             ops.append(recv_prev_op)
         if send_next_shape_tensor is not None:
-            send_next_op = torch.distributed.P2POp(torch.distributed.isend,
-                                                   send_next_shape_tensor,
-                                                   mpu.get_pipeline_model_parallel_next_rank())
+            send_next_op = torch.distributed.P2POp(
+                torch.distributed.isend,
+                send_next_shape_tensor,
+                get_pipeline_model_parallel_next_rank(),
+            )
            ops.append(send_next_op)
         if recv_next_shape_tensor is not None:
-            recv_next_op = torch.distributed.P2POp(torch.distributed.irecv,
-                                                   recv_next_shape_tensor,
-                                                   mpu.get_pipeline_model_parallel_next_rank())
+            recv_next_op = torch.distributed.P2POp(
+                torch.distributed.irecv,
+                recv_next_shape_tensor,
+                get_pipeline_model_parallel_next_rank(),
+            )
             ops.append(recv_next_op)
         if len(ops) > 0:
             reqs = torch.distributed.batch_isend_irecv(ops)
@@ -106,15 +117,126 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev,
     return recv_prev_shape, recv_next_shape
 
 
-def _communicate(*, tensor_send_next: Optional[torch.Tensor],
-                 tensor_send_prev: Optional[torch.Tensor],
-                 recv_prev: bool,
-                 recv_next: bool,
-                 tensor_shape: Shape,
-                 dtype: Optional[torch.dtype],
-                 variable_seq_lengths: bool = False,
-                 use_ring_exchange_p2p: bool = False,
-                 ) -> Tuple[torch.Tensor, torch.Tensor]:
+def _batched_p2p_ops(
+    *,
+    tensor_send_prev: Optional[torch.Tensor],
+    tensor_recv_prev: Optional[torch.Tensor],
+    tensor_send_next: Optional[torch.Tensor],
+    tensor_recv_next: Optional[torch.Tensor],
+    group: torch.distributed.ProcessGroup
+):
+    ops = []
+    if tensor_send_prev is not None:
+        send_prev_op = torch.distributed.P2POp(
+            torch.distributed.isend,
+            tensor_send_prev,
+            get_pipeline_model_parallel_prev_rank(),
+            group,
+        )
+        ops.append(send_prev_op)
+    if tensor_recv_prev is not None:
+        recv_prev_op = torch.distributed.P2POp(
+            torch.distributed.irecv,
+            tensor_recv_prev,
+            get_pipeline_model_parallel_prev_rank(),
+            group,
+        )
+        ops.append(recv_prev_op)
+    if tensor_send_next is not None:
+        send_next_op = torch.distributed.P2POp(
+            torch.distributed.isend,
+            tensor_send_next,
+            get_pipeline_model_parallel_next_rank(),
+            group,
+        )
+        ops.append(send_next_op)
+    if tensor_recv_next is not None:
+        recv_next_op = torch.distributed.P2POp(
+            torch.distributed.irecv,
+            tensor_recv_next,
+            get_pipeline_model_parallel_next_rank(),
+            group,
+        )
+        ops.append(recv_next_op)
+    if len(ops) > 0:
+        reqs = torch.distributed.batch_isend_irecv(ops)
+    else:
+        reqs = []
+    return reqs
+
+
+def _p2p_ops(
+    *,
+    tensor_send_prev: Optional[torch.Tensor],
+    tensor_recv_prev: Optional[torch.Tensor],
+    tensor_send_next: Optional[torch.Tensor],
+    tensor_recv_next: Optional[torch.Tensor],
+    group: torch.distributed.ProcessGroup
+):
+    reqs = []
+    rank = get_pipeline_model_parallel_rank()
+    if get_pipeline_model_parallel_rank() % 2 == 0:
+        if tensor_send_next is not None:
+            send_next_req = torch.distributed.isend(
+                tensor=tensor_send_next,
+                dst=get_pipeline_model_parallel_next_rank(),
+                group=group,
+            )
+            reqs.append(send_next_req)
+        if tensor_recv_prev is not None:
+            recv_prev_req = torch.distributed.irecv(
+                tensor=tensor_recv_prev,
+                src=get_pipeline_model_parallel_prev_rank(),
+                group=group,
+            )
+            reqs.append(recv_prev_req)
+        if tensor_send_prev is not None:
+            send_prev_req = torch.distributed.isend(
+                tensor=tensor_send_prev,
+                dst=get_pipeline_model_parallel_prev_rank(),
+                group=group,
+            )
+            reqs.append(send_prev_req)
+        if tensor_recv_next is not None:
+            recv_next_req = torch.distributed.irecv(
+                tensor=tensor_recv_next,
+                src=get_pipeline_model_parallel_next_rank(),
+                group=group,
+            )
+            reqs.append(recv_next_req)
+    else:
+        if tensor_recv_prev is not None:
+            recv_prev_req = torch.distributed.irecv(
+                tensor=tensor_recv_prev,
+                src=get_pipeline_model_parallel_prev_rank(),
+                group=group,
+            )
+            reqs.append(recv_prev_req)
+        if tensor_send_next is not None:
+            send_next_req = torch.distributed.isend(
+                tensor=tensor_send_next,
+                dst=get_pipeline_model_parallel_next_rank(),
+                group=group,
+            )
+            reqs.append(send_next_req)
+        if tensor_recv_next is not None:
+            recv_next_req = torch.distributed.irecv(
+                tensor=tensor_recv_next,
+                src=get_pipeline_model_parallel_next_rank(),
+                group=group,
+            )
+            reqs.append(recv_next_req)
+        if tensor_send_prev is not None:
+            send_prev_req = torch.distributed.isend(
+                tensor=tensor_send_prev,
+                dst=get_pipeline_model_parallel_prev_rank(),
+                group=group,
+            )
+            reqs.append(send_prev_req)
+    return reqs
+
+
+def _communicate(
+    *,
+    tensor_send_next: Optional[torch.Tensor],
+    tensor_send_prev: Optional[torch.Tensor],
+    recv_prev: bool,
+    recv_next: bool,
+    tensor_shape: Shape,
+    config: ModelParallelConfig,
+    wait_on_reqs: bool = True
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
@@ -136,23 +258,9 @@ def _communicate(*, tensor_send_next: Optional[torch.Tensor],
                    tensors sent and received in a single function call are
                    the same shape).
 
-        dtype (torch.dtype, required if either recv_{prev,next} is True):
-                this must be the type of the tensors that will be
-                received, will typically be params_dtype, but in the case
-                of fp32 residual connections might be torch.float.
-
-        variable_seq_lengths (bool, optional, default=False):
-                Support for variable sequence lengths across
-                microbatches. Setting this communicates the size of
-                tensors during pipeline parallelism communication, because
-                of this extra overhead it should only be set if the
-                sequence length is not constant during training.
-
-        use_ring_exchange_p2p (bool, optional, default = False):
-                Use custom ring_exchange kernel instead of
-                torch.distributed.batch_isend_irecv(). Requires custom
-                built torch with torch.distributed.ring_exchange.
+        wait_on_reqs (boolean, optional, default=False):
+                For non-batched p2p communication, wait on each request
+                before returning.
 
     Returns:
         tuple containing
@@ -167,84 +275,79 @@ def _communicate(*, tensor_send_next: Optional[torch.Tensor],
     tensor_recv_prev = None
     tensor_recv_next = None
 
-    if not variable_seq_lengths:
+    if not config.variable_seq_lengths:
         recv_prev_shape = tensor_shape
         recv_next_shape = tensor_shape
     else:
-        recv_prev_shape, recv_next_shape = \
-            _communicate_shapes(tensor_send_next,
-                                tensor_send_prev,
-                                recv_prev,
-                                recv_next)
+        recv_prev_shape, recv_next_shape = _communicate_shapes(
+            tensor_send_next, tensor_send_prev, recv_prev, recv_next, config
+        )
 
     if recv_prev:
-        if dtype is None:
-            raise RuntimeError("dtype must be provided if recv_prev is True")
+        if config.pipeline_dtype is None:
+            raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
         if tensor_shape is None:
             raise RuntimeError(
                 "tensor_shape must be specified if recv_prev is True. "
                 "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
             )
-        tensor_recv_prev = torch.empty(recv_prev_shape,
-                                       requires_grad=True,
-                                       device=torch.cuda.current_device(),
-                                       dtype=dtype)
+        tensor_recv_prev = torch.empty(
+            recv_prev_shape,
+            requires_grad=True,
+            device=torch.cuda.current_device(),
+            dtype=config.pipeline_dtype,
+        )
     if recv_next:
-        if dtype is None:
+        if config.pipeline_dtype is None:
             raise RuntimeError("dtype must be provided if recv_next is True")
         if tensor_shape is None:
             raise RuntimeError(
                 "tensor_shape must be specified if recv_next is True. "
                 "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
            )
-        tensor_recv_next = torch.empty(recv_next_shape,
-                                       requires_grad=True,
-                                       device=torch.cuda.current_device(),
-                                       dtype=dtype)
+        tensor_recv_next = torch.empty(
+            recv_next_shape,
+            requires_grad=True,
+            device=torch.cuda.current_device(),
+            dtype=config.pipeline_dtype,
+        )
 
     # Send tensors in both the forward and backward directions as appropriate.
-    if use_ring_exchange_p2p:
-        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
-                                        tensor_recv_prev=tensor_recv_prev,
-                                        tensor_send_next=tensor_send_next,
-                                        tensor_recv_next=tensor_recv_next,
-                                        group=get_pipeline_model_parallel_group())
-    else:
-        ops = []
-        if tensor_send_prev is not None:
-            send_prev_op = torch.distributed.P2POp(
-                torch.distributed.isend, tensor_send_prev,
-                get_pipeline_model_parallel_prev_rank())
-            ops.append(send_prev_op)
-        if tensor_recv_prev is not None:
-            recv_prev_op = torch.distributed.P2POp(
-                torch.distributed.irecv, tensor_recv_prev,
-                get_pipeline_model_parallel_prev_rank())
-            ops.append(recv_prev_op)
-        if tensor_send_next is not None:
-            send_next_op = torch.distributed.P2POp(
-                torch.distributed.isend, tensor_send_next,
-                get_pipeline_model_parallel_next_rank())
-            ops.append(send_next_op)
-        if tensor_recv_next is not None:
-            recv_next_op = torch.distributed.P2POp(
-                torch.distributed.irecv, tensor_recv_next,
-                get_pipeline_model_parallel_next_rank())
-            ops.append(recv_next_op)
-        if len(ops) > 0:
-            reqs = torch.distributed.batch_isend_irecv(ops)
-            for req in reqs:
-                req.wait()
+    if config.use_ring_exchange_p2p:
+
+        def _ring_exchange_wrapper(**kwargs):
+            torch.distributed.ring_exchange(**kwargs)
+            return []
+
+        p2p_func = _ring_exchange_wrapper
+    elif config.batch_p2p_comm:
+        assert wait_on_reqs
+        p2p_func = _batched_p2p_ops
+    else:
+        p2p_func = _p2p_ops
+
+    reqs = p2p_func(
+        tensor_send_prev=tensor_send_prev,
+        tensor_recv_prev=tensor_recv_prev,
+        tensor_send_next=tensor_send_next,
+        tensor_recv_next=tensor_recv_next,
+        group=get_pipeline_model_parallel_group(),
+    )
+
+    if wait_on_reqs and len(reqs) > 0:
+        for req in reqs:
+            req.wait()
+        reqs = None
+
+    if config.batch_p2p_comm and config.batch_p2p_sync:
         # To protect against race condition when using batch_isend_irecv().
         # User should assert that we have a modern enough PyTorch to not need this
         torch.cuda.synchronize()
 
-    return tensor_recv_prev, tensor_recv_next
+    return tensor_recv_prev, tensor_recv_next, reqs
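The even/odd branch in the new _p2p_ops above exists to avoid deadlock when neighbouring pipeline ranks exchange tensors without batch_isend_irecv: even ranks post their sends before their receives, odd ranks do the opposite. A torch-free sketch of the resulting operation order per rank, purely for illustration:

def p2p_op_order(rank, send_next=True, recv_prev=True, send_prev=True, recv_next=True):
    # Same ordering as the two _p2p_ops branches: even ranks send first,
    # odd ranks receive first, so adjacent ranks never block on each other.
    if rank % 2 == 0:
        plan = [('isend', 'next', send_next), ('irecv', 'prev', recv_prev),
                ('isend', 'prev', send_prev), ('irecv', 'next', recv_next)]
    else:
        plan = [('irecv', 'prev', recv_prev), ('isend', 'next', send_next),
                ('irecv', 'next', recv_next), ('isend', 'prev', send_prev)]
    return [(op, peer) for op, peer, wanted in plan if wanted]

for r in range(4):
    print(r, p2p_op_order(r))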
def
recv_forward
(
tensor_shape
:
Shape
,
def
recv_forward
(
tensor_shape
:
Shape
,
config
:
ModelParallelConfig
)
->
torch
.
Tensor
:
dtype
:
torch
.
dtype
,
timers
:
Callable
=
None
)
->
torch
.
Tensor
:
""" Receive tensor from previous rank in pipeline (forward receive).
""" Receive tensor from previous rank in pipeline (forward receive).
...
@@ -254,23 +357,22 @@ def recv_forward(tensor_shape: Shape,
...
@@ -254,23 +357,22 @@ def recv_forward(tensor_shape: Shape,
if
core
.
parallel_state
.
is_pipeline_first_stage
():
if
core
.
parallel_state
.
is_pipeline_first_stage
():
input_tensor
=
None
input_tensor
=
None
else
:
else
:
if
timers
is
not
None
:
if
config
.
timers
is
not
None
:
timers
(
'forward-recv'
,
log_level
=
2
).
start
()
config
.
timers
(
'forward-recv'
,
log_level
=
2
).
start
()
input_tensor
,
_
=
_communicate
(
input_tensor
,
_
,
_
=
_communicate
(
tensor_send_next
=
None
,
tensor_send_next
=
None
,
tensor_send_prev
=
None
,
tensor_send_prev
=
None
,
recv_prev
=
True
,
recv_prev
=
True
,
recv_next
=
False
,
recv_next
=
False
,
tensor_shape
=
tensor_shape
,
tensor_shape
=
tensor_shape
,
dtype
=
dtype
)
config
=
config
,
if
timers
is
not
None
:
)
timers
(
'forward-recv'
).
stop
()
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-recv'
).
stop
()
return
input_tensor
return
input_tensor
def
recv_backward
(
tensor_shape
:
Shape
,
def
recv_backward
(
tensor_shape
:
Shape
,
config
:
ModelParallelConfig
)
->
torch
.
Tensor
:
dtype
:
torch
.
dtype
,
timers
:
Callable
=
None
)
->
torch
.
Tensor
:
"""Receive tensor from next rank in pipeline (backward receive).
"""Receive tensor from next rank in pipeline (backward receive).
See _communicate for argument details.
See _communicate for argument details.
...
@@ -278,65 +380,65 @@ def recv_backward(tensor_shape: Shape,
...
@@ -278,65 +380,65 @@ def recv_backward(tensor_shape: Shape,
if
core
.
parallel_state
.
is_pipeline_last_stage
():
if
core
.
parallel_state
.
is_pipeline_last_stage
():
output_tensor_grad
=
None
output_tensor_grad
=
None
else
:
else
:
if
timers
is
not
None
:
if
config
.
timers
is
not
None
:
timers
(
'backward-recv'
,
log_level
=
2
).
start
()
config
.
timers
(
'backward-recv'
,
log_level
=
2
).
start
()
_
,
output_tensor_grad
=
_communicate
(
_
,
output_tensor_grad
,
_
=
_communicate
(
tensor_send_next
=
None
,
tensor_send_next
=
None
,
tensor_send_prev
=
None
,
tensor_send_prev
=
None
,
recv_prev
=
False
,
recv_prev
=
False
,
recv_next
=
True
,
recv_next
=
True
,
tensor_shape
=
tensor_shape
,
tensor_shape
=
tensor_shape
,
dtype
=
dtype
)
config
=
config
,
if
timers
is
not
None
:
)
timers
(
'backward-recv'
).
stop
()
if
config
.
timers
is
not
None
:
config
.
timers
(
'backward-recv'
).
stop
()
return
output_tensor_grad
return
output_tensor_grad
def
send_forward
(
output_tensor
:
torch
.
Tensor
,
def
send_forward
(
output_tensor
:
torch
.
Tensor
,
config
:
ModelParallelConfig
)
->
None
:
timers
:
Callable
=
None
)
->
None
:
"""Send tensor to next rank in pipeline (forward send).
"""Send tensor to next rank in pipeline (forward send).
See _communicate for argument details.
See _communicate for argument details.
"""
"""
if
not
core
.
parallel_state
.
is_pipeline_last_stage
():
if
not
core
.
parallel_state
.
is_pipeline_last_stage
():
if
timers
is
not
None
:
if
config
.
timers
is
not
None
:
timers
(
'forward-send'
,
log_level
=
2
).
start
()
config
.
timers
(
'forward-send'
,
log_level
=
2
).
start
()
_communicate
(
_communicate
(
tensor_send_next
=
output_tensor
,
tensor_send_next
=
output_tensor
,
tensor_send_prev
=
None
,
tensor_send_prev
=
None
,
recv_prev
=
False
,
recv_prev
=
False
,
recv_next
=
False
,
recv_next
=
False
,
tensor_shape
=
None
,
tensor_shape
=
None
,
dtype
=
None
)
config
=
config
,
if
timers
is
not
None
:
)
timers
(
'forward-send'
).
stop
()
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-send'
).
stop
()
def
send_backward
(
input_tensor_grad
:
torch
.
Tensor
,
def
send_backward
(
input_tensor_grad
:
torch
.
Tensor
,
config
:
ModelParallelConfig
)
->
None
:
timers
:
Callable
=
None
)
->
None
:
"""Send tensor to previous rank in pipeline (backward send).
"""Send tensor to previous rank in pipeline (backward send).
See _communicate for argument details.
See _communicate for argument details.
"""
"""
if
not
core
.
parallel_state
.
is_pipeline_first_stage
():
if
not
core
.
parallel_state
.
is_pipeline_first_stage
():
if
timers
is
not
None
:
if
config
.
timers
is
not
None
:
timers
(
'backward-send'
,
log_level
=
2
).
start
()
config
.
timers
(
'backward-send'
,
log_level
=
2
).
start
()
_communicate
(
_communicate
(
tensor_send_next
=
None
,
tensor_send_next
=
None
,
tensor_send_prev
=
input_tensor_grad
,
tensor_send_prev
=
input_tensor_grad
,
recv_prev
=
False
,
recv_prev
=
False
,
recv_next
=
False
,
recv_next
=
False
,
tensor_shape
=
None
,
tensor_shape
=
None
,
dtype
=
None
)
config
=
config
,
if
timers
is
not
None
:
)
timers
(
'backward-send'
).
stop
()
if
config
.
timers
is
not
None
:
config
.
timers
(
'backward-send'
).
stop
()
def
send_forward_recv_backward
(
output_tensor
:
torch
.
Tensor
,
def
send_forward_recv_backward
(
tensor_shape
:
Shape
,
output_tensor
:
torch
.
Tensor
,
tensor_shape
:
Shape
,
config
:
ModelParallelConfig
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
timers
:
Callable
=
None
)
->
torch
.
Tensor
:
"""Batched send and recv with next rank in pipeline.
"""Batched send and recv with next rank in pipeline.
See _communicate for argument details.
See _communicate for argument details.
...
@@ -344,24 +446,24 @@ def send_forward_recv_backward(output_tensor: torch.Tensor,
...
@@ -344,24 +446,24 @@ def send_forward_recv_backward(output_tensor: torch.Tensor,
if
core
.
parallel_state
.
is_pipeline_last_stage
():
if
core
.
parallel_state
.
is_pipeline_last_stage
():
output_tensor_grad
=
None
output_tensor_grad
=
None
else
:
else
:
if
timers
is
not
None
:
if
config
.
timers
is
not
None
:
timers
(
'forward-send-backward-recv'
,
log_level
=
2
).
start
()
config
.
timers
(
'forward-send-backward-recv'
,
log_level
=
2
).
start
()
_
,
output_tensor_grad
=
_communicate
(
_
,
output_tensor_grad
,
_
=
_communicate
(
tensor_send_next
=
output_tensor
,
tensor_send_next
=
output_tensor
,
tensor_send_prev
=
None
,
tensor_send_prev
=
None
,
recv_prev
=
False
,
recv_prev
=
False
,
recv_next
=
True
,
recv_next
=
True
,
tensor_shape
=
tensor_shape
,
tensor_shape
=
tensor_shape
,
dtype
=
dtype
)
config
=
config
,
if
timers
is
not
None
:
)
timers
(
'forward-send-backward-recv'
).
stop
()
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-send-backward-recv'
).
stop
()
return
output_tensor_grad
return
output_tensor_grad
def
send_backward_recv_forward
(
input_tensor_grad
:
torch
.
Tensor
,
def
send_backward_recv_forward
(
tensor_shape
:
Shape
,
input_tensor_grad
:
torch
.
Tensor
,
tensor_shape
:
Shape
,
config
:
ModelParallelConfig
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
timers
:
Callable
=
None
)
->
torch
.
Tensor
:
"""Batched send and recv with previous rank in pipeline.
"""Batched send and recv with previous rank in pipeline.
See _communicate for argument details.
See _communicate for argument details.
...
@@ -369,88 +471,101 @@ def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
...
@@ -369,88 +471,101 @@ def send_backward_recv_forward(input_tensor_grad: torch.Tensor,
if
core
.
parallel_state
.
is_pipeline_first_stage
():
if
core
.
parallel_state
.
is_pipeline_first_stage
():
input_tensor
=
None
input_tensor
=
None
else
:
else
:
if
timers
is
not
None
:
if
config
.
timers
is
not
None
:
timers
(
'backward-send-forward-recv'
,
log_level
=
2
).
start
()
config
.
timers
(
'backward-send-forward-recv'
,
log_level
=
2
).
start
()
input_tensor
,
_
=
_communicate
(
input_tensor
,
_
,
_
=
_communicate
(
tensor_send_next
=
None
,
tensor_send_next
=
None
,
tensor_send_prev
=
input_tensor_grad
,
tensor_send_prev
=
input_tensor_grad
,
recv_prev
=
True
,
recv_prev
=
True
,
recv_next
=
False
,
recv_next
=
False
,
tensor_shape
=
tensor_shape
,
tensor_shape
=
tensor_shape
,
dtype
=
dtype
)
config
=
config
,
if
timers
is
not
None
:
)
timers
(
'backward-send-forward-recv'
).
stop
()
if
config
.
timers
is
not
None
:
config
.
timers
(
'backward-send-forward-recv'
).
stop
()
return
input_tensor
return
input_tensor
def send_forward_recv_forward(
    output_tensor: torch.Tensor,
    recv_prev: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    overlap_p2p_comm: bool = False,
) -> torch.Tensor:
    """Batched recv from previous rank and send to next rank in pipeline.

    See _communicate for argument details.
    """
    if config.timers is not None:
        config.timers('forward-send-forward-recv', log_level=2).start()
    input_tensor, _, wait_handles = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape,
        wait_on_reqs=(not overlap_p2p_comm),
        config=config,
    )
    if config.timers is not None:
        config.timers('forward-send-forward-recv').stop()
    if overlap_p2p_comm:
        return input_tensor, wait_handles
    return input_tensor
def send_backward_recv_backward(
    input_tensor_grad: torch.Tensor,
    recv_next: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    overlap_p2p_comm: bool = False,
) -> torch.Tensor:
    """Batched recv from next rank and send to previous rank in pipeline.

    See _communicate for argument details.
    """
    if config.timers is not None:
        config.timers('backward-send-backward-recv', log_level=2).start()
    _, output_tensor_grad, wait_handles = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        wait_on_reqs=(not overlap_p2p_comm),
        config=config,
    )
    if config.timers is not None:
        config.timers('backward-send-backward-recv').stop()
    if overlap_p2p_comm:
        return output_tensor_grad, wait_handles
    return output_tensor_grad
def send_forward_backward_recv_forward_backward(
    output_tensor: torch.Tensor,
    input_tensor_grad: torch.Tensor,
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
) -> torch.Tensor:
    """Batched send and recv with previous and next ranks in pipeline.

    See _communicate for argument details.
    """
    if config.timers is not None:
        config.timers('forward-backward-send-forward-backward-recv', log_level=2).start()
    input_tensor, output_tensor_grad, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        config=config,
    )
    if config.timers is not None:
        config.timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
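Taken together, these batched helpers are what the non-interleaved 1F1B schedule drives in its steady state. A minimal sketch of one such step follows (assumptions: an already-initialized pipeline-parallel group, a ModelParallelConfig instance `config`, and hypothetical `run_forward` / `run_backward` stand-ins for the schedule's compute steps; none of this is part of the diff itself):

from megatron.core.pipeline_parallel import p2p_communication


def one_f_one_b_step(input_tensor, tensor_shape, config, run_forward, run_backward):
    # Forward compute for this stage's microbatch.
    output_tensor = run_forward(input_tensor)
    # Send the activation downstream and wait for its gradient to come back.
    output_tensor_grad = p2p_communication.send_forward_recv_backward(
        output_tensor, tensor_shape, config
    )
    # Backward compute using the received gradient.
    input_tensor_grad = run_backward(input_tensor, output_tensor, output_tensor_grad)
    # Send the input gradient upstream and receive the next microbatch's activation.
    return p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape, config)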
megatron/core/pipeline_parallel/schedules.py
View file @
3aca1415
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import contextlib
from typing import Callable, Iterator, List, Optional, Union

import torch
from torch.autograd.variable import Variable
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type

# Types
Shape = Union[List[int], torch.Size]


def get_forward_backward_func():
    """Retrieves the appropriate forward_backward function given the
    configuration of parallel_state.
...
@@ -24,6 +26,10 @@ def get_forward_backward_func():
    world size and virtual pipeline model parallel world size in the
    global parallel_state.

    Note that if using sequence parallelism, the sequence length component of
    the tensor shape is updated to original_sequence_length /
    tensor_model_parallel_world_size.

    The function returned takes the following arguments:

    forward_step_func (required): A function that takes a data
...
@@ -32,6 +38,13 @@ def get_forward_backward_func():
        take one torch.Tensor and return a torch.Tensor of loss and a
        dictionary of string -> torch.Tensor.

        A third argument, checkpoint_activations_microbatch, indicates
        that the activations for this microbatch should be
        checkpointed. A None value for this argument indicates that
        the default from the configuration should be used. This is
        used when the
        num_microbatches_with_partial_activation_checkpoints is used.

        For example:

        def loss_func(loss_mask, output_tensor):
...
@@ -54,44 +67,28 @@ def get_forward_backward_func():
    data_iterator (required): an iterator over the data, will be
        passed as is to forward_step_func. Expected to be a list of
        iterators in the case of interleaved pipeline parallelism.

    model (required): the actual model. Expected to be a list of modules in the case of interleaved
        pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.

    num_microbatches (int, required):
        The number of microbatches to go through

    seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
        transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
        in the config is True. Otherwise, each microbatch in the current global batch size must use
        this sequence length.

    micro_batch_size (int, required): The number of sequences in a microbatch.

    decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
        transformer. This is ignored for a single-stack transformer.

    forward_only (optional, default=False): Perform only the forward step

    collect_non_loss_data (optional, bool, default=False): TODO

    """
    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
...
@@ -104,24 +101,20 @@ def get_forward_backward_func():
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
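As a usage sketch (the surrounding training-loop names and values are assumptions, not part of this commit), a caller fetches the schedule once and passes the model, data iterator and microbatch geometry:

from megatron.core.pipeline_parallel import get_forward_backward_func


def train_step(forward_step_func, data_iterator, model, num_microbatches,
               seq_length, micro_batch_size):
    forward_backward_func = get_forward_backward_func()
    # On the last pipeline stage this returns the per-microbatch loss dicts;
    # elsewhere it returns an empty list.
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=data_iterator,
        model=model,
        num_microbatches=num_microbatches,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        forward_only=False,
    )
    return losses_reduced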
def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
    '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.
    '''
    if (out is None) or (not deallocate_pipeline_outputs):
        return
    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, "counter-productive to free a view of another tensor."
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)


def custom_backward(output, grad_output):
    '''Directly call C++ autograd engine.
...
@@ -132,53 +125,48 @@ def custom_backward(output, grad_output):
    grad have the same shape, while C++'s 'backward' does not.
    '''

    assert output.numel() == 1, "output should be pseudo-'freed' in schedule, to optimize memory"
    assert isinstance(output, torch.Tensor), "output == '%s'." % type(output).__name__
    assert isinstance(grad_output, (torch.Tensor, type(None))), (
        "grad_output == '%s'." % type(grad_output).__name__
    )

    # Handle scalar output
    if grad_output is None:
        assert output.numel() == 1, "implicit grad requires scalar output."
        grad_output = torch.ones_like(output, memory_format=torch.preserve_format,)

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )
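The two helpers above exist so the schedule can free an activation's storage as soon as it has been sent downstream while still running backward through its grad_fn later. A small self-contained sketch of that idea in plain PyTorch (illustrative only; shapes and names are made up, and the engine call simply mirrors custom_backward above, so it assumes a recent PyTorch that accepts these keyword arguments):

import torch
from torch.autograd.variable import Variable

x = torch.randn(2, 3, requires_grad=True)
out = x * 2                                 # activation that would be sent to the next stage
grad_from_next_stage = torch.ones(2, 3)     # gradient later received back for it

# Pseudo-free the data (as deallocate_output_tensor does); grad_fn is kept.
out.data = torch.empty((1,), device=out.device, dtype=out.dtype)

# torch.autograd.backward would now reject the shape mismatch between `out`
# and its gradient, so call the engine directly, as custom_backward does.
Variable._execution_engine.run_backward(
    tensors=(out,),
    grad_tensors=(grad_from_next_stage,),
    keep_graph=False,
    create_graph=False,
    inputs=tuple(),
    allow_unreachable=True,
    accumulate_grad=True,
)
assert x.grad.shape == (2, 3)               # gradient still reaches the leaf tensor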
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor."""
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
...
@@ -188,9 +176,17 @@ def forward_step(forward_step_func,
    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
    set_input_tensor(input_tensor)

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()
    with context_manager:
        if checkpoint_activations_microbatch is None:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
        else:
            output_tensor, loss_func = forward_step_func(
                data_iterator, model, checkpoint_activations_microbatch
            )

    if parallel_state.is_pipeline_last_stage():
        if not collect_non_loss_data:
...
@@ -202,24 +198,24 @@ def forward_step(forward_step_func,
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    model_type = get_model_type(model)
    if (
        parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        return [output_tensor, input_tensor[-1]]
    if unwrap_output_tensor:
        return output_tensor
    return [output_tensor]
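For reference, a minimal shape of the `forward_step_func` / `loss_func` pair this function expects from callers (the batch handling and loss below are placeholders, not part of the commit):

def example_forward_step_func(data_iterator, model):
    # Stand-in data fetch; real callers build tokens, masks, etc. from the iterator.
    batch = next(data_iterator)
    output_tensor = model(batch)

    def loss_func(output_tensor):
        loss = output_tensor.float().mean()
        # Return the loss plus a dict of values to report / average across microbatches.
        return loss, {'lm loss': loss}

    return output_tensor, loss_func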
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
...
@@ -232,8 +228,8 @@ def backward_step(grad_scaler, input_tensor, output_tensor,
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.
    if config.timers is not None:
        config.timers('backward-compute', log_level=2).start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = False
...
@@ -250,9 +246,13 @@ def backward_step(grad_scaler, input_tensor, output_tensor,
        output_tensor_grad = [output_tensor_grad]

    # Backward pass.
    if output_tensor_grad[0] is None and config.grad_scale_func is not None:
        output_tensor[0] = config.grad_scale_func(output_tensor[0])

    if config.deallocate_pipeline_outputs:
        custom_backward(output_tensor[0], output_tensor_grad[0])
    else:
        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
...
@@ -266,42 +266,34 @@ def backward_step(grad_scaler, input_tensor, output_tensor,
    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        if output_tensor_grad[1] is not None:
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    if config.timers is not None:
        config.timers('backward-compute').stop()

    return input_tensor_grad
def forward_backward_no_pipelining(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,  # unused
    micro_batch_size: int,  # unused
    decoder_seq_length: int = None,  # unused
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).
...
@@ -310,57 +302,121 @@ def forward_backward_no_pipelining(*,
    See get_forward_backward_func() for argument details
    """

    if isinstance(model, list):
        assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking"
        model = model[0]
    if isinstance(data_iterator, list):
        assert (
            len(data_iterator) == 1
        ), "non-pipeline-parallel schedule does not support model chunking"
        data_iterator = data_iterator[0]

    config = get_model_config(model)

    no_sync_func = config.no_sync_func
    if no_sync_func is None and isinstance(model, torchDDP):
        no_sync_func = model.no_sync
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext

    model_type = get_model_type(model)

    forward_data_store = []
    input_tensor, output_tensor_grad = None, None
    with no_sync_func():
        for i in range(num_microbatches - 1):
            output_tensor = forward_step(
                forward_step_func,
                data_iterator,
                model,
                num_microbatches,
                input_tensor,
                forward_data_store,
                config,
                collect_non_loss_data,
            )
            if not forward_only:
                backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(
        forward_step_func,
        data_iterator,
        model,
        num_microbatches,
        input_tensor,
        forward_data_store,
        config,
        collect_non_loss_data,
    )

    if not forward_only:
        backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

    return forward_data_store
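The `no_sync_func` plumbing above implements the familiar gradient-accumulation pattern: suppress the data-parallel gradient all-reduce for all but the last microbatch. A stand-alone sketch with a stand-in wrapper (no real all-reduce; all names here are illustrative):

import contextlib


class FakeDDP:
    """Stand-in for a DistributedDataParallel wrapper: counts all-reduces."""

    def __init__(self):
        self.reduce_calls = 0
        self._skip = False

    @contextlib.contextmanager
    def no_sync(self):
        self._skip = True
        try:
            yield
        finally:
            self._skip = False

    def backward_microbatch(self):
        if not self._skip:
            self.reduce_calls += 1  # a real DDP would all-reduce gradients here


model = FakeDDP()
num_microbatches = 4
with model.no_sync():
    for _ in range(num_microbatches - 1):
        model.backward_microbatch()   # gradients accumulate locally, no all-reduce
model.backward_microbatch()           # last microbatch outside no_sync: one all-reduce
assert model.reduce_calls == 1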
def forward_backward_pipelining_with_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
    assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
    assert isinstance(
        data_iterator, list
    ), "interleaved pipeline parallelism expected each model chunk to have a data iterator"

    config = get_model_config(model[0])
    if config.overlap_p2p_comm and config.batch_p2p_comm:
        raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")

    # Disable async grad reductions
    no_sync_func = config.no_sync_func
    if no_sync_func is None and all(isinstance(chunk, torchDDP) for chunk in model):

        def multi_no_sync():
            stack = contextlib.ExitStack()
            for chunk in model:
                stack.enter_context(chunk.no_sync())
            return stack

        no_sync_func = multi_no_sync
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    # Model chunk IDs with synchronized grads
    synchronized_model_chunks = set()

    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
...
@@ -381,17 +437,15 @@ def forward_backward_pipelining_with_interleaving(*,
    if model_type == ModelType.encoder_and_decoder:
        raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")

    if decoder_seq_length is not None and decoder_seq_length != seq_length:
        raise RuntimeError(
            "Interleaving is not supported with a different decoder sequence length."
        )

    tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
    if config.sequence_parallel:
        tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    total_num_microbatches = num_microbatches * num_model_chunks
...
@@ -409,45 +463,96 @@ def forward_backward_pipelining_with_interleaving(*,
        num_warmup_microbatches = total_num_microbatches
        all_warmup_microbatches = True
    else:
        num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
        num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
        num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
    num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches

    # Checkpoint the activations of partial Transformer layers in a number of micro-batches
    # within the maximum outstanding micro-batch backpropagations.
    # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
    # checkpoint partial Transformer layers (or skip checkpointing) and
    # the rest of micro-batches within a window of micro-batches checkpoint
    # all Transformer layers. The window of micro-batches is set by the maximum
    # outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1
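    # Worked example (illustrative numbers, not part of this commit): with
    # pipeline_parallel_size = 4, pipeline_parallel_rank = 0, num_model_chunks = 2 and
    # num_microbatches = 8, total_num_microbatches = 16, so
    # num_warmup_microbatches = min((4 - 0 - 1) * 2 + (2 - 1) * 4, 16) = 10,
    # num_microbatches_remaining = 6, and max_outstanding_backprops = 11. Microbatches
    # whose index modulo 11 falls below num_microbatches_with_partial_activation_checkpoints
    # checkpoint only partial layers (or skip checkpointing); the rest checkpoint all layers.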
    # Synchronize params for first two model chunks
    if config.param_sync_func is not None:
        config.param_sync_func(model[0].parameters())
        config.param_sync_func(model[1].parameters())

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id
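    # Worked example (illustrative numbers): with pipeline_parallel_size = 2 and
    # num_model_chunks = 2, microbatches are grouped in blocks of
    # pipeline_parallel_size * num_model_chunks = 4: forward microbatches 0-1 map to
    # chunk 0 and 2-3 to chunk 1, then the pattern repeats; in the backward direction
    # the chunk order is mirrored via num_model_chunks - model_chunk_id - 1.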
    def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
        """Check if an iteration is the first for a model chunk."""
        microbatch_group_size = pipeline_parallel_size * num_model_chunks
        num_microbatch_groups = total_num_microbatches // microbatch_group_size
        microbatch_group_id = microbatch_id // microbatch_group_size
        microbatch_id_in_group = microbatch_id % microbatch_group_size
        if microbatch_group_id == 0:
            return microbatch_id_in_group % pipeline_parallel_size == 0
        else:
            return False

    def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
        """Check if an iteration is the last for a model chunk."""
        microbatch_group_size = pipeline_parallel_size * num_model_chunks
        num_microbatch_groups = total_num_microbatches // microbatch_group_size
        microbatch_group_id = microbatch_id // microbatch_group_size
        microbatch_id_in_group = microbatch_id % microbatch_group_size
        if microbatch_group_id == num_microbatch_groups - 1:
            return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
        else:
            return False

    def forward_step_helper(microbatch_id, checkpoint_activations_microbatch):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # launch param synchronization for next model chunk
        # Note: Asynchronous communication tends to slow down compute.
        # To reduce idling from mismatched microbatch times, we launch
        # asynchronous communication at the same time across the
        # pipeline-parallel group.
        if config.param_sync_func is not None:
            param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
            if (
                param_sync_microbatch_id < total_num_microbatches
                and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
            ):
                param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
                if 1 < param_sync_chunk_id < num_model_chunks:
                    config.param_sync_func(model[param_sync_chunk_id].parameters())

        # forward step
        if parallel_state.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(
            forward_step_func,
            data_iterator[model_chunk_id],
            model[model_chunk_id],
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
        )
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
...
@@ -464,31 +569,65 @@ def forward_backward_pipelining_with_interleaving(*,
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # launch grad synchronization (default)
        if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
            enable_grad_sync()
            synchronized_model_chunks.add(model_chunk_id)

        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor, output_tensor, output_tensor_grad, model_type, config
        )

        # launch grad synchronization (custom grad sync)
        # Note: Asynchronous communication tends to slow down compute.
        # To reduce idling from mismatched microbatch times, we launch
        # asynchronous communication at the same time across the
        # pipeline-parallel group.
        if config.grad_sync_func is not None:
            grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
            if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
                grad_sync_microbatch_id
            ):
                grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
                enable_grad_sync()
                config.grad_sync_func(model[grad_sync_chunk_id].parameters())
                synchronized_model_chunks.add(grad_sync_chunk_id)
        disable_grad_sync()

        return input_tensor_grad
    # Run warmup forward passes.
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))

    fwd_wait_handles = None
    bwd_wait_handles = None

    for k in range(num_warmup_microbatches):

        if fwd_wait_handles is not None:
            for req in fwd_wait_handles:
                req.wait()

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        output_tensor = forward_step_helper(k, checkpoint_activations_microbatch)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
...
@@ -502,108 +641,255 @@ def forward_backward_pipelining_with_interleaving(*,
        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if not config.overlap_p2p_comm:
            if (
                k == (num_warmup_microbatches - 1)
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                )
                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            else:
                input_tensor = p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
                )
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        else:
            input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )

            if (
                k == (num_warmup_microbatches - 1)
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                (
                    output_tensor_grad,
                    bwd_wait_handles,
                ) = p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                    overlap_p2p_comm=True,
                )

                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            input_tensors[next_forward_model_chunk_id].append(input_tensor)

        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                forward_k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        if config.overlap_p2p_comm:
            if fwd_wait_handles is not None:
                for req in fwd_wait_handles:
                    req.wait()

            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

            output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)

            # Last virtual stage no activation tensor to send
            if parallel_state.is_pipeline_last_stage():
                output_tensor = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            recv_prev = True
            if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True
                )
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False

            # Send activation tensor to the next stage and receive activation tensor from the
            # previous stage
            input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )
            # assert fwd_wait_handles is not None

            if bwd_wait_handles is not None:
                for req in bwd_wait_handles:
                    req.wait()

            # Backward pass.
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)

            # First virtual stage no activation gradient tensor to send
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None

            # Determine if the current virtual stage has an activation gradient tensor to receive
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False
                )
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

            output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
                input_tensor_grad,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )

        else:  # no p2p overlap
            output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)

            # Backward pass.
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)

            # Send output_tensor and input_tensor_grad, receive input_tensor
            # and output_tensor_grad.

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
            if parallel_state.is_pipeline_last_stage():
                output_tensor = None

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            recv_prev = True
            if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True
                )
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False
                )
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False

            # Communicate tensors.
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                config=config,
            )
            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if config.overlap_p2p_comm and bwd_wait_handles is not None:
            for wait_handle in bwd_wait_handles:
                wait_handle.wait()

        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape, config=config)
            )
        for k in range(num_microbatches_remaining, total_num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
...
@@ -612,18 +898,33 @@ def forward_backward_pipelining_with_interleaving(*,
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
                )
            )

        # Launch any remaining grad reductions
        enable_grad_sync()
        if config.grad_sync_func is not None:
            params = []
            for model_chunk_id in range(num_model_chunks):
                if model_chunk_id not in synchronized_model_chunks:
                    params.extend(model[model_chunk_id].parameters())
                    synchronized_model_chunks.add(model_chunk_id)
            if params:
                config.grad_sync_func(params)

    return forward_data_store
def get_tensor_shapes(
    *,
    rank: int,
    model_type: ModelType,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int,
    config,
):
    # Determine right tensor sizes (based on position of rank with respect to split
    # rank) and model size.
    # Send two tensors if model is T5 and rank is in decoder stage:
...
@@ -634,71 +935,63 @@ def get_tensor_shapes(*,
    # Otherwise, send one tensor (pre-transpose).
    tensor_shapes = []

    if config.sequence_parallel:
        seq_length = seq_length // parallel_state.get_tensor_model_parallel_world_size()
        if model_type == ModelType.encoder_and_decoder:
            decoder_seq_length = (
                decoder_seq_length // parallel_state.get_tensor_model_parallel_world_size()
            )

    if model_type == ModelType.encoder_and_decoder:
        if parallel_state.is_pipeline_stage_before_split(rank):
            tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
        else:
            tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size))
            tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
    else:
        tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
    return tensor_shapes
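For concreteness, a worked example of the shapes this helper returns (all numbers are assumptions for illustration, not taken from the commit):

# seq_length=1024, micro_batch_size=4, config.hidden_size=2048, decoder_seq_length=128,
# tensor-model-parallel world size 2, config.sequence_parallel=True, T5-style model.
seq_length = 1024 // 2            # 512 after the sequence-parallel split
decoder_seq_length = 128 // 2     # 64
shapes_before_split = [(seq_length, 4, 2048)]                                 # encoder ranks
shapes_after_split = [(decoder_seq_length, 4, 2048), (seq_length, 4, 2048)]   # decoder ranks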
def recv_forward(tensor_shapes, config):
    input_tensors = []
    for tensor_shape in tensor_shapes:
        if tensor_shape is None:
            input_tensors.append(None)
        else:
            input_tensors.append(p2p_communication.recv_forward(tensor_shape, config))
    return input_tensors


def recv_backward(tensor_shapes, config):
    output_tensor_grads = []
    for tensor_shape in tensor_shapes:
        if tensor_shape is None:
            output_tensor_grads.append(None)
        else:
            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, config))
    return output_tensor_grads


def send_forward(output_tensors, tensor_shapes, config):
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
        if tensor_shape is None:
            continue
        p2p_communication.send_forward(output_tensor, config)


def send_backward(input_tensor_grads, tensor_shapes, config):
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
        if tensor_shape is None:
            continue
        p2p_communication.send_backward(input_tensor_grad, config)


def send_forward_recv_backward(output_tensors, tensor_shapes, config):
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    output_tensor_grads = []
...
@@ -707,12 +1000,13 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, dtype, timers):
            output_tensor_grads.append(None)
            continue
        output_tensor_grad = p2p_communication.send_forward_recv_backward(
            output_tensor, tensor_shape, config
        )
        output_tensor_grads.append(output_tensor_grad)
    return output_tensor_grads


def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    input_tensors = []
...
@@ -721,56 +1015,110 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, dtype, timers)
            input_tensors.append(None)
            continue
        input_tensor = p2p_communication.send_backward_recv_forward(
            input_tensor_grad, tensor_shape, config
        )
        input_tensors.append(input_tensor)
    return input_tensors
def forward_backward_pipelining_without_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""

    if isinstance(model, list):
        assert (
            len(model) == 1
        ), "non-interleaved pipeline parallelism does not support model chunking"
        model = model[0]
    if isinstance(data_iterator, list):
        assert (
            len(data_iterator) == 1
        ), "non-pipeline-parallel schedule does not support model chunking"
        data_iterator = data_iterator[0]

    config = get_model_config(model)
    if config.overlap_p2p_comm:
        raise ValueError(
            "Non-interleaved pipeline parallelism does not support overlapping p2p communication"
        )

    # Disable async grad reductions
    no_sync_func = config.no_sync_func
    if no_sync_func is None and isinstance(model, torchDDP):
        no_sync_func = model.no_sync
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    # Compute number of warmup microbatches.
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    # Checkpoint the activations of partial Transformer layers in a number of micro-batches
    # within the maximum outstanding micro-batch backpropagations.
    # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
    # checkpoint partial Transformer layers (or skip checkpointing) and
    # the rest of micro-batches within a window of micro-batches checkpoint
    # all Transformer layers. The window of micro-batches is set by the maximum
    # outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer to Appendix C in https://arxiv.org/pdf/2205.05198.pdf
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1

    model_type = get_model_type(model)

    rank = parallel_state.get_pipeline_model_parallel_rank()
    recv_tensor_shapes = get_tensor_shapes(
        rank=rank - 1,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
    )
    send_tensor_shapes = get_tensor_shapes(
        rank=rank,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
...
@@ -782,77 +1130,125 @@ def forward_backward_pipelining_without_interleaving(*,

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                i % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        input_tensor = recv_forward(recv_tensor_shapes, config)
        output_tensor = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
        )
        send_forward(output_tensor, send_tensor_shapes, config)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, config)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = i == (num_microbatches_remaining - 1)

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                (i + num_warmup_microbatches) % max_outstanding_backprops
            ) >= config.num_microbatches_with_partial_activation_checkpoints
        else:
            checkpoint_activations_microbatch = None

        output_tensor = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
        )

        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, config)

            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, config)

        else:
            output_tensor_grad = send_forward_recv_backward(
                output_tensor, send_tensor_shapes, config
            )

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(
                input_tensor, output_tensor, output_tensor_grad, model_type, config
            )

            if last_iteration:
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes, config)
            else:
                input_tensor = send_backward_recv_forward(
                    input_tensor_grad, recv_tensor_shapes, config
                )

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):

            # Enable async grad reduction in the last backward pass
            # Note: If grad sync function is provided, only enable
            # async grad reduction in first pipeline stage. Other
            # pipeline stages do grad reduction during pipeline
            # bubble.
            if i == num_warmup_microbatches - 1:
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()

            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes, config)

            input_tensor_grad = backward_step(
                input_tensor, output_tensor, output_tensor_grad, model_type, config
            )

            send_backward(input_tensor_grad, recv_tensor_shapes, config)

        # Launch any remaining grad reductions
        if no_sync_context is not None:
            enable_grad_sync()
            if config.grad_sync_func is not None:
                config.grad_sync_func(model.parameters())

    return forward_data_store
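A minimal usage sketch of the updated schedule, not taken from the diff: it assumes a single model chunk and a user-supplied forward_step_func that follows the forward_step contract used above; the iterator name and batch sizes are placeholders. In practice the schedule is normally obtained through get_forward_backward_func() rather than called directly.

forward_data_store = forward_backward_pipelining_without_interleaving(
    forward_step_func=forward_step_func,   # placeholder: returns (output_tensor, loss_func)
    data_iterator=train_data_iterator,     # placeholder iterator
    model=[gpt_model],                     # a single chunk; a bare module also works
    num_microbatches=8,
    seq_length=2048,
    micro_batch_size=2,
    decoder_seq_length=None,
    forward_only=False,
)
# On the last pipeline stage forward_data_store holds one entry per micro-batch;
# on the other stages it is empty.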
megatron/core/requirements.txt
0 → 100644
View file @ 3aca1415
torch
\ No newline at end of file
megatron/core/tensor_parallel/__init__.py
View file @ 3aca1415
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
    copy_tensor_model_parallel_attributes,
    linear_with_grad_accumulation_and_async_allreduce,
    param_is_not_tensor_parallel_duplicate,
    set_defaults_if_not_set_tensor_model_parallel_attributes,
    set_tensor_model_parallel_attributes,
)
from .mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_sequence_parallel_region,
    gather_from_tensor_model_parallel_region,
    scatter_to_sequence_parallel_region,
    scatter_to_tensor_model_parallel_region,
)
from .random import (
    checkpoint,
    get_cuda_rng_tracker,
    model_parallel_cuda_manual_seed,
)
from .utils import (
    gather_split_1d_tensor,
    split_tensor_along_last_dim,
    split_tensor_into_1d_equal_chunks,
)

__all__ = [
...
@@ -38,7 +29,7 @@ __all__ = [
    "vocab_parallel_cross_entropy",
    # data.py
    "broadcast_data",
    # layers.py
    "ColumnParallelLinear",
    "RowParallelLinear",
    "VocabParallelEmbedding",
...
@@ -51,7 +42,7 @@ __all__ = [
    "copy_to_tensor_model_parallel_region",
    "gather_from_tensor_model_parallel_region",
    "gather_from_sequence_parallel_region",
    # "reduce_from_tensor_model_parallel_region",
    "scatter_to_tensor_model_parallel_region",
    "scatter_to_sequence_parallel_region",
    # random.py
...
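Purely as orientation (not part of the diff), the package-level re-exports above are typically reached like this; the seed value is arbitrary.

from megatron.core.tensor_parallel import (
    model_parallel_cuda_manual_seed,
    vocab_parallel_cross_entropy,
)

# Seed the per-rank CUDA RNG tracker before initializing tensor-parallel weights.
model_parallel_cuda_manual_seed(1234)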
megatron/core/tensor_parallel/cross_entropy.py
View file @ 3aca1415
...
@@ -5,22 +5,21 @@ import torch
from megatron.core.parallel_state import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

from .utils import VocabUtility


class _VocabParallelCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):

        # Maximum value along vocab dimension across all GPUs.
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(
            logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
        )
        # Subtract the maximum value.
        vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
...
@@ -29,8 +28,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()
        vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)

        # Create a mask of valid vocab ids (1 means it needs to be masked).
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
...
@@ -42,24 +40,27 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        predicted_logits[target_mask] = 0.0
        # All reduce is needed to get the chunks from other GPUs.
        torch.distributed.all_reduce(
            predicted_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_tensor_model_parallel_group(),
        )

        # Sum of exponential of logits along vocab dimension across all GPUs.
        exp_logits = vocab_parallel_logits
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(
            sum_exp_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_tensor_model_parallel_group(),
        )

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits
...
@@ -87,7 +88,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
            loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
            ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size

        # Store softmax, target-mask and masked-target for backward pass.
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
...
@@ -108,8 +108,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Add the gradient from matching classes.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)

        softmax_update = 1.0 - target_mask.view(-1).float()
...
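As a quick illustration of the public wrapper around this autograd function (shapes are hypothetical and the tensors are assumed to live on the current CUDA device):

# vocab_parallel_logits: [sequence, batch, vocab_size // tp_world_size], already
# split along the vocab dimension; target: [sequence, batch] with global vocab ids.
loss = vocab_parallel_cross_entropy(vocab_parallel_logits, target)
# Optional label smoothing, mirroring the forward() signature above:
loss = vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.1)
# loss has the same shape as target.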
megatron/core/tensor_parallel/data.py
View file @ 3aca1415
...
@@ -8,15 +8,16 @@ from megatron.core.parallel_state import (
    get_tensor_model_parallel_src_rank,
)

_MAX_DATA_DIM = 5


def _check_data_types(keys, data, target_dtype):
    """Check that all the keys have the same target data type."""
    for key in keys:
        assert data[key].dtype == target_dtype, (
            '{} has data type {} which '
            'is different than {}'.format(key, data[key].dtype, target_dtype)
        )


def _build_key_size_numel_dictionaries(keys, data):
...
@@ -36,8 +37,9 @@ def _build_key_size_numel_dictionaries(keys, data):
    # Move to GPU and broadcast.
    sizes_cuda = torch.cuda.LongTensor(sizes)
    torch.distributed.broadcast(
        sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
    )

    # Move back to cpu and unpack.
    sizes_cpu = sizes_cuda.cpu()
...
@@ -74,24 +76,21 @@ def broadcast_data(keys, data, datatype):
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)

    # Pack on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)
        # Flatten the data associated with the keys
        flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
    else:
        flatten_data = torch.empty(
            total_numel, device=torch.cuda.current_device(), dtype=datatype
        )

    # Broadcast
    torch.distributed.broadcast(
        flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
    )

    # Unpack
    output = {}
...
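A hedged usage sketch of broadcast_data, not from the diff: only the first rank of each tensor-parallel group needs to supply the batch, the others can pass None; the iterator name is a placeholder.

keys = ['text', 'labels']
datatype = torch.int64

# Only tensor-parallel rank 0 reads the data; the other ranks pass None.
if get_tensor_model_parallel_rank() == 0:
    data = next(train_data_iterator)   # placeholder iterator yielding a dict of CPU tensors
else:
    data = None

data_b = broadcast_data(keys, data, datatype)
tokens, labels = data_b['text'], data_b['labels']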
megatron/core/tensor_parallel/layers.py
View file @ 3aca1415
...
@@ -5,37 +5,33 @@
import math
import os
import warnings
from typing import Callable, Optional

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.parameter import Parameter

from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
    get_global_memory_buffer,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

from .mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_sequence_parallel_region,
    gather_from_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
    scatter_to_tensor_model_parallel_region,
)
from .random import get_cuda_rng_tracker
from .utils import VocabUtility, divide, split_tensor_along_last_dim

_grad_accum_fusion_available = True
try:
...
@@ -43,14 +39,17 @@ try:
except ImportError:
    _grad_accum_fusion_available = False

_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    'tensor_model_parallel': False,
    'partition_dim': -1,
    'partition_stride': 1,
}


def param_is_not_tensor_parallel_duplicate(param):
    return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
        get_tensor_model_parallel_rank() == 0
    )


def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
...
@@ -67,6 +66,7 @@ def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
    def maybe_set(attribute, value):
        if not hasattr(tensor, attribute):
            setattr(tensor, attribute, value)

    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
...
@@ -74,51 +74,52 @@ def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
    def maybe_copy(attribute):
        if hasattr(source_tensor, attribute):
            setattr(destination_tensor, attribute, getattr(source_tensor, attribute))

    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_copy(attribute)


def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU."""

    set_tensor_model_parallel_attributes(
        tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
    )

    with get_cuda_rng_tracker().fork():
        init_method(weight)


def _initialize_affine_weight_cpu(
    weight,
    output_size,
    input_size,
    per_partition_size,
    partition_dim,
    init_method,
    stride=1,
    return_master_weight=False,
    *,
    params_dtype=torch.float32,
):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter
    the relevant chunk."""

    set_tensor_model_parallel_attributes(
        tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
    )

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False)
    init_method(master_weight)
    master_weight = master_weight.to(dtype=params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]
...
@@ -140,17 +141,17 @@ class VocabParallelEmbedding(torch.nn.Module):
        embedding_dim: size of hidden state.

    Keyword Arguments:
        config: A megatron.core.ModelParallelConfig object
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        *,
        init_method: Callable,
        config: ModelParallelConfig,
    ):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
...
@@ -158,52 +159,68 @@ class VocabParallelEmbedding(torch.nn.Module):
        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.0
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        (
            self.vocab_start_index,
            self.vocab_end_index,
        ) = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
        )
        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

        # Allocate weights and initialize.
        if config.use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(
                    self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
                )
            )
            if config.perform_initialization:
                _initialize_affine_weight_cpu(
                    self.weight,
                    self.num_embeddings,
                    self.embedding_dim,
                    self.num_embeddings_per_partition,
                    0,
                    init_method,
                    params_dtype=config.params_dtype,
                )
        else:
            self.weight = Parameter(
                torch.empty(
                    self.num_embeddings_per_partition,
                    self.embedding_dim,
                    device=torch.cuda.current_device(),
                    dtype=config.params_dtype,
                )
            )
            if config.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)

    def forward(self, input_):
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(
            masked_input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
...
@@ -212,13 +229,97 @@ class VocabParallelEmbedding(torch.nn.Module):
        return output
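An illustrative construction under the config-driven signature above, not part of the commit: `config` is assumed to be an already-built ModelParallelConfig, parallel state is assumed to be initialized, and the sizes are placeholders.

embedding = VocabParallelEmbedding(
    num_embeddings=50304,              # padded vocabulary size (placeholder)
    embedding_dim=4096,                # hidden size (placeholder)
    init_method=init.xavier_normal_,
    config=config,                     # ModelParallelConfig built elsewhere
)
tokens = torch.randint(0, 50304, (8, 2048), device=torch.cuda.current_device())
hidden = embedding(tokens)             # [8, 2048, 4096], reduced across the tensor-parallel group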
class
LinearWithFrozenWeight
(
torch
.
autograd
.
Function
):
"""Linear operator that does not calculate gradient for weight.
This op and LinearWithGradAccumulationAndAsyncCommunication performs
mathematically-identical forward and DGRAD.
Conceptually this op is the same as torch.nn.functional.linear with
weight.requires_grad==False, but in experiments they are not identical
mathematically. """
@
staticmethod
@
custom_fwd
def
forward
(
ctx
,
input
,
weight
,
bias
,
):
ctx
.
save_for_backward
(
weight
)
output
=
torch
.
matmul
(
input
,
weight
.
t
())
if
bias
is
not
None
:
output
=
output
+
bias
return
output
@
staticmethod
@
custom_bwd
def
backward
(
ctx
,
grad_output
):
(
weight
,)
=
ctx
.
saved_tensors
grad_input
=
grad_output
.
matmul
(
weight
)
return
grad_input
,
None
,
None
def
linear_with_frozen_weight
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
gradient_accumulation_fusion
:
bool
,
async_grad_allreduce
:
bool
,
sequence_parallel
:
bool
,
)
->
torch
.
Tensor
:
"""Linear layer execution with weight.requires_grad == False.
This function handles linear layers with weight frozen (untrainable).
In the forward, it only saves weight and does not save input activations.
In the backward, it does not perform weight gradient calculation, or
weight gradient allreduce.
Arguments:
input (torch.Tensor required): input like torch.nn.functional.linear
weight (torch.Tensor required): weight like torch.nn.functional.linear
bias (torch.Tensor optional): bias like torch.nn.functional.linear
gradient_accumulation_fusion (bool required): dummy argument, used to
keep the API unified between all forward implementation functions.
async_grad_allreduce (bool required): dummy argument, used to
keep the API unified between all forward implementation functions.
sequence_parallel (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
all gathered, and the backward pass the input gradients are
reduce scattered.
"""
if
sequence_parallel
:
input
=
gather_from_sequence_parallel_region
(
input
,
tensor_parallel_output_grad
=
True
)
else
:
input
=
input
args
=
[
input
,
weight
,
bias
,
]
return
LinearWithFrozenWeight
.
apply
(
*
args
)
class
LinearWithGradAccumulationAndAsyncCommunication
(
torch
.
autograd
.
Function
):
class
LinearWithGradAccumulationAndAsyncCommunication
(
torch
.
autograd
.
Function
):
"""See linear_with_grad_accumulation_and_async_allreduce"""
"""See linear_with_grad_accumulation_and_async_allreduce"""
@
staticmethod
@
staticmethod
@
custom_fwd
@
custom_fwd
def
forward
(
ctx
,
input
,
weight
,
bias
,
gradient_accumulation_fusion
,
def
forward
(
async_grad_allreduce
,
sequence_parallel
):
ctx
,
input
,
weight
,
bias
,
gradient_accumulation_fusion
,
async_grad_allreduce
,
sequence_parallel
,
):
ctx
.
save_for_backward
(
input
,
weight
)
ctx
.
save_for_backward
(
input
,
weight
)
ctx
.
use_bias
=
bias
is
not
None
ctx
.
use_bias
=
bias
is
not
None
ctx
.
gradient_accumulation_fusion
=
gradient_accumulation_fusion
ctx
.
gradient_accumulation_fusion
=
gradient_accumulation_fusion
...
@@ -230,12 +331,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
...
@@ -230,12 +331,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size
=
list
(
input
.
size
())
dim_size
=
list
(
input
.
size
())
dim_size
[
0
]
=
dim_size
[
0
]
*
world_size
dim_size
[
0
]
=
dim_size
[
0
]
*
world_size
all_gather_buffer
=
\
all_gather_buffer
=
get_global_memory_buffer
().
get_tensor
(
dim_size
,
input
.
dtype
,
"mpu"
)
get_global_memory_buffer
().
get_tensor
(
dim_size
,
input
.
dtype
,
"mpu"
)
torch
.
distributed
.
_all_gather_base
(
torch
.
distributed
.
_all_gather_base
(
all_gather_buffer
,
all_gather_buffer
,
input
,
group
=
get_tensor_model_parallel_group
()
input
,
)
group
=
get_tensor_model_parallel_group
())
total_input
=
all_gather_buffer
total_input
=
all_gather_buffer
else
:
else
:
total_input
=
input
total_input
=
input
...
@@ -256,12 +355,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
...
@@ -256,12 +355,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size
=
list
(
input
.
size
())
dim_size
=
list
(
input
.
size
())
dim_size
[
0
]
=
dim_size
[
0
]
*
world_size
dim_size
[
0
]
=
dim_size
[
0
]
*
world_size
all_gather_buffer
=
\
all_gather_buffer
=
get_global_memory_buffer
().
get_tensor
(
dim_size
,
input
.
dtype
,
"mpu"
)
get_global_memory_buffer
().
get_tensor
(
dim_size
,
input
.
dtype
,
"mpu"
)
handle
=
torch
.
distributed
.
_all_gather_base
(
handle
=
torch
.
distributed
.
_all_gather_base
(
all_gather_buffer
,
all_gather_buffer
,
input
,
group
=
get_tensor_model_parallel_group
(),
async_op
=
True
input
,
)
group
=
get_tensor_model_parallel_group
(),
async_op
=
True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
# gather is scheduled before the input gradient computation
...
@@ -273,43 +370,49 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
...
@@ -273,43 +370,49 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
if
ctx
.
sequence_parallel
:
if
ctx
.
sequence_parallel
:
handle
.
wait
()
handle
.
wait
()
# Doing gather + slicing during the NeMo forward pass can make this tensor
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output
=
grad_output
.
contiguous
()
grad_output
=
grad_output
.
contiguous
()
# Convert the tensor shapes to 2D for execution compatibility
# Convert the tensor shapes to 2D for execution compatibility
grad_output
=
grad_output
.
view
(
grad_output
.
shape
[
0
]
*
grad_output
.
shape
[
1
],
grad_output
=
grad_output
.
view
(
grad_output
.
shape
[
2
])
grad_output
.
shape
[
0
]
*
grad_output
.
shape
[
1
],
grad_output
.
shape
[
2
]
total_input
=
total_input
.
view
(
total_input
.
shape
[
0
]
*
total_input
.
shape
[
1
],
)
total_input
.
shape
[
2
])
total_input
=
total_input
.
view
(
total_input
.
shape
[
0
]
*
total_input
.
shape
[
1
],
total_input
.
shape
[
2
]
)
if
ctx
.
async_grad_allreduce
:
if
ctx
.
async_grad_allreduce
:
# Asynchronous all-reduce
# Asynchronous all-reduce
handle
=
torch
.
distributed
.
all_reduce
(
handle
=
torch
.
distributed
.
all_reduce
(
grad_input
,
group
=
get_tensor_model_parallel_group
(),
async_op
=
True
)
grad_input
,
group
=
get_tensor_model_parallel_group
(),
async_op
=
True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
# all-reduce is scheduled before the weight gradient computation
if
ctx
.
sequence_parallel
:
if
ctx
.
sequence_parallel
:
assert
not
ctx
.
async_grad_allreduce
assert
not
ctx
.
async_grad_allreduce
dim_size
=
list
(
input
.
size
())
dim_size
=
list
(
input
.
size
())
sub_grad_input
=
torch
.
empty
(
dim_size
,
dtype
=
input
.
dtype
,
sub_grad_input
=
torch
.
empty
(
device
=
torch
.
cuda
.
current_device
(),
dim_size
,
dtype
=
input
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
requires_grad
=
False
)
)
# reduce_scatter
# reduce_scatter
handle
=
torch
.
distributed
.
_reduce_scatter_base
(
sub_grad_input
,
grad_input
,
handle
=
torch
.
distributed
.
_reduce_scatter_base
(
group
=
get_tensor_model_parallel_group
(),
sub_grad_input
,
grad_input
,
group
=
get_tensor_model_parallel_group
(),
async_op
=
True
async_op
=
True
)
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter is scheduled before the weight gradient computation
# reduce scatter is scheduled before the weight gradient computation
if
ctx
.
gradient_accumulation_fusion
:
if
ctx
.
gradient_accumulation_fusion
:
if
weight
.
main_grad
.
dtype
==
torch
.
float32
:
if
weight
.
main_grad
.
dtype
==
torch
.
float32
:
fused_weight_gradient_mlp_cuda
.
wgrad_gemm_accum_fp32
(
total_input
,
grad_output
,
weight
.
main_grad
)
fused_weight_gradient_mlp_cuda
.
wgrad_gemm_accum_fp32
(
elif
weight
.
main_grad
.
dtype
==
torch
.
float16
:
total_input
,
grad_output
,
weight
.
main_grad
fused_weight_gradient_mlp_cuda
.
wgrad_gemm_accum_fp16
(
total_input
,
grad_output
,
weight
.
main_grad
)
)
elif
weight
.
main_grad
.
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
):
fused_weight_gradient_mlp_cuda
.
wgrad_gemm_accum_fp16
(
total_input
,
grad_output
,
weight
.
main_grad
)
else
:
else
:
raise
RuntimeError
(
"Unsupported gradient type for gradient accumulation fusion"
)
raise
RuntimeError
(
"Unsupported gradient type for gradient accumulation fusion"
)
grad_weight
=
None
grad_weight
=
None
...
@@ -326,13 +429,14 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
...
@@ -326,13 +429,14 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
return
grad_input
,
grad_weight
,
grad_bias
,
None
,
None
,
None
return
grad_input
,
grad_weight
,
grad_bias
,
None
,
None
,
None
def
linear_with_grad_accumulation_and_async_allreduce
(
def
linear_with_grad_accumulation_and_async_allreduce
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
bias
:
Optional
[
torch
.
Tensor
],
gradient_accumulation_fusion
:
bool
,
gradient_accumulation_fusion
:
bool
,
async_grad_allreduce
:
bool
,
async_grad_allreduce
:
bool
,
sequence_parallel
_enabled
:
bool
,
sequence_parallel
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Linear layer execution with asynchronous communication and
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
gradient accumulation fusion in backprop.
...
@@ -378,10 +482,10 @@ def linear_with_grad_accumulation_and_async_allreduce(
...
@@ -378,10 +482,10 @@ def linear_with_grad_accumulation_and_async_allreduce(
async_grad_allreduce (bool required): Do the allreduce of input
async_grad_allreduce (bool required): Do the allreduce of input
gradients asyncronously with the computation of weight
gradients asyncronously with the computation of weight
gradients. If sequence_parallel
_enabled
is True, this must be
gradients. If sequence_parallel is True, this must be
False, as no all reduce is performed.
False, as no all reduce is performed.
sequence_parallel
_enabled
(bool required): Indicates that sequence
sequence_parallel (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
parallelism is used and thus in the forward pass the input is
all gathered, and the backward pass the input gradients are
all gathered, and the backward pass the input gradients are
reduce scattered.
reduce scattered.
...
@@ -392,29 +496,33 @@ def linear_with_grad_accumulation_and_async_allreduce(
...
@@ -392,29 +496,33 @@ def linear_with_grad_accumulation_and_async_allreduce(
bias
,
bias
,
gradient_accumulation_fusion
,
gradient_accumulation_fusion
,
async_grad_allreduce
,
async_grad_allreduce
,
sequence_parallel
_enabled
,
sequence_parallel
,
]
]
if
not
linear_with_grad_accumulation_and_async_allreduce
.
warned
:
if
not
linear_with_grad_accumulation_and_async_allreduce
.
warned
:
if
os
.
environ
.
get
(
'CUDA_DEVICE_MAX_CONNECTIONS'
)
!=
"1"
:
if
os
.
environ
.
get
(
'CUDA_DEVICE_MAX_CONNECTIONS'
)
!=
"1"
:
if
sequence_parallel
_enabled
:
if
sequence_parallel
:
warnings
.
warn
(
warnings
.
warn
(
"When using sequence parallelism it is recommended to set the "
"When using sequence parallelism it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup"
)
"maximum speedup"
)
linear_with_grad_accumulation_and_async_allreduce
.
warned
=
True
linear_with_grad_accumulation_and_async_allreduce
.
warned
=
True
if
async_grad_allreduce
:
if
async_grad_allreduce
:
warnings
.
warn
(
warnings
.
warn
(
"When using async grad allreduce it is recommended to set the "
"When using async grad allreduce it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup"
)
"maximum speedup"
)
linear_with_grad_accumulation_and_async_allreduce
.
warned
=
True
linear_with_grad_accumulation_and_async_allreduce
.
warned
=
True
return
LinearWithGradAccumulationAndAsyncCommunication
.
apply
(
*
args
)
return
LinearWithGradAccumulationAndAsyncCommunication
.
apply
(
*
args
)
linear_with_grad_accumulation_and_async_allreduce
.
warned
=
False
linear_with_grad_accumulation_and_async_allreduce
.
warned
=
False
class
ColumnParallelLinear
(
torch
.
nn
.
Module
):
class
ColumnParallelLinear
(
torch
.
nn
.
Module
):
"""Linear layer with column parallelism.
"""Linear layer with column parallelism.
...
@@ -436,28 +544,34 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -436,28 +544,34 @@ class ColumnParallelLinear(torch.nn.Module):
keep_master_weight_for_test: This was added for testing and should be
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
set to False. It returns the master weights
used for initialization.
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
skip_bias_add: If True, do not add the bias term, instead
can be fused with other elementwise operations. we skip
return it to be added by the caller. This
adding bias but instead return it.
enables performance optimations where bias can
async_tensor_model_parallel_allreduce:
be fused with other elementwise operations.
params_dtype:
use_cpu_initialization:
skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed
gradient_accumulation_fusion:
as a keyword argument `weight` during the forward pass. Note
sequence_parallel_enabled:
that this does not affect bias, which will be allocated if
bias is True. Defaults to False.
config: ModelParallelConfig object
"""
"""
def
__init__
(
self
,
input_size
,
output_size
,
*
,
def
__init__
(
bias
=
True
,
gather_output
=
True
,
self
,
init_method
=
init
.
xavier_normal_
,
stride
=
1
,
input_size
,
keep_master_weight_for_test
=
False
,
output_size
,
skip_bias_add
=
False
,
*
,
async_tensor_model_parallel_allreduce
=
True
,
config
:
ModelParallelConfig
,
params_dtype
=
torch
.
float32
,
init_method
:
Callable
,
use_cpu_initialization
=
False
,
bias
=
True
,
perform_initialization
=
True
,
gather_output
=
False
,
gradient_accumulation_fusion
=
False
,
stride
=
1
,
sequence_parallel_enabled
:
bool
=
False
,
keep_master_weight_for_test
=
False
,
):
skip_bias_add
=
False
,
skip_weight_param_allocation
:
bool
=
False
,
):
super
(
ColumnParallelLinear
,
self
).
__init__
()
super
(
ColumnParallelLinear
,
self
).
__init__
()
# Keep input parameters
# Keep input parameters
...
@@ -468,105 +582,151 @@ class ColumnParallelLinear(torch.nn.Module):
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.config = config

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        if not skip_weight_param_allocation:
            if config.use_cpu_initialization:
                self.weight = Parameter(
                    torch.empty(
                        self.output_size_per_partition, self.input_size, dtype=config.params_dtype
                    )
                )
                if config.perform_initialization:
                    self.master_weight = _initialize_affine_weight_cpu(
                        self.weight,
                        self.output_size,
                        self.input_size,
                        self.output_size_per_partition,
                        0,
                        init_method,
                        stride=stride,
                        return_master_weight=keep_master_weight_for_test,
                    )
            else:
                self.weight = Parameter(
                    torch.empty(
                        self.output_size_per_partition,
                        self.input_size,
                        device=torch.cuda.current_device(),
                        dtype=config.params_dtype,
                    )
                )
                if config.perform_initialization:
                    _initialize_affine_weight_gpu(
                        self.weight, init_method, partition_dim=0, stride=stride
                    )
        else:
            self.weight = None

        if bias:
            if config.use_cpu_initialization:
                self.bias = Parameter(
                    torch.empty(self.output_size_per_partition, dtype=config.params_dtype)
                )
            else:
                self.bias = Parameter(
                    torch.empty(
                        self.output_size_per_partition,
                        device=torch.cuda.current_device(),
                        dtype=config.params_dtype,
                    )
                )
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            if config.perform_initialization:
                # Always initialize bias to zero.
                with torch.no_grad():
                    self.bias.zero_()
        else:
            self.register_parameter('bias', None)

        self.async_tensor_model_parallel_allreduce = (
            config.async_tensor_model_parallel_allreduce and world_size > 1
        )

        self.sequence_parallel = config.sequence_parallel
        if self.sequence_parallel and world_size <= 1:
            warnings.warn(
                f"`sequence_parallel` is set to `True`, but tensor model parallel size is {world_size}. "
                f"Disabling sequence parallel."
            )
            self.sequence_parallel = False

        if config.gradient_accumulation_fusion and not _grad_accum_fusion_available:
            raise RuntimeError(
                "ColumnParallelLinear was called with gradient_accumulation_fusion set "
                "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
                "module is not found. To use gradient_accumulation_fusion you must "
                "install APEX with --cpp_ext and --cuda_ext. For example: "
                "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
                "Note that the extension requires CUDA>=11. Otherwise, you must turn off "
                "gradient accumulation fusion."
            )
        self.gradient_accumulation_fusion = config.gradient_accumulation_fusion

        if self.async_tensor_model_parallel_allreduce and self.sequence_parallel:
            raise RuntimeError(
                "`async_tensor_model_parallel_allreduce` and `sequence_parallel` "
                "cannot be enabled at the same time."
            )

        self._forward_impl = linear_with_grad_accumulation_and_async_allreduce

    def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
        """Forward of ColumnParallelLinear

        Args:
            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
            weight (optional): weight tensor to use, compulsory when
                skip_weight_param_allocation is True.

        Returns:
            - output
            - bias
        """
        if weight is None:
            if self.weight is None:
                raise RuntimeError(
                    "weight was not supplied to ColumnParallelLinear forward pass "
                    "and skip_weight_param_allocation is True."
                )
            weight = self.weight
        else:
            # Check the weight passed in is the correct shape
            expected_shape = (self.output_size_per_partition, self.input_size)
            if weight.shape != expected_shape:
                raise RuntimeError(
                    f"supplied weight's shape is {tuple(weight.shape)}, "
                    f"not {expected_shape} as expected"
                )

        bias = self.bias if not self.skip_bias_add else None

        if self.async_tensor_model_parallel_allreduce or self.sequence_parallel:
            input_parallel = input_
        else:
            input_parallel = copy_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        if not weight.requires_grad:
            self._forward_impl = linear_with_frozen_weight
        else:
            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
        output_parallel = self._forward_impl(
            input=input_parallel,
            weight=weight,
            bias=bias,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
            sequence_parallel=self.sequence_parallel,
        )
        if self.gather_output:
            # All-gather across the partitions.
            assert not self.sequence_parallel
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
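For orientation, the arithmetic behind the column-parallel split can be checked in a single process. The sketch below is not the Megatron API and uses no distributed collectives; it only shows that sharding the weight's output rows across "ranks" and concatenating the per-shard outputs (what gather_output does via an all-gather) reproduces a full linear layer. All names here are illustrative.

# Minimal single-process sketch of the column-parallel decomposition.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size, input_size, output_size = 4, 8, 16
x = torch.randn(3, 2, input_size)                 # [sequence, batch, hidden]
full_weight = torch.randn(output_size, input_size)

# Shard the output dimension (rows of the weight), one shard per "rank".
shards = torch.chunk(full_weight, world_size, dim=0)
partial_outputs = [F.linear(x, w_shard) for w_shard in shards]

# gather_output=True corresponds to concatenating along the last dimension.
gathered = torch.cat(partial_outputs, dim=-1)
assert torch.allclose(gathered, F.linear(x, full_weight), atol=1e-5)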
...
@@ -601,27 +761,27 @@ class RowParallelLinear(torch.nn.Module):
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.

        skip_bias_add: If True, do not add the bias term, instead
                       return it to be added by the caller. This
                       enables performance optimizations where bias can
                       be fused with other elementwise operations.

        config: ModelParallelConfig object

    """
    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: ModelParallelConfig,
        init_method: Callable,
        bias: bool = True,
        input_is_parallel: bool = False,
        stride: int = 1,
        keep_master_weight_for_test: bool = False,
        skip_bias_add: bool = False,
    ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
...
@@ -632,49 +792,68 @@ class RowParallelLinear(torch.nn.Module):
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.config = config
        self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
        self.sequence_parallel = config.sequence_parallel
        if self.sequence_parallel and not self.input_is_parallel:
            raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        if config.use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(
                    self.output_size, self.input_size_per_partition, dtype=config.params_dtype
                )
            )
            if config.perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight,
                    self.output_size,
                    self.input_size,
                    self.input_size_per_partition,
                    1,
                    init_method,
                    stride=stride,
                    return_master_weight=keep_master_weight_for_test,
                    params_dtype=config.params_dtype,
                )
        else:
            self.weight = Parameter(
                torch.empty(
                    self.output_size,
                    self.input_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=config.params_dtype,
                )
            )
            if config.perform_initialization:
                _initialize_affine_weight_gpu(
                    self.weight, init_method, partition_dim=1, stride=stride
                )
        if bias:
            if config.use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype))
            else:
                self.bias = Parameter(
                    torch.empty(
                        self.output_size,
                        device=torch.cuda.current_device(),
                        dtype=config.params_dtype,
                    )
                )
            setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

            if config.perform_initialization:
                # Always initialize bias to zero.
                with torch.no_grad():
                    self.bias.zero_()
        else:
            self.register_parameter('bias', None)

        self._forward_impl = linear_with_grad_accumulation_and_async_allreduce

    def forward(self, input_):
        """Forward of RowParallelLinear
...
@@ -690,20 +869,24 @@ class RowParallelLinear(torch.nn.Module):
        if self.input_is_parallel:
            input_parallel = input_
        else:
            assert not self.sequence_parallel
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        if not self.weight.requires_grad:
            self._forward_impl = linear_with_frozen_weight
        else:
            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
        output_parallel = self._forward_impl(
            input=input_parallel,
            weight=self.weight,
            bias=None,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=False,
            sequence_parallel=False,
        )

        # All-reduce across all the partitions.
        if self.sequence_parallel:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
...
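The row-parallel layer is the algebraic counterpart of the column-parallel one: the input's hidden dimension is sharded, each rank multiplies its slice, and the partial products are summed, which is what the all-reduce (or reduce-scatter under sequence parallelism) does across ranks. A single-process sketch under those assumptions, with illustrative names only:

# Minimal single-process sketch of the row-parallel decomposition.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size, input_size, output_size = 4, 16, 8
x = torch.randn(3, 2, input_size)
full_weight = torch.randn(output_size, input_size)

x_shards = torch.chunk(x, world_size, dim=-1)            # input_is_parallel: each rank's slice
w_shards = torch.chunk(full_weight, world_size, dim=1)   # split the input (column) dimension
partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]

summed = torch.stack(partials).sum(dim=0)                # stands in for the all-reduce
assert torch.allclose(summed, F.linear(x, full_weight), atol=1e-4)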
megatron/core/tensor_parallel/mappings.py
View file @ 3aca1415
...
@@ -3,10 +3,11 @@
import torch

from megatron.core.parallel_state import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

from .utils import split_tensor_along_last_dim
...
@@ -14,7 +15,7 @@ def _reduce(input_):
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # All-reduce.
...
@@ -53,13 +54,14 @@ def _split_along_first_dim(input_):
    # Split along first dimension.
    dim_size = input_.size()[0]
    assert (
        dim_size % world_size == 0
    ), "First dimension of the tensor should be divisible by tensor parallel size"
    local_dim_size = dim_size // world_size
    rank = get_tensor_model_parallel_rank()
    dim_offset = rank * local_dim_size

    output = input_[dim_offset : dim_offset + local_dim_size].contiguous()

    return output
...
@@ -97,13 +99,14 @@ def _gather_along_first_dim(input_):
    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
    torch.distributed._all_gather_base(
        output, input_.contiguous(), group=get_tensor_model_parallel_group()
    )

    return output


def _reduce_scatter_along_first_dim(input_):
    """Reduce-scatter the input tensor across model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
...
@@ -112,15 +115,16 @@ def _reduce_scatter_along_first_dim(input_):
        return input_

    dim_size = list(input_.size())
    assert (
        dim_size[0] % world_size == 0
    ), "First dimension of the tensor should be divisible by tensor parallel size"

    dim_size[0] = dim_size[0] // world_size

    output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
    torch.distributed._reduce_scatter_base(
        output, input_.contiguous(), group=get_tensor_model_parallel_group()
    )
    return output
...
@@ -130,7 +134,7 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_
...
@@ -146,7 +150,7 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)
...
@@ -178,7 +182,7 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
    @staticmethod
    def symbolic(graph, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_last_dim(input_)
...
@@ -205,12 +209,12 @@ class _ScatterToSequenceParallelRegion(torch.autograd.Function):
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_, tensor_parallel_output_grad=True):
        return _gather_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_, tensor_parallel_output_grad=True):
        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
...
@@ -221,8 +225,8 @@ class _GatherFromSequenceParallelRegion(torch.autograd.Function):
        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad

        # If the computation graph after the gather operation is
        # in the tensor parallel mode, output gradients need to be
        # reduce-scattered, whereas if the computation is duplicated,
        # output gradients need to be scattered.
        if tensor_parallel_output_grad:
            return _reduce_scatter_along_first_dim(grad_output), None
...
@@ -236,7 +240,7 @@ class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    @staticmethod
    def symbolic(graph, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter_along_first_dim(input_)
...
@@ -250,6 +254,7 @@ class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
# Helper functions.
# -----------------


def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)
...
@@ -276,4 +281,3 @@ def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=Tru
def reduce_scatter_to_sequence_parallel_region(input_):
    return _ReduceScatterToSequenceParallelRegion.apply(input_)
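The first-dimension split and gather above are exact inverses of each other, which is what makes them usable as a forward/backward pair in the autograd functions. The toy below (single process, no torch.distributed, illustrative names only) shows just the slicing bookkeeping that each rank would perform and that the all-gather undoes:

# Toy illustration of _split_along_first_dim / _gather_along_first_dim bookkeeping.
import torch

world_size = 4
x = torch.arange(8 * 3, dtype=torch.float32).reshape(8, 3)   # first dim divisible by world_size
assert x.size(0) % world_size == 0
local = x.size(0) // world_size

# Rank r keeps rows [r*local : (r+1)*local], made contiguous before communication.
chunks = [x[r * local : (r + 1) * local].contiguous() for r in range(world_size)]
regathered = torch.cat(chunks, dim=0)                        # what the all-gather reassembles
assert torch.equal(regathered, x)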
megatron/core/tensor_parallel/random.py
View file @ 3aca1415
...
@@ -7,7 +7,8 @@ import contextlib
import torch
from torch import _C
from torch.cuda import _lazy_call
from torch.cuda import device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

from megatron.core.parallel_state import (
...
@@ -16,14 +17,10 @@ from megatron.core.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from megatron.core.utils import safely_set_viewless_tensor_data

from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
...
@@ -42,6 +39,7 @@ def _set_cuda_rng_state(new_state, device=-1):
        def cb():
            with device_ctx_manager(device):
                _C._cuda_setRNGState(new_state)

    else:
        # newer PyTorch
        if device == -1:
...
@@ -61,7 +59,6 @@ def _set_cuda_rng_state(new_state, device=-1):
    _lazy_call(cb)


class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.
...
@@ -169,8 +166,7 @@ def model_parallel_cuda_manual_seed(seed):
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)


class CheckpointFunction(torch.autograd.Function):
...
@@ -180,11 +176,11 @@ class CheckpointFunction(torch.autograd.Function):
    2) the states in the model parallel tracker are also properly
       tracked/set/reset.
    """

    @staticmethod
    def forward(ctx, run_function, distribute_saved_activations, *args):
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
...
@@ -199,8 +195,8 @@ class CheckpointFunction(torch.autograd.Function):
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)
            )

        # Store everything.
        ctx.save_for_backward(*args)
...
@@ -210,13 +206,15 @@ class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "
                "please use .backward() if possible"
            )
        inputs = ctx.saved_tensors
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)
            )

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
...
@@ -241,13 +239,11 @@ class CheckpointFunction(torch.autograd.Function):
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs
        )
        return (None, None) + grads


def checkpoint(function, distribute_saved_activations, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
    return CheckpointFunction.apply(function, distribute_saved_activations, *args)
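The point of the RNG tracker is the save/swap/restore discipline around a region of code, so that operations such as dropout inside the tensor-parallel region draw from their own stream without disturbing the surrounding generator state. A CPU-only sketch of that pattern (this is not the CUDA tracker itself; names are illustrative):

# CPU-only sketch of the fork pattern implemented by CudaRNGStatesTracker.
import contextlib
import torch

_states = {'model-parallel-rng': torch.manual_seed(1234).get_state()}

@contextlib.contextmanager
def fork(name='model-parallel-rng'):
    orig = torch.get_rng_state()
    torch.set_rng_state(_states[name])
    try:
        yield
    finally:
        _states[name] = torch.get_rng_state()   # remember where this stream left off
        torch.set_rng_state(orig)               # surrounding code is unaffected

torch.manual_seed(0)
with fork():
    dropout_sample = torch.rand(2)   # drawn from the tracked stream
outside_sample = torch.rand(2)       # continues from the default stream as if fork never ran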
megatron/core/tensor_parallel/utils.py
View file @ 3aca1415
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from typing import List, Sequence

import torch

from megatron.core import parallel_state
from megatron.core.utils import divide


def split_tensor_along_last_dim(
    tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """ Split a tensor along its last dimension.
...
@@ -33,6 +33,7 @@ def split_tensor_along_last_dim(
    return tensor_list


def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
    """ Break a tensor into equal 1D chunks across tensor parallel ranks.
...
@@ -47,14 +48,16 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
                    Default is False
    """
    partition_size = torch.numel(tensor) // parallel_state.get_tensor_model_parallel_world_size()
    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
    if new_buffer:
        data = torch.empty(
            partition_size,
            dtype=tensor.dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )
        data.copy_(tensor.view(-1)[start_index:end_index])
    else:
        data = tensor.view(-1)[start_index:end_index]
...
@@ -70,18 +73,18 @@ def gather_split_1d_tensor(tensor):
    Arguments:
        tensor: A Tensor or view of this rank's portion of the data.
    """
    numel_gathered = torch.numel(tensor) * parallel_state.get_tensor_model_parallel_world_size()
    gathered = torch.empty(
        numel_gathered,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    # TODO: This API is experimental in pytorch (as of Feb 2022) and
    # this might break in future pytorch releases. We chose this API
    # as opposed to torch.distributed.all_gather for efficiency reasons.
    # This API calls directly NCCL all-gather versus the former does
    # internal copies and can potentially cause slow down.
    torch.distributed._all_gather_base(
        gathered, tensor, group=parallel_state.get_tensor_model_parallel_group()
    )
    return gathered
...
@@ -101,7 +104,9 @@ class VocabUtility:
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(
        global_vocab_size: int, rank: int, world_size: int
    ) -> Sequence[int]:
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size
...
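A quick worked example of the vocabulary partitioning done by VocabUtility: with a padded vocabulary of 50304 tokens (an illustrative size, not taken from this commit) and tensor parallel size 4, each rank owns a contiguous slice of 12576 token ids. The helper below mirrors the arithmetic for illustration only:

# Worked example of vocab_range_from_global_vocab_size partitioning.
def vocab_range(global_vocab_size, rank, world_size):
    per_partition = global_vocab_size // world_size   # divide() additionally checks divisibility
    first = rank * per_partition
    return first, first + per_partition

assert vocab_range(50304, 0, 4) == (0, 12576)
assert vocab_range(50304, 3, 4) == (37728, 50304)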
megatron/core/transformer/__init__.py
0 → 100644
View file @ 3aca1415
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from .transformer_config import TransformerConfig
megatron/core/transformer/attention.py
0 → 100644
View file @ 3aca1415
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from abc import ABC, abstractmethod

import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.transformer.custom_layers.transformer_engine import (
    TEDotProductAttention,
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import divide

from .enums import AttnMaskType
from .transformer_config import TransformerConfig


class Attention(MegatronModule, ABC):
    """Attention layer abstract class.

    This layer only contains common modules required for the "self attn" and
    "cross attn" specializations.
    """

    def __init__(
        self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding,
    ):
        super().__init__(config=config)

        self.config = config
        self.layer_number = layer_number
        self.attn_mask_type = attn_mask_type

        # For normal attention without groups, num_query_groups == num_attention_heads,
        # so these two will be the same
        self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
        self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = divide(
            self.query_projection_size, self.config.num_attention_heads
        )
        self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
        self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)

        self.dot_product_attention = TEDotProductAttention(
            config=self.config, layer_number=self.layer_number, attn_mask_type=self.attn_mask_type
        )

        self.checkpoint_dot_product_attention = self.config.recompute_granularity == 'selective'

        # Output.
        self.linear_proj = TERowParallelLinear(
            self.query_projection_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
        )

    def _checkpointed_attention_forward(
        self, query, key, value, attention_mask, rotary_pos_emb=None
    ):
        """Forward method with selective activation checkpointing."""

        def custom_forward(*inputs):
            query = inputs[0]
            key = inputs[1]
            value = inputs[2]
            attention_mask = inputs[3]
            output_ = self.dot_product_attention(query, key, value, attention_mask)
            return output_

        hidden_states = tensor_parallel.checkpoint(
            custom_forward, False, query, key, value, attention_mask, rotary_pos_emb
        )

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
        """Allocate memory to store kv cache during inference."""
        return torch.empty(
            inference_max_sequence_length,
            batch_size,
            self.num_query_groups_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb):
        """
        Saves the generated key and value tensors to the end of the buffers in inference_params.
        Returns the full size keys and values from the provided inference_params, as well as
        adjusted rotary_pos_emb.

        Returns a tuple: (key, value, rotary_pos_emb)
        """
        if inference_params is None:
            return key, value, rotary_pos_emb

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if self.layer_number not in inference_params.key_value_memory_dict:
            inf_max_seq_length = inference_params.max_sequence_length
            inf_max_batch_size = inference_params.max_batch_size
            inference_key_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, key.dtype
            )
            inference_value_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, value.dtype
            )
            inference_params.key_value_memory_dict[self.layer_number] = (
                inference_key_memory,
                inference_value_memory,
            )
            is_first_step = True
        else:
            # Get the pre-allocated buffers for this layer
            inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
                self.layer_number
            ]

        batch_start = inference_params.batch_size_offset
        batch_end = batch_start + key.size(1)
        assert batch_end <= inference_key_memory.size(1)
        sequence_start = inference_params.sequence_len_offset
        sequence_end = sequence_start + key.size(0)
        assert sequence_end <= inference_key_memory.size(0)
        # Copy key and values.
        inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
        inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
        key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
        value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

        # adjust the key rotary positional embedding
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            # need to cross check this condition during inference
            # if not set_inference_key_value_memory:
            if not is_first_step:
                # In inference, we compute one token at a time.
                # Select the correct positional embedding
                # (only the last token in the sequence)
                q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
            else:
                # In the first forward pass of inference,
                # we use the entire provided prefix.
                # q_pos_emb here has the rope embeddings of the entire
                # prefix + to-be-generated output so
                # we slice to just the prefix.
                q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
            k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
            rotary_pos_emb = (q_pos_emb, k_pos_emb)

        return key, value, rotary_pos_emb

    @abstractmethod
    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        This method needs to be implemented based on whether the derived class
        is "self-attn" or "cross-attn".
        """

    def forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_params=None,
        rotary_pos_emb=None,
    ):
        # hidden_states: [sq, b, h]

        # For self attention we just duplicate the rotary_pos_emb if it isn't already
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = (rotary_pos_emb,) * 2

        # =====================
        # Query, Key, and Value
        # =====================
        # Get the query, key and value tensors based on the type of attention -
        # self or cross attn.
        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)

        # ===================================================
        # Adjust key, value, and rotary_pos_emb for inference
        # ===================================================
        key, value, rotary_pos_emb = self._adjust_key_value_for_inference(
            inference_params, key, value, rotary_pos_emb
        )

        # ================================================
        # relative positional embedding (rotary embedding)
        # ================================================
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query = apply_rotary_pos_emb(query, q_pos_emb)
            key = apply_rotary_pos_emb(key, k_pos_emb)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
        # This is a noop for normal attention where ng == np. When using group query attention this
        # creates a view that has the keys and values virtually repeated along their dimension to
        # match the number of queries.
        if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
            key = key.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
            )
            value = value.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
            )

        if self.checkpoint_dot_product_attention:
            core_attn_out = self._checkpointed_attention_forward(query, key, value, attention_mask)
        else:
            core_attn_out = self.dot_product_attention(query, key, value, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.linear_proj(core_attn_out)

        return output, bias


class SelfAttention(Attention):
    """Self-attention layer class

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(
        self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding
    ):
        super().__init__(config=config, layer_number=layer_number, attn_mask_type=attn_mask_type)

        self.linear_qkv = TELayerNormColumnParallelLinear(
            self.config.hidden_size,
            self.query_projection_size + 2 * self.kv_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """
        Derives `query`, `key` and `value` tensors from `hidden_states`.
        """
        # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
        mixed_qkv, _ = self.linear_qkv(hidden_states)

        # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
        new_tensor_shape = mixed_qkv.size()[:-1] + (
            self.num_query_groups_per_partition,
            (
                (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
                * self.hidden_size_per_attention_head
            ),
        )
        mixed_qkv = mixed_qkv.view(*new_tensor_shape)

        # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
        (query, key, value) = torch.split(
            mixed_qkv,
            [
                (
                    self.num_attention_heads_per_partition
                    // self.num_query_groups_per_partition
                    * self.hidden_size_per_attention_head
                ),
                self.hidden_size_per_attention_head,
                self.hidden_size_per_attention_head,
            ],
            dim=3,
        )
        # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
        query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)

        return query, key, value


class CrossAttention(Attention):
    """Cross-attention layer class

    Cross-attention layer takes input with size [s, b, h] and context with size
    [s, b, h] and returns output of the same size.
    """

    def __init__(
        self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding
    ):
        super().__init__(config=config, layer_number=layer_number, attn_mask_type=attn_mask_type)

        if self.config.num_query_groups != self.config.num_attention_heads:
            raise ValueError(f"Group query attention is not currently supported in cross attention.")
        assert self.query_projection_size == self.kv_projection_size

        self.linear_q = TELayerNormColumnParallelLinear(
            self.config.hidden_size,
            self.query_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
        )

        self.linear_kv = TELayerNormColumnParallelLinear(
            self.config.hidden_size,
            2 * self.kv_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
        from `key_value_states`.
        """
        # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
        mixed_kv, _ = self.linear_kv(key_value_states)

        # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
        new_tensor_shape = mixed_kv.size()[:-1] + (
            self.num_attention_heads_per_partition,
            2 * self.hidden_size_per_attention_head,
        )
        mixed_kv = mixed_kv.view(*new_tensor_shape)

        # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
        (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2)

        # Attention head [sq, b, h] --> [sq, b, hp]
        query, _ = self.linear_q(hidden_states)

        # [sq, b, hp] --> [sq, b, np, hn]
        new_tensor_shape = query.size()[:-1] + (
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        query = query.view(*new_tensor_shape)

        return query, key, value
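The group query attention expansion in Attention.forward is purely a shape operation: keys and values stored with ng query groups are repeated along the head dimension until they line up with np query heads. A toy shape check with plain torch (no TransformerConfig required; the sizes are made up):

# Toy shape check of the repeat_interleave expansion used for group query attention.
import torch

sk, b, ng, np_, hn = 5, 2, 2, 8, 16          # np_ stands in for the number of query heads
key = torch.randn(sk, b, ng, hn)
key_expanded = key.repeat_interleave(np_ // ng, dim=2)
assert key_expanded.shape == (sk, b, np_, hn)
# every query head within a group attends over the same key values
assert torch.equal(key_expanded[:, :, 0], key_expanded[:, :, np_ // ng - 1])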
megatron/core/transformer/custom_layers/__init__.py
0 → 100644
View file @ 3aca1415
megatron/core/transformer/custom_layers/transformer_engine.py
0 → 100644
View file @ 3aca1415
from importlib.metadata import version
from typing import Callable

import torch
import transformer_engine as te
from pkg_resources import packaging

from megatron.core.parallel_state import get_tensor_model_parallel_group
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig


def _get_extra_te_kwargs(config: TransformerConfig):
    extra_transformer_engine_kwargs = {}
    from importlib.metadata import version

    from pkg_resources import packaging

    te_version = packaging.version.Version(version("transformer-engine"))
    if te_version >= packaging.version.Version("0.12.0"):
        if config.use_cpu_initialization:
            extra_transformer_engine_kwargs["device"] = 'cpu'
        else:
            extra_transformer_engine_kwargs["device"] = torch.cuda.current_device()
    return extra_transformer_engine_kwargs


class TENorm:
    """
    A conditional wrapper to initialize an instance of Transformer-Engine's
    `LayerNorm` or `RMSNorm` based on input
    """

    def __new__(
        cls,
        config: TransformerConfig,
        hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        normalization="LayerNorm",
        **kwargs
    ):
        zero_centered_gamma = kwargs.get('zero_centered_gamma', False)
        if normalization == "LayerNorm":
            instance = te.pytorch.LayerNorm(
                hidden_size=hidden_size,
                eps=eps,
                sequence_parallel=sequence_parallel,
                zero_centered_gamma=zero_centered_gamma,
                **_get_extra_te_kwargs(config),
            )
        elif normalization == "RMSNorm":
            assert hasattr(
                te.pytorch, "RMSNorm"
            ), "Transformer-Engine >= v0.11 required to use this feature"
            instance = te.pytorch.RMSNorm(
                hidden_size=hidden_size,
                eps=eps,
                sequence_parallel=sequence_parallel,
                zero_centered_gamma=zero_centered_gamma,
                **_get_extra_te_kwargs(config),
            )
        else:
            raise Exception('Only LayerNorm and RMSNorm are currently supported')

        return instance


class TELinear(te.pytorch.Linear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer.

    Note that if Megatron's parallel_state has not been initialized
    yet, the tp_group passed to TE will be None and must be set later
    via set_tensor_parallel_group().
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        config: TransformerConfig,
        parallel_mode: str,
        init_method: Callable,
        *,
        bias: bool = True,
        skip_bias_add: bool = False,
        **kwargs
    ):
        self.config = config

        # TE returns a zero length Tensor when bias=False and
        # return_bias=True, but we prefer None. So in that case we
        # tell TE to not return the bias, and return None
        # ourselves. This way our forward always returns two values
        # and we don't have to deal with the zero length Tensor.
        self.te_return_bias = skip_bias_add and bias

        super().__init__(
            in_features=input_size,
            out_features=output_size,
            sequence_parallel=self.config.sequence_parallel,
            fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            tp_size=self.config.tensor_model_parallel_size,
            get_rng_state_tracker=get_cuda_rng_tracker,
            init_method=init_method,
            params_dtype=self.config.params_dtype,
            parallel_mode=parallel_mode,
            bias=bias,
            return_bias=self.te_return_bias,
            **_get_extra_te_kwargs(config),
            **kwargs,
        )

    def forward(self, x):
        out = super().forward(x)

        # TE only returns a tuple when return_bias is True, otherwise
        # it returns a single Tensor, we always want to return two
        # values regardless of the arguments.
        if self.te_return_bias:
            return out
        return out, None


class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
    """
    Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
    layernorm and linear layers
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        config: TransformerConfig,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        **kwargs
    ):
        self.config = config

        # TE returns a zero length Tensor when bias=False and
        # return_bias=True, but we prefer None. So in that case we
        # tell TE to not return the bias, and return None
        # ourselves. This way our forward always returns two values
        # and we don't have to deal with the zero length Tensor.
        self.te_return_bias = skip_bias_add and bias

        # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
        te_version = packaging.version.Version(version("transformer-engine"))
        if te_version >= packaging.version.Version("0.11.0"):
            kwargs["normalization"] = self.config.normalization

        super().__init__(
            in_features=input_size,
            out_features=output_size,
            bias=bias,
            sequence_parallel=self.config.sequence_parallel,
            fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            tp_size=self.config.tensor_model_parallel_size,
            get_rng_state_tracker=get_cuda_rng_tracker,
            init_method=init_method,
            params_dtype=self.config.params_dtype,
            parallel_mode="column",
            return_bias=self.te_return_bias,
            zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
            **_get_extra_te_kwargs(config),
            **kwargs,
        )

    def forward(self, x):
        out = super().forward(x)

        # TE only returns a tuple when return_bias is True, otherwise
        # it returns a single Tensor, we always want to return two
        # values regardless of the arguments.
        if self.te_return_bias:
            return out
        return out, None


class TEColumnParallelLinear(TELinear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
    to megatron's `ColumnParallelLinear` layer.
    """

    def __init__(self, input_size: int, output_size: int, config: TransformerConfig, **kwargs):
        self.config = config
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=self.config,
            parallel_mode="column",
            **kwargs,
        )


class TERowParallelLinear(TELinear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
    to megatron's `RowParallelLinear` layer.
    """

    def __init__(self, input_size: int, output_size: int, config: TransformerConfig, **kwargs):
        self.config = config
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=self.config,
            parallel_mode="row",
            **kwargs,
        )


class TEDotProductAttention(te.pytorch.DotProductAttention):
    """
    Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
    has "flash attention" enabled.

    Note that if Megatron's parallel_state has not been initialized
    yet, the tp_group passed to TE will be None and must be set later
    via set_tensor_parallel_group().
    """

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int = 1,
        attn_mask_type: AttnMaskType = AttnMaskType.padding,
        **kwargs
    ):
        self.config = config
        super().__init__(
            num_attention_heads=self.config.num_attention_heads,
            kv_channels=self.config.kv_channels,
            attention_dropout=self.config.attention_dropout,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type.name,
            sequence_parallel=self.config.sequence_parallel,
            tp_size=self.config.tensor_model_parallel_size,
            get_rng_state_tracker=get_cuda_rng_tracker,
            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            **kwargs,
        )
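A recurring idiom in these wrappers is gating keyword arguments on the installed Transformer-Engine version before forwarding them to TE constructors. The helper below sketches that pattern in isolation; it uses the `packaging` library directly rather than `pkg_resources`, and the kwarg names in the usage comment are only examples:

# Sketch of the version-gating pattern used by the TE wrappers above.
from importlib.metadata import version, PackageNotFoundError

from packaging.version import Version


def optional_te_kwargs(minimum="0.12.0", **maybe_kwargs):
    """Return maybe_kwargs only if transformer-engine >= minimum is installed."""
    try:
        te_version = Version(version("transformer-engine"))
    except PackageNotFoundError:
        return {}
    return dict(maybe_kwargs) if te_version >= Version(minimum) else {}


# e.g. kwargs.update(optional_te_kwargs(minimum="0.11.0", normalization="RMSNorm"))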
megatron/core/transformer/dot_product_attention.py
0 → 100644
View file @
3aca1415
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import math

import torch
from torch import Tensor

from megatron.core import parallel_state, tensor_parallel
from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import attention_mask_func
from megatron.core.utils import divide


class DotProductAttention(MegatronModule):
    """
    Region where selective activation recomputation is applied.
    This region is memory intensive but less compute intensive, which
    makes activation checkpointing more efficient for LLMs (20B+).
    See Reducing Activation Recomputation in Large Transformer Models
    (https://arxiv.org/abs/2205.05198) for more details.

    We use the following notation:
     h: hidden size
     n: number of attention heads
     p: number of tensor model parallel partitions
     b: batch size
     s: sequence length
    """

    def __init__(
        self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding
    ):
        super().__init__(config=config)

        self.config: TransformerConfig = config

        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type

        projection_size = self.config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = divide(projection_size, world_size)
        self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = divide(config.num_attention_heads, world_size)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.config.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            input_in_fp16=self.config.fp16,
            input_in_bf16=self.config.bf16,
            attn_mask_type=self.attn_mask_type,
            scaled_masked_softmax_fusion=self.config.masked_softmax_fusion,
            mask_func=attention_mask_func,
            softmax_in_fp32=self.config.attention_softmax_in_fp32,
            scale=coeff,
        )

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(self.config.attention_dropout)

    def forward(
        self, query_layer: Tensor, key_layer: Tensor, value_layer: Tensor, attention_mask: Tensor
    ):

        # ===================================
        # Raw attention scores. [b, n/p, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        # This will be a simple view when doing normal attention, but in group query attention
        # the key and value tensors are repeated to match the queries, so you can't use simple
        # strides to extract the queries.
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating input tensor: [b * np, sq, sk]
        matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
            (output_size[0] * output_size[1], output_size[2], output_size[3]),
            query_layer.dtype,
            "mpu",
        )

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.

        if not self.config.sequence_parallel:
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer
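A hedged shape walk-through of the layouts documented above (plain PyTorch einsum rather than the class itself, since forward() depends on Megatron's global memory buffer and parallel state): query/key/value enter as [sq, b, np, hn] and the context layer leaves as [sq, b, hp] with hp = np * hn. Dimensions are toy values; the mask, dropout, and query-key layer scaling are omitted.

import torch

sq, sk, b, np_, hn = 5, 5, 2, 4, 16       # sequence lengths, batch, heads per partition, head dim
q = torch.randn(sq, b, np_, hn)
k = torch.randn(sk, b, np_, hn)
v = torch.randn(sk, b, np_, hn)

scores = torch.einsum('sbnh,tbnh->bnst', q, k) / hn ** 0.5   # [b, np, sq, sk]
probs = scores.softmax(dim=-1)                               # mask and dropout omitted
context = torch.einsum('bnst,tbnh->sbnh', probs, v)          # [sq, b, np, hn]
context = context.reshape(sq, b, np_ * hn)                   # [sq, b, hp]
assert context.shape == (sq, b, np_ * hn)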
megatron/core/transformer/enums.py
0 → 100644
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import enum


# can we get rid of this?
# it's being used in pipeline schedules
class ModelType(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2


# class LayerType(enum.Enum):
#     encoder = 1
#     decoder = 2


class AttnType(enum.Enum):
    self_attn = 1
    cross_attn = 2


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
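One small usage note, grounded in the TEDotProductAttention wrapper earlier in this commit: TE is handed `attn_mask_type.name` (the string "padding" or "causal"), not the enum member itself.

from megatron.core.transformer.enums import AttnMaskType

assert AttnMaskType.padding.name == "padding"   # string form passed to the TE layer
assert AttnMaskType.causal.name == "causal"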
megatron/core/transformer/identity_op.py
0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import torch


class IdentityOp(torch.nn.Module):
    """
    This is a placeholder for IdentityOp (NoOp)
    """

    def __init__(self, *args, **kwargs):
        super(IdentityOp, self).__init__()

    def forward(self, x, *args, **kwargs):
        return x
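A minimal usage sketch (the optional-normalization setup below is hypothetical, not Megatron code): IdentityOp lets an optional sub-layer be disabled without branching in forward(), because the no-op simply returns its input.

import torch

from megatron.core.transformer.identity_op import IdentityOp

use_norm = False
maybe_norm = torch.nn.LayerNorm(16) if use_norm else IdentityOp()
x = torch.randn(4, 16)
assert torch.equal(maybe_norm(x), x)   # the no-op path returns its input unchanged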
megatron/core/transformer/mlp.py
0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import torch
import torch.nn.functional as F

from megatron.core import tensor_parallel
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.transformer.custom_layers.transformer_engine import (
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig


class MLP(MegatronModule):
    """
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.

    Returns an output and a bias to be added to the output.
    If config.add_bias_linear is False, the bias returned is None.

    We use the following notation:
     h: hidden size
     p: number of tensor model parallel partitions
     b: batch size
     s: sequence length
    """

    def __init__(self, config: TransformerConfig):
        super().__init__(config=config)

        self.config: TransformerConfig = config

        # If this is a gated linear unit we double the output width,
        # see https://arxiv.org/pdf/2002.05202.pdf
        ffn_hidden_size = self.config.ffn_hidden_size
        if self.config.gated_linear_unit:
            ffn_hidden_size *= 2

        self.linear_fc1 = TELayerNormColumnParallelLinear(
            self.config.hidden_size,
            ffn_hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
        )

        if self.config.gated_linear_unit:

            def glu(x):
                x = torch.chunk(x, 2, dim=-1)
                return self.config.activation_func(x[0]) * x[1]

            self.activation_func = glu
        else:
            self.activation_func = self.config.activation_func

        self.linear_fc2 = TERowParallelLinear(
            self.config.ffn_hidden_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
        )

    def forward(self, hidden_states):

        # [s, b, 4 * h/p]
        intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)

        if self.config.bias_gelu_fusion:
            assert self.config.add_bias_linear is True
            assert self.activation_func == F.gelu
            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.linear_fc2(intermediate_parallel)
        return output, output_bias
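A small standalone sketch of the gated-linear-unit path above (plain torch.nn layers with made-up sizes, not the TE parallel modules): fc1 projects to twice ffn_hidden_size, the result is chunked in half, and the activated first half gates the second, as in https://arxiv.org/pdf/2002.05202.pdf.

import torch
import torch.nn.functional as F

s, b, h, ffn = 3, 2, 8, 32
hidden_states = torch.randn(s, b, h)
fc1 = torch.nn.Linear(h, 2 * ffn)        # doubled width, as when gated_linear_unit=True
fc2 = torch.nn.Linear(ffn, h)

x = fc1(hidden_states)
gate, value = torch.chunk(x, 2, dim=-1)
out = fc2(F.gelu(gate) * value)          # activation(x[0]) * x[1], then project back to h
assert out.shape == (s, b, h)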