OpenDAS / apex · Commits · 96850dfa

Unverified commit 96850dfa, authored Aug 15, 2022 by Jithun Nair; committed by GitHub, Aug 15, 2022.

Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29

Parents: 87fc4125, cc5f83b5
Changes: 235
Showing 20 changed files with 2687 additions and 600 deletions (+2687, -600).
apex/transformer/layers/layer_norm.py                                                      +99    -0
apex/transformer/log_util.py                                                                +1    -2
apex/transformer/microbatches.py                                                           +40   -17
apex/transformer/parallel_state.py                                                        +329   -43
apex/transformer/pipeline_parallel/p2p_communication.py                                   +335  -161
apex/transformer/pipeline_parallel/schedules/__init__.py                                   +10   -14
apex/transformer/pipeline_parallel/schedules/common.py                                    +250   -70
apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py                      +55   -14
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py      +153   -46
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py   +364   -45
apex/transformer/pipeline_parallel/utils.py                                                +26    -2
apex/transformer/tensor_parallel/__init__.py                                                +2    -1
apex/transformer/tensor_parallel/data.py                                                   +15    -6
apex/transformer/tensor_parallel/layers.py                                                +374   -71
apex/transformer/tensor_parallel/mappings.py                                              +164   -19
apex/transformer/tensor_parallel/memory.py                                                 +20    -5
apex/transformer/tensor_parallel/random.py                                                 +41   -24
apex/transformer/tensor_parallel/utils.py                                                  +14    -4
apex/transformer/testing/arguments.py                                                     +197   -32
apex/transformer/testing/commons.py                                                       +198   -24
apex/transformer/layers/layer_norm.py (new file, mode 0 → 100644, +99 -0)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# NOTE(mkozuki): This file defines two LayerNorm that are compatible with Megatron-LM,
# while avoiding introducing the breaking change of the `"sequence_parallel_enabled"` attribute
# into apex.normalization.FusedLayerNorm and apex.contrib.layer_norm.FastLayerNorm.
import warnings

import torch

from apex.normalization import FusedLayerNorm as OrigFusedLayerNorm
from apex.normalization import MixedFusedLayerNorm as OrigMixedFusedLayerNorm
try:
    from apex.contrib.layer_norm import FastLayerNorm as OrigFastLayerNorm
except ImportError:
    HAS_FAST_LAYER_NORM = False
else:
    HAS_FAST_LAYER_NORM = True


__all__ = [
    "FusedLayerNorm",
    "FastLayerNorm",
    "MixedFusedLayerNorm",
]


def _set_sequence_parallel_enabled(
    param: torch.Tensor,
    sequence_parallel_enabled: bool,
) -> None:
    setattr(param, "sequence_parallel_enabled", sequence_parallel_enabled)


class FusedLayerNorm(OrigFusedLayerNorm):
    def __init__(
        self,
        normalized_shape,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        *,
        sequence_parallel_enabled: bool = False,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
        )
        self.sequence_parallel_enabled = sequence_parallel_enabled
        if self.elementwise_affine:
            _set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
            _set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)


# note: MixedFusedLayerNorm is no different from FusedLayerNorm if it's used in `torch.cuda.amp`.
class MixedFusedLayerNorm(OrigMixedFusedLayerNorm):
    def __init__(
        self,
        normalized_shape,
        eps: float = 1e-5,
        **kwargs,
    ) -> None:
        self.sequence_parallel_enabled = kwargs.get("sequence_parallel_enabled", False)
        super().__init__(normalized_shape=normalized_shape, eps=eps, **kwargs)
        if self.sequence_parallel_enabled:
            _set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
            _set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)


if HAS_FAST_LAYER_NORM:
    class FastLayerNorm(OrigFastLayerNorm):
        def __init__(
            self,
            hidden_size,
            eps: float = 1e-5,
            *,
            sequence_parallel_enabled: bool = False,
        ):
            super().__init__(hidden_size=hidden_size, eps=eps)
            self.sequence_parallel_enabled = sequence_parallel_enabled
            _set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled)
            _set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled)
else:
    class FastLayerNorm(FusedLayerNorm):
        def __init__(
            self,
            hidden_size,
            eps: float = 1e-5,
            *,
            sequence_parallel_enabled: bool = False,
        ):
            warnings.warn(
                "`apex.contrib.layer_norm.FastLayerNorm` isn't available thus falling back to `apex.normalization.FusedLayerNorm`"
            )
            super().__init__(
                normalized_shape=hidden_size,
                eps=eps,
                elementwise_affine=True,
                sequence_parallel_enabled=sequence_parallel_enabled,
            )
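Note (added for clarity, not part of the diff): a minimal usage sketch of the Megatron-compatible wrapper above. It assumes apex is built with the fused layer-norm extensions and a CUDA device is available; the point is only that the `sequence_parallel_enabled` flag is stamped onto the affine parameters, which is what sequence-parallel gradient handling keys off.

    # Minimal sketch, assuming apex with fused layer norm and a CUDA device.
    import torch
    from apex.transformer.layers.layer_norm import FusedLayerNorm

    ln = FusedLayerNorm(1024, sequence_parallel_enabled=True).cuda()
    x = torch.randn(8, 4, 1024, device="cuda")
    y = ln(x)

    # The wrapper tags the affine parameters so downstream code can special-case their grads.
    assert getattr(ln.weight, "sequence_parallel_enabled", False)
    assert getattr(ln.bias, "sequence_parallel_enabled", False)
    print(y.shape)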
apex/transformer/log_util.py (+1, -2)
@@ (imports) @@
-from typing import Optional
 import logging
 import os
-import threading


 def get_transformer_logger(name: str) -> logging.Logger:
@@ -16,4 +14,5 @@ def set_logging_level(verbosity) -> None:
        verbosity
    """
    from apex import _library_root_logger

    _library_root_logger.setLevel(verbosity)
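Note (added for clarity, not part of the diff): a minimal sketch of the logging helpers this file exposes, which the other changed files now use instead of `print`. It assumes only what the hunks above show: `get_transformer_logger` returns a standard `logging.Logger` and `set_logging_level` adjusts the apex library root logger.

    # Minimal sketch of the log_util API used throughout this commit.
    import logging

    from apex.transformer.log_util import get_transformer_logger, set_logging_level

    set_logging_level(logging.INFO)            # raises/lowers apex's library root logger level
    _logger = get_transformer_logger(__name__)
    _logger.info("microbatches.py and parallel_state.py now log through this logger")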
apex/transformer/microbatches.py (+40, -17)
@@ -17,13 +17,18 @@ from abc import ABC
 from abc import abstractmethod
 from typing import Optional, List

+from apex.transformer.log_util import get_transformer_logger
+
+
+_logger = get_transformer_logger(__name__)
+

 def build_num_microbatches_calculator(
     rank: int,
     rampup_batch_size: Optional[List[int]],
     global_batch_size: int,
     micro_batch_size: int,
     data_parallel_size: int,
 ):
     # Constant num micro-batches.
     if rampup_batch_size is None:
@@ -31,8 +36,10 @@ def build_num_microbatches_calculator(
             global_batch_size, micro_batch_size, data_parallel_size
         )
         if rank == 0:
-            print(
-                "setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()),
-                flush=True,
-            )
+            _logger.info(
+                "setting number of micro-batches to constant {}".format(
+                    num_microbatches_calculator.get()
+                )
+            )
     else:
@@ -45,13 +52,15 @@ def build_num_microbatches_calculator(
         batch_size_increment = int(rampup_batch_size[1])
         ramup_samples = int(rampup_batch_size[2])
         if rank == 0:
-            print(
+            _logger.info(
                 "will use batch size rampup starting from global batch "
                 "size {} to global batch size {} with batch size increments "
                 "{} over {} samples.".format(
-                    start_batch_size, global_batch_size, batch_size_increment, ramup_samples
-                ),
-                flush=True,
+                    start_batch_size,
+                    global_batch_size,
+                    batch_size_increment,
+                    ramup_samples,
+                )
             )
         num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
             start_batch_size,
@@ -86,7 +95,9 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
         micro_batch_times_data_parallel = micro_batch_size * data_parallel_size
         assert global_batch_size % micro_batch_times_data_parallel == 0, (
             "global batch size ({}) is not divisible by micro batch size ({})"
-            " times data parallel size ({})".format(global_batch_size, micro_batch_size, data_parallel_size)
+            " times data parallel size ({})".format(
+                global_batch_size, micro_batch_size, data_parallel_size
+            )
         )
         self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel
         assert self.num_micro_batches >= 1
@@ -126,7 +137,9 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
         self.micro_batch_size = micro_batch_size
         self.data_parallel_size = data_parallel_size
-        self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
+        self.micro_batch_times_data_parallel_size = (
+            self.micro_batch_size * self.data_parallel_size
+        )
         assert self.micro_batch_times_data_parallel_size > 0
         assert start_batch_size > 0
@@ -158,15 +171,25 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
             self.current_global_batch_size = self.global_batch_size
         else:
             steps = int(consumed_samples / self.rampup_samples_per_increment)
-            self.current_global_batch_size = self.start_batch_size + steps * self.batch_size_increment
+            self.current_global_batch_size = (
+                self.start_batch_size + steps * self.batch_size_increment
+            )
             assert self.current_global_batch_size <= self.global_batch_size

         if consistency_check:
-            assert self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0, (
+            assert (
+                self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0
+            ), (
                 "current global "
                 "batch size ({}) is not divisible by micro-batch-size ({}) times"
                 "data parallel size ({})".format(
-                    self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size
+                    self.current_global_batch_size,
+                    self.micro_batch_size,
+                    self.data_parallel_size,
                 )
             )
-        self.num_micro_batches = self.current_global_batch_size // self.micro_batch_times_data_parallel_size
+        self.num_micro_batches = (
+            self.current_global_batch_size // self.micro_batch_times_data_parallel_size
+        )
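Note (added for clarity, not part of the diff): a worked example of the micro-batch arithmetic that the asserts above enforce. The constant calculator requires the global batch size to be divisible by micro-batch size times data-parallel size; the quotient is the number of micro-batches per step.

    # Worked example of the constant micro-batch calculation.
    global_batch_size = 256
    micro_batch_size = 4
    data_parallel_size = 8

    micro_batch_times_data_parallel = micro_batch_size * data_parallel_size  # 32
    assert global_batch_size % micro_batch_times_data_parallel == 0
    num_micro_batches = global_batch_size // micro_batch_times_data_parallel
    print(num_micro_batches)  # 8 micro-batches per optimizer step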
apex/transformer/parallel_state.py (+329, -43)
@@ -12,14 +12,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# TODO (mkozuki): Replace assert with RuntimeError.
+# TODO (mkozuki): Sort the functions in the same order of megatron/mpu/initialize.py
 """Model and data parallel groups."""
-from typing import Tuple
+from typing import Tuple, Optional
+import warnings

 import torch

+# TODO (mkozuki): Consider dissecting utils as this utils import is here
+from apex.transformer.log_util import get_transformer_logger
+# only for ensure_divisibility
+from apex.transformer.utils import ensure_divisibility
+
+_logger = get_transformer_logger(__name__)
+
+# N.B. (mkozuki): Diff btwn Megatron-LM & apex parallel_state
+# set(megatron_mpu_initialize_funcs) - set(apex.transformer.parallel_state) =
+# {
+#     'get_num_layers',
+# }

 # Intra-layer model parallel group that the current rank belongs to.
@@ -30,11 +40,17 @@ _PIPELINE_MODEL_PARALLEL_GROUP = None
 _MODEL_PARALLEL_GROUP = None
 # Embedding group.
 _EMBEDDING_GROUP = None
+# Position embedding group.
+_POSITION_EMBEDDING_GROUP = None
+# Relative position embedding group.
+_ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
+_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None

 _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
 _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

 # These values enable us to change the mpu sizes on the fly.
 _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
@@ -45,6 +61,13 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
 # A list of ranks that have a copy of the embedding.
 _EMBEDDING_GLOBAL_RANKS = None
+# A list of ranks that have a copy of the position embedding.
+_POSITION_EMBEDDING_GLOBAL_RANKS = None
+# A list of ranks that have a copy of the relative position embedding.
+_ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None
+_DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None

 # A list of global ranks for each pipeline group to ease calculation of the source
 # rank when broadcasting from the first or last pipeline stage
 _PIPELINE_GLOBAL_RANKS = None
@@ -56,14 +79,31 @@ def is_unitialized():
 def initialize_model_parallel(
-    tensor_model_parallel_size_=1,
-    pipeline_model_parallel_size_=1,
-    virtual_pipeline_model_parallel_size_=None
-):
+    tensor_model_parallel_size_: int = 1,
+    pipeline_model_parallel_size_: int = 1,
+    virtual_pipeline_model_parallel_size_: Optional[int] = None,
+    pipeline_model_parallel_split_rank_: Optional[int] = None,
+    *,
+    default_backend: Optional[str] = None,
+    p2p_backend: Optional[str] = None,
+) -> None:
     """
     Initialize model data parallel groups.

     Arguments:
         tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
         pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
+        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline).
+        pipeline_model_parallel_split_rank: for models with both encoder and decoder, rank in pipeline with split point.
+
+    Keyword Arguments:
+        default_backend: Backend of process groups except for pipeline parallel ones.
+            If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
+        p2p_backend: Backend of process groups for pipeline model parallel.
+            If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
+
+    .. note::
+        `torch_ucc <https://github.com/facebookresearch/torch_ucc>`_ is
+        necessary for "ucc" backend.

     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
@@ -83,28 +123,61 @@ def initialize_model_parallel(
     """
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
-    world_size = torch.distributed.get_world_size()
-    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
-    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
-    ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
-    data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
+    assert default_backend is None or default_backend in ("nccl", "ucc")
+    assert p2p_backend is None or p2p_backend in ("nccl", "ucc")
+    if "ucc" in (default_backend, p2p_backend):
+        check_torch_ucc_availability()
+        warnings.warn("`ucc` backend support is experimental", ExperimentalWarning)
+    if default_backend == "ucc":
+        warnings.warn("The UCC's functionality as `default_backend` is not well verified", ExperimentalWarning)
+
+    world_size: int = torch.distributed.get_world_size()
+    tensor_model_parallel_size: int = min(tensor_model_parallel_size_, world_size)
+    pipeline_model_parallel_size: int = min(pipeline_model_parallel_size_, world_size)
+    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
+        raise RuntimeError(
+            f"`world_size` ({world_size}) is not divisible by tensor_model_parallel_size "
+            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
+        )
+    data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)

     if torch.distributed.get_rank() == 0:
-        print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size))
-        print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size))
-        print("> initializing data parallel with size {}".format(data_parallel_size))
+        _logger.info(
+            "> initializing tensor model parallel with size {}".format(tensor_model_parallel_size)
+        )
+        _logger.info(
+            "> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size)
+        )
+        _logger.info("> initializing data parallel with size {}".format(data_parallel_size))

-    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
-    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
-    num_data_parallel_groups = world_size // data_parallel_size
+    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
+    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
+    num_data_parallel_groups: int = world_size // data_parallel_size

     if virtual_pipeline_model_parallel_size_ is not None:
-        assert pipeline_model_parallel_size_ > 2, \
-            'pipeline-model-parallel size should be greater than 2 with ' \
-            'interleaved schedule'
+        # n.b. (eqy) This check was inherited from Megatron-LM, need to revisit
+        # the root cause as we do see numerical mismatches with 2 stages and
+        # the interleaved schedule
+        assert pipeline_model_parallel_size_ > 2, (
+            "pipeline-model-parallel size should be greater than 2 with "
+            "interleaved schedule"
+        )
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
         _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
-        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = (
+            virtual_pipeline_model_parallel_size_
+        )
+
+    if pipeline_model_parallel_split_rank_ is not None:
+        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_

     rank = torch.distributed.get_rank()
@@ -118,7 +191,7 @@ def initialize_model_parallel(
     for j in range(tensor_model_parallel_size):
         ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
         all_data_parallel_group_ranks.append(list(ranks))
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=default_backend)
         if rank in ranks:
             _DATA_PARALLEL_GROUP = group
@@ -126,17 +199,24 @@ def initialize_model_parallel(
     global _MODEL_PARALLEL_GROUP
     assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
     for i in range(data_parallel_size):
-        ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]
-        group = torch.distributed.new_group(ranks)
+        ranks = [
+            data_parallel_group_ranks[i]
+            for data_parallel_group_ranks in all_data_parallel_group_ranks
+        ]
+        group = torch.distributed.new_group(ranks, backend=default_backend)
         if rank in ranks:
             _MODEL_PARALLEL_GROUP = group

     # Build the tensor model-parallel groups.
     global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized"
+    assert (
+        _TENSOR_MODEL_PARALLEL_GROUP is None
+    ), "tensor model parallel group is already initialized"
     for i in range(num_tensor_model_parallel_groups):
-        ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
-        group = torch.distributed.new_group(ranks)
+        ranks = list(
+            range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
+        )
+        group = torch.distributed.new_group(ranks, backend=default_backend)
         if rank in ranks:
             _TENSOR_MODEL_PARALLEL_GROUP = group
@@ -144,43 +224,111 @@ def initialize_model_parallel(
     # (first and last rank in each pipeline model-parallel group).
     global _PIPELINE_MODEL_PARALLEL_GROUP
     global _PIPELINE_GLOBAL_RANKS
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
+    assert (
+        _PIPELINE_MODEL_PARALLEL_GROUP is None
+    ), "pipeline model parallel group is already initialized"
     global _EMBEDDING_GROUP
     global _EMBEDDING_GLOBAL_RANKS
     assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
+    global _POSITION_EMBEDDING_GROUP
+    global _POSITION_EMBEDDING_GLOBAL_RANKS
+    assert (
+        _POSITION_EMBEDDING_GROUP is None
+    ), "position embedding group is already initialized"
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is None or \
+        _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is None, \
+        'relative position embedding group is already initialized'
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size, num_pipeline_model_parallel_groups)
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=p2p_backend)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
             _PIPELINE_GLOBAL_RANKS = ranks
         # Setup embedding group (to exchange gradients between
         # first and last stages).
+        encoder_relative_position_embedding_ranks = None
+        decoder_relative_position_embedding_ranks = None
         if len(ranks) > 1:
             embedding_ranks = [ranks[0], ranks[-1]]
+            position_embedding_ranks = [ranks[0]]
+            encoder_relative_position_embedding_ranks = [ranks[0]]
+            decoder_relative_position_embedding_ranks = [ranks[0]]
+            if pipeline_model_parallel_split_rank_ is not None:
+                encoder_relative_position_embedding_ranks = \
+                    ranks[:pipeline_model_parallel_split_rank_]
+                decoder_relative_position_embedding_ranks = \
+                    ranks[pipeline_model_parallel_split_rank_:]
+                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
+                    embedding_ranks = [
+                        ranks[0],
+                        ranks[pipeline_model_parallel_split_rank_],
+                        ranks[-1],
+                    ]
+                if (
+                    ranks[pipeline_model_parallel_split_rank_]
+                    not in position_embedding_ranks
+                ):
+                    position_embedding_ranks = [
+                        ranks[0],
+                        ranks[pipeline_model_parallel_split_rank_],
+                    ]
         else:
             embedding_ranks = ranks
-        group = torch.distributed.new_group(embedding_ranks)
+            position_embedding_ranks = ranks
+            encoder_relative_position_embedding_ranks = ranks
+            decoder_relative_position_embedding_ranks = ranks
+
+        group = torch.distributed.new_group(embedding_ranks, backend=default_backend)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks
+
+        group = torch.distributed.new_group(position_embedding_ranks, backend=default_backend)
+        if rank in position_embedding_ranks:
+            _POSITION_EMBEDDING_GROUP = group
+        if rank in ranks:
+            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
+
+        if encoder_relative_position_embedding_ranks:
+            group = torch.distributed.new_group(encoder_relative_position_embedding_ranks)
+        if rank in encoder_relative_position_embedding_ranks:
+            _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
+        if rank in ranks:
+            _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
+                encoder_relative_position_embedding_ranks
+
+        if decoder_relative_position_embedding_ranks:
+            group = torch.distributed.new_group(decoder_relative_position_embedding_ranks)
+        if rank in decoder_relative_position_embedding_ranks:
+            _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
+        if rank in ranks:
+            _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
+                decoder_relative_position_embedding_ranks


 def get_rank_info() -> Tuple[int, int, int]:
-    """Returns a tuple of (tensor, pipeline, data)-parallel-rank for logger."""
+    """Returns a tuple of (data, tensor, pipeline, virtual pipeline)-parallel-rank for logger."""
     if model_parallel_is_initialized():
         return (
+            get_data_parallel_rank(),
             get_tensor_model_parallel_rank(),
             get_pipeline_model_parallel_rank(),
-            # get_virtual_pipeline_model_parallel_rank(),
-            get_data_parallel_rank(),
+            get_virtual_pipeline_model_parallel_rank(),
         )
-    return (0, 0, 0)
+    return (0, 0, 0, 0)


 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
-    if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
+    if (
+        _TENSOR_MODEL_PARALLEL_GROUP is None
+        or _PIPELINE_MODEL_PARALLEL_GROUP is None
+        or _DATA_PARALLEL_GROUP is None
+    ):
         return False
     return True
@@ -193,13 +341,17 @@ def get_model_parallel_group():
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "intra_layer_model parallel group is not initialized"
+    assert (
+        _TENSOR_MODEL_PARALLEL_GROUP is not None
+    ), "intra_layer_model parallel group is not initialized"
     return _TENSOR_MODEL_PARALLEL_GROUP


 def get_pipeline_model_parallel_group():
     """Get the pipeline model parallel group the caller rank belongs to."""
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized"
+    assert (
+        _PIPELINE_MODEL_PARALLEL_GROUP is not None
+    ), "pipeline_model parallel group is not initialized"
     return _PIPELINE_MODEL_PARALLEL_GROUP
@@ -215,6 +367,25 @@ def get_embedding_group():
     return _EMBEDDING_GROUP


+def get_position_embedding_group():
+    """Get the position embedding group the caller rank belongs to."""
+    assert (
+        _POSITION_EMBEDDING_GROUP is not None
+    ), "position embedding group is not initialized"
+    return _POSITION_EMBEDDING_GROUP
+
+
+def get_encoder_relative_position_embedding_group():
+    """Get the encoder relative position embedding group the caller rank belongs to."""
+    assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, \
+        'encoder relative position embedding group is not initialized'
+    return _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
+
+
+def get_decoder_relative_position_embedding_group():
+    """Get the decoder relative position embedding group the caller rank belongs to."""
+    assert _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, \
+        'decoder relative position embedding group is not initialized'
+    return _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
+
+
 def is_rank_in_embedding_group(ignore_virtual=False):
     """Return true if current rank is in embedding group, False otherwise."""
     rank = torch.distributed.get_rank()
@@ -231,6 +402,64 @@ def is_rank_in_embedding_group(ignore_virtual=False):
     return False


+def is_rank_in_position_embedding_group():
+    """Return whether the current rank is in position embedding group."""
+    rank = torch.distributed.get_rank()
+    global _POSITION_EMBEDDING_GLOBAL_RANKS
+    return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
+
+
+def is_rank_in_encoder_relative_position_embedding_group():
+    """Return true if current rank is in encoder relative position embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    return rank in _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+
+
+def is_rank_in_decoder_relative_position_embedding_group():
+    """Return true if current rank is in decoder relative position embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    return rank in _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+
+
+def is_pipeline_stage_before_split(rank=None):
+    """Return True if pipeline stage executes encoder block for a model
+    with both encoder and decoder."""
+    if get_pipeline_model_parallel_world_size() == 1:
+        return True
+    if rank is None:
+        rank = get_pipeline_model_parallel_rank()
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
+        return True
+    if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
+        return True
+    return False
+
+
+def is_pipeline_stage_after_split(rank=None):
+    """Return True if pipeline stage executes decoder block for a model
+    with both encoder and decoder."""
+    if get_pipeline_model_parallel_world_size() == 1:
+        return True
+    if rank is None:
+        rank = get_pipeline_model_parallel_rank()
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
+        return True
+    if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
+        return True
+    return False
+
+
+def is_pipeline_stage_at_split():
+    """Return true if pipeline stage executes decoder block and next
+    stage executes encoder block for a model with both encoder and
+    decoder."""
+    rank = get_pipeline_model_parallel_rank()
+    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
+
+
 def set_tensor_model_parallel_world_size(world_size):
     """Set the tensor model parallel size"""
     global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
@@ -287,6 +516,21 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


+# TODO (mkozuki): Add [`get_num_layers`](https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/mpu/initialize.py#L321) here, maybe?
+
+
+def get_pipeline_model_parallel_split_rank():
+    """Return my rank for the pipeline model parallel split rank."""
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+
+
+def set_pipeline_model_parallel_split_rank(pipeline_model_parallel_split_rank: int):
+    """Set my rank for the pipeline model parallel split rank."""
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
+
+
 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
@@ -301,12 +545,16 @@ def is_pipeline_first_stage(ignore_virtual=False):
 def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
-        virtual_pipeline_model_parallel_world_size = get_virtual_pipeline_model_parallel_world_size()
+        virtual_pipeline_model_parallel_world_size = (
+            get_virtual_pipeline_model_parallel_world_size()
+        )
         if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
             virtual_pipeline_model_parallel_world_size - 1
         ):
             return False
     return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)


 def get_virtual_pipeline_model_parallel_rank():
@@ -335,26 +583,42 @@ def get_tensor_model_parallel_src_rank():
     return (global_rank // local_world_size) * local_world_size


+def get_data_parallel_src_rank():
+    """Calculate the global rank corresponding to the first local rank in the data parallel group."""
+    global_rank = torch.distributed.get_rank()
+    data_parallel_size: int = get_data_parallel_world_size()
+    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
+    return global_rank % num_data_parallel_groups
+
+
 def get_pipeline_model_parallel_first_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     return _PIPELINE_GLOBAL_RANKS[0]


 def get_pipeline_model_parallel_last_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
     return _PIPELINE_GLOBAL_RANKS[last_rank_local]


 def get_pipeline_model_parallel_next_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


 def get_pipeline_model_parallel_prev_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
@@ -370,6 +634,9 @@ def get_data_parallel_rank():
     return torch.distributed.get_rank(group=get_data_parallel_group())


+# note (mkozuki): `destroy_model_parallel` voids more global variables than Megatron-LM.
+# Otherwise pipeline parallel forward_backward functions test hangs possibly because
+# the clean-up of the original is NOT enough.
 def destroy_model_parallel():
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
@@ -382,6 +649,12 @@ def destroy_model_parallel():
     _DATA_PARALLEL_GROUP = None
     global _EMBEDDING_GROUP
     _EMBEDDING_GROUP = None
+    global _POSITION_EMBEDDING_GROUP
+    _POSITION_EMBEDDING_GROUP = None
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
     _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
@@ -394,3 +667,16 @@ def destroy_model_parallel():
     _MPU_TENSOR_MODEL_PARALLEL_RANK = None
     global _MPU_PIPELINE_MODEL_PARALLEL_RANK
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+
+
+# Used to warn when the UCC is specified.
+class ExperimentalWarning(Warning):
+    pass
+
+
+def check_torch_ucc_availability() -> None:
+    try:
+        import torch_ucc  # NOQA
+    except ImportError:
+        raise ImportError(
+            "UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"
+        )
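Note (added for clarity, not part of the diff): a minimal sketch of wiring up the groups this file manages, mirroring the 16-GPU example from the docstring above (tensor=2, pipeline=4, so data=2). The backend keywords are the new arguments introduced by this change; the sketch assumes one process per GPU launched via torchrun and NCCL available.

    # Minimal sketch, assuming torchrun with one process per GPU and 16 GPUs total.
    import torch
    from apex.transformer import parallel_state

    torch.distributed.init_process_group(backend="nccl")
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=2,
        pipeline_model_parallel_size_=4,
        default_backend="nccl",  # data/tensor/embedding groups
        p2p_backend="nccl",      # pipeline (point-to-point) groups
    )

    # New 4-tuple: (data, tensor, pipeline, virtual pipeline) parallel ranks.
    print(parallel_state.get_rank_info())
    parallel_state.destroy_model_parallel()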
apex/transformer/pipeline_parallel/p2p_communication.py (+335, -161)
# coding=utf-8
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021
-22
, NVIDIA CORPORATION. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,63 +12,108 @@
...
@@ -12,63 +12,108 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# TODO(mkozuki): Consider removing `timers`.
from
functools
import
reduce
from
functools
import
reduce
import
operator
import
operator
from
typing
import
Union
,
Optional
,
Tuple
from
typing
import
Union
,
Optional
,
Tuple
import
warnings
import
torch
import
torch
from
apex._autocast_utils
import
_get_current_dtype
from
apex.transformer
import
parallel_state
from
apex.transformer
import
parallel_state
from
apex.transformer.log_util
import
get_transformer_logger
from
apex.transformer.utils
import
split_tensor_into_1d_equal_chunks
from
apex.transformer.utils
import
split_tensor_into_1d_equal_chunks
from
apex.transformer.utils
import
gather_split_1d_tensor
from
apex.transformer.utils
import
gather_split_1d_tensor
from
apex.transformer.pipeline_parallel.utils
import
Shape
from
apex.transformer.pipeline_parallel.utils
import
Shape
from
apex.transformer.pipeline_parallel._timers
import
_Timers
from
apex.transformer.pipeline_parallel._timers
import
_Timers
_logger
=
get_transformer_logger
(
__name__
)
class
FutureTensor
:
def
__init__
(
self
,
tensor
:
torch
.
Tensor
,
waitfunc
):
self
.
tensor
=
tensor
self
.
waitfunc
=
waitfunc
def
get
(
self
):
if
self
.
waitfunc
is
not
None
:
res
=
self
.
waitfunc
()
if
isinstance
(
res
,
torch
.
Tensor
):
self
.
tensor
=
res
self
.
waitfunc
=
None
return
self
.
tensor
def
_run_p2pops
(
def
_run_p2pops
(
tensor_send_prev
:
Union
[
torch
.
Tensor
,
None
],
tensor_send_prev
:
Union
[
torch
.
Tensor
,
None
],
tensor_send_next
:
Union
[
torch
.
Tensor
,
None
],
tensor_send_next
:
Union
[
torch
.
Tensor
,
None
],
tensor_recv_prev
:
Union
[
torch
.
Tensor
,
None
],
tensor_recv_prev
:
Union
[
torch
.
Tensor
,
None
],
tensor_recv_next
:
Union
[
torch
.
Tensor
,
None
],
tensor_recv_next
:
Union
[
torch
.
Tensor
,
None
],
async_comm
:
bool
=
False
):
):
ops
=
[]
ops
=
[]
p2p_group
=
parallel_state
.
get_pipeline_model_parallel_group
()
default_group
=
parallel_state
.
get_model_parallel_group
()
need_to_sync
=
p2p_group
.
name
()
!=
default_group
.
name
()
if
tensor_send_prev
is
not
None
:
if
tensor_send_prev
is
not
None
:
send_prev_op
=
torch
.
distributed
.
P2POp
(
send_prev_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
isend
,
op
=
torch
.
distributed
.
isend
,
tensor_send_prev
,
tensor
=
tensor_send_prev
,
parallel_state
.
get_pipeline_model_parallel_prev_rank
(),
peer
=
parallel_state
.
get_pipeline_model_parallel_prev_rank
(),
group
=
p2p_group
,
)
)
ops
.
append
(
send_prev_op
)
ops
.
append
(
send_prev_op
)
if
tensor_recv_prev
is
not
None
:
if
tensor_recv_prev
is
not
None
:
recv_prev_op
=
torch
.
distributed
.
P2POp
(
recv_prev_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
irecv
,
op
=
torch
.
distributed
.
irecv
,
tensor_recv_prev
,
tensor
=
tensor_recv_prev
,
parallel_state
.
get_pipeline_model_parallel_prev_rank
(),
peer
=
parallel_state
.
get_pipeline_model_parallel_prev_rank
(),
group
=
p2p_group
,
)
)
ops
.
append
(
recv_prev_op
)
ops
.
append
(
recv_prev_op
)
if
tensor_send_next
is
not
None
:
if
tensor_send_next
is
not
None
:
send_next_op
=
torch
.
distributed
.
P2POp
(
send_next_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
isend
,
op
=
torch
.
distributed
.
isend
,
tensor_send_next
,
tensor
=
tensor_send_next
,
parallel_state
.
get_pipeline_model_parallel_next_rank
(),
peer
=
parallel_state
.
get_pipeline_model_parallel_next_rank
(),
group
=
p2p_group
,
)
)
ops
.
append
(
send_next_op
)
ops
.
append
(
send_next_op
)
if
tensor_recv_next
is
not
None
:
if
tensor_recv_next
is
not
None
:
recv_next_op
=
torch
.
distributed
.
P2POp
(
recv_next_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
irecv
,
op
=
torch
.
distributed
.
irecv
,
tensor_recv_next
,
tensor
=
tensor_recv_next
,
parallel_state
.
get_pipeline_model_parallel_next_rank
(),
peer
=
parallel_state
.
get_pipeline_model_parallel_next_rank
(),
group
=
p2p_group
,
)
)
ops
.
append
(
recv_next_op
)
ops
.
append
(
recv_next_op
)
if
len
(
ops
)
>
0
:
if
len
(
ops
)
>
0
:
reqs
=
torch
.
distributed
.
batch_isend_irecv
(
ops
)
if
need_to_sync
:
for
req
in
reqs
:
torch
.
cuda
.
synchronize
()
req
.
wait
()
reqs
=
torch
.
distributed
.
batch_isend_irecv
(
ops
)
if
async_comm
:
assert
len
(
reqs
)
==
len
(
ops
)
tensor_send_prev_req
=
None
if
tensor_send_prev
is
None
else
reqs
.
pop
(
0
)
tensor_recv_prev_req
=
None
if
tensor_recv_prev
is
None
else
reqs
.
pop
(
0
)
tensor_send_next_req
=
None
if
tensor_send_next
is
None
else
reqs
.
pop
(
0
)
tensor_recv_next_req
=
None
if
tensor_recv_next
is
None
else
reqs
.
pop
(
0
)
return
(
tensor_send_prev_req
,
tensor_recv_prev_req
,
tensor_send_next_req
,
tensor_recv_next_req
)
else
:
for
req
in
reqs
:
req
.
wait
()
return
(
None
,
None
,
None
,
None
)
return
(
None
,
None
,
None
,
None
)
# TODO(mkozuki): Check if it's possible to sunset `override_scatter_gather_tensors_in_pipeline`.
# TODO(mkozuki): Think about if it's possible to push some logic and arguments e.g.
# `scatter_gather_tensors_in_pipeline`, `sequence_parallel_enabled`, and
# `override_scatter_gather_tensors_in_pipeline` # to the user of
# apex.transformer forward_backwardfunctions.
def
_communicate
(
def
_communicate
(
tensor_send_next
:
Optional
[
torch
.
Tensor
],
tensor_send_next
:
Optional
[
torch
.
Tensor
],
tensor_send_prev
:
Optional
[
torch
.
Tensor
],
tensor_send_prev
:
Optional
[
torch
.
Tensor
],
...
@@ -76,14 +121,26 @@ def _communicate(
...
@@ -76,14 +121,26 @@ def _communicate(
recv_next
:
bool
,
recv_next
:
bool
,
tensor_shape
:
Optional
[
Shape
]
=
None
,
tensor_shape
:
Optional
[
Shape
]
=
None
,
override_scatter_gather_tensors_in_pipeline
:
bool
=
False
,
override_scatter_gather_tensors_in_pipeline
:
bool
=
False
,
dtype_
:
torch
.
dtype
=
torch
.
float
,
dtype_
:
Optional
[
torch
.
dtype
]
=
None
,
*
,
*
,
scatter_gather_tensors_in_pipeline
:
bool
=
True
,
scatter_gather_tensors_in_pipeline
:
bool
=
True
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
fp32_residual_connection
:
bool
=
False
,
fp32_residual_connection
:
bool
=
False
,
)
->
Tuple
[
Union
[
torch
.
Tensor
,
None
],
Union
[
torch
.
Tensor
,
None
]]:
async_comm
:
bool
=
False
,
sequence_parallel_enabled
:
bool
=
False
,
)
->
Tuple
[
Union
[
torch
.
Tensor
,
FutureTensor
,
None
],
Union
[
torch
.
Tensor
,
FutureTensor
,
None
]]:
"""Base function for communication of tensors between stages.
"""Base function for communication of tensors between stages.
.. note::
Reference https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/cfd2e2160700b7f2c1bf35298ac14bc341f4c759/megatron/p2p_communication.py#L24-L159
dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified,
torch.float32 is used.
See https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/arguments.py#L145-L159
for the details of arguments of ``dtype_``, ``params_dtype``, ``fp32_residual_connection``.
Args:
Args:
tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
...
@@ -99,6 +156,9 @@ def _communicate(
...
@@ -99,6 +156,9 @@ def _communicate(
params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
your model deliberately, pass this argument.
your model deliberately, pass this argument.
fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.
fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.
sequence_parallel_enabled: Set to :obj:`True` if sequence parallel is enabled.
This argument is here for consistency with Megatron-LM.
This argument has an effect on the communication optimization, not on tensor_shape update.
Returns:
Returns:
tuple containing
tuple containing
...
@@ -106,6 +166,13 @@ def _communicate(
...
@@ -106,6 +166,13 @@ def _communicate(
- tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
- tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
- tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
- tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
"""
"""
if
async_comm
and
sequence_parallel_enabled
:
import
warnings
# NOQA
class
ExperimentalWarning
(
UserWarning
):
pass
# NOQA
warnings
.
warn
(
"The combination of `async_comm` and `sequence_parallel_enabled` is not well tested."
,
ExperimentalWarning
,
)
# Create placeholder tensors for receive in forward and backward directions if needed.
# Create placeholder tensors for receive in forward and backward directions if needed.
tensor_recv_prev
=
None
tensor_recv_prev
=
None
tensor_recv_next
=
None
tensor_recv_next
=
None
...
@@ -113,25 +180,45 @@ def _communicate(
...
@@ -113,25 +180,45 @@ def _communicate(
# In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
# In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
raise
RuntimeError
(
raise
RuntimeError
(
"`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`"
)
"`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`"
)
if
not
override_scatter_gather_tensors_in_pipeline
and
scatter_gather_tensors_in_pipeline
:
tensor_chunk_shape
=
(
reduce
(
operator
.
mul
,
tensor_shape
,
1
)
//
parallel_state
.
get_tensor_model_parallel_world_size
(),)
tensor_parallel_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
override_scatter_gather_tensors_in_pipeline_
=
False
# TODO(mkozuki): Demystify hardcode False of `scatter_gather_tensors_in_pipeline` and add a testcase if possible.
# NOTE(mkozuki): This is super strange and doesn't make sense to me. I have no idea what is happening here.
# However, I can say that this hardcoding override is necessary for sequence parallel in nemo megatron to work.
# I've not managed to reproduce the hang using standalone GPT with sequence parallel.
# The hang in NeMo Megatron happens in the 3rd iteration, the last iteration of stead phase inside
# forward_backward_pipelining_without_interleaving, pipeline parallel rank of 0 (tensor model parallel world
# size of 2 and pipeline model parallel world size of 2). The commit then of APEX and NeMo were
# https://github.com/NVIDIA/apex/pull/1396/commits/3060c98dd8ba42abf7702ea9d2cff0f39ea74f45 and
# https://github.com/NVIDIA/NeMo/pull/4232/commits/1cb32dfca2ab9b20f53ebdb84476c34cb42f0205.
# The PyTorch version was 1.13.0a0+git2d354cd, for what is worth.
# Currently, indiscriminately this is set to `False`, which can lead to an unexpected performance regression
# for non sequence parallel case.
    scatter_gather_tensors_in_pipeline = False
    if scatter_gather_tensors_in_pipeline and not sequence_parallel_enabled:
        tensor_chunk_size = int(reduce(operator.mul, tensor_shape, 1))
        if tensor_chunk_size % tensor_parallel_size == 0:
            tensor_chunk_shape = [tensor_chunk_size // tensor_parallel_size]
        else:
            tensor_chunk_shape = tensor_shape
            override_scatter_gather_tensors_in_pipeline_ = True
    else:
        tensor_chunk_shape = tensor_shape

    # NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
    # FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
    # It might be possible if we restrict model architecture.
    # The dtype logic below is copied from NVIDIA/Megatron-LM repo:
    # https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81
    # dtype = params_dtype or torch.float
    # if fp32_residual_connection:
    #     dtype = torch.float
    # if dtype_ is not None:
    #     dtype = dtype_
    #     requires_grad = False
    dtype = params_dtype or torch.float
    if fp32_residual_connection:
        dtype = torch.float
    if dtype_ != torch.float32 or params_dtype is not None:
        if torch.distributed.get_rank() == 0:
            warnings.warn("Tensor P2P communications are executed in FP32")
        dtype = torch.float32
    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        # TODO(mkozuki): Figure out why this logic of requires_grad isn't working
        # when sequence_parallel_enabled=True. Otherwise, `x.retain_grad()` of
        # https://github.com/crcrpar/apex/blob/069832078a652b4bd8a99db84faf953a81415ab3/apex/transformer/pipeline_parallel/schedules/common.py#L360
        # fails.
        # requires_grad = False

    if recv_prev:
        tensor_recv_prev = torch.empty(
...
@@ -149,7 +236,12 @@ def _communicate(
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    scatter_gather_optimization_doable = (
        not override_scatter_gather_tensors_in_pipeline_
        and scatter_gather_tensors_in_pipeline
        and not sequence_parallel_enabled
    )
    if scatter_gather_optimization_doable:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
...
@@ -157,41 +249,89 @@ def _communicate(
            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req = _run_p2pops(
        tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next, async_comm=async_comm
    )

    if async_comm:
        tensor_recv_prev_waitfunc = None
        tensor_recv_next_waitfunc = None
        # TODO: investigate whether this is necessary for correctness (ref: https://github.com/pytorch/pytorch/issues/38642)
        # see also: sync added for async_comm callbacks below in gather_recv_prev_wait and gather_recv_next_wait
        if tensor_recv_prev_req is not None:
            def tensor_recv_prev_wait():
                tensor_recv_prev_req.wait()
                torch.cuda.synchronize()
            tensor_recv_prev_waitfunc = tensor_recv_prev_wait
        if tensor_recv_next_req is not None:
            def tensor_recv_next_wait():
                tensor_recv_next_req.wait()
                torch.cuda.synchronize()
            tensor_recv_next_waitfunc = tensor_recv_next_wait
    else:
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if scatter_gather_optimization_doable:
        if not async_comm:
            if recv_prev:
                tensor_recv_prev = (
                    gather_split_1d_tensor(tensor_recv_prev)
                    .view(tensor_shape)
                    .requires_grad_()
                )
            if recv_next:
                tensor_recv_next = (
                    gather_split_1d_tensor(tensor_recv_next)
                    .view(tensor_shape)
                    .requires_grad_()
                )
        else:
            def gather_recv_prev_wait():
                tensor_recv_prev_req.wait()
                # From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14
                # A sync seems to be needed before gather otherwise losses jump around e.g., in run_gpt_minimal_test
                torch.cuda.synchronize()
                return (
                    gather_split_1d_tensor(tensor_recv_prev)
                    .view(tensor_shape)
                    .requires_grad_()
                )
            def gather_recv_next_wait():
                tensor_recv_next_req.wait()
                torch.cuda.synchronize()
                return (
                    gather_split_1d_tensor(tensor_recv_next)
                    .view(tensor_shape)
                    .requires_grad_()
                )
            tensor_recv_prev_waitfunc = gather_recv_prev_wait
            tensor_recv_next_waitfunc = gather_recv_next_wait

    if async_comm:
        future_tensor_recv_prev = None
        future_tensor_recv_next = None
        if tensor_recv_prev is not None:
            future_tensor_recv_prev = FutureTensor(tensor_recv_prev, tensor_recv_prev_waitfunc)
        if tensor_recv_next is not None:
            future_tensor_recv_next = FutureTensor(tensor_recv_next, tensor_recv_next_waitfunc)
        return future_tensor_recv_prev, future_tensor_recv_next

    return tensor_recv_prev, tensor_recv_next
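When `async_comm=True`, `_communicate` hands back `FutureTensor` objects instead of plain tensors: each one pairs the pre-allocated receive buffer with a wait callback (`req.wait()` plus a CUDA sync, and optionally the chunk gather), and the consumer calls `.get()` only when it actually needs the value. A minimal, self-contained sketch of that deferred-wait pattern is shown below; `MiniFutureTensor` and `fake_wait` are illustrative stand-ins, not apex APIs.

# Sketch of the deferred-wait pattern used above; not part of apex itself.
from typing import Callable, Optional
import torch


class MiniFutureTensor:
    def __init__(self, tensor: torch.Tensor, waitfunc: Optional[Callable]):
        self.tensor = tensor
        self.waitfunc = waitfunc  # e.g. wraps req.wait() + torch.cuda.synchronize()

    def get(self) -> torch.Tensor:
        # Wait (and optionally gather scattered chunks) exactly once, then hand out the tensor.
        if self.waitfunc is not None:
            res = self.waitfunc()
            self.waitfunc = None
            if res is not None:  # gather_recv_*_wait above returns the reshaped tensor
                self.tensor = res
        return self.tensor


def fake_wait() -> None:
    print("irecv completed")  # stands in for tensor_recv_prev_req.wait()


buf = torch.empty(4)                 # stands in for the pre-allocated recv buffer
fut = MiniFutureTensor(buf, fake_wait)
# ... overlap other work here ...
ready = fut.get()                    # blocks only when the value is actually needed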
def recv_forward(
    tensor_shape: Shape,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
    timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if parallel_state.is_pipeline_first_stage():
        return None
    if timers is not None:
        timers("forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
...
@@ -199,50 +339,58 @@ def recv_forward(
        recv_next=False,
        tensor_shape=tensor_shape,
        override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
        dtype_=dtype,
        async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    if timers is not None:
        timers("forward-recv").stop()
    return input_tensor


def recv_backward(
    tensor_shape: Shape = None,
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
    timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
    """Receive tensor from next rank in pipeline (backward receive)."""
    if parallel_state.is_pipeline_last_stage():
        return None
    if timers is not None:
        timers("backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape,
        dtype_=dtype,
        async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    if timers is not None:
        timers("backward-recv").stop()
    return output_tensor_grad


def send_forward(
    output_tensor: torch.Tensor,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    tensor_shape: Shape = None,
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
    timers: _Timers = None,
) -> None:
    """Send tensor to next rank in pipeline (forward send)."""
    if parallel_state.is_pipeline_last_stage():
        return
    if timers is not None:
        timers("forward-send").start()
    _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
...
@@ -250,155 +398,181 @@ def send_forward(
        recv_next=False,
        override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
        tensor_shape=tensor_shape,
        dtype_=dtype,
        async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    if timers is not None:
        timers("forward-send").stop()


def send_backward(
    input_tensor_grad: torch.Tensor,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
    timers: _Timers = None,
) -> None:
    """Send tensor to previous rank in pipeline (backward send)."""
    if parallel_state.is_pipeline_first_stage():
        return
    if timers is not None:
        timers("backward-send").start()
    _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype,
        async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    if timers is not None:
        timers("backward-send").stop()


def send_forward_recv_backward(
    output_tensor: torch.Tensor,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
    timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
    """Batched send and recv with next rank in pipeline."""
    if parallel_state.is_pipeline_last_stage():
        return None
    if timers is not None:
        timers("forward-send-backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape,
        dtype_=dtype,
        async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    if timers is not None:
        timers("forward-send-backward-recv").stop()
    return output_tensor_grad


def send_backward_recv_forward(
    input_tensor_grad: torch.Tensor,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
    timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
    """Batched send and recv with previous rank in pipeline."""
    if parallel_state.is_pipeline_first_stage():
        return None
    if timers is not None:
        timers("backward-send-forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype,
        async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    if timers is not None:
        timers("backward-send-forward-recv").stop()
    return input_tensor


def send_forward_recv_forward(
    output_tensor: torch.Tensor,
    recv_prev: bool,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
    timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers("forward-send-forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype,
        async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    if timers is not None:
        timers("forward-send-forward-recv").stop()
    return input_tensor


def send_backward_recv_backward(
    input_tensor_grad: torch.Tensor,
    recv_next: bool,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
    timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers("backward-send-backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        dtype_=dtype,
        async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    if timers is not None:
        timers("backward-send-backward-recv").stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
    output_tensor: torch.Tensor,
    input_tensor_grad: torch.Tensor,
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
    timers: _Timers = None,
) -> Tuple[Union[torch.Tensor, FutureTensor], Union[torch.Tensor, FutureTensor]]:
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers("forward-backward-send-forward-backward-recv").start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        dtype_=dtype,
        async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    if timers is not None:
        timers("forward-backward-send-forward-backward-recv").stop()
    return input_tensor, output_tensor_grad
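The helpers above are the communication primitives the schedules are built from. As a rough, forward-only sketch (not the real 1F1B schedule), a pipeline stage alternates `recv_forward` and `send_forward` per microbatch; `parallel_state` initialization and the `data_iterator` are assumed to exist.

# Forward-only sketch built on the helpers above; assumes parallel_state is initialized and that
# `model` maps a (sequence, batch, hidden) activation to another activation of the same shape.
import torch
from apex.transformer.pipeline_parallel import p2p_communication


def run_forward_only(model, data_iterator, num_microbatches, tensor_shape, dtype=torch.float32):
    for _ in range(num_microbatches):
        # Intermediate/last stages receive their input from the previous stage;
        # the first stage gets None back and reads real data instead.
        input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype)
        if input_tensor is None:
            input_tensor = next(data_iterator)
        output_tensor = model(input_tensor)
        # No-op on the last stage; otherwise forwards the activation downstream.
        p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape, dtype=dtype)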
apex/transformer/pipeline_parallel/schedules/__init__.py
View file @
96850dfa
import warnings

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
    forward_backward_no_pipelining,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
    _forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)

__all__ = [
    "get_forward_backward_func",
]


class ExperimentalWarning(Warning):
    pass

...
@@ -21,19 +27,9 @@ def get_forward_backward_func(
            if get_num_microbatches() % pipeline_model_parallel_size != 0:
                msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
                raise RuntimeError(msg)
            forward_backward_func = _forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
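A sketch of how a training loop would pick a schedule with `get_forward_backward_func`; it assumes the distributed/parallel state and the microbatch calculator have been initialized elsewhere, that the function takes the two pipeline sizes referenced in the hunk above as keyword arguments, and that `forward_step_func`, `batch`, `model`, and the shape variables are placeholders.

# Assumed placeholders: forward_step_func, batch, model, seq_length, micro_batch_size, hidden_size.
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func

forward_backward_func = get_forward_backward_func(
    virtual_pipeline_model_parallel_size=None,  # no interleaving
    pipeline_model_parallel_size=parallel_state.get_pipeline_model_parallel_world_size(),
)
losses_reduced = forward_backward_func(
    forward_step_func,
    batch,
    model,
    forward_only=False,
    tensor_shape=(seq_length, micro_batch_size, hidden_size),
)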
apex/transformer/pipeline_parallel/schedules/common.py
View file @
96850dfa
# NOTE (mkozuki): For simplicity, tentatively `timers` related operations are commented out.
from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence

import torch
from torch.autograd.variable import Variable

from apex.normalization.fused_layer_norm import FusedLayerNorm
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.tensor_parallel.layers import (
    set_defaults_if_not_set_tensor_model_parallel_attributes,
)
from apex.transformer.log_util import get_transformer_logger


_logger = get_transformer_logger(__name__)


Batch = Union[
    torch.Tensor,
    FutureTensor,
    List[Union[torch.Tensor, FutureTensor]],
    Tuple[Union[torch.Tensor, FutureTensor], ...],
]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
FwdStepFunc = Callable[
    [Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc]
]
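The `FwdStepFunc` alias above fixes the contract the schedules rely on: the function receives one microbatch and the model, and returns the model output together with a loss closure that maps that output to `(loss, reduced_metrics)` (see `forward_step` further down in this file). A hypothetical language-model-style example:

# Hypothetical FwdStepFunc: the `batch` layout (tokens, labels) and the metric name are assumptions.
import torch


def my_forward_step(batch, model):
    tokens, labels = batch

    def loss_func(output_tensor):
        loss = torch.nn.functional.cross_entropy(
            output_tensor.view(-1, output_tensor.size(-1)), labels.view(-1)
        )
        return loss, {"lm loss": loss.detach().clone()}

    output_tensor = model(tokens)
    return output_tensor, loss_func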
def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    model_type: ModelType = ModelType.encoder_or_decoder,
    *args: Any,
    **kwargs: Any,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.
...
@@ -32,6 +45,7 @@ def build_model(
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model
            with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        model_type:
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func
...
@@ -40,8 +54,8 @@ def build_model(
        the list has multiple models, otherwise one.
    """
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
...
@@ -51,22 +65,48 @@ def build_model(
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update(
                {"pre_process": pre_process, "post_process": post_process,}
            )
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        if model_type == ModelType.encoder_or_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update(
                {"pre_process": pre_process, "post_process": post_process,}
            )
            model = model_provider_func(*cur_args, **cur_kwargs)
        elif model_type == ModelType.encoder_and_decoder:
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            # `add_encoder` & `add_decoder` logic.
            add_encoder, add_decoder = True, True
            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
                split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
                if split_rank is None:
                    raise RuntimeError("Split rank needs to be specified for model with both encoder and decoder.")
                rank = parallel_state.get_pipeline_model_parallel_rank()
                world_size = parallel_state.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank
                post_process = rank == (split_rank - 1) or rank == (world_size - 1)
                add_encoder = parallel_state.is_pipeline_stage_before_split()
                add_decoder = parallel_state.is_pipeline_stage_after_split()
            cur_kwargs.update(
                {
                    "pre_process": pre_process,
                    "post_process": post_process,
                    "add_encoder": add_encoder,
                    "add_decoder": add_decoder,
                }
            )
            model = model_provider_func(*cur_args, **cur_kwargs)
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]
...
@@ -80,11 +120,14 @@ def build_model(
        set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if (
        parallel_state.model_parallel_is_initialized()
        and parallel_state.get_data_parallel_rank() == 0
    ):
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            _calc_number_of_params(model),
        )
        print(msg, flush=True)
...
@@ -106,44 +149,119 @@ def build_model(
    return model


def _calc_number_of_params(model: List[torch.nn.Module]) -> int:
    assert isinstance(model, list)
    return sum(
        [sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
    )
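A sketch of the `model_provider_func` contract that `build_model` relies on: the provider must accept the injected `pre_process`/`post_process` keywords (plus `add_encoder`/`add_decoder` for `ModelType.encoder_and_decoder`), and the returned module should expose `set_input_tensor` for the schedules. `TinyStage` is a hypothetical stand-in, and the call assumes `torch.distributed` and `apex.transformer.parallel_state` are already initialized.

import torch
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel.schedules.common import build_model


class TinyStage(torch.nn.Module):
    """Hypothetical pipeline stage: embeddings only when pre_process, LM head only when post_process."""

    def __init__(self, pre_process: bool = True, post_process: bool = True):
        super().__init__()
        self.pre_process, self.post_process = pre_process, post_process
        self.embed = torch.nn.Embedding(1000, 64) if pre_process else None
        self.core = torch.nn.Linear(64, 64)
        self.head = torch.nn.Linear(64, 1000) if post_process else None
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):  # required by forward_step() later in this file
        self.input_tensor = input_tensor

    def forward(self, x):
        h = self.embed(x) if self.pre_process else self.input_tensor
        h = self.core(h)
        return self.head(h) if self.post_process else h


def model_provider(pre_process: bool = True, post_process: bool = True) -> torch.nn.Module:
    return TinyStage(pre_process=pre_process, post_process=post_process)


# Requires torch.distributed and apex.transformer.parallel_state to be initialized beforehand.
models = build_model(
    model_provider,
    wrap_with_ddp=False,
    virtual_pipeline_model_parallel_size=None,
    model_type=ModelType.encoder_or_decoder,
)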
def _get_params_for_weight_decay_optimization(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    no_weight_decay_modules=(FusedLayerNorm,),
) -> Dict[str, torch.nn.Parameter]:
    """Divide params into with-weight-decay and without-weight-decay groups.

    Layernorms and biases will have no weight decay but the rest will.
    """
    modules = listify_model(model)
    weight_decay_params = {"params": []}
    no_weight_decay_params = {"params": [], "weight_decay": 0.0}
    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, no_weight_decay_modules):
                no_weight_decay_params["params"].extend(
                    [p for p in list(module_._parameters.values()) if p is not None]
                )
            else:
                weight_decay_params["params"].extend(
                    [
                        p
                        for n, p in list(module_._parameters.items())
                        if p is not None and n != "bias"
                    ]
                )
                no_weight_decay_params["params"].extend(
                    [
                        p
                        for n, p in list(module_._parameters.items())
                        if p is not None and n == "bias"
                    ]
                )
    return weight_decay_params, no_weight_decay_params
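The two returned groups are meant to be passed straight to a torch optimizer, so that LayerNorm parameters and biases skip weight decay. Continuing the hypothetical `models` list from the sketch above:

import torch
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization

weight_decay_params, no_weight_decay_params = _get_params_for_weight_decay_optimization(models)
optimizer = torch.optim.AdamW(
    [weight_decay_params, no_weight_decay_params],
    lr=1e-4,
    weight_decay=0.01,  # only applied to the group that does not carry an explicit "weight_decay": 0.0
)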
def free_output_tensor(
    output_tensors: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
    deallocate_pipeline_outputs: bool = False,
) -> None:
    """Pseudo-free the output tensor's `.data` field.

    This method should be called right after the output tensor has been sent to the next
    pipeline stage. At this point, the output tensor is only useful for its `.grad_fn` field,
    and not its `.data`.
    """
    if not deallocate_pipeline_outputs:
        return
    if output_tensors is None:
        return
    if isinstance(output_tensors, torch.Tensor):
        output_tensors = [output_tensors]
    for output_tensor in output_tensors:
        output_tensor.data = torch.cuda.FloatTensor([0])


def custom_backward(output: torch.Tensor, grad_output: Optional[torch.Tensor]) -> None:
    """Directly call C++ autograd engine.

    To make the `free_output_tensor` optimization work, the C++ autograd engine must be called
    directly, bypassing PyTorch's `torch.autograd.backward`. PyTorch's `backward` checks that the
    output and grad have the same shape, while C++ `backward` does not.
    """
    assert (
        output.numel() == 1
    ), "output should be pseudo-freed in schedule, to optimize memory consumption"
    assert isinstance(output, torch.Tensor), "output == {}.".format(type(output).__name__)
    assert isinstance(
        grad_output, (torch.Tensor, type(None))
    ), "grad_output == {}.".format(type(grad_output).__name__)

    # Handle scalar output
    if grad_output is None:
        assert output.numel() == 1, "Implicit grad requires scalar output."
        grad_output = torch.ones_like(output, memory_format=torch.preserve_format)

    # Call C++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=(),
        allow_unreachable=True,
        accumulate_grad=True,
    )
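A minimal sketch (CUDA required) of the pairing these two helpers are designed for: the activation's storage is pseudo-freed right after it has been sent downstream, and `custom_backward` later drives backward through the kept `.grad_fn` even though the tensor's data no longer matches the gradient's shape, which is exactly the check `torch.autograd.backward` would refuse.

# Assumes a CUDA device; names other than the two apex helpers are illustrative.
import torch
from apex.transformer.pipeline_parallel.schedules.common import custom_backward, free_output_tensor

x = torch.randn(4, 8, device="cuda", requires_grad=True)
out = (x * 2).sum(dim=-1)          # stand-in for a pipeline stage's output, shape (4,)
grad = torch.ones_like(out)        # stand-in for the gradient received from the next stage
free_output_tensor(out, deallocate_pipeline_outputs=True)  # out.data is now a single dummy element
custom_backward(out, grad)         # C++ engine runs backward through out.grad_fn anyway
print(x.grad.shape)                # torch.Size([4, 8]): gradients still reach the leaf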
def forward_step(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
    model: torch.nn.Module,
    input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]],
    losses_reduced: List[torch.Tensor],
    dtype: torch.dtype,
    disable_autocast: bool = False,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from batch, otherwise passed-in input_tensor is used.

    Returns output tensor.
...
@@ -154,6 +272,8 @@ def forward_step(
        model: unwrappable model
        input_tensor:
        losses_reduced:
        dtype:
        disable_autocast:

    Returns:
        output_tensor
...
@@ -161,27 +281,51 @@ def forward_step(
    # timers = get_timers()
    # timers("forward-compute").start()
    unwrapped_model = unwrap_model(model)
    model_type = get_model_type(unwrapped_model)
    # NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
    # See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679 # NOQA
    # for the details of `set_input_tensor`.
    unwrap_output_tensor = not isinstance(input_tensor, list)
    if unwrap_output_tensor:
        input_tensor = [input_tensor]
    input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]
    unwrapped_model.set_input_tensor(input_tensor)
    with torch.cuda.amp.autocast(
        enabled=not disable_autocast and dtype in (torch.half, torch.bfloat16),
        dtype=dtype,
    ):
        output_tensor, loss_func = forward_step_func(batch, model)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = loss_func(output_tensor)
            loss, loss_reduced = output_tensor
            output_tensor = loss / get_num_microbatches()
            losses_reduced.append(loss_reduced)
    # timers("forward-compute").stop()

    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    if (
        parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        return [output_tensor, input_tensor[-1]]
    if unwrap_output_tensor:
        return output_tensor
    return [output_tensor]


def backward_step(
    input_tensor: Optional[torch.Tensor],
    output_tensor: torch.Tensor,
    output_tensor_grad: Optional[torch.Tensor],
    model_type: ModelType,
    *,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    deallocate_pipeline_outputs: bool = False,
) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
...
@@ -194,25 +338,61 @@ def backward_step(
        input_tensor:
        output_tensor:
        output_tensor_grad:

    Keyword Arguments:
        grad_scaler:
        deallocate_pipeline_outputs: Experimental.

    Returns:
        input_tensor_grad
    """
    # timers = get_timers()
    # timers("backward-compute").start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = not isinstance(input_tensor, list)
    if unwrap_input_tensor_grad:
        input_tensor = [input_tensor]
    input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    output_tensor = [out.get() if isinstance(out, FutureTensor) else out for out in output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]
    output_tensor_grad = [
        ogr.get() if isinstance(ogr, FutureTensor) else ogr for ogr in output_tensor_grad
    ]

    # Backward pass.
    if grad_scaler is not None and output_tensor_grad[0] is None:
        output_tensor[0] = grad_scaler.scale(output_tensor[0])
    if deallocate_pipeline_outputs:
        custom_backward(output_tensor[0], output_tensor_grad[0])
    else:
        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
    if input_tensor is not None:
        input_tensor_grad = []
        for x in input_tensor:
            input_tensor_grad.append(None if x is None else x.grad)

    # Handle single skip connection if it exists (encoder_hidden_state in model with encoder and decoder).
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        if output_tensor_grad[1] is not None:
            # todo (mkozuki): Replace the inplace add with `+= output_tensor_grad[1]`?
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    # timers("backward-compute").stop()
    return input_tensor_grad[0] if unwrap_input_tensor_grad else input_tensor_grad
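A sketch of how `grad_scaler` threads through mixed-precision training with these steps: the schedule scales the last-stage loss inside `backward_step`, while the caller keeps ownership of unscale/step/update. `forward_backward_func` (any schedule in this package), `forward_step_func`, `batch`, `model`, `tensor_shape`, and `optimizer` are assumed to exist.

import torch

grad_scaler = torch.cuda.amp.GradScaler()

losses_reduced = forward_backward_func(
    forward_step_func,
    batch,
    model,
    forward_only=False,
    tensor_shape=tensor_shape,
    dtype=torch.half,
    grad_scaler=grad_scaler,
)
grad_scaler.step(optimizer)  # unscales gradients and skips the step on inf/nan
grad_scaler.update()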
apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py
View file @
96850dfa
from contextlib import contextmanager
from typing import List, Union, Optional

import torch

from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
...
@@ -27,12 +29,16 @@ def placeholder_handler():


def forward_backward_no_pipelining(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    custom_sync_context_handler=None,
    **kwargs,
):
    """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).
...
@@ -48,6 +54,12 @@ def forward_backward_no_pipelining(
    Keyword args:
        forward_only:
        grad_scaler:
        dtype:
        disable_autocast: Turn off `enabled` flag of `torch.cuda.amp.autocast` if :obj:`True`.
            Should be used when your forward and loss computation is in the autocast context to
            avoid unnecessarily nesting autocast contexts.
        custom_sync_context_handler:
        **kwargs: Added to handle `tensor_shape` which has no effect on this function.

    Returns:
...
@@ -58,10 +70,14 @@ def forward_backward_no_pipelining(
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]
    model_type = get_model_type(model)

    if custom_sync_context_handler is not None:
        context_handler = custom_sync_context_handler
    elif isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
        context_handler = model.no_sync
    else:
        context_handler = placeholder_handler

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
...
@@ -72,20 +88,45 @@ def forward_backward_no_pipelining(
            cur_micro_batch = get_kth_microbatch(batch, i)
            _logger.debug("Call `forward_step`")
            output_tensor = forward_step(
                forward_step_func,
                cur_micro_batch,
                model,
                input_tensor,
                losses_reduced,
                dtype=dtype,
                disable_autocast=disable_autocast,
            )
            if not forward_only:
                _logger.debug("Call `backward_step`")
                backward_step(
                    input_tensor,
                    output_tensor,
                    output_tensor_grad,
                    model_type=model_type,
                    grad_scaler=grad_scaler,
                )

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    _logger.info("Cooldown")
    _logger.debug("Call `forward_step`")
    output_tensor = forward_step(
        forward_step_func,
        get_kth_microbatch(batch, num_micro_batches - 1),
        model,
        input_tensor,
        losses_reduced,
        dtype=dtype,
        disable_autocast=disable_autocast,
    )
    if not forward_only:
        _logger.debug("Call `backward_step`")
        backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
        )

    return losses_reduced
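A minimal sketch of driving `forward_backward_no_pipelining` end to end, assuming `torch.distributed`, `parallel_state`, and the microbatch calculator have been initialized elsewhere (the schedule slices `batch` into `get_num_microbatches()` microbatches itself, and `forward_step` in common.py requires the model to implement `set_input_tensor`).

import torch
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):  # required by forward_step() in common.py
        self.input_tensor = input_tensor

    def forward(self, x):
        return self.linear(x)


def forward_step_func(microbatch, model):
    x, y = microbatch

    def loss_func(pred):
        loss = torch.nn.functional.mse_loss(pred, y)
        return loss, {"mse": loss.detach()}

    return model(x), loss_func


model = ToyModel().cuda()
batch = [torch.randn(32, 16).cuda(), torch.randn(32, 16).cuda()]  # sliced into microbatches internally
losses_reduced = forward_backward_no_pipelining(
    forward_step_func, batch, model, forward_only=False, dtype=torch.float32
)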
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py
View file @
96850dfa
from
typing
import
List
,
Union
,
Optional
from
typing
import
List
,
Union
,
Optional
,
Sequence
import
warnings
import
torch
import
torch
from
apex.transformer
import
parallel_state
from
apex.transformer
import
parallel_state
from
apex.transformer.pipeline_parallel
import
p2p_communication
from
apex.transformer.pipeline_parallel
import
p2p_communication
from
apex.transformer.pipeline_parallel.schedules.common
import
Batch
,
FwdStepFunc
from
apex.transformer.pipeline_parallel.schedules.common
import
Batch
from
apex.transformer.pipeline_parallel.schedules.common
import
FwdStepFunc
from
apex.transformer.pipeline_parallel.schedules.common
import
backward_step
from
apex.transformer.pipeline_parallel.schedules.common
import
backward_step
from
apex.transformer.pipeline_parallel.schedules.common
import
forward_step
from
apex.transformer.pipeline_parallel.schedules.common
import
forward_step
from
apex.transformer.pipeline_parallel.schedules.common
import
free_output_tensor
from
apex.transformer.pipeline_parallel.utils
import
get_kth_microbatch
from
apex.transformer.pipeline_parallel.utils
import
get_kth_microbatch
from
apex.transformer.pipeline_parallel.utils
import
get_num_microbatches
from
apex.transformer.pipeline_parallel.utils
import
get_num_microbatches
from
apex.transformer.pipeline_parallel.utils
import
get_model_type
from
apex.transformer.log_util
import
get_transformer_logger
from
apex.transformer.log_util
import
get_transformer_logger
...
@@ -18,15 +22,22 @@ __all__ = ["_forward_backward_pipelining_with_interleaving"]
...
@@ -18,15 +22,22 @@ __all__ = ["_forward_backward_pipelining_with_interleaving"]
_logger
=
get_transformer_logger
(
__name__
)
_logger
=
get_transformer_logger
(
__name__
)
# TODO
(mkozuki): Reduce cyclomatic complexity
# TODO(mkozuki): Reduce cyclomatic complexity
def
_forward_backward_pipelining_with_interleaving
(
def
_forward_backward_pipelining_with_interleaving
(
forward_step_func
:
FwdStepFunc
,
forward_step_func
:
FwdStepFunc
,
batch
:
List
[
Batch
],
batch
:
List
[
Optional
[
Batch
]],
model
:
List
[
torch
.
nn
.
Module
],
model
:
List
[
torch
.
nn
.
Module
],
*
,
*
,
forward_only
:
bool
,
forward_only
:
bool
,
tensor_shape
:
Optional
[
Union
[
List
[
int
],
torch
.
Size
]]
=
None
,
tensor_shape
:
Optional
[
Union
[
List
[
int
],
torch
.
Size
]]
=
None
,
):
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
grad_scaler
:
Optional
[
torch
.
cuda
.
amp
.
GradScaler
]
=
None
,
disable_autocast
:
bool
=
False
,
deallocate_pipeline_outputs
:
bool
=
False
,
async_comm
:
bool
=
False
,
sequence_parallel_enabled
:
bool
=
False
,
**
kwargs
,
)
->
List
[
Union
[
torch
.
Tensor
,
Sequence
[
torch
.
Tensor
]]]:
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
This function assumes `batch` and `model` is a list of `Batch`'s and a list of `torch.nn.Module`, respectively.
This function assumes `batch` and `model` is a list of `Batch`'s and a list of `torch.nn.Module`, respectively.
...
@@ -48,7 +59,17 @@ def _forward_backward_pipelining_with_interleaving(
...
@@ -48,7 +59,17 @@ def _forward_backward_pipelining_with_interleaving(
Keyword args:
Keyword args:
forward_only:
forward_only:
tensor_shape: Shape of tensor.
tensor_shape: Shape of tensor. The tensor is expected to be 3D and its order of dimension
is supposed to be ``(sequence, batch, hidden)``.
dtype: dtype used in p2p communication. If ``None`` (default value),
torch.float32 will be used even if ``autocast`` is enabled.
grad_scaler:
disable_autocast:
deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
each pipeline stage. Experimental.
sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
When :obj:`True`, the sequence length on each tensor model parallel rank is updated
to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.
Returns:
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
...
@@ -56,22 +77,43 @@ def _forward_backward_pipelining_with_interleaving(
...
@@ -56,22 +77,43 @@ def _forward_backward_pipelining_with_interleaving(
if
not
isinstance
(
model
,
list
):
if
not
isinstance
(
model
,
list
):
raise
RuntimeError
(
"`model` must be a list of `nn.Module`'s'"
)
raise
RuntimeError
(
"`model` must be a list of `nn.Module`'s'"
)
num_model_chunks
=
len
(
model
)
if
deallocate_pipeline_outputs
:
input_tensors
=
[[]
for
_
in
range
(
num_model_chunks
)]
warnings
.
warn
(
output_tensors
=
[[]
for
_
in
range
(
num_model_chunks
)]
"`deallocate_pipeline_outputs` is experimental and subject to change. "
curr_iters
=
[
0
for
_
in
range
(
num_model_chunks
)]
"This option is not recommended."
losses_reduced
=
[]
)
# mypy will blame the following if statement
if
sequence_parallel_enabled
:
seq_length
,
batch_size
,
hidden
=
tensor_shape
tensor_shape
=
(
seq_length
//
parallel_state
.
get_tensor_model_parallel_world_size
(),
batch_size
,
hidden
,
)
num_model_chunks
:
int
=
len
(
model
)
input_tensors
:
List
[
List
[
Union
[
None
,
torch
.
Tensor
]]]
=
[
[]
for
_
in
range
(
num_model_chunks
)
]
output_tensors
:
List
[
List
[
Union
[
None
,
torch
.
Tensor
]]]
=
[
[]
for
_
in
range
(
num_model_chunks
)
]
curr_iters
:
List
[
int
]
=
[
0
for
_
in
range
(
num_model_chunks
)]
losses_reduced
:
List
[
Union
[
None
,
torch
.
Tensor
]]
=
[]
if
not
forward_only
:
if
not
forward_only
:
output_tensor_grads
=
[[]
for
_
in
range
(
num_model_chunks
)]
output_tensor_grads
:
List
[
List
[
Union
[
None
,
torch
.
Tensor
]]]
=
[
[]
for
_
in
range
(
num_model_chunks
)
]
pipeline_parallel_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
pipeline_parallel_size
:
int
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
pipeline_parallel_rank
=
parallel_state
.
get_pipeline_model_parallel_rank
()
pipeline_parallel_rank
:
int
=
parallel_state
.
get_pipeline_model_parallel_rank
()
# Compute number of warmup and remaining microbatches.
# Compute number of warmup and remaining microbatches.
num_microbatches
=
get_num_microbatches
()
*
num_model_chunks
num_microbatches
:
int
=
get_num_microbatches
()
*
num_model_chunks
all_warmup_microbatches
=
False
all_warmup_microbatches
:
bool
=
False
if
forward_only
:
if
forward_only
:
num_warmup_microbatches
=
num_microbatches
num_warmup_microbatches
:
int
=
num_microbatches
else
:
else
:
# Run all forward passes and then all backward passes if number of
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# microbatches is just the number of pipeline stages.
...
@@ -83,10 +125,12 @@ def _forward_backward_pipelining_with_interleaving(
...
@@ -83,10 +125,12 @@ def _forward_backward_pipelining_with_interleaving(
num_warmup_microbatches
=
num_microbatches
num_warmup_microbatches
=
num_microbatches
all_warmup_microbatches
=
True
all_warmup_microbatches
=
True
else
:
else
:
num_warmup_microbatches
=
(
pipeline_parallel_size
-
pipeline_parallel_rank
-
1
)
*
2
num_warmup_microbatches
=
(
pipeline_parallel_size
-
pipeline_parallel_rank
-
1
)
*
2
num_warmup_microbatches
+=
(
num_model_chunks
-
1
)
*
pipeline_parallel_size
num_warmup_microbatches
+=
(
num_model_chunks
-
1
)
*
pipeline_parallel_size
num_warmup_microbatches
=
min
(
num_warmup_microbatches
,
num_microbatches
)
num_warmup_microbatches
=
min
(
num_warmup_microbatches
,
num_microbatches
)
num_microbatches_remaining
=
num_microbatches
-
num_warmup_microbatches
num_microbatches_remaining
:
int
=
num_microbatches
-
num_warmup_microbatches
_logger
.
info
(
_logger
.
info
(
f
"num_microbatches:
{
num_microbatches
}
, "
f
"num_microbatches:
{
num_microbatches
}
, "
...
@@ -100,24 +144,26 @@ def _forward_backward_pipelining_with_interleaving(
...
@@ -100,24 +144,26 @@ def _forward_backward_pipelining_with_interleaving(
def
get_model_chunk_id
(
microbatch_id
:
int
,
forward
:
bool
)
->
int
:
def
get_model_chunk_id
(
microbatch_id
:
int
,
forward
:
bool
)
->
int
:
"""Helper function to get the model chunk ID given the iteration number."""
"""Helper function to get the model chunk ID given the iteration number."""
pipeline_parallel_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
pipeline_parallel_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
microbatch_id_in_group
=
microbatch_id
%
(
pipeline_parallel_size
*
num_model_chunks
)
microbatch_id_in_group
=
microbatch_id
%
(
pipeline_parallel_size
*
num_model_chunks
)
model_chunk_id
=
microbatch_id_in_group
//
pipeline_parallel_size
model_chunk_id
=
microbatch_id_in_group
//
pipeline_parallel_size
if
not
forward
:
if
not
forward
:
model_chunk_id
=
num_model_chunks
-
model_chunk_id
-
1
model_chunk_id
=
num_model_chunks
-
model_chunk_id
-
1
return
model_chunk_id
return
model_chunk_id
def
forward_step_helper
(
microbatch_id
,
curr_iters
)
:
def
forward_step_helper
(
microbatch_id
:
int
,
curr_iters
:
List
[
int
])
->
torch
.
Tensor
:
"""Helper method to run forward step with model split into chunks
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
(run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
"""
model_chunk_id
=
get_model_chunk_id
(
microbatch_id
,
forward
=
True
)
model_chunk_id
=
get_model_chunk_id
(
microbatch_id
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
model_chunk_id
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
model_chunk_id
)
# forward step
# forward step
if
(
if
parallel_state
.
is_pipeline_first_stage
()
and
len
(
parallel_state
.
is_pipeline_first_stage
()
and
input_tensors
[
model_chunk_id
]
len
(
input_tensors
[
model_chunk_id
])
==
len
(
output_tensors
[
model_chunk_id
])
)
==
len
(
output_tensors
[
model_chunk_id
]):
):
input_tensors
[
model_chunk_id
].
append
(
None
)
input_tensors
[
model_chunk_id
].
append
(
None
)
input_tensor
=
input_tensors
[
model_chunk_id
][
-
1
]
input_tensor
=
input_tensors
[
model_chunk_id
][
-
1
]
output_tensor
=
forward_step
(
output_tensor
=
forward_step
(
...
@@ -126,6 +172,8 @@ def _forward_backward_pipelining_with_interleaving(
...
@@ -126,6 +172,8 @@ def _forward_backward_pipelining_with_interleaving(
model
[
model_chunk_id
],
model
[
model_chunk_id
],
input_tensor
,
input_tensor
,
losses_reduced
,
losses_reduced
,
dtype
,
disable_autocast
,
)
)
curr_iters
[
model_chunk_id
]
+=
1
curr_iters
[
model_chunk_id
]
+=
1
output_tensors
[
model_chunk_id
].
append
(
output_tensor
)
output_tensors
[
model_chunk_id
].
append
(
output_tensor
)
...
@@ -137,11 +185,13 @@ def _forward_backward_pipelining_with_interleaving(
        return output_tensor

    def backward_step_helper(microbatch_id: int) -> torch.Tensor:
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        model_type = get_model_type(model[model_chunk_id])
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if parallel_state.is_pipeline_last_stage():
@@ -150,7 +200,14 @@ def _forward_backward_pipelining_with_interleaving(
...
@@ -150,7 +200,14 @@ def _forward_backward_pipelining_with_interleaving(
input_tensor
=
input_tensors
[
model_chunk_id
].
pop
(
0
)
input_tensor
=
input_tensors
[
model_chunk_id
].
pop
(
0
)
output_tensor
=
output_tensors
[
model_chunk_id
].
pop
(
0
)
output_tensor
=
output_tensors
[
model_chunk_id
].
pop
(
0
)
output_tensor_grad
=
output_tensor_grads
[
model_chunk_id
].
pop
(
0
)
output_tensor_grad
=
output_tensor_grads
[
model_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
)
input_tensor_grad
=
backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
,
model_type
=
model_type
,
grad_scaler
=
grad_scaler
,
deallocate_pipeline_outputs
=
deallocate_pipeline_outputs
,
)
return
input_tensor_grad
return
input_tensor_grad
...
@@ -158,7 +215,14 @@ def _forward_backward_pipelining_with_interleaving(
    # Run warmup forward passes.
    ###################################################################################################################
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(
            tensor_shape=tensor_shape,
            dtype=dtype,
            async_comm=async_comm,
            sequence_parallel_enabled=sequence_parallel_enabled,
        )
    )
    _logger.info("Warmup phase")
    for k in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
...
@@ -172,7 +236,9 @@ def _forward_backward_pipelining_with_interleaving(
            recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False
        _logger.debug(f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}")

        # Don't send tensor downstream if on last stage.
        if parallel_state.is_pipeline_last_stage():
...
@@ -181,7 +247,11 @@ def _forward_backward_pipelining_with_interleaving(
        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if (
            k == (num_warmup_microbatches - 1)
            and not forward_only
            and not all_warmup_microbatches
        ):
            input_tensor_grad = None
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
...
@@ -196,12 +266,23 @@ def _forward_backward_pipelining_with_interleaving(
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                dtype=dtype,
                async_comm=async_comm,
                sequence_parallel_enabled=sequence_parallel_enabled,
            )
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                dtype=dtype,
                async_comm=async_comm,
                sequence_parallel_enabled=sequence_parallel_enabled,
            )
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    ###################################################################################################################
    # Run 1F1B in steady state.
...
@@ -229,7 +310,9 @@ def _forward_backward_pipelining_with_interleaving(
        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        _logger.debug(f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}")
        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None
...
@@ -245,7 +328,9 @@ def _forward_backward_pipelining_with_interleaving(
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
...
@@ -257,7 +342,9 @@ def _forward_backward_pipelining_with_interleaving(
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
...
@@ -275,7 +362,11 @@ def _forward_backward_pipelining_with_interleaving(
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
            dtype=dtype,
            async_comm=async_comm,
            sequence_parallel_enabled=sequence_parallel_enabled,
        )
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
...
@@ -290,9 +381,18 @@ def _forward_backward_pipelining_with_interleaving(
    _logger.info("Cooldown phase")
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(
                    tensor_shape=tensor_shape,
                    dtype=dtype,
                    async_comm=async_comm,
                    sequence_parallel_enabled=sequence_parallel_enabled,
                )
            )
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
...
@@ -302,7 +402,14 @@ def _forward_backward_pipelining_with_interleaving(
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    dtype=dtype,
                    async_comm=async_comm,
                    sequence_parallel_enabled=sequence_parallel_enabled,
                )
            )

    return losses_reduced
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py
from typing import Union, List, Optional, Sequence
import warnings

import torch

from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor
from apex.transformer.log_util import get_transformer_logger
...
@@ -19,14 +25,222 @@ __all__ = ["forward_backward_pipelining_without_interleaving"]
_logger = get_transformer_logger(__name__)
def get_tensor_shapes(
    rank: int,
    model_type: ModelType,
    *,
    tensor_shape: Union[List[int], torch.Size],
    decoder_sequence_length: Optional[int] = None,
    sequence_parallel_enabled: bool = False,
) -> Sequence[Sequence[int]]:
    """Get tensor shapes.

    Args:
        rank: pipeline parallel rank
        model_type:

    Keyword Args:
        tensor_shape:
        decoder_sequence_length:
        sequence_parallel_enabled:
    """
    # Determine right tensor sizes (based on position of rank with respect to split rank) and model size.
    # Send two tensors if model is T5 and rank is in decoder stage:
    #     first tensor is decoder (pre-transpose),
    #     second tensor is encoder (post-transpose).
    # If model is T5 and rank is at the boundary:
    #     send one tensor (post-transpose from encoder).
    # Otherwise, send one tensor (pre-transpose).
    assert len(tensor_shape) == 3, (
        f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}"
    )
    sequence_length, micro_batch_size, hidden_size = tensor_shape

    tensor_shapes = []

    if sequence_parallel_enabled:
        seq_length = sequence_length // parallel_state.get_tensor_model_parallel_world_size()
    else:
        seq_length = sequence_length

    if model_type == ModelType.encoder_and_decoder:
        if sequence_parallel_enabled:
            dec_seq_length = decoder_sequence_length // parallel_state.get_tensor_model_parallel_world_size()
        else:
            dec_seq_length = decoder_sequence_length

        if parallel_state.is_pipeline_stage_before_split(rank):
            tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
        else:
            tensor_shapes.append((dec_seq_length, micro_batch_size, hidden_size))
            tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
    else:
        tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
    return tensor_shapes
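# --- Illustrative sketch (not part of this commit) ---------------------------
# The shapes get_tensor_shapes() produces, reproduced as plain tuple math so it
# runs without initializing parallel_state. Assumed example values:
# sequence_length=1024, micro_batch_size=4, hidden_size=768,
# tensor_model_parallel_world_size=2, decoder_sequence_length=128.
sequence_length, micro_batch_size, hidden_size = 1024, 4, 768
tensor_model_parallel_world_size, decoder_sequence_length = 2, 128

# encoder_or_decoder model, sequence parallelism off: one (s, b, h) tensor.
print([(sequence_length, micro_batch_size, hidden_size)])            # [(1024, 4, 768)]

# sequence_parallel_enabled=True shards the sequence dim across tensor-parallel ranks.
seq_length = sequence_length // tensor_model_parallel_world_size
print([(seq_length, micro_batch_size, hidden_size)])                 # [(512, 4, 768)]

# encoder_and_decoder model on a rank after the split (sequence parallelism off):
# decoder tensor first, then the encoder tensor.
print([(decoder_sequence_length, micro_batch_size, hidden_size),
       (sequence_length, micro_batch_size, hidden_size)])            # [(128, 4, 768), (1024, 4, 768)]
# --- end sketch ---------------------------------------------------------------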
def recv_forward(
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
    input_tensors = []
    for tensor_shape in tensor_shapes:
        if tensor_shape is None:
            input_tensors.append(None)
        else:
            input_tensors.append(
                p2p_communication.recv_forward(
                    tensor_shape=tensor_shape,
                    dtype=dtype,
                    async_comm=async_comm,
                    sequence_parallel_enabled=sequence_parallel_enabled,
                )
            )
    return input_tensors


def recv_backward(
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
    output_tensor_grads = []
    for tensor_shape in tensor_shapes:
        if tensor_shape is None:
            output_tensor_grads.append(None)
        else:
            output_tensor_grads.append(
                p2p_communication.recv_backward(
                    tensor_shape=tensor_shape,
                    dtype=dtype,
                    async_comm=async_comm,
                    sequence_parallel_enabled=sequence_parallel_enabled,
                )
            )
    return output_tensor_grads


def send_forward(
    output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
) -> None:
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
        if tensor_shape is None:
            continue
        p2p_communication.send_forward(
            output_tensor,
            tensor_shape=tensor_shape,
            dtype=dtype,
            async_comm=async_comm,
            sequence_parallel_enabled=sequence_parallel_enabled,
        )


def send_backward(
    input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
) -> None:
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
        if tensor_shape is None:
            continue
        p2p_communication.send_backward(
            input_tensor_grad,
            tensor_shape=tensor_shape,
            dtype=dtype,
            async_comm=async_comm,
            sequence_parallel_enabled=sequence_parallel_enabled,
        )


def send_forward_recv_backward(
    output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    output_tensor_grads = []
    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
        if tensor_shape is None:
            output_tensor_grads.append(None)
            continue
        output_tensor_grad = p2p_communication.send_forward_recv_backward(
            output_tensor,
            tensor_shape=tensor_shape,
            dtype=dtype,
            async_comm=async_comm,
            sequence_parallel_enabled=sequence_parallel_enabled,
        )
        output_tensor_grads.append(output_tensor_grad)
    return output_tensor_grads


def send_backward_recv_forward(
    input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    input_tensors = []
    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
        if tensor_shape is None:
            input_tensors.append(None)
            continue
        input_tensor = p2p_communication.send_backward_recv_forward(
            input_tensor_grad,
            tensor_shape=tensor_shape,
            dtype=dtype,
            async_comm=async_comm,
            sequence_parallel_enabled=sequence_parallel_enabled,
        )
        input_tensors.append(input_tensor)
    return input_tensors
def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
    decoder_sequence_length: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
    async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
...
@@ -44,28 +258,59 @@ def forward_backward_pipelining_without_interleaving(
    Keyword args:
        forward_only:
        tensor_shape: Shape of tensor. The tensor is expected to be 3D and its order of dimension
            is supposed to be ``(sequence, batch, hidden)``.
        dtype: dtype used in p2p communication. If ``None`` (default value),
            torch.float32 will be used even if ``autocast`` is enabled.
        grad_scaler:
        disable_autocast:
        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
            each pipeline stage. Experimental.
        sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
            When :obj:`True`, the sequence length on each tensor model parallel rank is updated
            to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    # timers = get_timers()

    if deallocate_pipeline_outputs:
        warnings.warn(
            "`deallocate_pipeline_outputs` is experimental and subject to change. "
            "This option is not recommended."
        )

    model: List[torch.nn.Module] = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model: torch.nn.Module = model[0]

    # Compute number of warmup microbatches.
    num_microbatches: int = get_num_microbatches()
    num_warmup_microbatches: int = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    model_type = get_model_type(model)
    rank: int = parallel_state.get_pipeline_model_parallel_rank()
    recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank - 1,
        model_type,
        tensor_shape=tensor_shape,
        decoder_sequence_length=decoder_sequence_length,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    send_tensor_shapes: List[List[int]] = get_tensor_shapes(
        rank,
        model_type,
        tensor_shape=tensor_shape,
        decoder_sequence_length=decoder_sequence_length,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
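# --- Illustrative sketch (not part of this commit) ---------------------------
# Warmup/steady-state split of the 1F1B schedule for an assumed 4-stage pipeline
# and 8 microbatches: each stage runs (world_size - rank - 1) warmup forward
# passes, capped at num_microbatches.
pipeline_world_size, num_microbatches_example = 4, 8
for rank in range(pipeline_world_size):
    warmup = min(pipeline_world_size - rank - 1, num_microbatches_example)
    print(f"rank {rank}: warmup={warmup}, steady 1F1B={num_microbatches_example - warmup}")
# rank 0: warmup=3, steady 1F1B=5
# rank 1: warmup=2, steady 1F1B=6
# rank 2: warmup=1, steady 1F1B=7
# rank 3: warmup=0, steady 1F1B=8
# --- end sketch ---------------------------------------------------------------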
    _logger.info(
        f"num_microbatches: {num_microbatches}, "
...
@@ -74,13 +319,9 @@ def forward_backward_pipelining_without_interleaving(
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors: List[Union[None, torch.Tensor]] = []
    output_tensors: List[Union[None, torch.Tensor]] = []
    losses_reduced: List[Union[None, torch.Tensor]] = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
...
@@ -88,22 +329,42 @@ def forward_backward_pipelining_without_interleaving(
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = recv_forward(
            tensor_shapes=recv_tensor_shapes,
            dtype=dtype,
            async_comm=async_comm,
            sequence_parallel_enabled=sequence_parallel_enabled,
        )
        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
        output_tensor = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        _logger.debug("send fwd")
        send_forward(
            output_tensor,
            tensor_shapes=send_tensor_shapes,
            dtype=dtype,
            async_comm=async_comm,
            sequence_parallel_enabled=sequence_parallel_enabled,
        )

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor: List[Union[None, torch.Tensor, FutureTensor]] = recv_forward(
            tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm
        )

    ###################################################################################################################
    # Run 1F1B in steady state.
...
@@ -111,42 +372,84 @@ def forward_backward_pipelining_without_interleaving(
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration: bool = i == (num_microbatches_remaining - 1)

        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
            forward_step_func,
            cur_microbatch,
            model,
            input_tensor,
            losses_reduced,
            dtype,
            disable_autocast,
        )
        if forward_only:
            _logger.debug("send fwd")
            send_forward(
                output_tensor,
                tensor_shapes=send_tensor_shapes,
                dtype=dtype,
                async_comm=async_comm,
                sequence_parallel_enabled=sequence_parallel_enabled,
            )

            if not last_iteration:
                _logger.debug("receive fwd (last iteration)")
                input_tensor = recv_forward(
                    tensor_shapes=recv_tensor_shapes,
                    dtype=dtype,
                    async_comm=async_comm,
                    sequence_parallel_enabled=sequence_parallel_enabled,
                )
        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = send_forward_recv_backward(
                output_tensor,
                tensor_shapes=send_tensor_shapes,
                dtype=dtype,
                async_comm=async_comm,
                sequence_parallel_enabled=sequence_parallel_enabled,
            )

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                send_backward(
                    input_tensor_grad,
                    tensor_shapes=recv_tensor_shapes,
                    dtype=dtype,
                    async_comm=async_comm,
                    sequence_parallel_enabled=sequence_parallel_enabled,
                )
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = send_backward_recv_forward(
                    input_tensor_grad,
                    tensor_shapes=recv_tensor_shapes,
                    dtype=dtype,
                    async_comm=async_comm,
                    sequence_parallel_enabled=sequence_parallel_enabled,
                )

    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
...
@@ -158,13 +461,29 @@ def forward_backward_pipelining_without_interleaving(
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = recv_backward(
                tensor_shapes=send_tensor_shapes,
                dtype=dtype,
                async_comm=async_comm,
                sequence_parallel_enabled=sequence_parallel_enabled,
            )

            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grad,
                model_type=model_type,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            _logger.debug("send bwd")
            send_backward(
                input_tensor_grad,
                tensor_shapes=recv_tensor_shapes,
                dtype=dtype,
                async_comm=async_comm,
                sequence_parallel_enabled=sequence_parallel_enabled,
            )

    return losses_reduced
apex/transformer/pipeline_parallel/utils.py
...
@@ -21,6 +21,7 @@ from torch.nn.parallel import DistributedDataParallel
from apex.multi_tensor_apply import multi_tensor_applier
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.microbatches import build_num_microbatches_calculator
from apex.transformer.pipeline_parallel._timers import _Timers

if multi_tensor_applier.available:
...
@@ -118,14 +119,24 @@ def _split_batch_into_microbatch(
# TODO(mkozuki): Support non-tensor local minibatches?
def get_kth_microbatch(batch: Optional[List[torch.Tensor]], k: int) -> List[torch.Tensor]:
    """Create a list of microbatches from a list of local minibatches.

    This function creates a list of `k`th microbatches from a list of local minibatches.
    `a local minibatch` consists of `global_batch_size / data_parallel_size` samples.
    """
    if batch is None:
        return batch
    micro_batch_size = get_micro_batch_size()
    start = k * micro_batch_size
    end = start + micro_batch_size
    microbatch = list()
    for x in batch:
        size = x.size(0)
        assert size > start and size >= end
        microbatch.append(x[start:end])
    assert len(microbatch) > 0
    return microbatch
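# --- Illustrative sketch (not part of this commit) ---------------------------
# The slicing performed above, with an assumed micro_batch_size of 2 instead of
# the value returned by get_micro_batch_size().
import torch

local_minibatch = [torch.arange(8).reshape(8, 1), torch.arange(8, 16).reshape(8, 1)]
micro_batch_size, k = 2, 1  # assumed values
start, end = k * micro_batch_size, (k + 1) * micro_batch_size
microbatch = [x[start:end] for x in local_minibatch]
print([t.squeeze(1).tolist() for t in microbatch])  # [[2, 3], [10, 11]]
# --- end sketch ---------------------------------------------------------------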
def get_autoresume():
...
@@ -186,6 +197,19 @@ def unwrap_model(model, module_instances=(DistributedDataParallel,)):
    return unwrapped_model


def get_model_type(model: torch.nn.Module,) -> ModelType:
    """Get `model_type` of `model`.

    If ``model`` doesn't have ``model_type`` attribute, return ``ModelType.encoder_or_decoder``.

    Args:
        model
    """
    return getattr(unwrap_model(model), "model_type", ModelType.encoder_or_decoder)


def calc_params_l2_norm(model: torch.nn.Module, bf16: bool):
    """Calculate l2 norm of parameters """
    # args = get_args()
...
apex/transformer/tensor_parallel/__init__.py
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -33,6 +32,7 @@ from apex.transformer.tensor_parallel.mappings import (
    gather_from_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    scatter_to_tensor_model_parallel_region,
    scatter_to_sequence_parallel_region,
)

from .random import (
...
@@ -63,6 +63,7 @@ __all__ = [
    "gather_from_tensor_model_parallel_region",
    "reduce_from_tensor_model_parallel_region",
    "scatter_to_tensor_model_parallel_region",
    "scatter_to_sequence_parallel_region",
    # random.py
    "checkpoint",
    "get_cuda_rng_tracker",
...
apex/transformer/tensor_parallel/data.py
...
@@ -25,8 +25,9 @@ _MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
    """Check that all the keys have the same target data type."""
    for key in keys:
        assert data[key].dtype == target_dtype, (
            "{} has data type {} which "
            "is different than {}".format(key, data[key].dtype, target_dtype)
        )
...
@@ -48,7 +49,9 @@ def _build_key_size_numel_dictionaries(keys, data):
    # Move to GPU and broadcast.
    sizes_cuda = torch.cuda.LongTensor(sizes)
    torch.distributed.broadcast(
        sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(),
    )

    # Move back to cpu and unpack.
...
@@ -92,13 +95,19 @@ def broadcast_data(keys, data, datatype):
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)

        # Flatten the data associated with the keys
        flatten_data = torch.cat(
            [data[key].contiguous().view(-1) for key in keys], dim=0
        ).cuda()
    else:
        flatten_data = torch.empty(
            total_numel, device=torch.cuda.current_device(), dtype=datatype
        )

    # Broadcast
    torch.distributed.broadcast(
        flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(),
    )

    # Unpack
...
apex/transformer/tensor_parallel/layers.py
# coding=utf-8
# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -16,6 +16,9 @@
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
from typing import Optional, Dict, Tuple, List
import warnings

import torch
import torch.nn.functional as F
import torch.nn.init as init
...
@@ -26,12 +29,34 @@ from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.utils import divide
from apex.transformer.tensor_parallel.mappings import (
    copy_to_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
    gather_from_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
    reduce_from_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
    scatter_to_tensor_model_parallel_region,
)
from apex.transformer.tensor_parallel.mappings import (
    reduce_scatter_to_sequence_parallel_region,
)
from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
from apex.transformer.tensor_parallel.utils import VocabUtility
from apex.transformer.log_util import get_transformer_logger


_logger = get_transformer_logger(__name__)


_grad_accum_fusion_available = True
try:
    import fused_weight_gradient_mlp_cuda
except ImportError:
    _grad_accum_fusion_available = False


_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
...
@@ -41,13 +66,13 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
}


def param_is_not_tensor_parallel_duplicate(param: torch.Tensor) -> bool:
    return (
        hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
    ) or (get_tensor_model_parallel_rank() == 0)


def set_tensor_model_parallel_attributes(tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int) -> None:
    # Make sure the attributes are not set.
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attribute)
...
@@ -57,7 +82,7 @@ def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    setattr(tensor, "partition_stride", stride)


def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor: torch.Tensor) -> None:
    def maybe_set(attribute, value):
        if not hasattr(tensor, attribute):
            setattr(tensor, attribute, value)
...
@@ -66,7 +91,7 @@ def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
        maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])


def copy_tensor_model_parallel_attributes(destination_tensor: torch.Tensor, source_tensor: torch.Tensor) -> None:
    def maybe_copy(attribute):
        if hasattr(source_tensor, attribute):
            setattr(destination_tensor, attribute, getattr(source_tensor, attribute))
...
@@ -76,9 +101,18 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU.

    Args:
        weight (Parameter):
        init_method (Callable[[Tensor], None]): Taking a Tensor and initialize its elements.
        partition_dim (int): Dimension to apply partition.
        stride (int):
    """
    set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)

    with get_cuda_rng_tracker().fork():
        init_method(weight)
...
@@ -103,16 +137,22 @@ def _initialize_affine_weight_cpu(
    Build the master weight on all processes and scatter
    the relevant chunk."""
    set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False)
    init_method(master_weight)
    master_weight = master_weight.to(dtype=params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]
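# --- Illustrative sketch (not part of this commit) ---------------------------
# How the master weight is carved into per-rank shards above. Assumed sizes:
# an (8, 4) master weight, world_size=2, per_partition_size=4, stride=1,
# partitioned along dim 0.
import torch

master_weight = torch.arange(32, dtype=torch.float).reshape(8, 4)
world_size, per_partition_size, stride = 2, 4, 1
per_partition_per_stride_size = per_partition_size // stride
weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=0)
for rank in range(world_size):
    my_weight_list = weight_list[rank::world_size]
    print(rank, torch.cat(my_weight_list, dim=0).shape)  # each rank keeps a (4, 4) shard
# --- end sketch ---------------------------------------------------------------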
...
@@ -136,9 +176,15 @@ class VocabParallelEmbedding(torch.nn.Module):
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        init_method=init.xavier_normal_,
        *,
        params_dtype: torch.dtype = torch.float32,
        use_cpu_initialization: bool = False,
    ):
        super().__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
...
@@ -150,19 +196,35 @@ class VocabParallelEmbedding(torch.nn.Module):
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        (
            self.vocab_start_index,
            self.vocab_end_index,
        ) = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size,
        )
        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

        # Allocate weights and initialize.
        if use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(
                    self.num_embeddings_per_partition, self.embedding_dim, dtype=params_dtype,
                )
            )
            _initialize_affine_weight_cpu(
                self.weight,
                self.num_embeddings,
                self.embedding_dim,
                self.num_embeddings_per_partition,
                0,
                init_method,
                params_dtype=params_dtype,
            )
        else:
...
@@ -174,12 +236,16 @@ class VocabParallelEmbedding(torch.nn.Module):
                    dtype=params_dtype,
                )
            )
            _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)

    def forward(self, input_):
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
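# --- Illustrative sketch (not part of this commit) ---------------------------
# The masking above, for an assumed vocabulary of 10 split across 2
# tensor-parallel ranks, as seen from rank 1 (vocab_start_index=5,
# vocab_end_index=10). Out-of-range token ids are shifted and zeroed; their
# embedding rows are zeroed after the local lookup, before the cross-rank reduce.
import torch

vocab_start_index, vocab_end_index = 5, 10
input_ = torch.tensor([1, 5, 7, 9, 3])
input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
masked_input = input_.clone() - vocab_start_index
masked_input[input_mask] = 0
print(masked_input.tolist())  # [0, 0, 2, 4, 0]
# --- end sketch ---------------------------------------------------------------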
...
@@ -203,16 +269,44 @@ class VocabParallelEmbedding(torch.nn.Module):
        return output


class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    """Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop."""

    @staticmethod
    def forward(
        ctx,
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        gradient_accumulation_fusion: bool,
        async_grad_allreduce: bool,
        sequence_parallel_enabled: bool,
        use_16bit_in_wgrad_accum_fusion: bool = False,
    ):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.sequence_parallel_enabled = sequence_parallel_enabled
        ctx.use_16bit_in_wgrad_accum_fusion = use_16bit_in_wgrad_accum_fusion

        if ctx.sequence_parallel_enabled:
            world_size = get_tensor_model_parallel_world_size()
            # `input` is supposed to be 3D and its order of dimension is [sequence, batch, hidden]
            shape = list(input.shape)
            shape[0] *= world_size

            all_gather_buffer = torch.empty(
                shape,
                dtype=input.dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
            torch.distributed._all_gather_base(all_gather_buffer, input, group=get_tensor_model_parallel_group())
            total_input = all_gather_buffer
        else:
            total_input = input

        output = torch.matmul(total_input, weight.t())
        if bias is not None:
            output = output + bias
        return output
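# --- Illustrative sketch (not part of this commit) ---------------------------
# The buffer-shape math used above when sequence_parallel_enabled is True: each
# rank holds a [s/tp, b, h] slice, so the all-gather buffer is the full [s, b, h]
# activation. Assumed sizes; runs on CPU without torch.distributed.
import torch

tensor_model_parallel_world_size = 2
local_input = torch.randn(512, 4, 768)               # [s/tp, b, h] on one rank
shape = list(local_input.shape)
shape[0] *= tensor_model_parallel_world_size         # -> [1024, 4, 768] once gathered
all_gather_buffer = torch.empty(shape, dtype=local_input.dtype, requires_grad=False)
print(all_gather_buffer.shape)                       # torch.Size([1024, 4, 768])
# --- end sketch ---------------------------------------------------------------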
...
@@ -221,23 +315,115 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
+
+        if ctx.sequence_parallel_enabled:
+            world_size = get_tensor_model_parallel_world_size()
+            shape = list(input.shape)
+            shape[0] *= world_size
+
+            all_gather_buffer = torch.empty(
+                shape,
+                dtype=input.dtype,
+                device=torch.cuda.current_device(),
+                requires_grad=False,
+            )
+            handle = torch.distributed._all_gather_base(
+                all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group(),
+                async_op=True,
+            )
+            total_input = all_gather_buffer
+        else:
+            total_input = input
+
         grad_input = grad_output.matmul(weight)
-        # Asynchronous all-reduce
-        handle = torch.distributed.all_reduce(
-            grad_input, group=get_tensor_model_parallel_group(), async_op=True)
-        # Delay the start of weight gradient computation shortly (3us) to have
-        # all-reduce scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
-        grad_weight = grad_output.t().matmul(input)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None
-        handle.wait()
-        return grad_input, grad_weight, grad_bias
+
+        if ctx.sequence_parallel_enabled:
+            handle.wait()
+
+        # Convert the tensor shapes to 2D for execution compatibility
+        grad_output = grad_output.view(
+            grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+        )
+        total_input = total_input.view(
+            total_input.shape[0] * total_input.shape[1], total_input.shape[2]
+        )
+
+        if ctx.async_grad_allreduce:
+            # Asynchronous all-reduce
+            handle = torch.distributed.all_reduce(
+                grad_input, group=get_tensor_model_parallel_group(), async_op=True
+            )
+
+        if ctx.sequence_parallel_enabled:
+            assert not ctx.async_grad_allreduce
+            sub_grad_input = torch.empty(
+                input.shape, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
+            )
+            handle = torch.distributed._reduce_scatter_base(
+                sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
+            )
+
+        if ctx.gradient_accumulation_fusion:
+            if not ctx.use_16bit_in_wgrad_accum_fusion:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+                    total_input, grad_output, weight.main_grad
+                )
+            else:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+                    total_input, grad_output, weight.main_grad
+                )
+            grad_weight = None
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        if ctx.sequence_parallel_enabled:
+            handle.wait()
+            return sub_grad_input, grad_weight, grad_bias, None, None, None, None
+
+        if ctx.async_grad_allreduce:
+            handle.wait()
+
+        return grad_input, grad_weight, grad_bias, None, None, None, None


-def column_parallel_linear(input, weight, bias):
-    args = _cast_if_autocast_enabled(input, weight, bias)
+def linear_with_grad_accumulation_and_async_allreduce(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    gradient_accumulation_fusion: bool,
+    async_grad_allreduce: bool,
+    sequence_parallel_enabled: bool,
+) -> torch.Tensor:
+    args = _cast_if_autocast_enabled(
+        input,
+        weight,
+        bias,
+        gradient_accumulation_fusion,
+        async_grad_allreduce,
+        sequence_parallel_enabled,
+        False,  # use_16bit_in_wgrad_accum_fusion
+    )
     with torch.cuda.amp.autocast(enabled=False):
-        return ColumnParallelLinearWithAsyncAllreduce.apply(*args)
+        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
+
+
+def linear_with_grad_accumulation_and_async_allreduce_in16bit(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    gradient_accumulation_fusion: bool,
+    async_grad_allreduce: bool,
+    sequence_parallel_enabled: bool,
+) -> torch.Tensor:
+    args = _cast_if_autocast_enabled(
+        input,
+        weight,
+        bias,
+        gradient_accumulation_fusion,
+        async_grad_allreduce,
+        sequence_parallel_enabled,
+        True,  # use_16bit_in_wgrad_accum_fusion
+    )
+    with torch.cuda.amp.autocast(enabled=False):
+        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
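The key trick in the backward above is overlapping communication with the weight-gradient GEMM: the partial input gradient is all-reduced (or reduce-scattered under sequence parallelism) asynchronously, and `handle.wait()` is only called right before the gradient is needed. A minimal sketch of that overlap pattern, assuming `torch.distributed` is already initialized and `group` is the tensor-model-parallel process group (both assumptions, not part of this diff):

import torch
import torch.distributed as dist

def overlapped_linear_backward(grad_output, input_, weight, group):
    # Partial dL/dX on this rank; the full value needs an all-reduce over `group`.
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=group, async_op=True)
    # The weight gradient does not depend on the all-reduce result,
    # so this GEMM runs while the collective is in flight.
    grad_weight = grad_output.t().matmul(input_)
    handle.wait()
    return grad_input, grad_weight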
class ColumnParallelLinear(torch.nn.Module):
...
@@ -246,6 +432,10 @@ class ColumnParallelLinear(torch.nn.Module):
     The linear layer is defined as Y = XA + b. A is parallelized along
     its second dimension as A = [A_1, ..., A_p].

+    .. note::
+        Input is supposed to be three dimensional and each dimension
+        is expected to be sequence, batch, and hidden feature, respectively.
+
     Arguments:
         input_size: first dimension of matrix A.
         output_size: second dimension of matrix A.
...
@@ -262,6 +452,14 @@ class ColumnParallelLinear(torch.nn.Module):
         skip_bias_add: This was added to enable performance optimations where bias
                        can be fused with other elementwise operations. we skip
                        adding bias but instead return it.
+
+    Keyword Arguments:
+        no_async_tensor_model_parallel_allreduce:
+        params_dtype:
+        use_cpu_initialization:
+        gradient_accumulation_fusion:
+        accumulation_in_fp16:
+        sequence_parallel_enabled:
     """

     def __init__(
...
@@ -278,8 +476,11 @@ class ColumnParallelLinear(torch.nn.Module):
         no_async_tensor_model_parallel_allreduce=False,
         params_dtype=torch.float32,
         use_cpu_initialization=False,
+        gradient_accumulation_fusion=False,
+        accumulation_in_fp16: bool = False,
+        sequence_parallel_enabled: bool = False,
     ):
-        super(ColumnParallelLinear, self).__init__()
+        super().__init__()
         # Keep input parameters
         self.input_size = input_size
...
@@ -295,7 +496,9 @@ class ColumnParallelLinear(torch.nn.Module):
         # we allocate the transpose.
         # Initialize weight.
         if use_cpu_initialization:
             self.weight = Parameter(
                 torch.empty(self.output_size_per_partition, self.input_size, dtype=params_dtype)
             )
             self.master_weight = _initialize_affine_weight_cpu(
                 self.weight,
                 self.output_size,
...
@@ -323,7 +526,11 @@ class ColumnParallelLinear(torch.nn.Module):
                 self.bias = Parameter(torch.empty(self.output_size_per_partition, dtype=params_dtype))
             else:
                 self.bias = Parameter(
                     torch.empty(
                         self.output_size_per_partition,
                         device=torch.cuda.current_device(),
                         dtype=params_dtype,
                     )
                 )
             set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
             # Always initialize bias to zero.
...
@@ -333,28 +540,69 @@ class ColumnParallelLinear(torch.nn.Module):
             self.register_parameter("bias", None)

         self.async_tensor_model_parallel_allreduce = (
             not no_async_tensor_model_parallel_allreduce and
             world_size > 1
         )
+
+        if sequence_parallel_enabled:
+            if world_size <= 1:
+                warnings.warn(
+                    f"`sequence_parallel_enabled` is set to `True`, but got world_size of {world_size}"
+                )
+                # sequence_parallel_enabled = False
+        self.sequence_parallel_enabled = sequence_parallel_enabled
+
+        if gradient_accumulation_fusion:
+            if not _grad_accum_fusion_available:
+                # Basically, apex.transformer module users are expected to install APEX's
+                # `--cpp_ext` and `--cuda_ext`. The example installation command is as follows:
+                # `pip install --global-option="--cpp_ext" --global-option="--cuda_ext ."
+                # at the root of APEX repository.
+                warnings.warn(
+                    "`gradient_accumulation_fusion` is set to `True` but "
+                    "the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not "
+                    "found. Thus `gradient_accumulation_fusion` set to `False`. "
+                    "Note that the extension requires CUDA>=11."
+                )
+                gradient_accumulation_fusion = False
+        self.gradient_accumulation_fusion = gradient_accumulation_fusion
+
+        if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
+            raise RuntimeError(
+                "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time."
+            )
+
+        self._forward_impl = (
+            linear_with_grad_accumulation_and_async_allreduce_in16bit
+            if accumulation_in_fp16
+            else linear_with_grad_accumulation_and_async_allreduce
+        )

-    def forward(self, input_):
+    def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Forward of ColumnParallelLinear
+
+        Args:
+            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+
+        Returns:
+            - output
+            - bias
+        """
         bias = self.bias if not self.skip_bias_add else None

-        if self.async_tensor_model_parallel_allreduce:
-            input_shape = input_.shape
-            input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
-            # Matrix multiply with asynchronous all-reduce execution
-            output_parallel = column_parallel_linear(input_, self.weight, bias)
-            output_parallel = output_parallel.view(
-                input_shape[0], input_shape[1], output_parallel.shape[1])
+        if self.async_tensor_model_parallel_allreduce or self.sequence_parallel_enabled:
+            input_parallel = input_
         else:
-            # Set up backprop all-reduce.
             input_parallel = copy_to_tensor_model_parallel_region(input_)
-            # Matrix multiply.
-            output_parallel = F.linear(input_parallel, self.weight, bias)
+        # Matrix multiply.
+        output_parallel = self._forward_impl(
+            input=input_parallel,
+            weight=self.weight,
+            bias=bias,
+            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+            async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
+            sequence_parallel_enabled=self.sequence_parallel_enabled,
+        )
         if self.gather_output:
             # All-gather across the partitions.
+            assert not self.sequence_parallel_enabled
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
...
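For orientation, a hedged usage sketch of the new `ColumnParallelLinear` keyword arguments (not part of the diff). It assumes `torch.distributed` and `apex.transformer.parallel_state` have already been initialized on a CUDA device, here with a tensor-model-parallel size of 1 so no real sharding happens, and the sizes are made up:

import torch
from apex.transformer import tensor_parallel

# Assumes parallel_state.initialize_model_parallel() has been called beforehand.
layer = tensor_parallel.ColumnParallelLinear(
    1024, 4096,
    gather_output=False,
    init_method=torch.nn.init.xavier_normal_,
    no_async_tensor_model_parallel_allreduce=True,
    gradient_accumulation_fusion=False,
    sequence_parallel_enabled=False,
)
x = torch.randn(128, 4, 1024, device="cuda")  # [sequence, batch, hidden]
output, output_bias = layer(x)                # output_bias is None unless skip_bias_add=True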
@@ -374,6 +622,11 @@ class RowParallelLinear(torch.nn.Module):
                | .   |
                | A_p |
                 -   -
+
+    .. note::
+        Input is supposed to be three dimensional and each dimension
+        is expected to be sequence, batch, and hidden feature, respectively.
+
     Arguments:
         input_size: first dimension of matrix A.
         output_size: second dimension of matrix A.
...
@@ -390,6 +643,12 @@ class RowParallelLinear(torch.nn.Module):
         skip_bias_add: This was added to enable performance optimization where bias
                        can be fused with other elementwise operations. We skip
                        adding bias but instead return it.
+
+    Keyword Arguments:
+        params_dtype:
+        use_cpu_initialization:
+        gradient_accumulation_fusion:
+        accumulation_in_fp16:
+        sequence_parallel_enabled:
     """

     def __init__(
...
@@ -405,8 +664,11 @@ class RowParallelLinear(torch.nn.Module):
         *,
         params_dtype=torch.float32,
         use_cpu_initialization=False,
+        gradient_accumulation_fusion=False,
+        accumulation_in_fp16: bool = False,
+        sequence_parallel_enabled: bool = False,
     ):
-        super(RowParallelLinear, self).__init__()
+        super().__init__()
         # Keep input parameters
         self.input_size = input_size
...
@@ -416,6 +678,10 @@ class RowParallelLinear(torch.nn.Module):
         world_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, world_size)
         self.skip_bias_add = skip_bias_add
+        self.gradient_accumulation_fusion = gradient_accumulation_fusion
+        self.sequence_parallel_enabled = sequence_parallel_enabled
+        if self.sequence_parallel_enabled and not self.input_is_parallel:
+            raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")
         # as an argument to this function?
         # Parameters.
...
@@ -423,7 +689,11 @@ class RowParallelLinear(torch.nn.Module):
         # we allocate the transpose.
         # Initialize weight.
         if use_cpu_initialization:
             self.weight = Parameter(
                 torch.empty(self.output_size, self.input_size_per_partition, dtype=params_dtype)
             )
             self.master_weight = _initialize_affine_weight_cpu(
                 self.weight,
                 self.output_size,
...
@@ -444,30 +714,63 @@ class RowParallelLinear(torch.nn.Module):
                     dtype=params_dtype,
                 )
             )
             _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=1, stride=stride)
         if bias:
             if use_cpu_initialization:
                 self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
             else:
                 self.bias = Parameter(
                     torch.empty(
                         self.output_size,
                         device=torch.cuda.current_device(),
                         dtype=params_dtype,
                     )
                 )
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
+            setattr(self.bias, "sequence_parallel_enabled", sequence_parallel_enabled)
         else:
             self.register_parameter("bias", None)

-    def forward(self, input_):
+        self._forward_impl = (
+            linear_with_grad_accumulation_and_async_allreduce_in16bit
+            if accumulation_in_fp16
+            else linear_with_grad_accumulation_and_async_allreduce
+        )
+
+    def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Forward of RowParallelLinear
+
+        Args:
+            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+
+        Returns:
+            - output
+            - bias
+        """
         # Set up backprop all-reduce.
         if self.input_is_parallel:
             input_parallel = input_
         else:
+            assert not self.sequence_parallel_enabled
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        output_parallel = self._forward_impl(
+            input=input_parallel,
+            weight=self.weight,
+            bias=None,
+            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+            async_grad_allreduce=False,
+            sequence_parallel_enabled=False,
+        )
         # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        if self.sequence_parallel_enabled:
+            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
+        else:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
...
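As a rough mental model of how these two layers compose in a transformer MLP (a sketch with made-up sizes, no communication shown): `ColumnParallelLinear` shards the first GEMM's output features, `RowParallelLinear` shards the second GEMM's input features, and only the final partial sum needs an all-reduce, or a reduce-scatter along the sequence dimension when `sequence_parallel_enabled` is set.

import torch
import torch.nn.functional as F

s, b, h, ffn, tp = 128, 4, 1024, 4096, 8      # made-up sizes; tp = tensor-parallel world size
x = torch.randn(s, b, h)                      # [sequence, batch, hidden]
w1 = torch.randn(ffn // tp, h)                # ColumnParallelLinear shard (output features split)
w2 = torch.randn(h, ffn // tp)                # RowParallelLinear shard (input features split)

y_local = F.linear(x, w1)                     # [s, b, ffn/tp], no communication needed
z_partial = F.linear(y_local, w2)             # [s, b, h], partial sum held by this rank
# A real run all-reduces z_partial across the tensor-parallel group, or reduce-scatters it
# along dim 0 so each rank keeps an [s/tp, b, h] slice under sequence parallelism.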
apex/transformer/tensor_parallel/mappings.py
View file @ 96850dfa
 # coding=utf-8
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -20,7 +20,7 @@ from apex.transformer.parallel_state import get_tensor_model_parallel_rank
 from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim


-def _reduce(input_):
+def _reduce(input_: torch.Tensor) -> torch.Tensor:
     """All-reduce the input tensor across model parallel group."""

     # Bypass the function if we are using only 1 GPU.
...
@@ -33,7 +33,7 @@ def _reduce(input_):
     return input_


-def _split(input_):
+def _split_along_last_dim(input_: torch.Tensor) -> torch.Tensor:
     """Split the tensor along its last dimension and keep the
     corresponding slice."""
...
@@ -52,8 +52,24 @@ def _split(input_):
     return output


-def _gather(input_):
-    """Gather tensors and concatinate along the last dimension."""
+def _split_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
+    """Split the tensor along its first dimension and keep the corresponding slice."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU for tensor model parallel.
+    if world_size == 1:
+        return input_
+
+    # Split along first dimension.
+    dim_size = input_.size(0)
+    assert dim_size % world_size == 0
+    local_dim_size = dim_size // world_size
+    dim_offset = get_tensor_model_parallel_rank() * local_dim_size
+    output = input_[dim_offset:dim_offset + local_dim_size].contiguous()
+    return output
+
+
+def _gather_along_last_dim(input_: torch.Tensor) -> torch.Tensor:
+    """Gather tensors and concatenate along the last dimension."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
...
@@ -66,7 +82,9 @@ def _gather(input_):
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
     torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=last_dim).contiguous()
...
@@ -74,9 +92,49 @@ def _gather(input_):
     return output


+def _gather_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
+    """Gather tensors and concatenate along the first dimension."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    shape = list(input_.shape)
+    shape[0] *= world_size
+    output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
+    torch.distributed._all_gather_base(
+        output, input_.contiguous(), group=get_tensor_model_parallel_group()
+    )
+    return output
+
+
+def _reduce_scatter_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
+    """Reduce-scatter the input tensor across model parallel group."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    shape = list(input_.shape)
+    assert shape[0] % world_size == 0
+    shape[0] //= world_size
+    output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
+    torch.distributed._reduce_scatter_base(
+        output, input_.contiguous(), group=get_tensor_model_parallel_group()
+    )
+    return output
+
+
 class _CopyToModelParallelRegion(torch.autograd.Function):
-    """Pass the input to the model parallel region."""
+    """Pass the input to the tensor model parallel region."""

+    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
+    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
     @staticmethod
     def symbolic(graph, input_):
         return input_
...
@@ -91,8 +149,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
 class _ReduceFromModelParallelRegion(torch.autograd.Function):
-    """All-reduce the input from the model parallel region."""
+    """All-reduce the input from the tensor model parallel region."""

+    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
+    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
     @staticmethod
     def symbolic(graph, input_):
         return _reduce(input_)
...
@@ -109,33 +169,95 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
 class _ScatterToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""

+    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
+    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
     @staticmethod
     def symbolic(graph, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _gather(grad_output)
+        return _gather_along_last_dim(grad_output)


 class _GatherFromModelParallelRegion(torch.autograd.Function):
-    """Gather the input from model parallel region and concatinate."""
+    """Gather the input from tensor model parallel region and concatenate."""

+    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
+    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
     @staticmethod
     def symbolic(graph, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _split(grad_output)
+        return _split_along_last_dim(grad_output)
+
+
+class _ScatterToSequenceParallelRegion(torch.autograd.Function):
+    """Split the input and keep only the corresponding chunk to the rank."""
+
+    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
+    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)
+
+
+class _GatherFromSequenceParallelRegion(torch.autograd.Function):
+    """Gather the input from sequence parallel region and concatenate."""
+
+    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
+    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
+    @staticmethod
+    def symbolic(graph, input_, to_model_parallel: bool = True):
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_, to_model_parallel: bool = True):
+        ctx.to_model_parallel = to_model_parallel
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.to_model_parallel:
+            return _reduce_scatter_along_first_dim(grad_output), None
+        else:
+            return _split_along_first_dim(grad_output), None
+
+
+class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
+    """Reduce scatter the input from the sequence parallel region and concatenate."""
+
+    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
+    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)


 # -----------------
...
@@ -143,17 +265,40 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
 # -----------------


-def copy_to_tensor_model_parallel_region(input_):
+def copy_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
     return _CopyToModelParallelRegion.apply(input_)


-def reduce_from_tensor_model_parallel_region(input_):
+def reduce_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
     return _ReduceFromModelParallelRegion.apply(input_)


-def scatter_to_tensor_model_parallel_region(input_):
+def scatter_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
     return _ScatterToModelParallelRegion.apply(input_)


-def gather_from_tensor_model_parallel_region(input_):
+def gather_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
     return _GatherFromModelParallelRegion.apply(input_)
+
+
+def scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor:
+    return _ScatterToSequenceParallelRegion.apply(input_)
+
+
+def gather_from_sequence_parallel_region(input_: torch.Tensor, to_model_parallel: bool = True) -> torch.Tensor:
+    return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel)
+
+
+def reduce_scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor:
+    return _ReduceScatterToSequenceParallelRegion.apply(input_)
+
+
+__all__ = [
+    "copy_to_tensor_model_parallel_region",
+    "reduce_from_tensor_model_parallel_region",
+    "scatter_to_tensor_model_parallel_region",
+    "gather_from_tensor_model_parallel_region",
+    "scatter_to_sequence_parallel_region",
+    "gather_from_sequence_parallel_region",
+    "reduce_scatter_to_sequence_parallel_region",
+]
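The new first-dimension collectives are what sequence parallelism is built on: `_gather_along_first_dim` stitches the per-rank `[s/tp, b, h]` activation chunks back into `[s, b, h]`, and `_reduce_scatter_along_first_dim` is its adjoint (sum over ranks, then keep your own chunk). A single-process sketch with two simulated ranks, purely illustrative and with no real process group:

import torch

tp = 2                                                     # simulated tensor-parallel world size
per_rank = [torch.randn(4, 2, 8) for _ in range(tp)]       # each rank's [s/tp, b, h] chunk

# all-gather along dim 0: every rank ends up with the full [s, b, h] tensor
gathered = torch.cat(per_rank, dim=0)                      # shape [8, 2, 8]

# reduce-scatter along dim 0: sum the ranks' full tensors, then split so rank r keeps chunk r
full_per_rank = [torch.randn(8, 2, 8) for _ in range(tp)]  # each rank's partial [s, b, h]
reduced = torch.stack(full_per_rank).sum(dim=0)
scattered = list(reduced.chunk(tp, dim=0))                 # scattered[r] has shape [4, 2, 8]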
apex/transformer/tensor_parallel/memory.py
View file @ 96850dfa
 # coding=utf-8
 # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -12,6 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# TODO(mkozuki): Remove this file as Megatron-LM seems to have done so.

 import torch
...
@@ -49,13 +52,20 @@ class MemoryBuffer:
         element_size = torch.tensor([], dtype=dtype).element_size()
         print(
             "> building the {} memory buffer with {} num elements "
             "and {} dtype ({:.1f} MB)...".format(name, numel, dtype, numel * element_size / 1024 / 1024),
             flush=True,
         )
         self.name = name
         self.numel = numel
         self.dtype = dtype
         self.data = torch.empty(
             self.numel,
             dtype=self.dtype,
             device=torch.cuda.current_device(),
             requires_grad=False,
         )

         # Index tracking the start of the free memory.
         self._start = 0
...
@@ -81,13 +91,17 @@ class MemoryBuffer:
     def add(self, tensor):
         """Allocate a chunk of memory from the buffer to tensor and copy
         the values."""
         assert (
             tensor.dtype == self.dtype
         ), "Input tensor type {} different from buffer type {}".format(
             tensor.dtype, self.dtype
         )
         # Number of elements of the input tensor.
         tensor_numel = torch.numel(tensor)
         new_start = self._start + tensor_numel
         assert (
             new_start <= self.numel
         ), "Not enough memory left in the buffer ({} > {})".format(
             tensor_numel, self.numel - self._start
         )
         # New tensor is a view into the memory.
...
@@ -124,7 +138,8 @@ class RingMemBuffer:
     def __init__(self, name, num_buffers, numel, dtype, track_usage):
         self.num_buffers = num_buffers
         self.buffers = [
             allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage)
             for i in range(num_buffers)
         ]
         self._index = -1
...
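`MemoryBuffer` is essentially a bump allocator over one preallocated flat tensor: `add` hands out views instead of fresh allocations. A CPU-only sketch of that idea, with a hypothetical `BumpBuffer` class that is not the apex API:

import torch

class BumpBuffer:
    """Hand out views of one flat preallocated tensor, advancing a start index."""

    def __init__(self, numel: int, dtype=torch.float32):
        self.data = torch.empty(numel, dtype=dtype)
        self._start = 0

    def add(self, tensor: torch.Tensor) -> torch.Tensor:
        n = tensor.numel()
        assert self._start + n <= self.data.numel(), "not enough memory left in the buffer"
        view = self.data[self._start:self._start + n].view(tensor.shape)
        view.copy_(tensor)
        self._start += n
        return view

buf = BumpBuffer(1024)
chunk = buf.add(torch.randn(4, 8))   # `chunk` is a view into buf.data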
apex/transformer/tensor_parallel/random.py
View file @ 96850dfa
 # coding=utf-8
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -13,9 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# TODO (mkozuki): Audit this file.
-# I don't think some functions strongly relate to `random` in tensor_parallel.
-# Rather, some functions are mainly for gradient checkpointing (torch.utils.checkpoint).
+# NOTE(mkozuki): This file is based on megatron-lm/mpu/random.py with some differences:
+# - Not using "viewless" tensor:
+#   - _kernel_make_viewless_tensor
+#   - MakeViewlessTensor
+#   - make_viewless_tensor
+#   - assert_viewless_tensor
+#   - safely_set_viewless_tensor_data
 # Parts of the code here are adapted from PyTorch
 # repo: https://github.com/pytorch/pytorch
...
@@ -35,13 +39,12 @@ from apex.transformer.utils import gather_split_1d_tensor
 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"

+# TODO(mkozuki): Remove `_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER` as megatron-lm doesn't seem to use.
 # Whether apply model parallelism to checkpointed hidden states.
 _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None


-# TODO (mkozuki): Consider the possibility of removing `tensor_model_parallel_size`,
-# `get_tensor_model_parallel_world_size()` might be alternative.
+# TODO(mkozuki): Remove `init_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use.
 def init_checkpointed_activations_memory_buffer(
     micro_batch_size,
     max_position_embeddings,
...
@@ -53,8 +56,15 @@ def init_checkpointed_activations_memory_buffer(
 ):
     """Initializ the memory buffer for the checkpointed activations."""
     per_layer = (
         micro_batch_size * max_position_embeddings * hidden_size // tensor_model_parallel_size
     )
     assert (
         num_layers % checkpoint_num_layers == 0
     ), "number of layers is not divisible by checkpoint-num-layers"
     num_checkpointer_layers = num_layers // checkpoint_num_layers
     numel = per_layer * num_checkpointer_layers
     dtype = torch.half
...
@@ -70,6 +80,7 @@ def init_checkpointed_activations_memory_buffer(
     )


+# TODO(mkozuki): Remove `reset_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use.
 def reset_checkpointed_activations_memory_buffer():
     """Reset the memory used for checkpointing."""
     if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
...
@@ -79,7 +90,7 @@ def reset_checkpointed_activations_memory_buffer():
 def _set_cuda_rng_state(new_state, device=-1):
     """Sets the random number generator state of the current GPU.

-    Argumentss:
+    Arguments:
         new_state (torch.ByteTensor): The desired state
     This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
     with a single change: the input state is not cloned. Cloning caused
...
@@ -217,7 +228,9 @@ def model_parallel_cuda_manual_seed(seed):
     # Set the default state.
     torch.cuda.manual_seed(data_parallel_seed)
     # and model parallel state.
     _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)


 # TODO (mkozuki): Move the below gradient checkpoint related features to another (new) file.
...
@@ -230,8 +243,9 @@ class CheckpointFunction(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, distribute_saved_activations, *args):
         ctx.run_function = run_function
+        ctx.distribute_saved_activations = distribute_saved_activations

         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
...
@@ -243,10 +257,8 @@ class CheckpointFunction(torch.autograd.Function):
         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
-            ctx.input_0_shape = args[0].data.shape
+        if ctx.distribute_saved_activations:
+            ctx.input_0_shape = args[0].shape
             args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
-            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(args[0].data)

         # Store everything.
         ctx.save_for_backward(*args)
...
@@ -255,11 +267,11 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *args):
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError("Checkpointing is not compatible with .grad(), " "please use .backward() if possible")
         inputs = ctx.saved_tensors
         if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
             inputs[0].data = gather_split_1d_tensor(inputs[0].data)
             inputs[0].data = inputs[0].data.view(ctx.input_0_shape)

         # Store the current states.
         bwd_cpu_rng_state = torch.get_rng_state()
...
@@ -284,11 +296,16 @@ class CheckpointFunction(torch.autograd.Function):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
         torch.autograd.backward(outputs, args)
-        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
-        return (None,) + grads
+        grads = tuple(
+            inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs
+        )
+        return (None, None) + grads


-def checkpoint(function, *args):
+# NOTE(mkozuki): It doesn't look like `distribute_saved_activations` is used in apex.transformer
+# but I added this change to reduce the superficial difference from Megatron-LM.
+def checkpoint(function, distribute_saved_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function, distribute_saved_activations, *args)
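The `checkpoint` wrapper now takes `distribute_saved_activations` before the checkpointed inputs, which is why `backward` returns `(None, None) + grads` (one `None` per non-tensor argument). A hedged sketch of the calling convention, with `block` standing in for a transformer layer and the stock PyTorch checkpoint shown as the single-extra-argument analogue:

import torch
import torch.utils.checkpoint

block = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.GELU(), torch.nn.Linear(64, 16))

def run_function(x):
    return block(x)

x = torch.randn(8, 16, requires_grad=True)

# apex.transformer.tensor_parallel.random.checkpoint(run_function, False, x) would be the
# two-argument call introduced here; the stock PyTorch utility below takes only the inputs.
out = torch.utils.checkpoint.checkpoint(run_function, x)
out.sum().backward()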
apex/transformer/tensor_parallel/utils.py
View file @ 96850dfa
...
@@ -12,12 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List, Sequence

 import torch

 from apex.transformer.utils import divide


-def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
+def split_tensor_along_last_dim(
+    tensor: torch.Tensor,
+    num_partitions: int,
+    contiguous_split_chunks: bool = False,
+) -> List[torch.Tensor]:
     """Split a tensor along its last dimension.
     Arguments:
         tensor: input tensor.
...
@@ -43,12 +49,16 @@ class VocabUtility:
         partition: Note that indices in [fist, last)"""

     @staticmethod
-    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
+    def vocab_range_from_per_partition_vocab_size(
+        per_partition_vocab_size: int, rank, world_size: int
+    ) -> Sequence[int]:
         index_f = rank * per_partition_vocab_size
         index_l = index_f + per_partition_vocab_size
         return index_f, index_l

     @staticmethod
-    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
+    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
         per_partition_vocab_size = divide(global_vocab_size, world_size)
         return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
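A worked example of what these helpers compute: with a 50,000-token vocabulary split across 4 tensor-parallel ranks, each rank owns 12,500 consecutive token ids, and rank 2 owns the half-open range [25000, 37500).

global_vocab_size, world_size, rank = 50_000, 4, 2

per_partition_vocab_size = global_vocab_size // world_size   # divide() also asserts divisibility
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size

assert (index_f, index_l) == (25_000, 37_500)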
apex/transformer/testing/arguments.py
View file @ 96850dfa
...
@@ -39,9 +39,13 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
     parser = _add_biencoder_args(parser)
-    parser = _add_vit_args(parser)
+    parser = _add_vision_args(parser)
     parser = _add_logging_args(parser)
+    # NOTE(mkozuki): This option is added to investigate the potential of `torch.autograd.graph.save_on_cpu()`.
+    # ref: https://pytorch.org/docs/stable/autograd.html#torch.autograd.graph.save_on_cpu.
+    parser.add_argument('--cpu-offload', action='store_true', default=False, help='Turns on CPU offloading')

     # Custom arguments.
     if extra_args_provider is not None:
         parser = extra_args_provider(parser)
...
@@ -65,6 +69,11 @@ def parse_args(extra_args_provider=None, defaults={},
     args.pipeline_model_parallel_size = min(
         args.pipeline_model_parallel_size,
         (args.world_size // args.tensor_model_parallel_size))
+    args.transformer_pipeline_model_parallel_size = (
+        args.pipeline_model_parallel_size - 1
+        if args.standalone_embedding_stage
+        else args.pipeline_model_parallel_size
+    )
     # Checks.
     model_parallel_size = args.pipeline_model_parallel_size * \
         args.tensor_model_parallel_size
...
@@ -98,13 +107,18 @@ def parse_args(extra_args_provider=None, defaults={},
             'longer valid, use --tensor-model-parallel-size instead'
         del args.model_parallel_size
     if args.checkpoint_activations:
-        args.activations_checkpoint_method = 'uniform'
+        args.recompute_granularity = 'full'
+        args.recompute_method = 'uniform'
         if args.rank == 0:
             print('--checkpoint-activations is no longer valid, '
-                  'use --activation-checkpoint-method instead. '
-                  'Defaulting to activation-checkpoint-method=uniform.')
+                  'use --recompute-granularity and --recompute-method instead. '
+                  'Defaulting to recompute-granularity=full and recompute-method=uniform.')
         del args.checkpoint_activations
+    if args.recompute_activations:
+        args.recompute_granularity = 'selective'
+    del args.recompute_activations

     # Set input defaults.
     for key in defaults:
         # For default to be valid, it should not be provided in the
...
@@ -166,6 +180,14 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.accumulate_allreduce_grads_in_fp32:
         assert args.DDP_impl == 'local'
         assert args.use_contiguous_buffers_in_local_ddp
+    else:
+        if args.gradient_accumulation_fusion:
+            args.gradient_accumulation_fusion = False
+            if args.rank == 0:
+                print('Gradient accumulation fusion to linear layer weight '
+                      'gradient computation is supported only with fp32 '
+                      'gradient accumulation. Setting gradient_accumulation_fusion '
+                      'to False', flush=True)

     # For torch DDP, we do not use contiguous buffer
     if args.DDP_impl == 'torch':
...
@@ -244,17 +266,51 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.fp32_residual_connection:
         assert args.fp16 or args.bf16, \
             'residual connection in fp32 only supported when using fp16 or bf16.'

-    # Activation checkpointing.
-    if args.distribute_checkpointed_activations:
+    if args.weight_decay_incr_style == 'constant':
+        assert args.start_weight_decay is None
+        assert args.end_weight_decay is None
+        args.start_weight_decay = args.weight_decay
+        args.end_weight_decay = args.weight_decay
+    else:
+        assert args.start_weight_decay is not None
+        assert args.end_weight_decay is not None
+
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    # Persistent fused layer norm.
+    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
+        args.no_persist_layer_norm = True
+        if args.rank == 0:
+            print('Persistent fused layer norm kernel is supported from '
+                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
+                  'Defaulting to no_persist_layer_norm=True')
+
+    # Activation recomputing.
+    if args.distribute_saved_activations:
         assert args.tensor_model_parallel_size > 1, 'can distribute ' \
-            'checkpointed activations only across tensor model ' \
+            'recomputed activations only across tensor model ' \
             'parallel groups'
-        assert args.activations_checkpoint_method is not None, \
-            'for distribute-checkpointed-activations to work you '\
-            'need to use a activation-checkpoint method '
-        assert args.num_layers_per_virtual_pipeline_stage is None, \
-            'currently distrobuted checkpoint activations only supported for '\
-            'nointerleaved pipeline parallelism'
+        assert args.recompute_granularity == 'full', \
+            'distributed recompute activations is only '\
+            'application to full recompute granularity'
+        assert args.recompute_method is not None, \
+            'for distributed recompute activations to work you '\
+            'need to use a recompute method '
+        assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
+            'distributed recompute activations are supported for pytorch ' \
+            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
+            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
+
+    if args.recompute_granularity == 'selective':
+        assert args.recompute_method is None, \
+            'recompute method is not yet supported for ' \
+            'selective recomputing granularity'
+
+    # disable async_tensor_model_parallel_allreduce when
+    # model parallel memory optimization is enabled
+    if args.sequence_parallel:
+        args.async_tensor_model_parallel_allreduce = False

     _print_args(args)
     return args
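The post-parse logic above couples several of the new flags: sequence parallelism turns off the asynchronous column-parallel all-reduce, and 'selective' recompute granularity must not be combined with a recompute method. A tiny self-contained sketch of those two rules outside of argparse (a hypothetical namespace, not the real parser):

from types import SimpleNamespace

args = SimpleNamespace(
    sequence_parallel=True,
    async_tensor_model_parallel_allreduce=True,
    recompute_granularity='selective',
    recompute_method=None,
)

# Sequence parallelism and the async tensor-model-parallel all-reduce are mutually exclusive.
if args.sequence_parallel:
    args.async_tensor_model_parallel_allreduce = False

# Selective recompute granularity does not take a recompute method.
if args.recompute_granularity == 'selective':
    assert args.recompute_method is None

print(args.async_tensor_model_parallel_allreduce)  # False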
...
@@ -279,6 +335,18 @@ def _check_arg_is_not_none(args, arg):
...
@@ -279,6 +335,18 @@ def _check_arg_is_not_none(args, arg):
assert
getattr
(
args
,
arg
)
is
not
None
,
'{} argument is None'
.
format
(
arg
)
assert
getattr
(
args
,
arg
)
is
not
None
,
'{} argument is None'
.
format
(
arg
)
def
_add_inference_args
(
parser
):
group
=
parser
.
add_argument_group
(
title
=
'inference'
)
group
.
add_argument
(
'--inference-batch-times-seqlen-threshold'
,
type
=
int
,
default
=
512
,
help
=
'During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.'
)
return
parser
def
_add_network_size_args
(
parser
):
def
_add_network_size_args
(
parser
):
group
=
parser
.
add_argument_group
(
title
=
'network size'
)
group
=
parser
.
add_argument_group
(
title
=
'network size'
)
...
@@ -318,6 +386,8 @@ def _add_network_size_args(parser):
...
@@ -318,6 +386,8 @@ def _add_network_size_args(parser):
group
.
add_argument
(
'--bert-no-binary-head'
,
action
=
'store_false'
,
group
.
add_argument
(
'--bert-no-binary-head'
,
action
=
'store_false'
,
help
=
'Disable BERT binary head.'
,
help
=
'Disable BERT binary head.'
,
dest
=
'bert_binary_head'
)
dest
=
'bert_binary_head'
)
group
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
None
,
help
=
'Number of Experts in Switch Transformer (None means no Switch)'
)
return
parser
return
parser
...
@@ -354,6 +424,9 @@ def _add_logging_args(parser):
...
@@ -354,6 +424,9 @@ def _add_logging_args(parser):
group
.
add_argument
(
'--log-memory-to-tensorboard'
,
group
.
add_argument
(
'--log-memory-to-tensorboard'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'Enable memory logging to tensorboard.'
)
help
=
'Enable memory logging to tensorboard.'
)
group
.
add_argument
(
'--log-world-size-to-tensorboard'
,
action
=
'store_true'
,
help
=
'Enable world size logging to tensorboard.'
)
return
parser
return
parser
...
@@ -367,6 +440,13 @@ def _add_regularization_args(parser):
...
@@ -367,6 +440,13 @@ def _add_regularization_args(parser):
help
=
'Dropout probability for hidden state transformer.'
)
help
=
'Dropout probability for hidden state transformer.'
)
group
.
add_argument
(
'--weight-decay'
,
type
=
float
,
default
=
0.01
,
group
.
add_argument
(
'--weight-decay'
,
type
=
float
,
default
=
0.01
,
help
=
'Weight decay coefficient for L2 regularization.'
)
help
=
'Weight decay coefficient for L2 regularization.'
)
group
.
add_argument
(
'--start-weight-decay'
,
type
=
float
,
help
=
'Initial weight decay coefficient for L2 regularization.'
)
group
.
add_argument
(
'--end-weight-decay'
,
type
=
float
,
help
=
'End of run weight decay coefficient for L2 regularization.'
)
group
.
add_argument
(
'--weight-decay-incr-style'
,
type
=
str
,
default
=
'constant'
,
choices
=
[
'constant'
,
'linear'
,
'cosine'
],
help
=
'Weight decay increment function.'
)
group
.
add_argument
(
'--clip-grad'
,
type
=
float
,
default
=
1.0
,
group
.
add_argument
(
'--clip-grad'
,
type
=
float
,
default
=
1.0
,
help
=
'Gradient clipping based on global L2 norm.'
)
help
=
'Gradient clipping based on global L2 norm.'
)
group
.
add_argument
(
'--adam-beta1'
,
type
=
float
,
default
=
0.9
,
group
.
add_argument
(
'--adam-beta1'
,
type
=
float
,
default
=
0.9
,
...
@@ -413,27 +493,40 @@ def _add_training_args(parser):
...
@@ -413,27 +493,40 @@ def _add_training_args(parser):
' (1024 - 16) / 8 = 126 intervals will increase'
' (1024 - 16) / 8 = 126 intervals will increase'
'the batch size linearly to 1024. In each interval'
'the batch size linearly to 1024. In each interval'
'we will use approximately 300000 / 126 = 2380 samples.'
)
'we will use approximately 300000 / 126 = 2380 samples.'
)
group
.
add_argument
(
'--
checkpoint
-activations'
,
action
=
'store_true'
,
group
.
add_argument
(
'--
recompute
-activations'
,
action
=
'store_true'
,
help
=
'
Checkpoint
activation to allow for training '
help
=
'
recompute
activation to allow for training '
'with larger models, sequences, and batch sizes.'
)
'with larger models, sequences, and batch sizes.'
)
group
.
add_argument
(
'--distribute-checkpointed-activations'
,
group
.
add_argument
(
'--recompute-granularity'
,
type
=
str
,
default
=
None
,
choices
=
[
'full'
,
'selective'
],
help
=
'Checkpoint activations to allow for training '
'with larger models, sequences, and batch sizes. '
'It is supported at two granularities 1) full: '
'whole transformer layer is recomputed, '
'2) selective: core attention part of the transformer '
'layer is recomputed.'
)
group
.
add_argument
(
'--distribute-saved-activations'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'If set, distribute
checkpoin
ted activations '
help
=
'If set, distribute
recompu
ted activations '
'across model parallel group.'
)
'across model parallel group.'
)
group
.
add_argument
(
'--
activations-checkpoint
-method'
,
type
=
str
,
default
=
None
,
group
.
add_argument
(
'--
recompute
-method'
,
type
=
str
,
default
=
None
,
choices
=
[
'uniform'
,
'block'
],
choices
=
[
'uniform'
,
'block'
],
help
=
'1) uniform: uniformly divide the total number of '
help
=
'1) uniform: uniformly divide the total number of '
'Transformer layers and
checkpoint
the input activation of '
'Transformer layers and
recompute
the input activation of '
'each divided chunk, '
'each divided chunk
at specified granularity
, '
'2)
checkpoint
the input activations of only a set number of '
'2)
recompute
the input activations of only a set number of '
'individual Transformer layers per pipeline stage and do the '
'individual Transformer layers per pipeline stage and do the '
'rest without any
checkpointing
'
'rest without any
recomputing at specified granularity
'
'default) do not apply activations
checkpoint
to any layers'
)
'default) do not apply activations
recompute
to any layers'
)
group
.
add_argument
(
'--
activations-checkpoint
-num-layers'
,
type
=
int
,
default
=
1
,
group
.
add_argument
(
'--
recompute
-num-layers'
,
type
=
int
,
default
=
1
,
help
=
'1) uniform: the number of Transformer layers in each '
help
=
'1) uniform: the number of Transformer layers in each '
'uniformly divided
checkpoint
unit, '
'uniformly divided
recompute
unit, '
'2) block: the number of individual Transformer layers '
'2) block: the number of individual Transformer layers '
'to checkpoint within each pipeline stage.'
)
'to recompute within each pipeline stage.'
)
# deprecated
group
.
add_argument
(
'--checkpoint-activations'
,
action
=
'store_true'
,
help
=
'Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.'
)
group
.
add_argument
(
'--train-iters'
,
type
=
int
,
default
=
None
,
group
.
add_argument
(
'--train-iters'
,
type
=
int
,
default
=
None
,
help
=
'Total number of iterations to train over all '
help
=
'Total number of iterations to train over all '
'training runs. Note that either train-iters or '
'training runs. Note that either train-iters or '
...
@@ -472,7 +565,20 @@ def _add_training_args(parser):
                        action='store_true',
                        help='Disable asynchronous execution of '
                        'tensor-model-parallel all-reduce with weight '
-                       'gradient compuation of a column-linear layer.')
+                       'gradient compuation of a column-linear layer.',
+                       dest='async_tensor_model_parallel_allreduce')
+    group.add_argument('--no-persist-layer-norm', action='store_true',
+                       help='Disable using persistent fused layer norm kernel. '
+                       'This kernel supports only a set of hidden sizes. Please '
+                       'check persist_ln_hidden_sizes if your hidden '
+                       'size is supported.')
+    group.add_argument('--sequence-parallel', action='store_true',
+                       help='Enable sequence parallel optimization.')
+    group.add_argument('--no-gradient-accumulation-fusion',
+                       action='store_false',
+                       help='Disable fusing gradient accumulation to weight '
+                       'gradient computation of linear layers',
+                       dest='gradient_accumulation_fusion')
    return parser
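For a quick illustration of the store_false plus dest idiom used above (a sketch, not part of the diff): the user-facing flag is spelled negatively while the stored attribute keeps the positive name and defaults to True.

import argparse

parser = argparse.ArgumentParser()
# Mirrors the idiom in the hunk above: a "--no-..." flag that stores False
# into a positively named attribute, so the default behaviour stays enabled.
parser.add_argument('--no-gradient-accumulation-fusion', action='store_false',
                    dest='gradient_accumulation_fusion')
parser.add_argument('--sequence-parallel', action='store_true')

print(parser.parse_args([]).gradient_accumulation_fusion)    # True
print(parser.parse_args(['--no-gradient-accumulation-fusion']).gradient_accumulation_fusion)  # False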
...
@@ -645,6 +751,11 @@ def _add_distributed_args(parser):
                        help='Call torch.cuda.empty_cache() each iteration '
                        '(training and eval), to reduce fragmentation.'
                        '0=off, 1=moderate, 2=aggressive.')
+    group.add_argument('--standalone-embedding-stage', action='store_true',
+                       default=False, help='If set, *input* embedding layer '
+                       'is placed on its own pipeline stage, without any '
+                       'transformer layers. (For T5, this flag currently only '
+                       'affects the encoder embedding.)')
    return parser
...
@@ -791,16 +902,70 @@ def _add_biencoder_args(parser):
    return parser
-def _add_vit_args(parser):
-    group = parser.add_argument_group(title="vit")
+def _add_vision_args(parser):
+    group = parser.add_argument_group(title="vision")

+    # general vision arguments
    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classificaiton task')
-    group.add_argument('--img-dim', type=int, default=224,
-                       help='Image size for vision classification task')
+    group.add_argument('--img-h', type=int, default=224,
+                       help='Image height for vision classification task')
+    group.add_argument('--img-w', type=int, default=224,
+                       help='Image height for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
-                       help='patch dimension used in vit')
+                       help='patch dimension')
+    group.add_argument('--classes-fraction', type=float, default=1.0,
+                       help='training with fraction of classes.')
+    group.add_argument('--data-per-class-fraction', type=float, default=1.0,
+                       help='training with fraction of data per class.')
+    group.add_argument('--no-data-sharding', action='store_false',
+                       help='Disable data sharding.',
+                       dest='data_sharding')
+    group.add_argument('--head-lr-mult', type=float, default=1.0,
+                       help='learning rate multiplier for head during finetuning')
+
+    # pretraining type and backbone selection`
+    group.add_argument('--vision-pretraining', action='store_true',
+                       help='flag to indicate vision pretraining')
+    group.add_argument('--vision-pretraining-type', type=str, default='classify',
+                       choices=['classify', 'inpaint', 'dino'],
+                       help='pretraining objectives')
+    group.add_argument('--vision-backbone-type', type=str, default='vit',
+                       choices=['vit', 'mit', 'swin'],
+                       help='backbone types types')
+    group.add_argument('--swin-backbone-type', type=str, default='tiny',
+                       choices=['tiny', 'base', 'h3'],
+                       help='pretraining objectives')
+
+    # inpainting arguments
+    group.add_argument('--mask-type', type=str, default='random',
+                       choices=['random', 'row'],
+                       help='mask types')
+    group.add_argument('--mask-factor', type=float, default=1.0,
+                       help='mask size scaling parameter')
+
+    # dino arguments
+    group.add_argument('--iter-per-epoch', type=int, default=1250,
+                       help='iterations per epoch')
+    group.add_argument('--dino-local-img-size', type=int, default=96,
+                       help='Image size for vision classification task')
+    group.add_argument('--dino-local-crops-number', type=int, default=10,
+                       help='Number of local crops')
+    group.add_argument('--dino-head-hidden-size', type=int, default=2048,
+                       help='Hidden dimension size in dino head')
+    group.add_argument('--dino-bottleneck-size', type=int, default=256,
+                       help='Bottle neck dimension in dino head ')
+    group.add_argument('--dino-freeze-last-layer', type=float, default=1,
+                       help='Freezing last layer weights')
+    group.add_argument('--dino-norm-last-layer', action='store_true',
+                       help='Disable Norm in last layer.')
+    group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
+                       help='warump teacher temperature')
+    group.add_argument('--dino-teacher-temp', type=float, default=0.07,
+                       help='teacher temperature')
+    group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
+                       help='warmup teacher temperaure epochs')

    return parser
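A small sketch (not part of the commit) of how the new vision defaults can be inspected; the parser here is a hypothetical stand-in covering only a few of the _add_vision_args options.

import argparse

# Hypothetical mini-parser mirroring a subset of _add_vision_args above.
parser = argparse.ArgumentParser()
parser.add_argument('--vision-backbone-type', type=str, default='vit',
                    choices=['vit', 'mit', 'swin'])
parser.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04)
parser.add_argument('--dino-teacher-temp', type=float, default=0.07)

args = parser.parse_args(['--vision-backbone-type', 'swin'])
print(args.vision_backbone_type, args.dino_teacher_temp)  # swin 0.07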
apex/transformer/testing/commons.py  View file @ 96850dfa
...
@@ -12,15 +12,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from dataclasses import dataclass
+import datetime
 import os
 import random
-from typing import Optional, Union, List
+from typing import Optional, Union, List, Tuple, Callable, Dict

 import numpy
 import torch
 import torch.nn as nn

 from apex import transformer
+from apex.transformer.tensor_parallel import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+    scatter_to_sequence_parallel_region,
+)
+from apex.transformer.pipeline_parallel.utils import (
+    average_losses_across_data_parallel_group,
+)
+from apex.transformer.pipeline_parallel.schedules.common import (
+    Batch,
+)
 from apex.transformer.testing import global_vars
...
@@ -29,7 +42,6 @@ TEST_SUCCESS_MESSAGE = ">> passed the test :-)"
 # note (mkozuki): `pre_process` and `post_process` are a placeholder until interleaving schedule test comes.
 class MyLayer(nn.Module):

     def __init__(self, hidden_size: int, pre_process: bool, post_process: bool):
         super().__init__()
         self.pre_process = pre_process
...
@@ -39,17 +51,28 @@ class MyLayer(nn.Module):
     def forward(self, x):
         return self.layer(x)
 class MyModel(nn.Module):

-    def __init__(self, hidden_size: int, pre_process: bool = False, post_process: bool = False) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        pre_process: bool = False,
+        post_process: bool = False,
+        *,
+        add_encoder: bool = False,
+        add_decoder: bool = False,
+    ) -> None:
         super().__init__()
         self.pre_process = pre_process
         self.post_process = post_process
         self.layer = MyLayer(hidden_size=hidden_size, pre_process=pre_process, post_process=post_process)
         self.input_tensor = None

-    def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
-        self.input_tensor = input_tensor
+    def set_input_tensor(
+        self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> None:
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
+        self.input_tensor = input_tensor[0]

     def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
         if self.input_tensor is None:
...
@@ -57,8 +80,154 @@ class MyModel(nn.Module):
             return self.layer(self.input_tensor)
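For context, a brief usage sketch (an assumption about intent, not code from this commit): on a non-first pipeline stage the schedule attaches the activation received from the previous stage via set_input_tensor and then drives the model with forward(None). The hidden size and tensor shape below are illustrative, and MyLayer is assumed to wrap a simple hidden_size-to-hidden_size linear layer as elsewhere in this file.

import torch

model = MyModel(hidden_size=16)
recv_activation = torch.randn(4, 16)   # stand-in for the tensor received over p2p
model.set_input_tensor(recv_activation)
out = model.forward(None)              # forward falls back to self.input_tensor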
-def model_provider_func(hidden_size, pre_process, post_process) -> MyModel:
-    return MyModel(hidden_size, pre_process, post_process)
+class ToyParallelMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        pre_process: bool = False,
+        post_process: bool = False,
+        *,
+        sequence_parallel_enabled: bool = False,
+        # TODO(mkozuki): Support these two?
+        add_encoder: bool = False,
+        add_decoder: bool = False,
+    ) -> None:
+        super().__init__()
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.sequence_parallel_enabled = sequence_parallel_enabled
+
+        ffn_hidden_size = 4 * hidden_size
+        self.dense_h_to_4h = ColumnParallelLinear(
+            hidden_size,
+            ffn_hidden_size,
+            gather_output=False,
+            # init_method=init_method,
+            skip_bias_add=True,
+            # use_cpu_initialization=use_cpu_initialization,
+            bias=True,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+            no_async_tensor_model_parallel_allreduce=True,
+        )
+        self.dense_4h_to_h = RowParallelLinear(
+            ffn_hidden_size,
+            hidden_size,
+            input_is_parallel=True,
+            # init_method=output_layer_init_method,
+            skip_bias_add=False,
+            # use_cpu_initialization=use_cpu_initialization,
+            bias=True,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+        )
+        self.activation_func = torch.nn.GELU()
+
+    def set_input_tensor(
+        self,
+        input_tensor: Union[torch.Tensor, List[torch.Tensor]],
+    ) -> None:
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
+        self.input_tensor = input_tensor[0]
+
+    def forward(
+        self,
+        x: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        """Forward of Simplified ParallelMLP.
+
+        Args:
+            x: :obj:`None` if pipeline rank != pippeline first rank. When :obj:`None`,
+                `self.input_tensor` is taken care of by `forward_step` defined in
+                apex/transformer/pipeline_parallel/schedules/common.py
+        """
+        # [s, b, h]
+        if self.input_tensor is None:
+            input = x
+        else:
+            input = self.input_tensor
+
+        intermediate_parallel, bias_parallel = self.dense_h_to_4h(input)
+        if bias_parallel is not None:
+            intermediate_parallel += bias_parallel
+        intermediate_parallel = self.activation_func(intermediate_parallel)
+
+        # [s, b, h]
+        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
+        return output
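As a single-rank reference (an illustrative assumption, not apex code) for what ToyParallelMLP computes: Linear(h to 4h), GELU, Linear(4h to h). Under tensor parallelism the first projection is column-split across ranks (gather_output=False) and the second is row-split (input_is_parallel=True), so only the final output needs an all-reduce.

import torch
import torch.nn as nn

hidden_size = 32
ref_mlp = nn.Sequential(
    nn.Linear(hidden_size, 4 * hidden_size),  # column-parallel in the real module
    nn.GELU(),
    nn.Linear(4 * hidden_size, hidden_size),  # row-parallel in the real module
)
x = torch.randn(8, 2, hidden_size)  # [s, b, h], matching the comment in forward()
print(ref_mlp(x).shape)             # torch.Size([8, 2, 32])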
+def model_provider_func(
+    hidden_size: int,
+    pre_process: bool,
+    post_process: bool,
+    *,
+    add_encoder: bool = False,
+    add_decoder: bool = False
+) -> MyModel:
+    return MyModel(
+        hidden_size,
+        pre_process,
+        post_process,
+        add_encoder=add_encoder,
+        add_decoder=add_decoder,
+    )
+
+
+def mlp_provider_func(
+    hidden_size: int,
+    pre_process: bool,
+    post_process: bool,
+    *,
+    add_encoder: bool = False,
+    add_decoder: bool = False,
+    sequence_parallel_enabled: bool = False,
+) -> ToyParallelMLP:
+    return ToyParallelMLP(
+        hidden_size,
+        pre_process,
+        post_process,
+        add_encoder=add_encoder,
+        add_decoder=add_decoder,
+        sequence_parallel_enabled=sequence_parallel_enabled,
+    )
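A minimal sketch of the provider-function contract (the build_stage helper below is hypothetical, not an apex API): the caller decides per pipeline stage whether the chunk is first and/or last and asks the provider for a module.

# Hypothetical helper illustrating how a provider function is typically consumed.
def build_stage(provider, hidden_size, is_first_stage, is_last_stage):
    return provider(hidden_size, pre_process=is_first_stage, post_process=is_last_stage)

stage_model = build_stage(model_provider_func, hidden_size=16,
                          is_first_stage=True, is_last_stage=False)
print(type(stage_model).__name__)  # MyModel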
+def process_batch(batch):
+    if isinstance(batch, list):
+        x = batch[0]
+    else:
+        x = batch
+    return x
+
+
+def fwd_step_func(batch, model):
+    x = process_batch(batch)
+    y = model(x)
+
+    # note (mkozuki): I don't think this function is nice but I do think this is enough for now
+    # just to check the sanity of ported pipeline functions.
+    def loss_func(x):
+        loss = torch.sum(x)
+        averaged_loss = average_losses_across_data_parallel_group([loss])
+        return loss, {"avg": averaged_loss}
+
+    return y, loss_func
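A dependency-free stand-in (a sketch under the assumption of a single process, so no data-parallel averaging) for the contract fwd_step_func implements: return the stage output together with a callable that maps the final output to a loss and a stats dict.

import torch
import torch.nn as nn

def toy_fwd_step(batch, model):
    y = model(batch)

    def loss_func(out):
        loss = out.float().sum()
        return loss, {"avg": loss.detach()}

    return y, loss_func

toy_model = nn.Linear(8, 8)
out, loss_fn = toy_fwd_step(torch.randn(2, 8), toy_model)
loss, stats = loss_fn(out)
loss.backward()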
+@dataclass(frozen=True)
+class ToyParallelMLPFwdBwdStepFunc:
+
+    sequence_parallel_enabled: bool
+
+    def __call__(
+        self,
+        batch: Batch,
+        model: torch.nn.Module,
+    ) -> Tuple[torch.Tensor, Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]]:
+        x = batch[0] if isinstance(batch, list) else batch
+        if isinstance(x, torch.Tensor):
+            x = x.transpose(0, 1).contiguous()
+            if self.sequence_parallel_enabled:
+                x = scatter_to_sequence_parallel_region(x)
+        y = model(x)
+
+        # note (mkozuki): I don't think this function is nice but I do think this is enough for now
+        # just to check the sanity of ported pipeline functions.
+        def loss_func(x):
+            loss = torch.sum(x)
+            averaged_loss = average_losses_across_data_parallel_group([loss])
+            return loss, {"avg": averaged_loss}
+
+        return y, loss_func
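For illustration (an assumption about intended use, not taken from the diff): the frozen dataclass packages the sequence-parallel flag so an instance can be dropped in wherever a plain forward-step callable is expected.

# Configure once, then use like a function; invoking it with a real batch and
# model still requires the parallel state to be initialized.
step_fn = ToyParallelMLPFwdBwdStepFunc(sequence_parallel_enabled=False)
# A pipeline schedule would later call: y, loss_func = step_fn(batch, model)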
 class IdentityLayer(torch.nn.Module):
...
@@ -78,22 +247,28 @@ def set_random_seed(seed):
     transformer.tensor_parallel.model_parallel_cuda_manual_seed(seed)
-def initialize_distributed(backend='nccl'):
+def initialize_distributed(backend="nccl"):
     """Initialize torch.distributed."""
     # Get local rank in case it is provided.
     # parser = argparse.ArgumentParser()
     # parser.add_argument('--local_rank', type=int, default=None,
     #                     help='local rank passed from distributed launcher')
     # args = parser.parse_args()
+    if backend not in ("nccl", "ucc"):
+        raise RuntimeError(f"Currently only nccl & ucc are supported but {backend}")
+    if backend == "ucc":
+        import torch_ucc  # NOQA
     args = global_vars.get_args()
     local_rank = args.local_rank

     # Get rank and world size.
-    rank = int(os.getenv('RANK', '0'))
-    world_size = int(os.getenv("WORLD_SIZE", '1'))
-    print('> initializing torch.distributed with local rank: {}, '
-          'rank: {}, world size: {}'.format(local_rank, rank, world_size))
+    rank = int(os.getenv("RANK", "0"))
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    print(
+        "> initializing torch.distributed with local rank: {}, "
+        "rank: {}, world size: {}".format(local_rank, rank, world_size)
+    )

     # Set the device id.
     device = rank % torch.cuda.device_count()
...
@@ -102,22 +277,21 @@ def initialize_distributed(backend='nccl'):
     torch.cuda.set_device(device)

     # Call the init process.
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
+    init_method = "tcp://"
+    master_ip = os.getenv("MASTER_ADDR", "localhost")
+    master_port = os.getenv("MASTER_PORT", "6000")
+    init_method += master_ip + ":" + master_port
     torch.distributed.init_process_group(
         backend=backend,
-        world_size=world_size,
-        rank=rank,
-        init_method=init_method)
+        init_method=init_method,
+        world_size=world_size,
+        timeout=datetime.timedelta(seconds=60),
+        rank=rank,
+    )
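A self-contained sketch of the same initialization pattern (assumptions: gloo backend and a single process so it runs without GPUs or a launcher; the real helper takes the local rank from the global_vars args):

import datetime
import os
import torch

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "6000")
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
init_method = "tcp://" + os.environ["MASTER_ADDR"] + ":" + os.environ["MASTER_PORT"]

torch.distributed.init_process_group(
    backend="gloo",  # assumption: CPU-friendly backend for this sketch
    init_method=init_method,
    world_size=world_size,
    rank=rank,
    timeout=datetime.timedelta(seconds=60),
)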
 def print_separator(message):
     torch.distributed.barrier()
     filler_len = (78 - len(message)) // 2
-    filler = '-' * filler_len
-    string = '\n' + filler + ' {} '.format(message) + filler
+    filler = "-" * filler_len
+    string = "\n" + filler + " {} ".format(message) + filler
     if torch.distributed.get_rank() == 0:
         print(string, flush=True)
     torch.distributed.barrier()