OpenDAS / apex / Commits / 795a5e5b

Commit 795a5e5b, authored Jul 29, 2022 by hubertlu-tw

    Merge remote-tracking branch 'upstream/master' into IFU-master-2022-07-29

Parents: 016c8d4f, 3c19f106
Changes: 230
Showing 20 changed files with 2718 additions and 600 deletions (+2718, -600)
apex/transformer/log_util.py  (+1, -2)
apex/transformer/microbatches.py  (+40, -17)
apex/transformer/parallel_state.py  (+329, -43)
apex/transformer/pipeline_parallel/p2p_communication.py  (+335, -161)
apex/transformer/pipeline_parallel/schedules/__init__.py  (+10, -14)
apex/transformer/pipeline_parallel/schedules/common.py  (+250, -70)
apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py  (+55, -14)
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py  (+153, -46)
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py  (+364, -45)
apex/transformer/pipeline_parallel/utils.py  (+26, -2)
apex/transformer/tensor_parallel/__init__.py  (+2, -1)
apex/transformer/tensor_parallel/data.py  (+15, -6)
apex/transformer/tensor_parallel/layers.py  (+374, -71)
apex/transformer/tensor_parallel/mappings.py  (+164, -19)
apex/transformer/tensor_parallel/memory.py  (+20, -5)
apex/transformer/tensor_parallel/random.py  (+41, -24)
apex/transformer/tensor_parallel/utils.py  (+14, -4)
apex/transformer/testing/arguments.py  (+197, -32)
apex/transformer/testing/commons.py  (+198, -24)
apex/transformer/testing/distributed_test_base.py  (+130, -0)
apex/transformer/log_util.py  (view file @ 795a5e5b)

-from typing import Optional
 import logging
 import os
 import threading

 def get_transformer_logger(name: str) -> logging.Logger:
...
@@ -16,4 +14,5 @@ def set_logging_level(verbosity) -> None:
         verbosity
     """
+    from apex import _library_root_logger
     _library_root_logger.setLevel(verbosity)
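For orientation, a minimal usage sketch of these logging helpers (not part of this commit; it only uses the two public names visible in the diff, get_transformer_logger and set_logging_level):

    import logging

    from apex.transformer.log_util import get_transformer_logger, set_logging_level

    # Module-level logger keyed by __name__, the same pattern the transformer
    # modules in this merge now use instead of bare print(..., flush=True).
    _logger = get_transformer_logger(__name__)

    # Raise the apex library root logger to INFO so the new _logger.info calls
    # become visible.
    set_logging_level(logging.INFO)

    _logger.info("pipeline parallel setup starting")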
apex/transformer/microbatches.py  (view file @ 795a5e5b)

...
@@ -17,13 +17,18 @@ from abc import ABC
 from abc import abstractmethod
 from typing import Optional, List

+from apex.transformer.log_util import get_transformer_logger
+
+
+_logger = get_transformer_logger(__name__)
+

 def build_num_microbatches_calculator(
     rank: int,
     rampup_batch_size: Optional[List[int]],
     global_batch_size: int,
     micro_batch_size: int,
     data_parallel_size: int,
 ):
     # Constant num micro-batches.
     if rampup_batch_size is None:
...
@@ -31,8 +36,10 @@ def build_num_microbatches_calculator(
             global_batch_size, micro_batch_size, data_parallel_size
         )
         if rank == 0:
-            print(
-                "setting number of micro-batches to constant {}".format(
-                    num_microbatches_calculator.get()),
-                flush=True
+            _logger.info(
+                "setting number of micro-batches to constant {}".format(
+                    num_microbatches_calculator.get()
+                )
             )
     else:
...
@@ -45,13 +52,15 @@ def build_num_microbatches_calculator(
         batch_size_increment = int(rampup_batch_size[1])
         ramup_samples = int(rampup_batch_size[2])
         if rank == 0:
-            print(
+            _logger.info(
                 "will use batch size rampup starting from global batch "
                 "size {} to global batch size {} with batch size increments "
                 "{} over {} samples.".format(
-                    start_batch_size, global_batch_size, batch_size_increment, ramup_samples),
-                flush=True,
+                    start_batch_size,
+                    global_batch_size,
+                    batch_size_increment,
+                    ramup_samples,
+                )
             )
         num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
             start_batch_size,
...
@@ -86,7 +95,9 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
         micro_batch_times_data_parallel = micro_batch_size * data_parallel_size
         assert global_batch_size % micro_batch_times_data_parallel == 0, (
             "global batch size ({}) is not divisible by micro batch size ({})"
-            " times data parallel size ({})".format(global_batch_size, micro_batch_size, data_parallel_size)
+            " times data parallel size ({})".format(
+                global_batch_size, micro_batch_size, data_parallel_size
+            )
         )
         self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel
         assert self.num_micro_batches >= 1
...
@@ -126,7 +137,9 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
         self.micro_batch_size = micro_batch_size
         self.data_parallel_size = data_parallel_size
-        self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
+        self.micro_batch_times_data_parallel_size = (
+            self.micro_batch_size * self.data_parallel_size
+        )
         assert self.micro_batch_times_data_parallel_size > 0
         assert start_batch_size > 0
...
@@ -158,15 +171,25 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
             self.current_global_batch_size = self.global_batch_size
         else:
             steps = int(consumed_samples / self.rampup_samples_per_increment)
-            self.current_global_batch_size = self.start_batch_size + steps * self.batch_size_increment
+            self.current_global_batch_size = (
+                self.start_batch_size + steps * self.batch_size_increment
+            )
             assert self.current_global_batch_size <= self.global_batch_size

         if consistency_check:
-            assert self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0, (
+            assert (
+                self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0
+            ), (
                 "current global "
                 "batch size ({}) is not divisible by micro-batch-size ({}) times"
                 "data parallel size ({})".format(
-                    self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size
+                    self.current_global_batch_size,
+                    self.micro_batch_size,
+                    self.data_parallel_size,
                 )
             )
-        self.num_micro_batches = self.current_global_batch_size // self.micro_batch_times_data_parallel_size
+        self.num_micro_batches = (
+            self.current_global_batch_size // self.micro_batch_times_data_parallel_size
+        )
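As a usage sketch of the calculator built above (the batch sizes are illustrative, not taken from this commit): rampup_batch_size is the triple (start batch size, increment, ramp-up samples), and the calculator exposes get() and update().

    from apex.transformer.microbatches import build_num_microbatches_calculator

    # Constant schedule: 512 global batch, micro batch 4, data-parallel size 8
    # -> 512 // (4 * 8) = 16 micro-batches per step.
    calc = build_num_microbatches_calculator(
        rank=0,
        rampup_batch_size=None,
        global_batch_size=512,
        micro_batch_size=4,
        data_parallel_size=8,
    )
    print(calc.get())  # 16

    # Ramp-up schedule: start at a global batch of 32 and grow by 32 each time
    # another 1_000_000 samples have been consumed, until 512 is reached.
    ramped = build_num_microbatches_calculator(
        rank=0,
        rampup_batch_size=[32, 32, 1_000_000],
        global_batch_size=512,
        micro_batch_size=4,
        data_parallel_size=8,
    )
    ramped.update(consumed_samples=0, consistency_check=True)
    print(ramped.get())  # 1, since the current global batch size is still 32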
apex/transformer/parallel_state.py  (view file @ 795a5e5b)

...
@@ -12,14 +12,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# TODO (mkozuki): Replace assert with RuntimeError.
+# TODO (mkozuki): Sort the functions in the same order of megatron/mpu/initialize.py
 """Model and data parallel groups."""
-from typing import Tuple
+from typing import Tuple, Optional
+import warnings

 import torch

+# TODO (mkozuki): Consider dissecting utils as this utils import is here
+# only for ensure_divisibility
 from apex.transformer.utils import ensure_divisibility
+from apex.transformer.log_util import get_transformer_logger
+
+
+_logger = get_transformer_logger(__name__)

 # N.B. (mkozuki): Diff btwn Megatron-LM & apex parallel_state
 # set(megatron_mpu_initialize_funcs) - set(apex.transformer.parallel_state) =
 # {
 #     'get_num_layers',
 # }

 # Intra-layer model parallel group that the current rank belongs to.
...
@@ -30,11 +40,17 @@ _PIPELINE_MODEL_PARALLEL_GROUP = None
 _MODEL_PARALLEL_GROUP = None
 # Embedding group.
 _EMBEDDING_GROUP = None
+# Position embedding group.
+_POSITION_EMBEDDING_GROUP = None
+# Relative position embedding group.
+_ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
+_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None

 _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
 _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

 # These values enable us to change the mpu sizes on the fly.
 _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
...
@@ -45,6 +61,13 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
 # A list of ranks that have a copy of the embedding.
 _EMBEDDING_GLOBAL_RANKS = None
+# A list of ranks that have a copy of the position embedding.
+_POSITION_EMBEDDING_GLOBAL_RANKS = None
+# A list of ranks that have a copy of the relative position embedding.
+_ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None
+_DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None

 # A list of global ranks for each pipeline group to ease calculation of the source
 # rank when broadcasting from the first or last pipeline stage
 _PIPELINE_GLOBAL_RANKS = None
...
@@ -56,14 +79,31 @@ def is_unitialized():
 def initialize_model_parallel(
-    tensor_model_parallel_size_=1,
-    pipeline_model_parallel_size_=1,
-    virtual_pipeline_model_parallel_size_=None,
-):
+    tensor_model_parallel_size_: int = 1,
+    pipeline_model_parallel_size_: int = 1,
+    virtual_pipeline_model_parallel_size_: Optional[int] = None,
+    pipeline_model_parallel_split_rank_: Optional[int] = None,
+    *,
+    default_backend: Optional[str] = None,
+    p2p_backend: Optional[str] = None,
+) -> None:
     """
     Initialize model data parallel groups.

     Arguments:
         tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
         pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
         virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline).
+        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
+            rank in pipeline with split point.

+    Keyword Arguments:
+        default_backend: Backend of process groups except for pipeline parallel ones.
+            If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
+        p2p_backend: Backend of process groups for pipeline model parallel.
+            If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
+
+    .. note::
+        `torch_ucc <https://github.com/facebookresearch/torch_ucc>`_ is
+        necessary for "ucc" backend.

     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
...
@@ -83,28 +123,61 @@ def initialize_model_parallel(
     """
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
-    world_size = torch.distributed.get_world_size()
-    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
-    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
-    ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
-    data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
+    assert default_backend is None or default_backend in ("nccl", "ucc")
+    assert p2p_backend is None or p2p_backend in ("nccl", "ucc")
+    if "ucc" in (default_backend, p2p_backend):
+        check_torch_ucc_availability()
+        warnings.warn("`ucc` backend support is experimental", ExperimentalWarning)
+    if default_backend == "ucc":
+        warnings.warn("The UCC's functionality as `default_backend` is not well verified", ExperimentalWarning)
+
+    world_size: int = torch.distributed.get_world_size()
+    tensor_model_parallel_size: int = min(tensor_model_parallel_size_, world_size)
+    pipeline_model_parallel_size: int = min(pipeline_model_parallel_size_, world_size)
+    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
+        raise RuntimeError(
+            f"`world_size` ({world_size}) is not divisible by tensor_model_parallel_size ({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
+        )
+    data_parallel_size: int = world_size // (
+        tensor_model_parallel_size * pipeline_model_parallel_size
+    )
     if torch.distributed.get_rank() == 0:
-        print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size))
-        print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size))
-        print("> initializing data parallel with size {}".format(data_parallel_size))
+        _logger.info(
+            "> initializing tensor model parallel with size {}".format(tensor_model_parallel_size)
+        )
+        _logger.info(
+            "> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size)
+        )
+        _logger.info("> initializing data parallel with size {}".format(data_parallel_size))
-    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
-    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
-    num_data_parallel_groups = world_size // data_parallel_size
+    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
+    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
+    num_data_parallel_groups: int = world_size // data_parallel_size

     if virtual_pipeline_model_parallel_size_ is not None:
-        assert pipeline_model_parallel_size_ > 2, \
-            'pipeline-model-parallel size should be greater than 2 with ' \
-            'interleaved schedule'
+        # n.b. (eqy) This check was inherited from Megatron-LM, need to revisit
+        # the root cause as we do see numerical mismatches with 2 stages and
+        # the interleaved schedule
+        assert pipeline_model_parallel_size_ > 2, (
+            "pipeline-model-parallel size should be greater than 2 with "
+            "interleaved schedule"
+        )
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
         _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
-        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = (
+            virtual_pipeline_model_parallel_size_
+        )
+
+    if pipeline_model_parallel_split_rank_ is not None:
+        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_

     rank = torch.distributed.get_rank()
...
@@ -118,7 +191,7 @@ def initialize_model_parallel(
     for j in range(tensor_model_parallel_size):
         ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
         all_data_parallel_group_ranks.append(list(ranks))
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=default_backend)
         if rank in ranks:
             _DATA_PARALLEL_GROUP = group
...
@@ -126,17 +199,24 @@ def initialize_model_parallel(
     global _MODEL_PARALLEL_GROUP
     assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
     for i in range(data_parallel_size):
-        ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]
-        group = torch.distributed.new_group(ranks)
+        ranks = [
+            data_parallel_group_ranks[i]
+            for data_parallel_group_ranks in all_data_parallel_group_ranks
+        ]
+        group = torch.distributed.new_group(ranks, backend=default_backend)
         if rank in ranks:
             _MODEL_PARALLEL_GROUP = group

     # Build the tensor model-parallel groups.
     global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized"
+    assert (
+        _TENSOR_MODEL_PARALLEL_GROUP is None
+    ), "tensor model parallel group is already initialized"
     for i in range(num_tensor_model_parallel_groups):
-        ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
-        group = torch.distributed.new_group(ranks)
+        ranks = list(
+            range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
+        )
+        group = torch.distributed.new_group(ranks, backend=default_backend)
         if rank in ranks:
             _TENSOR_MODEL_PARALLEL_GROUP = group
...
@@ -144,43 +224,111 @@ def initialize_model_parallel(
     # (first and last rank in each pipeline model-parallel group).
     global _PIPELINE_MODEL_PARALLEL_GROUP
     global _PIPELINE_GLOBAL_RANKS
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
+    assert (
+        _PIPELINE_MODEL_PARALLEL_GROUP is None
+    ), "pipeline model parallel group is already initialized"
     global _EMBEDDING_GROUP
     global _EMBEDDING_GLOBAL_RANKS
     assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
+    global _POSITION_EMBEDDING_GROUP
+    global _POSITION_EMBEDDING_GLOBAL_RANKS
+    assert (
+        _POSITION_EMBEDDING_GROUP is None
+    ), "position embedding group is already initialized"
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is None or \
+        _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is None, \
+        'relative position embedding group is already initialized'
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size, num_pipeline_model_parallel_groups)
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=p2p_backend)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
             _PIPELINE_GLOBAL_RANKS = ranks
         # Setup embedding group (to exchange gradients between
         # first and last stages).
+        encoder_relative_position_embedding_ranks = None
+        decoder_relative_position_embedding_ranks = None
         if len(ranks) > 1:
             embedding_ranks = [ranks[0], ranks[-1]]
+            position_embedding_ranks = [ranks[0]]
+            encoder_relative_position_embedding_ranks = [ranks[0]]
+            decoder_relative_position_embedding_ranks = [ranks[0]]
+            if pipeline_model_parallel_split_rank_ is not None:
+                encoder_relative_position_embedding_ranks = \
+                    ranks[:pipeline_model_parallel_split_rank_]
+                decoder_relative_position_embedding_ranks = \
+                    ranks[pipeline_model_parallel_split_rank_:]
+                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
+                    embedding_ranks = [
+                        ranks[0],
+                        ranks[pipeline_model_parallel_split_rank_],
+                        ranks[-1],
+                    ]
+                if (
+                    ranks[pipeline_model_parallel_split_rank_]
+                    not in position_embedding_ranks
+                ):
+                    position_embedding_ranks = [
+                        ranks[0],
+                        ranks[pipeline_model_parallel_split_rank_],
+                    ]
         else:
             embedding_ranks = ranks
-        group = torch.distributed.new_group(embedding_ranks)
+            position_embedding_ranks = ranks
+            encoder_relative_position_embedding_ranks = ranks
+            decoder_relative_position_embedding_ranks = ranks
+
+        group = torch.distributed.new_group(embedding_ranks, backend=default_backend)
         if rank in embedding_ranks:
             _EMBEDDING_GROUP = group
         if rank in ranks:
             _EMBEDDING_GLOBAL_RANKS = embedding_ranks
+
+        group = torch.distributed.new_group(position_embedding_ranks, backend=default_backend)
+        if rank in position_embedding_ranks:
+            _POSITION_EMBEDDING_GROUP = group
+        if rank in ranks:
+            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
+
+        if encoder_relative_position_embedding_ranks:
+            group = torch.distributed.new_group(encoder_relative_position_embedding_ranks)
+        if rank in encoder_relative_position_embedding_ranks:
+            _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
+        if rank in ranks:
+            _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
+                encoder_relative_position_embedding_ranks
+
+        if decoder_relative_position_embedding_ranks:
+            group = torch.distributed.new_group(decoder_relative_position_embedding_ranks)
+        if rank in decoder_relative_position_embedding_ranks:
+            _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
+        if rank in ranks:
+            _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
+                decoder_relative_position_embedding_ranks
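For orientation, a minimal usage sketch of initialize_model_parallel with the backend keywords added in this merge (not part of the commit; the 16-GPU layout mirrors the example in the docstring above, and torch.distributed.init_process_group is assumed to have already run on all 16 ranks):

    from apex.transformer import parallel_state

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=2,
        pipeline_model_parallel_size_=4,
        virtual_pipeline_model_parallel_size_=None,
        pipeline_model_parallel_split_rank_=None,
        default_backend="nccl",   # process groups other than pipeline-parallel ones
        p2p_backend="nccl",       # pipeline-parallel (point-to-point) groups
    )

    # With 16 ranks this yields data_parallel_size = 16 // (2 * 4) = 2.
    print(parallel_state.get_tensor_model_parallel_world_size())    # 2
    print(parallel_state.get_pipeline_model_parallel_world_size())  # 4
    print(parallel_state.get_data_parallel_world_size())            # 2

    parallel_state.destroy_model_parallel()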

 def get_rank_info() -> Tuple[int, int, int]:
-    """Returns a tuple of (tensor, pipeline, data)-parallel-rank for logger."""
+    """Returns a tuple of (data, tensor, pipeline, virtual pipeline)-parallel-rank for logger."""
     if model_parallel_is_initialized():
         return (
+            get_data_parallel_rank(),
             get_tensor_model_parallel_rank(),
             get_pipeline_model_parallel_rank(),
-            # get_virtual_pipeline_model_parallel_rank(),
-            get_data_parallel_rank(),
+            get_virtual_pipeline_model_parallel_rank(),
         )
-    return (0, 0, 0)
+    return (0, 0, 0, 0)


 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
-    if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
+    if (
+        _TENSOR_MODEL_PARALLEL_GROUP is None
+        or _PIPELINE_MODEL_PARALLEL_GROUP is None
+        or _DATA_PARALLEL_GROUP is None
+    ):
         return False
     return True
...
@@ -193,13 +341,17 @@ def get_model_parallel_group():
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "intra_layer_model parallel group is not initialized"
+    assert (
+        _TENSOR_MODEL_PARALLEL_GROUP is not None
+    ), "intra_layer_model parallel group is not initialized"
     return _TENSOR_MODEL_PARALLEL_GROUP


 def get_pipeline_model_parallel_group():
     """Get the pipeline model parallel group the caller rank belongs to."""
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized"
+    assert (
+        _PIPELINE_MODEL_PARALLEL_GROUP is not None
+    ), "pipeline_model parallel group is not initialized"
     return _PIPELINE_MODEL_PARALLEL_GROUP
...
@@ -215,6 +367,25 @@ def get_embedding_group():
     return _EMBEDDING_GROUP


+def get_position_embedding_group():
+    """Get the position embedding group the caller rank belongs to."""
+    assert (
+        _POSITION_EMBEDDING_GROUP is not None
+    ), "position embedding group is not initialized"
+    return _POSITION_EMBEDDING_GROUP
+
+
+def get_encoder_relative_position_embedding_group():
+    """Get the encoder relative position embedding group the caller rank belongs to."""
+    assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, \
+        'encoder relative position embedding group is not initialized'
+    return _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
+
+
+def get_decoder_relative_position_embedding_group():
+    """Get the decoder relative position embedding group the caller rank belongs to."""
+    assert _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, \
+        'decoder relative position embedding group is not initialized'
+    return _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
+
+
 def is_rank_in_embedding_group(ignore_virtual=False):
     """Return true if current rank is in embedding group, False otherwise."""
     rank = torch.distributed.get_rank()
...
@@ -231,6 +402,64 @@ def is_rank_in_embedding_group(ignore_virtual=False):
     return False


+def is_rank_in_position_embedding_group():
+    """Return whether the current rank is in position embedding group."""
+    rank = torch.distributed.get_rank()
+    global _POSITION_EMBEDDING_GLOBAL_RANKS
+    return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
+
+
+def is_rank_in_encoder_relative_position_embedding_group():
+    """Return true if current rank is in encoder relative position embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    return rank in _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+
+
+def is_rank_in_decoder_relative_position_embedding_group():
+    """Return true if current rank is in decoder relative position embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+    return rank in _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
+
+
+def is_pipeline_stage_before_split(rank=None):
+    """Return True if pipeline stage executes encoder block for a model
+    with both encoder and decoder."""
+    if get_pipeline_model_parallel_world_size() == 1:
+        return True
+    if rank is None:
+        rank = get_pipeline_model_parallel_rank()
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
+        return True
+    if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
+        return True
+    return False
+
+
+def is_pipeline_stage_after_split(rank=None):
+    """Return True if pipeline stage executes decoder block for a model
+    with both encoder and decoder."""
+    if get_pipeline_model_parallel_world_size() == 1:
+        return True
+    if rank is None:
+        rank = get_pipeline_model_parallel_rank()
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
+        return True
+    if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
+        return True
+    return False
+
+
+def is_pipeline_stage_at_split():
+    """Return true if pipeline stage executes decoder block and next
+    stage executes encoder block for a model with both encoder and
+    decoder."""
+    rank = get_pipeline_model_parallel_rank()
+    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
+
+
 def set_tensor_model_parallel_world_size(world_size):
     """Set the tensor model parallel size"""
     global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
...
@@ -287,6 +516,21 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


+# TODO (mkozuki): Add [`get_num_layers`](https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/mpu/initialize.py#L321) here, maybe?
+
+
+def get_pipeline_model_parallel_split_rank():
+    """Return my rank for the pipeline model parallel split rank."""
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+
+
+def set_pipeline_model_parallel_split_rank(pipeline_model_parallel_split_rank: int):
+    """Set my rank for the pipeline model parallel split rank."""
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
+
+
 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
...
@@ -301,12 +545,16 @@ def is_pipeline_first_stage(ignore_virtual=False):
 def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
-        virtual_pipeline_model_parallel_world_size = get_virtual_pipeline_model_parallel_world_size()
+        virtual_pipeline_model_parallel_world_size = (
+            get_virtual_pipeline_model_parallel_world_size()
+        )
         if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
            virtual_pipeline_model_parallel_world_size - 1
        ):
             return False
-    return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
+    return get_pipeline_model_parallel_rank() == (
+        get_pipeline_model_parallel_world_size() - 1
+    )


 def get_virtual_pipeline_model_parallel_rank():
...
@@ -335,26 +583,42 @@ def get_tensor_model_parallel_src_rank():
     return (global_rank // local_world_size) * local_world_size


+def get_data_parallel_src_rank():
+    """Calculate the global rank corresponding to the first local rank in the data parallel group."""
+    global_rank = torch.distributed.get_rank()
+    data_parallel_size: int = get_data_parallel_world_size()
+    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
+    return global_rank % num_data_parallel_groups
+
+
 def get_pipeline_model_parallel_first_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     return _PIPELINE_GLOBAL_RANKS[0]


 def get_pipeline_model_parallel_last_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
     return _PIPELINE_GLOBAL_RANKS[last_rank_local]


 def get_pipeline_model_parallel_next_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


 def get_pipeline_model_parallel_prev_rank():
-    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+    assert (
+        _PIPELINE_GLOBAL_RANKS is not None
+    ), "Pipeline parallel group is not initialized"
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
...
@@ -370,6 +634,9 @@ def get_data_parallel_rank():
     return torch.distributed.get_rank(group=get_data_parallel_group())


+# note (mkozuki): `destroy_model_parallel` voids more global variables than Megatron-LM.
+# Otherwise pipeline parallel forward_backward functions test hangs possibly because
+# the clean-up of the original is NOT enough.
 def destroy_model_parallel():
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
...
@@ -382,6 +649,12 @@ def destroy_model_parallel():
     _DATA_PARALLEL_GROUP = None
     global _EMBEDDING_GROUP
     _EMBEDDING_GROUP = None
+    global _POSITION_EMBEDDING_GROUP
+    _POSITION_EMBEDDING_GROUP = None
+    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
+    global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
+    _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
     _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
...
@@ -394,3 +667,16 @@ def destroy_model_parallel():
     _MPU_TENSOR_MODEL_PARALLEL_RANK = None
     global _MPU_PIPELINE_MODEL_PARALLEL_RANK
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+
+
+# Used to warn when the UCC is specified.
+class ExperimentalWarning(Warning):
+    pass
+
+
+def check_torch_ucc_availability() -> None:
+    try:
+        import torch_ucc  # NOQA
+    except ImportError:
+        raise ImportError(
+            "UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"
+        )
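A sketch of how the new encoder/decoder split helpers behave (not part of this commit; it assumes initialize_model_parallel has already run on every rank with pipeline_model_parallel_size_=4 and pipeline_model_parallel_split_rank_=2, so stages 0-1 hold the encoder and stages 2-3 hold the decoder):

    from apex.transformer import parallel_state

    for stage in range(4):
        print(
            stage,
            parallel_state.is_pipeline_stage_before_split(stage),  # True for stages 0, 1
            parallel_state.is_pipeline_stage_after_split(stage),   # True for stages 2, 3
        )

    # True only on the last encoder stage (stage 1 here), whose next stage is
    # the first decoder stage.
    print(parallel_state.is_pipeline_stage_at_split())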
apex/transformer/pipeline_parallel/p2p_communication.py  (view file @ 795a5e5b)

 # coding=utf-8
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -12,63 +12,108 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# TODO(mkozuki): Consider removing `timers`.
 from functools import reduce
 import operator
 from typing import Union, Optional, Tuple
+import warnings

 import torch

-from apex._autocast_utils import _get_current_dtype
 from apex.transformer import parallel_state
+from apex.transformer.log_util import get_transformer_logger
 from apex.transformer.utils import split_tensor_into_1d_equal_chunks
 from apex.transformer.utils import gather_split_1d_tensor
 from apex.transformer.pipeline_parallel.utils import Shape
 from apex.transformer.pipeline_parallel._timers import _Timers


+_logger = get_transformer_logger(__name__)
+
+
+class FutureTensor:
+    def __init__(self, tensor: torch.Tensor, waitfunc):
+        self.tensor = tensor
+        self.waitfunc = waitfunc
+
+    def get(self):
+        if self.waitfunc is not None:
+            res = self.waitfunc()
+            if isinstance(res, torch.Tensor):
+                self.tensor = res
+            self.waitfunc = None
+        return self.tensor
+
+
 def _run_p2pops(
     tensor_send_prev: Union[torch.Tensor, None],
     tensor_send_next: Union[torch.Tensor, None],
     tensor_recv_prev: Union[torch.Tensor, None],
     tensor_recv_next: Union[torch.Tensor, None],
+    async_comm: bool = False,
 ):
     ops = []
+    p2p_group = parallel_state.get_pipeline_model_parallel_group()
+    default_group = parallel_state.get_model_parallel_group()
+
+    need_to_sync = p2p_group.name() != default_group.name()
     if tensor_send_prev is not None:
         send_prev_op = torch.distributed.P2POp(
-            torch.distributed.isend,
-            tensor_send_prev,
-            parallel_state.get_pipeline_model_parallel_prev_rank(),
+            op=torch.distributed.isend,
+            tensor=tensor_send_prev,
+            peer=parallel_state.get_pipeline_model_parallel_prev_rank(),
+            group=p2p_group,
         )
         ops.append(send_prev_op)
     if tensor_recv_prev is not None:
         recv_prev_op = torch.distributed.P2POp(
-            torch.distributed.irecv,
-            tensor_recv_prev,
-            parallel_state.get_pipeline_model_parallel_prev_rank(),
+            op=torch.distributed.irecv,
+            tensor=tensor_recv_prev,
+            peer=parallel_state.get_pipeline_model_parallel_prev_rank(),
+            group=p2p_group,
         )
         ops.append(recv_prev_op)
     if tensor_send_next is not None:
         send_next_op = torch.distributed.P2POp(
-            torch.distributed.isend,
-            tensor_send_next,
-            parallel_state.get_pipeline_model_parallel_next_rank(),
+            op=torch.distributed.isend,
+            tensor=tensor_send_next,
+            peer=parallel_state.get_pipeline_model_parallel_next_rank(),
+            group=p2p_group,
        )
         ops.append(send_next_op)
     if tensor_recv_next is not None:
         recv_next_op = torch.distributed.P2POp(
-            torch.distributed.irecv,
-            tensor_recv_next,
-            parallel_state.get_pipeline_model_parallel_next_rank(),
+            op=torch.distributed.irecv,
+            tensor=tensor_recv_next,
+            peer=parallel_state.get_pipeline_model_parallel_next_rank(),
+            group=p2p_group,
        )
         ops.append(recv_next_op)
     if len(ops) > 0:
-        reqs = torch.distributed.batch_isend_irecv(ops)
-        for req in reqs:
-            req.wait()
+        if need_to_sync:
+            torch.cuda.synchronize()
+        reqs = torch.distributed.batch_isend_irecv(ops)
+        if async_comm:
+            assert len(reqs) == len(ops)
+            tensor_send_prev_req = None if tensor_send_prev is None else reqs.pop(0)
+            tensor_recv_prev_req = None if tensor_recv_prev is None else reqs.pop(0)
+            tensor_send_next_req = None if tensor_send_next is None else reqs.pop(0)
+            tensor_recv_next_req = None if tensor_recv_next is None else reqs.pop(0)
+            return (tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req)
+        else:
+            for req in reqs:
+                req.wait()
+            return (None, None, None, None)
+    return (None, None, None, None)


+# TODO(mkozuki): Check if it's possible to sunset `override_scatter_gather_tensors_in_pipeline`.
+# TODO(mkozuki): Think about if it's possible to push some logic and arguments e.g.
+# `scatter_gather_tensors_in_pipeline`, `sequence_parallel_enabled`, and
+# `override_scatter_gather_tensors_in_pipeline` # to the user of
+# apex.transformer forward_backwardfunctions.
 def _communicate(
     tensor_send_next: Optional[torch.Tensor],
     tensor_send_prev: Optional[torch.Tensor],
...
@@ -76,14 +121,26 @@ def _communicate(
     recv_next: bool,
     tensor_shape: Optional[Shape] = None,
     override_scatter_gather_tensors_in_pipeline: bool = False,
-    dtype_: torch.dtype = torch.float,
+    dtype_: Optional[torch.dtype] = None,
     *,
     scatter_gather_tensors_in_pipeline: bool = True,
     params_dtype: Optional[torch.dtype] = None,
     fp32_residual_connection: bool = False,
-) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
+) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]:
     """Base function for communication of tensors between stages.

     .. note::
         Reference https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/cfd2e2160700b7f2c1bf35298ac14bc341f4c759/megatron/p2p_communication.py#L24-L159

     dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified,
     torch.float32 is used.

     See https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/arguments.py#L145-L159
     for the details of arguments of ``dtype_``, ``params_dtype``, ``fp32_residual_connection``.

     Args:
         tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
         tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
...
@@ -99,6 +156,9 @@ def _communicate(
         params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
             your model deliberately, pass this argument.
         fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.
+        sequence_parallel_enabled: Set to :obj:`True` if sequence parallel is enabled.
+            This argument is here for consistency with Megatron-LM.
+            This argument has an effect on the communication optimization, not on tensor_shape update.

     Returns:
         tuple containing
...
@@ -106,6 +166,13 @@ def _communicate(
         - tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
         - tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
     """
+    if async_comm and sequence_parallel_enabled:
+        import warnings  # NOQA
+        class ExperimentalWarning(UserWarning): pass  # NOQA
+        warnings.warn(
+            "The combination of `async_comm` and `sequence_parallel_enabled` is not well tested.",
+            ExperimentalWarning,
+        )
     # Create placeholder tensors for receive in forward and backward directions if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
...
@@ -113,25 +180,45 @@ def _communicate(
         # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
         raise RuntimeError(
             "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`"
         )
-    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
-        tensor_chunk_shape = (
-            reduce(operator.mul, tensor_shape, 1)
-            // parallel_state.get_tensor_model_parallel_world_size(),
-        )
+
+    tensor_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
+    override_scatter_gather_tensors_in_pipeline_ = False
+    # TODO(mkozuki): Demystify hardcode False of `scatter_gather_tensors_in_pipeline` and add a testcase if possible.
+    # NOTE(mkozuki): This is super strange and doesn't make sense to me. I have no idea what is happening here.
+    # However, I can say that this hardcoding override is necessary for sequence parallel in nemo megatron to work.
+    # I've not managed to reproduce the hang using standalone GPT with sequence parallel.
+    # The hang in NeMo Megatron happens in the 3rd iteration, the last iteration of stead phase inside
+    # forward_backward_pipelining_without_interleaving, pipeline parallel rank of 0 (tensor model parallel world
+    # size of 2 and pipeline model parallel world size of 2). The commit then of APEX and NeMo were
+    # https://github.com/NVIDIA/apex/pull/1396/commits/3060c98dd8ba42abf7702ea9d2cff0f39ea74f45 and
+    # https://github.com/NVIDIA/NeMo/pull/4232/commits/1cb32dfca2ab9b20f53ebdb84476c34cb42f0205.
+    # The PyTorch version was 1.13.0a0+git2d354cd, for what is worth.
+    # Currently, indiscriminately this is set to `False`, which can lead to an unexpected performance regression
+    # for non sequence parallel case.
+    scatter_gather_tensors_in_pipeline = False
+    if scatter_gather_tensors_in_pipeline and not sequence_parallel_enabled:
+        tensor_chunk_size = int(reduce(operator.mul, tensor_shape, 1))
+        if tensor_chunk_size % tensor_parallel_size == 0:
+            tensor_chunk_shape = [tensor_chunk_size // tensor_parallel_size]
+        else:
+            tensor_chunk_shape = tensor_shape
+            override_scatter_gather_tensors_in_pipeline_ = True
     else:
         tensor_chunk_shape = tensor_shape

-    # NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
-    # FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
-    # It might be possible if we restrict model architecture.
-    # dtype = params_dtype or torch.float
-    # if fp32_residual_connection:
-    #     dtype = torch.float
-    # if dtype_ is not None:
-    #     dtype = dtype_
-    # requires_grad = False
-    if dtype_ != torch.float32 or params_dtype is not None:
-        if torch.distributed.get_rank() == 0:
-            warnings.warn("Tensor P2P communications are executed in FP32")
-        dtype = torch.float32
+    # The dtype logic below is copied from NVIDIA/Megatron-LM repo:
+    # https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81
+    dtype = params_dtype or torch.float
+    if fp32_residual_connection:
+        dtype = torch.float
     requires_grad = True
+    if dtype_ is not None:
+        dtype = dtype_
+        # TODO(mkozuki): Figure out why this logic of requires_grad isn't working
+        # when sequence_parallel_enabled=True. Otherwise, `x.retain_grad()` of
+        # https://github.com/crcrpar/apex/blob/069832078a652b4bd8a99db84faf953a81415ab3/apex/transformer/pipeline_parallel/schedules/common.py#L360
+        # fails.
+        # requires_grad = False

     if recv_prev:
         tensor_recv_prev = torch.empty(
...
@@ -149,7 +236,12 @@ def _communicate(
     )

     # Split tensor into smaller chunks if using scatter-gather optimization.
-    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
+    scatter_gather_optimization_doable = (
+        not override_scatter_gather_tensors_in_pipeline_
+        and scatter_gather_tensors_in_pipeline
+        and not sequence_parallel_enabled
+    )
+    if scatter_gather_optimization_doable:
         if tensor_send_next is not None:
             tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
...
@@ -157,41 +249,89 @@ def _communicate(
         tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

     # Send tensors in both the forward and backward directions as appropriate.
-    _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
-    # To protect against race condition when using batch_isend_irecv().
-    torch.cuda.synchronize()
+    tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req = _run_p2pops(
+        tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next, async_comm=async_comm
+    )
+
+    if async_comm:
+        tensor_recv_prev_waitfunc = None
+        tensor_recv_next_waitfunc = None
+        # TODO: investigate whether this is necessary for correctness (ref: https://github.com/pytorch/pytorch/issues/38642)
+        # see also: sync added for async_comm callbacks below in gather_recv_prev_wait and gather_recv_next_wait
+        if tensor_recv_prev_req is not None:
+            def tensor_recv_prev_wait():
+                tensor_recv_prev_req.wait()
+                torch.cuda.synchronize()
+            tensor_recv_prev_waitfunc = tensor_recv_prev_wait
+        if tensor_recv_next_req is not None:
+            def tensor_recv_next_wait():
+                tensor_recv_next_req.wait()
+                torch.cuda.synchronize()
+            tensor_recv_next_waitfunc = tensor_recv_next_wait
+    else:
+        # To protect against race condition when using batch_isend_irecv().
+        torch.cuda.synchronize()

     # If using scatter-gather optimization, gather smaller chunks.
-    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
-        if recv_prev:
-            tensor_recv_prev = (
-                gather_split_1d_tensor(tensor_recv_prev)
-                .view(tensor_shape)
-                .requires_grad_()
-            )
-        if recv_next:
-            tensor_recv_next = (
-                gather_split_1d_tensor(tensor_recv_next)
-                .view(tensor_shape)
-                .requires_grad_()
-            )
+    if scatter_gather_optimization_doable:
+        if not async_comm:
+            if recv_prev:
+                tensor_recv_prev = (
+                    gather_split_1d_tensor(tensor_recv_prev)
+                    .view(tensor_shape)
+                    .requires_grad_()
+                )
+            if recv_next:
+                tensor_recv_next = (
+                    gather_split_1d_tensor(tensor_recv_next)
+                    .view(tensor_shape)
+                    .requires_grad_()
+                )
+        else:
+            def gather_recv_prev_wait():
+                tensor_recv_prev_req.wait()
+                # From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14
+                # A sync seems to be needed before gather otherwise losses jump around e.g., in run_gpt_minimal_test
+                torch.cuda.synchronize()
+                return (
+                    gather_split_1d_tensor(tensor_recv_prev)
+                    .view(tensor_shape)
+                    .requires_grad_()
+                )
+            def gather_recv_next_wait():
+                tensor_recv_next_req.wait()
+                torch.cuda.synchronize()
+                return (
+                    gather_split_1d_tensor(tensor_recv_next)
+                    .view(tensor_shape)
+                    .requires_grad_()
+                )
+            tensor_recv_prev_waitfunc = gather_recv_prev_wait
+            tensor_recv_next_waitfunc = gather_recv_next_wait
+    if async_comm:
+        future_tensor_recv_prev = None
+        future_tensor_recv_next = None
+        if tensor_recv_prev is not None:
+            future_tensor_recv_prev = FutureTensor(tensor_recv_prev, tensor_recv_prev_waitfunc)
+        if tensor_recv_next is not None:
+            future_tensor_recv_next = FutureTensor(tensor_recv_next, tensor_recv_next_waitfunc)
+        return future_tensor_recv_prev, future_tensor_recv_next

     return tensor_recv_prev, tensor_recv_next
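A sketch of how the asynchronous path introduced here is meant to be consumed (not part of the commit; it assumes parallel state is already initialized and the caller is not the first pipeline stage, and the dimensions are illustrative): with async_comm=True the receive helpers return a FutureTensor whose get() waits on the underlying request only when the tensor is actually needed.

    import torch
    from apex.transformer.pipeline_parallel import p2p_communication

    # Illustrative dimensions; in Megatron-style training these come from the args.
    seq_length, micro_batch_size, hidden_size = 2048, 2, 4096

    # Start an asynchronous forward receive from the previous stage.
    maybe_future = p2p_communication.recv_forward(
        tensor_shape=(seq_length, micro_batch_size, hidden_size),
        dtype=torch.float16,
        async_comm=True,
    )

    # ... overlap other work here ...

    # Block only when the activation is actually needed; recv_forward returns None
    # on the first pipeline stage and a plain tensor when async_comm=False.
    if isinstance(maybe_future, p2p_communication.FutureTensor):
        input_tensor = maybe_future.get()
    else:
        input_tensor = maybe_future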

 def recv_forward(
     tensor_shape: Shape,
     override_scatter_gather_tensors_in_pipeline: bool = False,
     *,
     dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
     timers: _Timers = None,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, FutureTensor, None]:
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if parallel_state.is_pipeline_first_stage():
         return None
-    if timers is not None:
-        timers("forward-recv").start()
+    # if timers is not None:
+    #     timers("forward-recv").start()
     input_tensor, _ = _communicate(
         tensor_send_next=None,
         tensor_send_prev=None,
...
@@ -199,50 +339,58 @@ def recv_forward(
         recv_next=False,
         tensor_shape=tensor_shape,
         override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
+        async_comm=async_comm,
+        sequence_parallel_enabled=sequence_parallel_enabled,
     )
-    if timers is not None:
-        timers("forward-recv").stop()
+    # if timers is not None:
+    #     timers("forward-recv").stop()
     return input_tensor


 def recv_backward(
     tensor_shape: Shape = None,
     *,
     dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
     timers: _Timers = None,
-):
+) -> Union[torch.Tensor, FutureTensor, None]:
     """Receive tensor from next rank in pipeline (backward receive)."""
     if parallel_state.is_pipeline_last_stage():
         return None
-    if timers is not None:
-        timers("backward-recv").start()
+    # if timers is not None:
+    #     timers("backward-recv").start()
     _, output_tensor_grad = _communicate(
         tensor_send_next=None,
         tensor_send_prev=None,
         recv_prev=False,
         recv_next=True,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
+        async_comm=async_comm,
+        sequence_parallel_enabled=sequence_parallel_enabled,
     )
-    if timers is not None:
-        timers("backward-recv").stop()
+    # if timers is not None:
+    #     timers("backward-recv").stop()
     return output_tensor_grad


 def send_forward(
     output_tensor: torch.Tensor,
     override_scatter_gather_tensors_in_pipeline: bool = False,
     tensor_shape: Shape = None,
     *,
     dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
     timers: _Timers = None,
 ) -> None:
     """Send tensor to next rank in pipeline (forward send)."""
     if parallel_state.is_pipeline_last_stage():
         return
-    if timers is not None:
-        timers("forward-send").start()
+    # if timers is not None:
+    #     timers("forward-send").start()
     _communicate(
         tensor_send_next=output_tensor,
         tensor_send_prev=None,
...
@@ -250,155 +398,181 @@ def send_forward(
         recv_next=False,
         override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
+        async_comm=async_comm,
+        sequence_parallel_enabled=sequence_parallel_enabled,
     )
-    if timers is not None:
-        timers("forward-send").stop()
+    # if timers is not None:
+    #     timers("forward-send").stop()


 def send_backward(
     input_tensor_grad: torch.Tensor,
     tensor_shape: Shape,
     *,
     dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
     timers: _Timers = None,
 ) -> None:
     """Send tensor to previous rank in pipeline (backward send)."""
     if parallel_state.is_pipeline_first_stage():
         return
-    if timers is not None:
-        timers("backward-send").start()
+    # if timers is not None:
+    #     timers("backward-send").start()
     _communicate(
         tensor_send_next=None,
         tensor_send_prev=input_tensor_grad,
         recv_prev=False,
         recv_next=False,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
+        async_comm=async_comm,
+        sequence_parallel_enabled=sequence_parallel_enabled,
     )
-    if timers is not None:
-        timers("backward-send").stop()
+    # if timers is not None:
+    #     timers("backward-send").stop()


 def send_forward_recv_backward(
     output_tensor: torch.Tensor,
     tensor_shape: Shape,
     *,
     dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
     timers: _Timers = None,
-) -> None:
+) -> Union[torch.Tensor, FutureTensor, None]:
     """Batched send and recv with next rank in pipeline."""
     if parallel_state.is_pipeline_last_stage():
         return None
-    if timers is not None:
-        timers("forward-send-backward-recv").start()
+    # if timers is not None:
+    #     timers("forward-send-backward-recv").start()
     _, output_tensor_grad = _communicate(
         tensor_send_next=output_tensor,
         tensor_send_prev=None,
         recv_prev=False,
         recv_next=True,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
+        async_comm=async_comm,
+        sequence_parallel_enabled=sequence_parallel_enabled,
     )
-    if timers is not None:
-        timers("forward-send-backward-recv").stop()
+    # if timers is not None:
+    #     timers("forward-send-backward-recv").stop()
     return output_tensor_grad


 def send_backward_recv_forward(
     input_tensor_grad: torch.Tensor,
     tensor_shape: Shape,
     *,
     dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
     timers: _Timers = None,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, FutureTensor, None]:
     """Batched send and recv with previous rank in pipeline."""
     if parallel_state.is_pipeline_first_stage():
         return None
-    if timers is not None:
-        timers("backward-send-forward-recv").start()
+    # if timers is not None:
+    #     timers("backward-send-forward-recv").start()
     input_tensor, _ = _communicate(
         tensor_send_next=None,
         tensor_send_prev=input_tensor_grad,
         recv_prev=True,
         recv_next=False,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
+        async_comm=async_comm,
+        sequence_parallel_enabled=sequence_parallel_enabled,
     )
-    if timers is not None:
-        timers("backward-send-forward-recv").stop()
+    # if timers is not None:
+    #     timers("backward-send-forward-recv").stop()
     return input_tensor


 def send_forward_recv_forward(
     output_tensor: torch.Tensor,
     recv_prev: bool,
     tensor_shape: Shape,
     *,
     dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
     timers: _Timers = None,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, FutureTensor]:
     """Batched recv from previous rank and send to next rank in pipeline."""
-    if timers is not None:
-        timers("forward-send-forward-recv").start()
+    # if timers is not None:
+    #     timers("forward-send-forward-recv").start()
     input_tensor, _ = _communicate(
         tensor_send_next=output_tensor,
         tensor_send_prev=None,
         recv_prev=recv_prev,
         recv_next=False,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
+        async_comm=async_comm,
+        sequence_parallel_enabled=sequence_parallel_enabled,
     )
-    if timers is not None:
-        timers("forward-send-forward-recv").stop()
+    # if timers is not None:
+    #     timers("forward-send-forward-recv").stop()
     return input_tensor


 def send_backward_recv_backward(
     input_tensor_grad: torch.Tensor,
     recv_next: bool,
     tensor_shape: Shape,
     *,
-    dtype: torch.dtype = torch.float,
+    dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
     timers: _Timers = None,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, FutureTensor]:
     """Batched recv from next rank and send to previous rank in pipeline."""
-    if timers is not None:
-        timers("backward-send-backward-recv").start()
+    # if timers is not None:
+    #     timers("backward-send-backward-recv").start()
     _, output_tensor_grad = _communicate(
         tensor_send_next=None,
         tensor_send_prev=input_tensor_grad,
         recv_prev=False,
         recv_next=recv_next,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
+        async_comm=async_comm,
+        sequence_parallel_enabled=sequence_parallel_enabled,
     )
-    if timers is not None:
-        timers("backward-send-backward-recv").stop()
+    # if timers is not None:
+    #     timers("backward-send-backward-recv").stop()
     return output_tensor_grad


 def send_forward_backward_recv_forward_backward(
     output_tensor: torch.Tensor,
     input_tensor_grad: torch.Tensor,
     recv_prev: bool,
     recv_next: bool,
     tensor_shape: Shape,
     *,
     dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
     timers: _Timers = None,
-):
+) -> Tuple[Union[torch.Tensor, FutureTensor], Union[torch.Tensor, FutureTensor]]:
     """Batched send and recv with previous and next ranks in pipeline."""
-    if timers is not None:
-        timers("forward-backward-send-forward-backward-recv").start()
+    # if timers is not None:
+    #     timers("forward-backward-send-forward-backward-recv").start()
     input_tensor, output_tensor_grad = _communicate(
         tensor_send_next=output_tensor,
         tensor_send_prev=input_tensor_grad,
         recv_prev=recv_prev,
         recv_next=recv_next,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
+        async_comm=async_comm,
+        sequence_parallel_enabled=sequence_parallel_enabled,
     )
-    if timers is not None:
-        timers("forward-backward-send-forward-backward-recv").stop()
+    # if timers is not None:
+    #     timers("forward-backward-send-forward-backward-recv").stop()
     return input_tensor, output_tensor_grad
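For orientation, a sketch of how a middle pipeline stage might drive these wrappers in a 1F1B-style steady state (hypothetical driver code, not part of this commit; the real schedules under apex/transformer/pipeline_parallel/schedules handle warm-up, cool-down, and microbatch bookkeeping, and forward_step/backward_step below are stand-ins for the per-stage compute):

    import torch
    from apex.transformer.pipeline_parallel import p2p_communication

    # Hypothetical per-stage compute; a real schedule runs the model chunk here.
    def forward_step(activation: torch.Tensor) -> torch.Tensor:
        return activation * 2.0

    def backward_step(inp, out, out_grad) -> torch.Tensor:
        return out_grad * 2.0

    tensor_shape = (2048, 2, 4096)  # illustrative (seq_length, micro_batch_size, hidden_size)
    dtype = torch.float16

    # Receive the activation from the previous stage.
    input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype)
    output_tensor = forward_step(input_tensor)

    # Send the activation forward and, in one batched call, receive the gradient
    # of the output coming back from the next stage.
    output_tensor_grad = p2p_communication.send_forward_recv_backward(
        output_tensor, tensor_shape, dtype=dtype
    )

    input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)

    # Send the input gradient backward and receive the next forward activation.
    input_tensor = p2p_communication.send_backward_recv_forward(
        input_tensor_grad, tensor_shape, dtype=dtype
    )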
apex/transformer/pipeline_parallel/schedules/__init__.py  View file @ 795a5e5b

import warnings

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
-from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
-from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
+from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
+    forward_backward_no_pipelining,
+)
+from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
+    _forward_backward_pipelining_with_interleaving,
+)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)

+__all__ = [
+    "get_forward_backward_func",
+]


class ExperimentalWarning(Warning):
    pass
...
@@ -21,19 +27,9 @@ def get_forward_backward_func(
            if get_num_microbatches() % pipeline_model_parallel_size != 0:
                msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
                raise RuntimeError(msg)
-            warnings.warn(
-                "Pipeline Model Parallel with interleaving scheduling is experimental. "
-                f"To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: {virtual_pipeline_model_parallel_size}",
-                ExperimentalWarning,
-            )
            forward_backward_func = _forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
-
-__all__ = [
-    "get_forward_backward_func",
-]
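For context, a typical caller asks this module for the schedule matching its parallel configuration. The snippet below is an illustrative sketch, not taken from the commit; it assumes the two values used inside the function (`virtual_pipeline_model_parallel_size` and `pipeline_model_parallel_size`) are its positional parameters and that `parallel_state` has been initialized.

# Illustrative sketch: select a schedule for the current parallel configuration.
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func

def pick_schedule():
    return get_forward_backward_func(
        None,  # virtual_pipeline_model_parallel_size: None disables interleaving
        parallel_state.get_pipeline_model_parallel_world_size(),
    )
# With a pipeline world size of 1 this yields forward_backward_no_pipelining;
# otherwise forward_backward_pipelining_without_interleaving.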
apex/transformer/pipeline_parallel/schedules/common.py  View file @ 795a5e5b

# NOTE (mkozuki): For simplicity, tentatively `timers` related operations are commented out.
-from typing import Any, Callable, Dict, List, Tuple, Union, Optional
+from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence

import torch
+from torch.autograd.variable import Variable

+from apex.normalization.fused_layer_norm import FusedLayerNorm
from apex.transformer import parallel_state
+from apex.transformer.enums import ModelType
+from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
-from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
+from apex.transformer.pipeline_parallel.utils import get_model_type
+from apex.transformer.tensor_parallel.layers import (
+    set_defaults_if_not_set_tensor_model_parallel_attributes,
+)
from apex.transformer.log_util import get_transformer_logger

-Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
_logger = get_transformer_logger(__name__)

+Batch = Union[torch.Tensor, FutureTensor, List[Union[torch.Tensor, FutureTensor]], Tuple[Union[torch.Tensor, FutureTensor], ...]]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
-FwdStepFunc = Callable[[Batch, torch.nn.Module], Tuple[torch.Tensor, LossFunc]]
+FwdStepFunc = Callable[
+    [Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc]
+]


def build_model(
-    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
-    wrap_with_ddp: bool = True,
-    virtual_pipeline_model_parallel_size: Optional[int] = None,
-    *args,
-    **kwargs
+    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
+    wrap_with_ddp: bool = True,
+    virtual_pipeline_model_parallel_size: Optional[int] = None,
+    model_type: ModelType = ModelType.encoder_or_decoder,
+    *args: Any,
+    **kwargs: Any,
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.
...
@@ -32,6 +45,7 @@ def build_model(
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model
            with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
+        model_type:
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func
...
@@ -40,8 +54,8 @@ def build_model(
        the list has multiple models, otherwise one.
    """
    if (
-        parallel_state.get_pipeline_model_parallel_world_size() > 1
-        and virtual_pipeline_model_parallel_size is not None
+        parallel_state.get_pipeline_model_parallel_world_size() > 1 and virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
...
@@ -51,22 +65,48 @@ def build_model(
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
-            cur_kwargs.update({
-                "pre_process": pre_process,
-                "post_process": post_process,
-            })
+            cur_kwargs.update(
+                {"pre_process": pre_process, "post_process": post_process,}
+            )
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
-        pre_process = parallel_state.is_pipeline_first_stage()
-        post_process = parallel_state.is_pipeline_last_stage()
-        cur_kwargs.update({
-            "pre_process": pre_process,
-            "post_process": post_process,
-        })
-        model = model_provider_func(*cur_args, **cur_kwargs)
+        if model_type == ModelType.encoder_or_decoder:
+            pre_process = parallel_state.is_pipeline_first_stage()
+            post_process = parallel_state.is_pipeline_last_stage()
+            cur_kwargs.update(
+                {"pre_process": pre_process, "post_process": post_process,}
+            )
+            model = model_provider_func(*cur_args, **cur_kwargs)
+        elif model_type == ModelType.encoder_and_decoder:
+            pre_process = parallel_state.is_pipeline_first_stage()
+            post_process = parallel_state.is_pipeline_last_stage()
+            # `add_encoder` & `add_decoder` logic.
+            add_encoder, add_decoder = True, True
+            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+                split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
+                if split_rank is None:
+                    raise RuntimeError("Split rank needs to be specified for model with both encoder and decoder.")
+                rank = parallel_state.get_pipeline_model_parallel_rank()
+                world_size = parallel_state.get_pipeline_model_parallel_world_size()
+                pre_process = rank == 0 or rank == split_rank
+                post_process = rank == (split_rank - 1) or rank == (world_size - 1)
+                add_encoder = parallel_state.is_pipeline_stage_before_split()
+                add_decoder = parallel_state.is_pipeline_stage_after_split()
+            cur_kwargs.update(
+                {
+                    "pre_process": pre_process,
+                    "post_process": post_process,
+                    "add_encoder": add_encoder,
+                    "add_decoder": add_decoder,
+                }
+            )
+            model = model_provider_func(*cur_args, **cur_kwargs)
+        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]
...
@@ -80,11 +120,14 @@ def build_model(
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
-    if parallel_state.get_data_parallel_rank() == 0:
+    if (
+        parallel_state.model_parallel_is_initialized()
+        and parallel_state.get_data_parallel_rank() == 0
+    ):
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
-            sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]),
+            _calc_number_of_params(model),
        )
        print(msg, flush=True)
...
@@ -106,44 +149,119 @@ def build_model(
    return model


+def _calc_number_of_params(model: List[torch.nn.Module]) -> int:
+    assert isinstance(model, list)
+    return sum(
+        [sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
+    )


def _get_params_for_weight_decay_optimization(
-    model: Union[torch.nn.Module, List[torch.nn.Module]],
+    model: Union[torch.nn.Module, List[torch.nn.Module]],
+    *,
+    no_weight_decay_modules=(FusedLayerNorm,),
) -> Dict[str, torch.nn.Parameter]:
    """Divide params into with-weight-decay and without-weight-decay groups.

    Layernorms and biases will have no weight decay but the rest will.
    """
    modules = listify_model(model)
-    from apex.normalization.fused_layer_norm import FusedLayerNorm  # NOQA
-    weight_decay_params = {'params': []}
-    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
+    weight_decay_params = {"params": []}
+    no_weight_decay_params = {"params": [], "weight_decay": 0.0}
    for module in modules:
        for module_ in module.modules():
-            if isinstance(module_, FusedLayerNorm):
-                no_weight_decay_params['params'].extend(
-                    [p for p in list(module_._parameters.values()) if p is not None]
-                )
+            if isinstance(module_, no_weight_decay_modules):
+                no_weight_decay_params["params"].extend(
+                    [p for p in list(module_._parameters.values()) if p is not None]
+                )
            else:
-                weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias'])
-                no_weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias'])
+                weight_decay_params["params"].extend(
+                    [p for n, p in list(module_._parameters.items()) if p is not None and n != "bias"]
+                )
+                no_weight_decay_params["params"].extend(
+                    [p for n, p in list(module_._parameters.items()) if p is not None and n == "bias"]
+                )

    return weight_decay_params, no_weight_decay_params


+def free_output_tensor(
+    output_tensors: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
+    deallocate_pipeline_outputs: bool = False,
+) -> None:
+    """Pseudo-free the output tensor's `.data` field.
+
+    This method should be called right after the output tensor has been sent to the next
+    pipeline stage. At this point, the output tensor is only useful for its `.grad_fn` field,
+    and not its `.data`.
+    """
+    if not deallocate_pipeline_outputs:
+        return
+    if output_tensors is None:
+        return
+    if isinstance(output_tensors, torch.Tensor):
+        output_tensors = [output_tensors]
+    for output_tensor in output_tensors:
+        output_tensor.data = torch.cuda.FloatTensor([0])


+def custom_backward(output: torch.Tensor, grad_output: Optional[torch.Tensor]) -> None:
+    """Directly call C++ autograd engine.
+
+    To make the `free_output_tensor` optimization work, the C++ autograd engine must be called
+    directly, bypassing PyTorch's `torch.autograd.backward`. PyTorch's `backward` checks that the
+    output and grad have the same shape, while C++ `backward` does not.
+    """
+    assert (
+        output.numel() == 1
+    ), "output should be pseudo-freed in schedule, to optimize memory consumption"
+    assert isinstance(output, torch.Tensor), "output == {}.".format(type(output).__name__)
+    assert isinstance(
+        grad_output, (torch.Tensor, type(None))
+    ), "grad_outptu == {}.".format(type(grad_output).__name__)
+
+    # Handle scalar output
+    if grad_output is None:
+        assert output.numel() == 1, "Implicit grad requires scalar output."
+        grad_output = torch.ones_like(output, memory_format=torch.preserve_format)
+
+    # Call C++ engine [ see torch/csrc/autograd/python_engine.cpp ]
+    Variable._execution_engine.run_backward(
+        tensors=(output,),
+        grad_tensors=(grad_output,),
+        keep_graph=False,
+        create_graph=False,
+        inputs=(),
+        allow_unreachable=True,
+        accumulate_grad=True,
+    )


def forward_step(
-    forward_step_func: FwdStepFunc,
-    batch: Batch,
-    model: torch.nn.Module,
-    input_tensor: Optional[torch.Tensor],
-    losses_reduced: List[torch.Tensor],
-):
+    forward_step_func: FwdStepFunc,
+    batch: Optional[Batch],
+    model: torch.nn.Module,
+    input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]],
+    losses_reduced: List[torch.Tensor],
+    dtype: torch.dtype,
+    disable_autocast: bool = False,
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """Forward step for passed-in model.

-    If first stage, input tensor is obtained from data_iterator, otherwise
-    passed-in input_tensor is used.
+    If first stage, input tensor is obtained from batch, otherwise passed-in input_tensor is used.

    Returns output tensor.
...
@@ -154,6 +272,8 @@ def forward_step(
        model: unwrappable model
        input_tensor:
        losses_reduced:
+        dtype:
+        disable_autocast:

    Returns:
        output_tensor
...
@@ -161,27 +281,51 @@ def forward_step(
    # timers = get_timers()
    # timers("forward-compute").start()
    unwrapped_model = unwrap_model(model)
+    model_type = get_model_type(unwrapped_model)
    # NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
    # See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679  # NOQA
    # for the details of `set_input_tensor`.
+    unwrap_output_tensor = not isinstance(input_tensor, list)
+    if unwrap_output_tensor:
+        input_tensor = [input_tensor]
+    input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]
    unwrapped_model.set_input_tensor(input_tensor)
-    output_tensor, loss_func = forward_step_func(batch, model)
-    # print(f"forward_step| pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()} is_pipeline_last_stage?: {parallel_state.is_pipeline_last_stage()}")
-    if parallel_state.is_pipeline_last_stage():
-        output_tensor = loss_func(output_tensor)
-        loss, loss_reduced = output_tensor
-        output_tensor = loss / get_num_microbatches()
-        losses_reduced.append(loss_reduced)
+    with torch.cuda.amp.autocast(
+        enabled=not disable_autocast and dtype in (torch.half, torch.bfloat16),
+        dtype=dtype,
+    ):
+        output_tensor, loss_func = forward_step_func(batch, model)
+        if parallel_state.is_pipeline_last_stage():
+            output_tensor = loss_func(output_tensor)
+            loss, loss_reduced = output_tensor
+            output_tensor = loss / get_num_microbatches()
+            losses_reduced.append(loss_reduced)
    # timers("forward-compute").stop()

-    return output_tensor
+    # If T5 model (or other model with encoder and decoder)
+    # and in decoder stack, then send encoder_hidden_state
+    # downstream as well.
+    if (
+        parallel_state.is_pipeline_stage_after_split()
+        and model_type == ModelType.encoder_and_decoder
+    ):
+        return [output_tensor, input_tensor[-1]]
+    if unwrap_output_tensor:
+        return output_tensor
+    return [output_tensor]


def backward_step(
-    input_tensor: Optional[torch.Tensor],
-    output_tensor: torch.Tensor,
-    output_tensor_grad: Optional[torch.Tensor],
-) -> Optional[torch.Tensor]:
+    input_tensor: Optional[torch.Tensor],
+    output_tensor: torch.Tensor,
+    output_tensor_grad: Optional[torch.Tensor],
+    model_type: ModelType,
+    *,
+    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
+    deallocate_pipeline_outputs: bool = False,
+) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
...
@@ -194,25 +338,61 @@ def backward_step(
        input_tensor:
        output_tensor:
        output_tensor_grad:
+    Keyword Arguments:
+        grad_scaler:
+        deallocate_pipeline_outputs: Experimental.
    Returns:
        input_tensor_grad
    """
    # timers = get_timers()
    # timers("backward-compute").start()
    # Retain the grad on the input_tensor.
+    unwrap_input_tensor_grad = not isinstance(input_tensor, list)
+    if unwrap_input_tensor_grad:
+        input_tensor = [input_tensor]
+    input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]
+    for x in input_tensor:
+        if x is not None:
+            x.retain_grad()
+
+    if not isinstance(output_tensor, list):
+        output_tensor = [output_tensor]
+    output_tensor = [out.get() if isinstance(out, FutureTensor) else out for out in output_tensor]
+
+    if not isinstance(output_tensor_grad, list):
+        output_tensor_grad = [output_tensor_grad]
+    output_tensor_grad = [ogr.get() if isinstance(ogr, FutureTensor) else ogr for ogr in output_tensor_grad]
-    # if parallel_state.get_pipeline_model_parallel_rank() == 0:
-    #     print(f"{input_tensor}, {output_tensor}, {output_tensor_grad}")
-    if input_tensor is not None:
-        input_tensor.retain_grad()
    # Backward pass.
-    # if output_tensor_grad is None:
-    #     output_tensor = optimizer.scale_loss(output_tensor)
-    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
-    input_tensor_grad = None
+    if grad_scaler is not None and output_tensor_grad[0] is None:
+        output_tensor[0] = grad_scaler.scale(output_tensor[0])
+    if deallocate_pipeline_outputs:
+        custom_backward(output_tensor[0], output_tensor_grad[0])
+    else:
+        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
+    input_tensor_grad = [None]
    if input_tensor is not None:
-        input_tensor_grad = input_tensor.grad
-    # timers("backward-compute").stop()
-    return input_tensor_grad
+        input_tensor_grad = []
+        for x in input_tensor:
+            input_tensor_grad.append(None if x is None else x.grad)
+
+    # Handle single skip connection if it exists (encoder_hidden_state in model with encoder and decoder).
+    if (
+        parallel_state.get_pipeline_model_parallel_world_size() > 1
+        and parallel_state.is_pipeline_stage_after_split()
+        and model_type == ModelType.encoder_and_decoder
+    ):
+        if output_tensor_grad[1] is not None:
+            # todo (mkozuki): Replace the inplace add with `+= output_tensor_grad[1]`?
+            input_tensor_grad[-1].add_(output_tensor_grad[1])
+    # timers("backward-compute").stop()
+    return input_tensor_grad[0] if unwrap_input_tensor_grad else input_tensor_grad
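As a usage note (not part of the diff), the grouping helper above returns two parameter-group dicts that can be handed straight to a PyTorch optimizer; the learning rate and weight-decay values below are made up for illustration.

# Illustrative sketch: feed the weight-decay groups into an optimizer.
import torch
from apex.transformer.pipeline_parallel.schedules.common import (
    _get_params_for_weight_decay_optimization,
)

def build_optimizer(model: torch.nn.Module, lr: float = 1e-4, weight_decay: float = 0.01):
    decay_group, no_decay_group = _get_params_for_weight_decay_optimization(model)
    decay_group["weight_decay"] = weight_decay  # the helper only zeroes decay for the other group
    return torch.optim.AdamW([decay_group, no_decay_group], lr=lr)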
apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py  View file @ 795a5e5b

from contextlib import contextmanager
-from typing import List, Union
+from typing import List, Union, Optional

import torch

from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
-from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
+from apex.transformer.pipeline_parallel.utils import get_model_type
+from apex.transformer.pipeline_parallel.schedules.common import Batch
+from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger
...
@@ -27,12 +29,16 @@ def placeholder_handler():
def forward_backward_no_pipelining(
-    forward_step_func: FwdStepFunc,
-    batch: Batch,
-    model: Union[torch.nn.Module, List[torch.nn.Module]],
-    *,
-    forward_only: bool,
-    **kwargs,
+    forward_step_func: FwdStepFunc,
+    batch: Batch,
+    model: Union[torch.nn.Module, List[torch.nn.Module]],
+    *,
+    forward_only: bool,
+    dtype: Optional[torch.dtype] = None,
+    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
+    disable_autocast: bool = False,
+    custom_sync_context_handler=None,
+    **kwargs,
):
    """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).
...
@@ -48,6 +54,12 @@ def forward_backward_no_pipelining(
    Keyword args:
        forward_only:
        grad_scaler:
+        dtype:
+        disable_autocast: Turn off `enabled` flag of `torch.cuda.amp.autocast` if :obj:`True`.
+            Should be used when your forward and loss computation is in the autocast context to
+            avoid unnecesarily nest autocast context.
+        custom_sync_context_handler:
        **kwargs: Added to handle `tensor_shape` which has no effect on this function.

    Returns:
...
@@ -58,10 +70,14 @@ def forward_backward_no_pipelining(
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]
+    model_type = get_model_type(model)

-    context_handler = placeholder_handler
-    if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
+    if custom_sync_context_handler is not None:
+        context_handler = custom_sync_context_handler
+    elif isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
        context_handler = model.no_sync
+    else:
+        context_handler = placeholder_handler

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
...
@@ -72,20 +88,45 @@ def forward_backward_no_pipelining(
            cur_micro_batch = get_kth_microbatch(batch, i)
            _logger.debug("Call `forward_step`")
            output_tensor = forward_step(
-                forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced
+                forward_step_func,
+                cur_micro_batch,
+                model,
+                input_tensor,
+                losses_reduced,
+                dtype=dtype,
+                disable_autocast=disable_autocast,
            )
            if not forward_only:
                _logger.debug("Call `backward_step`")
-                backward_step(input_tensor, output_tensor, output_tensor_grad)
+                backward_step(
+                    input_tensor,
+                    output_tensor,
+                    output_tensor_grad,
+                    model_type=model_type,
+                    grad_scaler=grad_scaler,
+                )

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    _logger.info("Cooldown")
    _logger.debug("Call `forward_step`")
    output_tensor = forward_step(
-        forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced
+        forward_step_func,
+        get_kth_microbatch(batch, num_micro_batches - 1),
+        model,
+        input_tensor,
+        losses_reduced,
+        dtype=dtype,
+        disable_autocast=disable_autocast,
    )
    if not forward_only:
        _logger.debug("Call `backward_step`")
-        backward_step(input_tensor, output_tensor, output_tensor_grad)
+        backward_step(
+            input_tensor,
+            output_tensor,
+            output_tensor_grad,
+            model_type=model_type,
+            grad_scaler=grad_scaler,
+        )

    return losses_reduced
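Every schedule in this commit consumes a `forward_step_func` with the contract enforced by `forward_step` in common.py: it takes a microbatch and the model and returns the model output together with a loss closure yielding `(loss, reduced_losses)`. The example below is a minimal sketch of that contract with made-up tensors and names; it is not taken from apex, and running a schedule additionally requires the microbatch calculator in `pipeline_parallel.utils` to be initialized.

# Minimal sketch of the FwdStepFunc contract used by these schedules.
import torch
import torch.nn.functional as F

def loss_func(labels: torch.Tensor, output_tensor: torch.Tensor):
    loss = F.cross_entropy(output_tensor, labels)
    # The second element is what forward_step appends to `losses_reduced`.
    return loss, {"lm loss": loss.detach()}

def forward_step_func(batch, model):
    tokens, labels = batch
    output_tensor = model(tokens)
    # Return the output and a closure so the last pipeline stage can finish the loss.
    return output_tensor, lambda output: loss_func(labels, output)

# forward_backward_no_pipelining(forward_step_func, batch, model,
#                                forward_only=False, dtype=torch.float32)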
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py  View file @ 795a5e5b

-from typing import List, Union, Optional
+from typing import List, Union, Optional, Sequence
+import warnings

import torch

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
-from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
+from apex.transformer.pipeline_parallel.schedules.common import Batch
+from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
+from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
+from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.log_util import get_transformer_logger
...
@@ -18,15 +22,22 @@ __all__ = ["_forward_backward_pipelining_with_interleaving"]
_logger = get_transformer_logger(__name__)


-# TODO (mkozuki): Reduce cyclomatic complexity
+# TODO(mkozuki): Reduce cyclomatic complexity
def _forward_backward_pipelining_with_interleaving(
-    forward_step_func: FwdStepFunc,
-    batch: List[Batch],
-    model: List[torch.nn.Module],
-    *,
-    forward_only: bool,
-    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
-):
+    forward_step_func: FwdStepFunc,
+    batch: List[Optional[Batch]],
+    model: List[torch.nn.Module],
+    *,
+    forward_only: bool,
+    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
+    dtype: Optional[torch.dtype] = None,
+    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
+    disable_autocast: bool = False,
+    deallocate_pipeline_outputs: bool = False,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
+    **kwargs,
+) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run interleaved 1F1B schedule with communication between pipeline stages as needed.

    This function assumes `batch` and `model` is a list of `Batch`'s and a list of `torch.nn.Module`, respectively.
...
@@ -48,7 +59,17 @@ def _forward_backward_pipelining_with_interleaving(
    Keyword args:
        forward_only:
-        tensor_shape: Shape of tensor.
+        tensor_shape: Shape of tensor. The tensor is expected to be 3D and its order of dimension
+            is supposed to be ``(sequence, batch, hidden)``.
+        dtype: dtype used in p2p communication. If ``None`` (default value),
+            torch.float32 will be used even if ``autocast`` is enabled.
+        grad_scaler:
+        disable_autocast:
+        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
+            each pipeline stage. Experimental.
+        sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
+            When :obj:`True`, the sequence length on each tensor model parallel rank is updated
+            to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
...
@@ -56,22 +77,43 @@ def _forward_backward_pipelining_with_interleaving(
    if not isinstance(model, list):
        raise RuntimeError("`model` must be a list of `nn.Module`'s'")

-    num_model_chunks = len(model)
-    input_tensors = [[] for _ in range(num_model_chunks)]
-    output_tensors = [[] for _ in range(num_model_chunks)]
-    curr_iters = [0 for _ in range(num_model_chunks)]
-    losses_reduced = []
+    if deallocate_pipeline_outputs:
+        warnings.warn(
+            "`deallocate_pipeline_outputs` is experimental and subject to change. "
+            "This option is not recommended."
+        )
+
+    # mypy will blame the following if statement
+    if sequence_parallel_enabled:
+        seq_length, batch_size, hidden = tensor_shape
+        tensor_shape = (
+            seq_length // parallel_state.get_tensor_model_parallel_world_size(),
+            batch_size,
+            hidden,
+        )
+
+    num_model_chunks: int = len(model)
+    input_tensors: List[List[Union[None, torch.Tensor]]] = [
+        [] for _ in range(num_model_chunks)
+    ]
+    output_tensors: List[List[Union[None, torch.Tensor]]] = [
+        [] for _ in range(num_model_chunks)
+    ]
+    curr_iters: List[int] = [0 for _ in range(num_model_chunks)]
+    losses_reduced: List[Union[None, torch.Tensor]] = []
    if not forward_only:
-        output_tensor_grads = [[] for _ in range(num_model_chunks)]
+        output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [
+            [] for _ in range(num_model_chunks)
+        ]

-    pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
-    pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
+    pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size()
+    pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
-    num_microbatches = get_num_microbatches() * num_model_chunks
-    all_warmup_microbatches = False
+    num_microbatches: int = get_num_microbatches() * num_model_chunks
+    all_warmup_microbatches: bool = False
    if forward_only:
-        num_warmup_microbatches = num_microbatches
+        num_warmup_microbatches: int = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
...
@@ -83,10 +125,12 @@ def _forward_backward_pipelining_with_interleaving(
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
-            num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
+            num_warmup_microbatches = (
+                pipeline_parallel_size - pipeline_parallel_rank - 1
+            ) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
-    num_microbatches_remaining = num_microbatches - num_warmup_microbatches
+    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
...
@@ -100,24 +144,26 @@ def _forward_backward_pipelining_with_interleaving(
    def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
        """Helper function to get the model chunk ID given the iteration number."""
-        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

-    def forward_step_helper(microbatch_id, curr_iters):
+    def forward_step_helper(microbatch_id: int, curr_iters: List[int]) -> torch.Tensor:
        """Helper method to run forward step with model split into chunks
-        (run set_virtual_pipeline_model_parallel_rank() before calling
-        forward_step())."""
+        (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).
+        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        if parallel_state.is_pipeline_first_stage() and len(
            input_tensors[model_chunk_id]
        ) == len(output_tensors[model_chunk_id]):
            input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(
...
@@ -126,6 +172,8 @@ def _forward_backward_pipelining_with_interleaving(
            model[model_chunk_id],
            input_tensor,
            losses_reduced,
+            dtype,
+            disable_autocast,
        )
        curr_iters[model_chunk_id] += 1
        output_tensors[model_chunk_id].append(output_tensor)
...
@@ -137,11 +185,13 @@ def _forward_backward_pipelining_with_interleaving(
        return output_tensor

-    def backward_step_helper(microbatch_id):
+    def backward_step_helper(microbatch_id: int) -> torch.Tensor:
        """Helper method to run backward step with model split into chunks
-        (run set_virtual_pipeline_model_parallel_rank() before calling
-        backward_step())."""
+        (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).
+        """
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
+        model_type = get_model_type(model[model_chunk_id])
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if parallel_state.is_pipeline_last_stage():
...
@@ -150,7 +200,14 @@ def _forward_backward_pipelining_with_interleaving(
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
-        input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)
+        input_tensor_grad = backward_step(
+            input_tensor,
+            output_tensor,
+            output_tensor_grad,
+            model_type=model_type,
+            grad_scaler=grad_scaler,
+            deallocate_pipeline_outputs=deallocate_pipeline_outputs,
+        )

        return input_tensor_grad
...
@@ -158,7 +215,14 @@ def _forward_backward_pipelining_with_interleaving(
    # Run warmup forward passes.
    ###################################################################################################################
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
-    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
+    input_tensors[0].append(
+        p2p_communication.recv_forward(
+            tensor_shape=tensor_shape,
+            dtype=dtype,
+            async_comm=async_comm,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+        )
+    )
    _logger.info("Warmup phase")
    for k in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
...
@@ -172,7 +236,9 @@ def _forward_backward_pipelining_with_interleaving(
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False
        _logger.debug(f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}")

        # Don't send tensor downstream if on last stage.
        if parallel_state.is_pipeline_last_stage():
...
@@ -181,7 +247,11 @@ def _forward_backward_pipelining_with_interleaving(
        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
-        if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
+        if (
+            k == (num_warmup_microbatches - 1)
+            and not forward_only
+            and not all_warmup_microbatches
+        ):
            input_tensor_grad = None
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
...
@@ -196,12 +266,23 @@ def _forward_backward_pipelining_with_interleaving(
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
+                dtype=dtype,
+                async_comm=async_comm,
+                sequence_parallel_enabled=sequence_parallel_enabled,
            )
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
-            input_tensor = p2p_communication.send_forward_recv_forward(output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape)
+            input_tensor = p2p_communication.send_forward_recv_forward(
+                output_tensor,
+                recv_prev=recv_prev,
+                tensor_shape=tensor_shape,
+                dtype=dtype,
+                async_comm=async_comm,
+                sequence_parallel_enabled=sequence_parallel_enabled,
+            )
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
+        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    ###################################################################################################################
    # Run 1F1B in steady state.
...
@@ -229,7 +310,9 @@ def _forward_backward_pipelining_with_interleaving(
        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        _logger.debug(f"fwd/bwd model chunk id: {forward_model_chunk_id} / {backward_model_chunk_id}")

        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None
...
@@ -245,7 +328,9 @@ def _forward_backward_pipelining_with_interleaving(
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
...
@@ -257,7 +342,9 @@ def _forward_backward_pipelining_with_interleaving(
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
...
@@ -275,7 +362,11 @@ def _forward_backward_pipelining_with_interleaving(
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
+            dtype=dtype,
+            async_comm=async_comm,
+            sequence_parallel_enabled=sequence_parallel_enabled,
        )
+        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
...
@@ -290,9 +381,18 @@ def _forward_backward_pipelining_with_interleaving(
    _logger.info("Cooldown phase")
    if not forward_only:
        if all_warmup_microbatches:
-            output_tensor_grads[num_model_chunks - 1].append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
+            output_tensor_grads[num_model_chunks - 1].append(
+                p2p_communication.recv_backward(
+                    tensor_shape=tensor_shape,
+                    dtype=dtype,
+                    async_comm=async_comm,
+                    sequence_parallel_enabled=sequence_parallel_enabled,
+                )
+            )
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
...
@@ -302,7 +402,14 @@ def _forward_backward_pipelining_with_interleaving(
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
-                p2p_communication.send_backward_recv_backward(input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape)
+                p2p_communication.send_backward_recv_backward(
+                    input_tensor_grad,
+                    recv_next=recv_next,
+                    tensor_shape=tensor_shape,
+                    dtype=dtype,
+                    async_comm=async_comm,
+                    sequence_parallel_enabled=sequence_parallel_enabled,
+                )
            )

    return losses_reduced
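To make the interleaving arithmetic above concrete, the standalone snippet below re-implements `get_model_chunk_id` with made-up sizes and prints the chunk order the warmup loop walks through; it is illustrative only and not part of the commit.

# Worked example of the microbatch -> model-chunk mapping (illustrative values).
pipeline_parallel_size = 2
num_model_chunks = 2

def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

print([get_model_chunk_id(i, forward=True) for i in range(8)])
# -> [0, 0, 1, 1, 0, 0, 1, 1]: each rank alternates between its two model chunks.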
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py  View file @ 795a5e5b

-from typing import Union, List, Optional
+from typing import Union, List, Optional, Sequence
+import warnings

import torch

from apex.transformer import parallel_state
+from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel import p2p_communication
+from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
-from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
-from apex.transformer.pipeline_parallel.schedules.common import forward_step
+from apex.transformer.pipeline_parallel.utils import get_model_type
+from apex.transformer.pipeline_parallel.schedules.common import Batch
+from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
+from apex.transformer.pipeline_parallel.schedules.common import backward_step
+from apex.transformer.pipeline_parallel.schedules.common import forward_step
+from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor
from apex.transformer.log_util import get_transformer_logger
...
@@ -19,14 +25,222 @@ __all__ = ["forward_backward_pipelining_without_interleaving"]
_logger = get_transformer_logger(__name__)


+def get_tensor_shapes(
+    rank: int,
+    model_type: ModelType,
+    *,
+    tensor_shape: Union[List[int], torch.Size],
+    decoder_sequence_length: Optional[int] = None,
+    sequence_parallel_enabled: bool = False,
+) -> Sequence[Sequence[int]]:
+    """Get tensors shapes
+
+    Args:
+        rank: pipeline parallel rank
+        model_type:
+
+    Keyword Args:
+        tensor_shape:
+        decoder_sequence_length:
+        sequence_parallel_enabled:
+    """
+    # Determine right tensor sizes (based on position of rank with respect to split
+    # rank) and model size.
+    # Send two tensors if model is T5 and rank is in decoder stage:
+    #     first tensor is decoder (pre-transpose),
+    #     second tensor is encoder (post-transpose).
+    # If model is T5 and rank is at the boundary:
+    #     send one tensor (post-transpose from encoder).
+    # Otherwise, send one tensor (pre-transpose).
+    assert (
+        len(tensor_shape) == 3
+    ), f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}"
+    sequence_length, micro_batch_size, hidden_size = tensor_shape
+
+    tensor_shapes = []
+
+    if sequence_parallel_enabled:
+        seq_length = sequence_length // parallel_state.get_tensor_model_parallel_world_size()
+    else:
+        seq_length = sequence_length
+
+    if model_type == ModelType.encoder_and_decoder:
+        if sequence_parallel_enabled:
+            dec_seq_length = decoder_sequence_length // parallel_state.get_tensor_model_parallel_world_size()
+        else:
+            dec_seq_length = decoder_sequence_length
+
+        if parallel_state.is_pipeline_stage_before_split(rank):
+            tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
+        else:
+            tensor_shapes.append((dec_seq_length, micro_batch_size, hidden_size))
+            tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
+    else:
+        tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
+    return tensor_shapes


+def recv_forward(
+    tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
+) -> List[Union[None, torch.Tensor, FutureTensor]]:
+    input_tensors = []
+    for tensor_shape in tensor_shapes:
+        if tensor_shape is None:
+            input_tensors.append(None)
+        else:
+            input_tensors.append(
+                p2p_communication.recv_forward(
+                    tensor_shape=tensor_shape,
+                    dtype=dtype,
+                    async_comm=async_comm,
+                    sequence_parallel_enabled=sequence_parallel_enabled,
+                )
+            )
+    return input_tensors


+def recv_backward(
+    tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
+) -> List[Union[None, torch.Tensor, FutureTensor]]:
+    output_tensor_grads = []
+    for tensor_shape in tensor_shapes:
+        if tensor_shape is None:
+            output_tensor_grads.append(None)
+        else:
+            output_tensor_grads.append(
+                p2p_communication.recv_backward(
+                    tensor_shape=tensor_shape,
+                    dtype=dtype,
+                    async_comm=async_comm,
+                    sequence_parallel_enabled=sequence_parallel_enabled,
+                )
+            )
+    return output_tensor_grads


+def send_forward(
+    output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
+    tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
+) -> None:
+    if not isinstance(output_tensors, list):
+        output_tensors = [output_tensors]
+    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+        if tensor_shape is None:
+            continue
+        p2p_communication.send_forward(
+            output_tensor,
+            tensor_shape=tensor_shape,
+            dtype=dtype,
+            async_comm=async_comm,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+        )


+def send_backward(
+    input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
+    tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
+) -> None:
+    if not isinstance(input_tensor_grads, list):
+        input_tensor_grads = [input_tensor_grads]
+    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+        if tensor_shape is None:
+            continue
+        p2p_communication.send_backward(
+            input_tensor_grad,
+            tensor_shape=tensor_shape,
+            dtype=dtype,
+            async_comm=async_comm,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+        )


+def send_forward_recv_backward(
+    output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
+    tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
+) -> List[Union[None, torch.Tensor, FutureTensor]]:
+    if not isinstance(output_tensors, list):
+        output_tensors = [output_tensors]
+    output_tensor_grads = []
+    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+        if tensor_shape is None:
+            output_tensor_grads.append(None)
+            continue
+        output_tensor_grad = p2p_communication.send_forward_recv_backward(
+            output_tensor,
+            tensor_shape=tensor_shape,
+            dtype=dtype,
+            async_comm=async_comm,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+        )
+        output_tensor_grads.append(output_tensor_grad)
+    return output_tensor_grads


+def send_backward_recv_forward(
+    input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
+    tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
+) -> List[Union[None, torch.Tensor, FutureTensor]]:
+    if not isinstance(input_tensor_grads, list):
+        input_tensor_grads = [input_tensor_grads]
+    input_tensors = []
+    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+        if tensor_shape is None:
+            input_tensors.append(None)
+            continue
+        input_tensor = p2p_communication.send_backward_recv_forward(
+            input_tensor_grad,
+            tensor_shape=tensor_shape,
+            dtype=dtype,
+            async_comm=async_comm,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+        )
+        input_tensors.append(input_tensor)
+    return input_tensors


def forward_backward_pipelining_without_interleaving(
-    forward_step_func: FwdStepFunc,
-    batch: Batch,
-    model: Union[torch.nn.Module, List[torch.nn.Module]],
-    *,
-    forward_only: bool,
-    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
-):
+    forward_step_func: FwdStepFunc,
+    batch: Optional[Batch],
+    model: Union[torch.nn.Module, List[torch.nn.Module]],
+    *,
+    forward_only: bool,
+    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
+    decoder_sequence_length: Optional[int] = None,
+    dtype: Optional[torch.dtype] = None,
+    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
+    disable_autocast: bool = False,
+    deallocate_pipeline_outputs: bool = False,
+    async_comm: bool = False,
+    sequence_parallel_enabled: bool = False,
+    **kwargs,
+) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
...
@@ -44,28 +258,59 @@ def forward_backward_pipelining_without_interleaving(
    Keyword args:
        forward_only:
-        tensor_shape: Shape of tensor. Required for P2P communication.
+        tensor_shape: Shape of tensor. The tensor is expected to be 3D and its order of dimension
+            is supposed to be ``(sequence, batch, hidden)``.
+        dtype: dtype used in p2p communication. If ``None`` (default value),
+            torch.float32 will be used even if ``autocast`` is enabled.
+        grad_scaler:
+        disable_autocast:
+        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
+            each pipeline stage. Experimental.
+        sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
+            When :obj:`True`, the sequence length on each tensor model parallel rank is updated
+            to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    # timers = get_timers()

-    model = listify_model(model)
+    if deallocate_pipeline_outputs:
+        warnings.warn(
+            "`deallocate_pipeline_outputs` is experimental and subject to change. "
+            "This option is not recommended."
+        )
+
+    model: List[torch.nn.Module] = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
-    model = model[0]
+    model: torch.nn.Module = model[0]

    # Compute number of warmup microbatches.
-    num_microbatches = get_num_microbatches()
-    num_warmup_microbatches = (
-        parallel_state.get_pipeline_model_parallel_world_size()
-        - parallel_state.get_pipeline_model_parallel_rank()
-        - 1
+    num_microbatches: int = get_num_microbatches()
+    num_warmup_microbatches: int = (
+        parallel_state.get_pipeline_model_parallel_world_size()
+        - parallel_state.get_pipeline_model_parallel_rank()
+        - 1
    )
+    num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches)
+    num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches
+
+    model_type = get_model_type(model)
+    rank: int = parallel_state.get_pipeline_model_parallel_rank()
+    recv_tensor_shapes: List[List[int]] = get_tensor_shapes(
+        rank - 1,
+        model_type,
+        tensor_shape=tensor_shape,
+        decoder_sequence_length=decoder_sequence_length,
+        sequence_parallel_enabled=sequence_parallel_enabled,
+    )
+    send_tensor_shapes: List[List[int]] = get_tensor_shapes(
+        rank,
+        model_type,
+        tensor_shape=tensor_shape,
+        decoder_sequence_length=decoder_sequence_length,
+        sequence_parallel_enabled=sequence_parallel_enabled,
+    )
-    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
-    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
...
@@ -74,13 +319,9 @@ def forward_backward_pipelining_without_interleaving(
    )

    # Input, output tensors only need to be saved when doing backward passes
-    input_tensors = None
-    output_tensors = None
-    if not forward_only:
-        input_tensors = []
-        output_tensors = []
-    losses_reduced = []
+    input_tensors: List[Union[None, torch.Tensor]] = []
+    output_tensors: List[Union[None, torch.Tensor]] = []
+    losses_reduced: List[Union[None, torch.Tensor]] = []
    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
...
@@ -88,22 +329,42 @@ def forward_backward_pipelining_without_interleaving(
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
-        input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
-        cur_microbatch = get_kth_microbatch(batch, i)
-        output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
+        input_tensor = recv_forward(
+            tensor_shapes=recv_tensor_shapes,
+            dtype=dtype,
+            async_comm=async_comm,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+        )
+        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
+        output_tensor = forward_step(
+            forward_step_func,
+            cur_microbatch,
+            model,
+            input_tensor,
+            losses_reduced,
+            dtype,
+            disable_autocast,
+        )
        _logger.debug("send fwd")
-        p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
+        send_forward(
+            output_tensor,
+            tensor_shapes=send_tensor_shapes,
+            dtype=dtype,
+            async_comm=async_comm,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+        )

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
+            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
-        input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
+        input_tensor: List[Union[None, torch.Tensor, FutureTensor]] = recv_forward(
+            tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm
+        )

    ###################################################################################################################
    # Run 1F1B in steady state.
...
@@ -111,42 +372,84 @@ def forward_backward_pipelining_without_interleaving(
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
-        last_iteration = i == (num_microbatches_remaining - 1)
+        last_iteration: bool = i == (num_microbatches_remaining - 1)

-        cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
-        output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
+        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches)
+        output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
+            forward_step_func,
+            cur_microbatch,
+            model,
+            input_tensor,
+            losses_reduced,
+            dtype,
+            disable_autocast,
+        )
        if forward_only:
            _logger.debug("send fwd")
-            p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
+            send_forward(
+                output_tensor,
+                tensor_shapes=send_tensor_shapes,
+                dtype=dtype,
+                async_comm=async_comm,
+                sequence_parallel_enabled=sequence_parallel_enabled,
+            )

            if not last_iteration:
                _logger.debug("receive fwd (last iteration)")
-                input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
+                input_tensor = recv_forward(
+                    tensor_shapes=recv_tensor_shapes,
+                    dtype=dtype,
+                    async_comm=async_comm,
+                    sequence_parallel_enabled=sequence_parallel_enabled,
+                )

        else:
            _logger.debug("send fwd & receive bwd")
-            output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)
+            output_tensor_grad = send_forward_recv_backward(
+                output_tensor,
+                tensor_shapes=send_tensor_shapes,
+                dtype=dtype,
+                async_comm=async_comm,
+                sequence_parallel_enabled=sequence_parallel_enabled,
+            )

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
+            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(
-                input_tensor, output_tensor, output_tensor_grad
+                input_tensor,
+                output_tensor,
+                output_tensor_grad,
+                model_type=model_type,
+                grad_scaler=grad_scaler,
+                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
-                p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
+                send_backward(
+                    input_tensor_grad,
+                    tensor_shapes=recv_tensor_shapes,
+                    dtype=dtype,
+                    async_comm=async_comm,
+                    sequence_parallel_enabled=sequence_parallel_enabled,
+                )
            else:
                _logger.debug("send bwd and receive fwd")
-                input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape)
+                input_tensor = send_backward_recv_forward(
+                    input_tensor_grad,
+                    tensor_shapes=recv_tensor_shapes,
+                    dtype=dtype,
+                    async_comm=async_comm,
+                    sequence_parallel_enabled=sequence_parallel_enabled,
+                )
    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
...
@@ -158,13 +461,29 @@ def forward_backward_pipelining_without_interleaving(
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
-            output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)
+            output_tensor_grad = recv_backward(
+                tensor_shapes=send_tensor_shapes,
+                dtype=dtype,
+                async_comm=async_comm,
+                sequence_parallel_enabled=sequence_parallel_enabled,
+            )

            input_tensor_grad = backward_step(
-                input_tensor, output_tensor, output_tensor_grad
+                input_tensor,
+                output_tensor,
+                output_tensor_grad,
+                model_type=model_type,
+                grad_scaler=grad_scaler,
+                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
            )

            _logger.debug("send bwd")
-            p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
+            send_backward(
+                input_tensor_grad,
+                tensor_shapes=recv_tensor_shapes,
+                dtype=dtype,
+                async_comm=async_comm,
+                sequence_parallel_enabled=sequence_parallel_enabled,
+            )

    return losses_reduced
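A hedged sketch of how this non-interleaved schedule is typically invoked follows; it assumes the model, batch, forward-step function, and the microbatch calculator are prepared elsewhere, and the shape and dtype values are placeholders for whatever the model actually produces.

# Illustrative driver for the non-interleaved 1F1B schedule above.
import torch
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)

def run_one_step(forward_step_func, batch, model, seq_length, micro_batch_size, hidden_size):
    losses = forward_backward_pipelining_without_interleaving(
        forward_step_func,
        batch,
        model,
        forward_only=False,
        tensor_shape=(seq_length, micro_batch_size, hidden_size),
        dtype=torch.bfloat16,  # must match the dtype of the exchanged activations
        sequence_parallel_enabled=False,
    )
    return losses  # non-empty only on the last pipeline stage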
apex/transformer/pipeline_parallel/utils.py  View file @ 795a5e5b

...
@@ -21,6 +21,7 @@ from torch.nn.parallel import DistributedDataParallel
from apex.multi_tensor_apply import multi_tensor_applier
from apex.transformer import parallel_state
+from apex.transformer.enums import ModelType
from apex.transformer.microbatches import build_num_microbatches_calculator
from apex.transformer.pipeline_parallel._timers import _Timers

if multi_tensor_applier.available:
...
@@ -118,14 +119,24 @@ def _split_batch_into_microbatch(
# TODO(mkozuki): Support non-tensor local minibatches?
-def get_kth_microbatch(batch: List[torch.Tensor], k: int) -> List[torch.Tensor]:
+def get_kth_microbatch(batch: Optional[List[torch.Tensor]], k: int) -> List[torch.Tensor]:
    """Create a list of microbatches from a list of local minibatches.

    This function creates a list of `k`th microbatches from a list of local minibatches.
    `a local minibatch` consists of `global_batch_size / data_parallel_size` samples.
    """
+    if batch is None:
+        return batch
    micro_batch_size = get_micro_batch_size()
-    return [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in batch]
+    start = k * micro_batch_size
+    end = start + micro_batch_size
+    microbatch = list()
+    for x in batch:
+        size = x.size(0)
+        assert size > start and size >= end
+        microbatch.append(x[start:end])
+    assert len(microbatch) > 0
+    return microbatch


def get_autoresume():
...
@@ -186,6 +197,19 @@ def unwrap_model(model, module_instances=(DistributedDataParallel,)):
    return unwrapped_model


+def get_model_type(
+    model: torch.nn.Module,
+) -> ModelType:
+    """Get `model_type` of `model`.
+
+    If ``model`` doesn't have ``model_type`` attribute, return ``ModelType.encoder_or_decoder``.
+
+    Args:
+        model
+    """
+    return getattr(unwrap_model(model), "model_type", ModelType.encoder_or_decoder)


def calc_params_l2_norm(model: torch.nn.Module, bf16: bool):
    """Calculate l2 norm of parameters """
    # args = get_args()
...
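The slicing performed by the rewritten `get_kth_microbatch` can be reproduced in isolation; the snippet below hard-codes a micro-batch size instead of reading it from the calculator and is purely illustrative.

# Standalone illustration of k-th microbatch slicing (micro_batch_size is hard-coded).
import torch

micro_batch_size = 2
batch = [torch.arange(8).reshape(8, 1), torch.arange(8)]  # a local minibatch of 8 samples
k = 1
microbatch = [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in batch]
# microbatch[1] == tensor([2, 3]): the second slice of two samples from each tensor.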
apex/transformer/tensor_parallel/__init__.py  View file @ 795a5e5b

# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -33,6 +32,7 @@ from apex.transformer.tensor_parallel.mappings import (
    gather_from_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    scatter_to_tensor_model_parallel_region,
+    scatter_to_sequence_parallel_region,
)
from .random import (
...
@@ -63,6 +63,7 @@ __all__ = [
    "gather_from_tensor_model_parallel_region",
    "reduce_from_tensor_model_parallel_region",
    "scatter_to_tensor_model_parallel_region",
+    "scatter_to_sequence_parallel_region",
    # random.py
    "checkpoint",
    "get_cuda_rng_tracker",
...
apex/transformer/tensor_parallel/data.py
...
@@ -25,8 +25,9 @@ _MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
    """Check that all the keys have the same target data type."""
    for key in keys:
-        assert data[key].dtype == target_dtype, "{} has data type {} which " "is different than {}".format(
-            key, data[key].dtype, target_dtype
+        assert data[key].dtype == target_dtype, (
+            "{} has data type {} which "
+            "is different than {}".format(key, data[key].dtype, target_dtype)
        )
...
@@ -48,7 +49,9 @@ def _build_key_size_numel_dictionaries(keys, data):
    # Move to GPU and broadcast.
    sizes_cuda = torch.cuda.LongTensor(sizes)
    torch.distributed.broadcast(
-        sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(),
+        sizes_cuda,
+        get_tensor_model_parallel_src_rank(),
+        group=get_tensor_model_parallel_group(),
    )

    # Move back to cpu and unpack.
...
@@ -92,13 +95,19 @@ def broadcast_data(keys, data, datatype):
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)

        # Flatten the data associated with the keys
-        flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
+        flatten_data = torch.cat(
+            [data[key].contiguous().view(-1) for key in keys], dim=0
+        ).cuda()
    else:
-        flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)
+        flatten_data = torch.empty(
+            total_numel, device=torch.cuda.current_device(), dtype=datatype
+        )

    # Broadcast
    torch.distributed.broadcast(
-        flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(),
+        flatten_data,
+        get_tensor_model_parallel_src_rank(),
+        group=get_tensor_model_parallel_group(),
    )

    # Unpack
...
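A hedged usage sketch of `broadcast_data`. It assumes a `torch.distributed` job where `parallel_state.initialize_model_parallel()` has already been called; the keys, shapes, and vocabulary size are illustrative. Only the source rank of the tensor-model-parallel group needs real tensors, the other ranks just take part in the broadcast.

import torch
from apex.transformer import parallel_state, tensor_parallel

keys = ["text", "labels"]
if parallel_state.get_tensor_model_parallel_rank() == 0:
    data = {
        "text": torch.randint(0, 32000, (4, 128)),
        "labels": torch.randint(0, 32000, (4, 128)),
    }
else:
    data = None  # non-source ranks receive the broadcast
data_b = tensor_parallel.broadcast_data(keys, data, torch.int64)
# data_b["text"] / data_b["labels"] are identical on every rank of the group afterwards.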
apex/transformer/tensor_parallel/layers.py
# coding=utf-8
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -16,6 +16,9 @@
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

+from typing import Optional, Dict, Tuple, List
+import warnings
+
import torch
import torch.nn.functional as F
import torch.nn.init as init
...
@@ -26,12 +29,34 @@ from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.utils import divide
-from apex.transformer.tensor_parallel.mappings import copy_to_tensor_model_parallel_region
-from apex.transformer.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
-from apex.transformer.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
-from apex.transformer.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.mappings import (
+    copy_to_tensor_model_parallel_region,
+)
+from apex.transformer.tensor_parallel.mappings import (
+    gather_from_tensor_model_parallel_region,
+)
+from apex.transformer.tensor_parallel.mappings import (
+    reduce_from_tensor_model_parallel_region,
+)
+from apex.transformer.tensor_parallel.mappings import (
+    scatter_to_tensor_model_parallel_region,
+)
+from apex.transformer.tensor_parallel.mappings import (
+    reduce_scatter_to_sequence_parallel_region,
+)
from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
from apex.transformer.tensor_parallel.utils import VocabUtility
+from apex.transformer.log_util import get_transformer_logger
+
+
+_logger = get_transformer_logger(__name__)
+
+
+_grad_accum_fusion_available = True
+try:
+    import fused_weight_gradient_mlp_cuda
+except ImportError:
+    _grad_accum_fusion_available = False

_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
...
@@ -41,13 +66,13 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
}


-def param_is_not_tensor_parallel_duplicate(param):
-    return (hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel) or (
-        get_tensor_model_parallel_rank() == 0
-    )
+def param_is_not_tensor_parallel_duplicate(param: torch.Tensor) -> bool:
+    return (
+        hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
+    ) or (get_tensor_model_parallel_rank() == 0)


-def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
+def set_tensor_model_parallel_attributes(tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int) -> None:
    # Make sure the attributes are not set.
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attribute)
...
@@ -57,7 +82,7 @@ def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    setattr(tensor, "partition_stride", stride)


-def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
+def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor: torch.Tensor) -> None:
    def maybe_set(attribute, value):
        if not hasattr(tensor, attribute):
            setattr(tensor, attribute, value)
...
@@ -66,7 +91,7 @@ def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
        maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])


-def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
+def copy_tensor_model_parallel_attributes(destination_tensor: torch.Tensor, source_tensor: torch.Tensor) -> None:
    def maybe_copy(attribute):
        if hasattr(source_tensor, attribute):
            setattr(destination_tensor, attribute, getattr(source_tensor, attribute))
...
@@ -76,9 +101,18 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
-    """Initialize affine weight for model parallel on GPU."""
-    set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)
+    """Initialize affine weight for model parallel on GPU.
+
+    Args:
+        weight (Parameter):
+        init_method (Callable[[Tensor], None]): Taking a Tensor and initialize its elements.
+        partition_dim (int): Dimension to apply partition.
+        stride (int):
+    """
+    set_tensor_model_parallel_attributes(
+        tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
+    )

    with get_cuda_rng_tracker().fork():
        init_method(weight)
...
@@ -103,16 +137,22 @@ def _initialize_affine_weight_cpu(
    Build the master weight on all processes and scatter
    the relevant chunk."""

-    set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)
+    set_tensor_model_parallel_attributes(
+        tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
+    )

    # Initialize master weight
-    master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False)
+    master_weight = torch.empty(
+        output_size, input_size, dtype=torch.float, requires_grad=False
+    )
    init_method(master_weight)
    master_weight = master_weight.to(dtype=params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
-    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
+    weight_list = torch.split(
+        master_weight, per_partition_per_stride_size, dim=partition_dim
+    )
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]
...
@@ -136,9 +176,15 @@ class VocabParallelEmbedding(torch.nn.Module):
    """

    def __init__(
-        self, num_embeddings, embedding_dim, init_method=init.xavier_normal_,
-        *, params_dtype=torch.float32, use_cpu_initialization=False,
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        init_method=init.xavier_normal_,
+        *,
+        params_dtype: torch.dtype = torch.float32,
+        use_cpu_initialization: bool = False,
    ):
-        super(VocabParallelEmbedding, self).__init__()
+        super().__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
...
@@ -150,19 +196,35 @@ class VocabParallelEmbedding(torch.nn.Module):
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
-        # Divide the weight matrix along the vocaburaly dimension.
-        self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
-            self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
-        )
-        self.num_embeddings_per_partition = (self.vocab_end_index - self.vocab_start_index)
+        # Divide the weight matrix along the vocabulary dimension.
+        (
+            self.vocab_start_index,
+            self.vocab_end_index,
+        ) = VocabUtility.vocab_range_from_global_vocab_size(
+            self.num_embeddings,
+            get_tensor_model_parallel_rank(),
+            self.tensor_model_parallel_size,
+        )
+        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

        # Allocate weights and initialize.
        if use_cpu_initialization:
            self.weight = Parameter(
-                torch.empty(self.num_embeddings_per_partition, self.embedding_dim, dtype=params_dtype)
+                torch.empty(
+                    self.num_embeddings_per_partition,
+                    self.embedding_dim,
+                    dtype=params_dtype,
+                )
            )
            _initialize_affine_weight_cpu(
-                self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method,
+                self.weight,
+                self.num_embeddings,
+                self.embedding_dim,
+                self.num_embeddings_per_partition,
+                0,
+                init_method,
+                params_dtype=params_dtype,
            )
        else:
...
@@ -174,12 +236,16 @@ class VocabParallelEmbedding(torch.nn.Module):
                    dtype=params_dtype,
                )
            )
-            _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
+            _initialize_affine_weight_gpu(
+                self.weight, init_method, partition_dim=0, stride=1
+            )

    def forward(self, input_):
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
-            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
+            input_mask = (input_ < self.vocab_start_index) | (
+                input_ >= self.vocab_end_index
+            )
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
...
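The masking logic above keeps only the vocabulary slice owned by the current rank and zeroes everything else before the cross-rank reduction. A toy illustration of the index arithmetic (the numbers are invented and no distributed setup is involved):

import torch

vocab_start_index, vocab_end_index = 4, 8       # this rank owns token ids [4, 8)
input_ = torch.tensor([1, 5, 7, 9])
input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
masked_input = input_.clone() - vocab_start_index
masked_input[input_mask] = 0                    # out-of-range tokens map to local row 0
# After the local embedding lookup, the rows belonging to masked tokens are zeroed out,
# so the subsequent all-reduce sums exactly one non-zero contribution per token.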
@@ -203,16 +269,44 @@ class VocabParallelEmbedding(torch.nn.Module):
        return output


-class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
-    """
-    Column-parallel linear layer execution with asynchronous all-reduce
-    execution in backprop.
-    """
+class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
+    """Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop."""

    @staticmethod
-    def forward(ctx, input, weight, bias):
+    def forward(
+        ctx,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        gradient_accumulation_fusion: bool,
+        async_grad_allreduce: bool,
+        sequence_parallel_enabled: bool,
+        use_16bit_in_wgrad_accum_fusion: bool = False,
+    ):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
-        output = torch.matmul(input, weight.t())
+        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+        ctx.async_grad_allreduce = async_grad_allreduce
+        ctx.sequence_parallel_enabled = sequence_parallel_enabled
+        ctx.use_16bit_in_wgrad_accum_fusion = use_16bit_in_wgrad_accum_fusion
+
+        if ctx.sequence_parallel_enabled:
+            world_size = get_tensor_model_parallel_world_size()
+            # `input` is supposed to be 3D and its order of dimension is [sequence, batch, hidden]
+            shape = list(input.shape)
+            shape[0] *= world_size
+
+            all_gather_buffer = torch.empty(
+                shape,
+                dtype=input.dtype,
+                device=torch.cuda.current_device(),
+                requires_grad=False,
+            )
+            torch.distributed._all_gather_base(
+                all_gather_buffer, input, group=get_tensor_model_parallel_group()
+            )
+            total_input = all_gather_buffer
+        else:
+            total_input = input
+        output = torch.matmul(total_input, weight.t())
        if bias is not None:
            output = output + bias
        return output
...
@@ -221,23 +315,115 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias
+
+        if ctx.sequence_parallel_enabled:
+            world_size = get_tensor_model_parallel_world_size()
+            shape = list(input.shape)
+            shape[0] *= world_size
+
+            all_gather_buffer = torch.empty(
+                shape,
+                dtype=input.dtype,
+                device=torch.cuda.current_device(),
+                requires_grad=False,
+            )
+            handle = torch.distributed._all_gather_base(
+                all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group(),
+                async_op=True,
+            )
+            total_input = all_gather_buffer
+        else:
+            total_input = input
        grad_input = grad_output.matmul(weight)
-        # Asynchronous all-reduce
-        handle = torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=True)
-        # Delay the start of weight gradient computation shortly (3us) to have
-        # all-reduce scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
-        grad_weight = grad_output.t().matmul(input)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None
-        handle.wait()
-        return grad_input, grad_weight, grad_bias
+
+        if ctx.sequence_parallel_enabled:
+            handle.wait()
+
+        # Convert the tensor shapes to 2D for execution compatibility
+        grad_output = grad_output.view(
+            grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+        )
+        total_input = total_input.view(
+            total_input.shape[0] * total_input.shape[1], total_input.shape[2]
+        )
+
+        if ctx.async_grad_allreduce:
+            # Asynchronous all-reduce
+            handle = torch.distributed.all_reduce(
+                grad_input, group=get_tensor_model_parallel_group(), async_op=True
+            )
+
+        if ctx.sequence_parallel_enabled:
+            assert not ctx.async_grad_allreduce
+            sub_grad_input = torch.empty(
+                input.shape,
+                dtype=input.dtype,
+                device=torch.cuda.current_device(),
+                requires_grad=False,
+            )
+            handle = torch.distributed._reduce_scatter_base(
+                sub_grad_input,
+                grad_input,
+                group=get_tensor_model_parallel_group(),
+                async_op=True,
+            )
+
+        if ctx.gradient_accumulation_fusion:
+            if not ctx.use_16bit_in_wgrad_accum_fusion:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+                    total_input, grad_output, weight.main_grad
+                )
+            else:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+                    total_input, grad_output, weight.main_grad
+                )
+            grad_weight = None
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        if ctx.sequence_parallel_enabled:
+            handle.wait()
+            return sub_grad_input, grad_weight, grad_bias, None, None, None, None
+
+        if ctx.async_grad_allreduce:
+            handle.wait()
+
+        return grad_input, grad_weight, grad_bias, None, None, None, None


-def column_parallel_linear(input, weight, bias):
-    args = _cast_if_autocast_enabled(input, weight, bias)
+def linear_with_grad_accumulation_and_async_allreduce(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    gradient_accumulation_fusion: bool,
+    async_grad_allreduce: bool,
+    sequence_parallel_enabled: bool,
+) -> torch.Tensor:
+    args = _cast_if_autocast_enabled(
+        input,
+        weight,
+        bias,
+        gradient_accumulation_fusion,
+        async_grad_allreduce,
+        sequence_parallel_enabled,
+        False,  # use_16bit_in_wgrad_accum_fusion
+    )
    with torch.cuda.amp.autocast(enabled=False):
-        return ColumnParallelLinearWithAsyncAllreduce.apply(*args)
+        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
+
+
+def linear_with_grad_accumulation_and_async_allreduce_in16bit(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    gradient_accumulation_fusion: bool,
+    async_grad_allreduce: bool,
+    sequence_parallel_enabled: bool,
+) -> torch.Tensor:
+    args = _cast_if_autocast_enabled(
+        input,
+        weight,
+        bias,
+        gradient_accumulation_fusion,
+        async_grad_allreduce,
+        sequence_parallel_enabled,
+        True,  # use_16bit_in_wgrad_accum_fusion
+    )
+    with torch.cuda.amp.autocast(enabled=False):
+        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)


class ColumnParallelLinear(torch.nn.Module):
...
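The backward above hides the all-reduce of `grad_input` behind the weight-gradient GEMM by issuing the collective with `async_op=True` and waiting only after the GEMM has been launched. A stripped-down sketch of that overlap pattern, independent of this class (the process group and 2D tensors are assumed to exist already):

import torch
import torch.distributed as dist

def overlapped_backward(grad_output_2d, total_input_2d, weight, group):
    grad_input = grad_output_2d.matmul(weight)
    handle = dist.all_reduce(grad_input, group=group, async_op=True)  # communication starts here
    grad_weight = grad_output_2d.t().matmul(total_input_2d)           # compute runs concurrently
    handle.wait()                                                     # synchronize before grad_input is consumed
    return grad_input, grad_weight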
@@ -246,6 +432,10 @@ class ColumnParallelLinear(torch.nn.Module):
    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

+    .. note::
+        Input is supposed to be three dimensional and each dimension
+        is expected to be sequence, batch, and hidden feature, respectively.
+
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
...
@@ -262,6 +452,14 @@ class ColumnParallelLinear(torch.nn.Module):
        skip_bias_add: This was added to enable performance optimations where bias
                       can be fused with other elementwise operations. we skip
                       adding bias but instead return it.
+
+    Keyword Arguments:
+        no_async_tensor_model_parallel_allreduce:
+        params_dtype:
+        use_cpu_initialization:
+        gradient_accumulation_fusion:
+        accumulation_in_fp16:
+        sequence_parallel_enabled:
    """

    def __init__(
...
@@ -278,8 +476,11 @@ class ColumnParallelLinear(torch.nn.Module):
        no_async_tensor_model_parallel_allreduce=False,
        params_dtype=torch.float32,
        use_cpu_initialization=False,
+        gradient_accumulation_fusion=False,
+        accumulation_in_fp16: bool = False,
+        sequence_parallel_enabled: bool = False,
    ):
-        super(ColumnParallelLinear, self).__init__()
+        super().__init__()

        # Keep input parameters
        self.input_size = input_size
...
@@ -295,7 +496,9 @@ class ColumnParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        # Initialize weight.
        if use_cpu_initialization:
-            self.weight = Parameter(torch.empty(self.output_size_per_partition, self.input_size, dtype=params_dtype))
+            self.weight = Parameter(
+                torch.empty(
+                    self.output_size_per_partition, self.input_size, dtype=params_dtype
+                )
+            )
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight,
                self.output_size,
...
@@ -323,7 +526,11 @@ class ColumnParallelLinear(torch.nn.Module):
                self.bias = Parameter(torch.empty(self.output_size_per_partition, dtype=params_dtype))
            else:
                self.bias = Parameter(
-                    torch.empty(self.output_size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype)
+                    torch.empty(
+                        self.output_size_per_partition,
+                        device=torch.cuda.current_device(),
+                        dtype=params_dtype,
+                    )
                )
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
...
@@ -333,28 +540,69 @@ class ColumnParallelLinear(torch.nn.Module):
            self.register_parameter("bias", None)

        self.async_tensor_model_parallel_allreduce = (
-            not no_async_tensor_model_parallel_allreduce and world_size > 1)
+            not no_async_tensor_model_parallel_allreduce and world_size > 1
+        )
+        if sequence_parallel_enabled:
+            if world_size <= 1:
+                warnings.warn(
+                    f"`sequence_parallel_enabled` is set to `True`, but got world_size of {world_size}"
+                )
+                # sequence_parallel_enabled = False
+        self.sequence_parallel_enabled = sequence_parallel_enabled
+
+        if gradient_accumulation_fusion:
+            if not _grad_accum_fusion_available:
+                # Basically, apex.transformer module users are expected to install APEX's
+                # `--cpp_ext` and `--cuda_ext`. The example installation command is as follows:
+                # `pip install --global-option="--cpp_ext" --global-option="--cuda_ext ."
+                # at the root of APEX repository.
+                warnings.warn(
+                    "`gradient_accumulation_fusion` is set to `True` but "
+                    "the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not "
+                    "found. Thus `gradient_accumulation_fusion` set to `False`. "
+                    "Note that the extension requires CUDA>=11."
+                )
+                gradient_accumulation_fusion = False
+        self.gradient_accumulation_fusion = gradient_accumulation_fusion

-    def forward(self, input_):
+        if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
+            raise RuntimeError(
+                "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time."
+            )
+        self._forward_impl = (
+            linear_with_grad_accumulation_and_async_allreduce_in16bit
+            if accumulation_in_fp16
+            else linear_with_grad_accumulation_and_async_allreduce
+        )
+
+    def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Forward of ColumnParallelLinear
+
+        Args:
+            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+
+        Returns:
+            - output
+            - bias
+        """
        bias = self.bias if not self.skip_bias_add else None

-        if self.async_tensor_model_parallel_allreduce:
-            input_shape = input_.shape
-            input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
-            # Matrix multiply with asynchronous all-reduce execution
-            output_parallel = column_parallel_linear(input_, self.weight, bias)
-            output_parallel = output_parallel.view(input_shape[0], input_shape[1], output_parallel.shape[1])
+        if self.async_tensor_model_parallel_allreduce or self.sequence_parallel_enabled:
+            input_parallel = input_
        else:
            # Set up backprop all-reduce.
            input_parallel = copy_to_tensor_model_parallel_region(input_)
-            # Matrix multiply.
-            output_parallel = F.linear(input_parallel, self.weight, bias)
+
+        # Matrix multiply.
+        output_parallel = self._forward_impl(
+            input=input_parallel,
+            weight=self.weight,
+            bias=bias,
+            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+            async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
+            sequence_parallel_enabled=self.sequence_parallel_enabled,
+        )
        if self.gather_output:
            # All-gather across the partitions.
+            assert not self.sequence_parallel_enabled
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
...
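A hedged construction sketch for the new keyword arguments. It assumes `parallel_state.initialize_model_parallel()` has already been called inside a `torch.distributed` job; the sizes are arbitrary.

import torch
from apex.transformer.tensor_parallel.layers import ColumnParallelLinear

linear = ColumnParallelLinear(
    1024,                                # input_size
    4096,                                # output_size, split across tensor-parallel ranks
    gather_output=False,
    skip_bias_add=True,                  # forward() then returns (output, bias) separately
    params_dtype=torch.float32,
    gradient_accumulation_fusion=False,  # True requires the fused_weight_gradient_mlp_cuda extension
    sequence_parallel_enabled=False,     # cannot be combined with the async TP all-reduce
)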
@@ -374,6 +622,11 @@ class RowParallelLinear(torch.nn.Module):
               | .   |
               | A_p |
                -   -
+    .. note::
+        Input is supposed to be three dimensional and each dimension
+        is expected to be sequence, batch, and hidden feature, respectively.
+
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
...
@@ -390,6 +643,12 @@ class RowParallelLinear(torch.nn.Module):
        skip_bias_add: This was added to enable performance optimization where bias
                       can be fused with other elementwise operations. We skip
                       adding bias but instead return it.
+
+    Keyword Arguments:
+        params_dtype:
+        use_cpu_initialization:
+        gradient_accumulation_fusion:
+        accumulation_in_fp16:
+        sequence_parallel_enabled:
    """

    def __init__(
...
@@ -405,8 +664,11 @@ class RowParallelLinear(torch.nn.Module):
        *,
        params_dtype=torch.float32,
        use_cpu_initialization=False,
+        gradient_accumulation_fusion=False,
+        accumulation_in_fp16: bool = False,
+        sequence_parallel_enabled: bool = False,
    ):
-        super(RowParallelLinear, self).__init__()
+        super().__init__()

        # Keep input parameters
        self.input_size = input_size
...
@@ -416,6 +678,10 @@ class RowParallelLinear(torch.nn.Module):
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add
+        self.gradient_accumulation_fusion = gradient_accumulation_fusion
+        self.sequence_parallel_enabled = sequence_parallel_enabled
+        if self.sequence_parallel_enabled and not self.input_is_parallel:
+            raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")

    # as an argument to this function?
    # Parameters.
...
@@ -423,7 +689,11 @@ class RowParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        # Initialize weight.
        if use_cpu_initialization:
-            self.weight = Parameter(torch.empty(self.output_size, self.input_size_per_partition, dtype=params_dtype))
+            self.weight = Parameter(
+                torch.empty(
+                    self.output_size, self.input_size_per_partition, dtype=params_dtype
+                )
+            )
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight,
                self.output_size,
...
@@ -444,30 +714,63 @@ class RowParallelLinear(torch.nn.Module):
                    dtype=params_dtype,
                )
            )
-            _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=1, stride=stride)
+            _initialize_affine_weight_gpu(
+                self.weight, init_method, partition_dim=1, stride=stride
+            )
        if bias:
            if use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            else:
                self.bias = Parameter(
-                    torch.empty(self.output_size, device=torch.cuda.current_device(), dtype=params_dtype)
+                    torch.empty(
+                        self.output_size,
+                        device=torch.cuda.current_device(),
+                        dtype=params_dtype,
+                    )
                )
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
+            setattr(self.bias, "sequence_parallel_enabled", sequence_parallel_enabled)
        else:
            self.register_parameter("bias", None)

-    def forward(self, input_):
+        self._forward_impl = (
+            linear_with_grad_accumulation_and_async_allreduce_in16bit
+            if accumulation_in_fp16
+            else linear_with_grad_accumulation_and_async_allreduce
+        )
+
+    def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Forward of RowParallelLinear
+
+        Args:
+            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+
+        Returns:
+            - output
+            - bias
+        """
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
+            assert not self.sequence_parallel_enabled
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        output_parallel = self._forward_impl(
+            input=input_parallel,
+            weight=self.weight,
+            bias=None,
+            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+            async_grad_allreduce=False,
+            sequence_parallel_enabled=False,
+        )
        # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        if self.sequence_parallel_enabled:
+            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
+        else:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
...
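The usual pairing of these two layers, sketched under the same assumptions (an initialized model-parallel state and a CUDA device; hidden sizes are invented): the column-parallel layer keeps its partitioned output local (`gather_output=False`) and feeds it directly into a row-parallel layer with `input_is_parallel=True`, so only one all-reduce (or reduce-scatter when sequence parallelism is enabled) is needed per MLP block.

import torch
from apex.transformer.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear

hidden = 1024
dense_h_to_4h = ColumnParallelLinear(hidden, 4 * hidden, gather_output=False, skip_bias_add=True)
dense_4h_to_h = RowParallelLinear(4 * hidden, hidden, input_is_parallel=True, skip_bias_add=True)

x = torch.randn(128, 2, hidden, device="cuda")   # [sequence, batch, hidden]
intermediate, bias = dense_h_to_4h(x)
intermediate = torch.nn.functional.gelu(intermediate + bias)
output, output_bias = dense_4h_to_h(intermediate)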
apex/transformer/tensor_parallel/mappings.py
# coding=utf-8
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -20,7 +20,7 @@ from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim


-def _reduce(input_):
+def _reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
...
@@ -33,7 +33,7 @@ def _reduce(input_):
    return input_


-def _split(input_):
+def _split_along_last_dim(input_: torch.Tensor) -> torch.Tensor:
    """Split the tensor along its last dimension and keep the
    corresponding slice."""
...
@@ -52,8 +52,24 @@ def _split(input_):
    return output


-def _gather(input_):
-    """Gather tensors and concatinate along the last dimension."""
+def _split_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
+    """Split the tensor along its first dimension and keep the corresponding slice."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU for tensor model parallel.
+    if world_size == 1:
+        return input_
+
+    # Split along first dimension.
+    dim_size = input_.size(0)
+    assert dim_size % world_size == 0
+    local_dim_size = dim_size // world_size
+    dim_offset = get_tensor_model_parallel_rank() * local_dim_size
+    output = input_[dim_offset:dim_offset + local_dim_size].contiguous()
+    return output
+
+
+def _gather_along_last_dim(input_: torch.Tensor) -> torch.Tensor:
+    """Gather tensors and concatenate along the last dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
...
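A toy, single-process illustration of the first-dimension split/gather pair the new helpers implement; with a world size of 2, each rank keeps a contiguous slice of the sequence dimension and an all-gather restores the original ordering:

import torch

world_size = 2
x = torch.arange(8).view(4, 2)                                # [sequence, hidden]
chunks = [x[i * 2:(i + 1) * 2] for i in range(world_size)]    # what each rank would keep
restored = torch.cat(chunks, dim=0)                           # what _gather_along_first_dim reassembles
assert torch.equal(restored, x)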
@@ -66,7 +82,9 @@ def _gather(input_):
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
-    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
+    torch.distributed.all_gather(
+        tensor_list, input_, group=get_tensor_model_parallel_group()
+    )

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()
...
@@ -74,9 +92,49 @@ def _gather(input_):
    return output


+def _gather_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
+    """Gather tensors and concatenate along the first dimension."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    shape = list(input_.shape)
+    shape[0] *= world_size
+
+    output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
+    torch.distributed._all_gather_base(
+        output, input_.contiguous(), group=get_tensor_model_parallel_group()
+    )
+    return output
+
+
+def _reduce_scatter_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
+    """Reduce-scatter the input tensor across model parallel group."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    shape = list(input_.shape)
+    assert shape[0] % world_size == 0
+    shape[0] //= world_size
+
+    output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device())
+    torch.distributed._reduce_scatter_base(
+        output, input_.contiguous(), group=get_tensor_model_parallel_group()
+    )
+    return output
+
+
class _CopyToModelParallelRegion(torch.autograd.Function):
-    """Pass the input to the model parallel region."""
+    """Pass the input to the tensor model parallel region."""

    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
    @staticmethod
    def symbolic(graph, input_):
        return input_
...
@@ -91,8 +149,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
class _ReduceFromModelParallelRegion(torch.autograd.Function):
-    """All-reduce the input from the model parallel region."""
+    """All-reduce the input from the tensor model parallel region."""

+    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
+    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)
...
@@ -109,33 +169,95 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the corresponding chuck to the rank."""

    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
    @staticmethod
    def symbolic(graph, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
-        return _gather(grad_output)
+        return _gather_along_last_dim(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
-    """Gather the input from model parallel region and concatinate."""
+    """Gather the input from tensor model parallel region and concatenate."""

    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
    @staticmethod
    def symbolic(graph, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
-        return _split(grad_output)
+        return _split_along_last_dim(grad_output)


+class _ScatterToSequenceParallelRegion(torch.autograd.Function):
+    """Split the input and keep only the corresponding chunk to the rank."""
+
+    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
+    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)
+
+
+class _GatherFromSequenceParallelRegion(torch.autograd.Function):
+    """Gather the input from sequence parallel region and concatenate."""
+
+    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
+    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
+    @staticmethod
+    def symbolic(graph, input_, to_model_parallel: bool = True):
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_, to_model_parallel: bool = True):
+        ctx.to_model_parallel = to_model_parallel
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.to_model_parallel:
+            return _reduce_scatter_along_first_dim(grad_output), None
+        else:
+            return _split_along_first_dim(grad_output), None
+
+
+class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
+    """Reduce scatter the input from the sequence parallel region and concatenate."""
+
+    # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to
+    # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)
+
+
# -----------------
...
@@ -143,17 +265,40 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
# -----------------


-def copy_to_tensor_model_parallel_region(input_):
+def copy_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _CopyToModelParallelRegion.apply(input_)


-def reduce_from_tensor_model_parallel_region(input_):
+def reduce_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _ReduceFromModelParallelRegion.apply(input_)


-def scatter_to_tensor_model_parallel_region(input_):
+def scatter_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _ScatterToModelParallelRegion.apply(input_)


-def gather_from_tensor_model_parallel_region(input_):
+def gather_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _GatherFromModelParallelRegion.apply(input_)


+def scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor:
+    return _ScatterToSequenceParallelRegion.apply(input_)
+
+
+def gather_from_sequence_parallel_region(input_: torch.Tensor, to_model_parallel: bool = True) -> torch.Tensor:
+    return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel)
+
+
+def reduce_scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor:
+    return _ReduceScatterToSequenceParallelRegion.apply(input_)
+
+
+__all__ = [
+    "copy_to_tensor_model_parallel_region",
+    "reduce_from_tensor_model_parallel_region",
+    "scatter_to_tensor_model_parallel_region",
+    "gather_from_tensor_model_parallel_region",
+    "scatter_to_sequence_parallel_region",
+    "gather_from_sequence_parallel_region",
+    "reduce_scatter_to_sequence_parallel_region",
+]
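Note how every region op pairs a forward collective with its adjoint in backward: the first-dim gather pairs with a reduce-scatter (when the gradient should land back in the tensor-parallel region) or a plain split. A hedged usage sketch of the new public wrappers inside an initialized model-parallel job (shapes are invented; the sequence length must divide evenly by the tensor-parallel world size):

import torch
from apex.transformer.tensor_parallel import mappings

x = torch.randn(128, 2, 1024, device="cuda")                      # [sequence, batch, hidden]
x_local = mappings.scatter_to_sequence_parallel_region(x)         # each rank keeps seq/world_size rows
x_full = mappings.gather_from_sequence_parallel_region(x_local)   # all-gather back along the sequence dim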
apex/transformer/tensor_parallel/memory.py
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -12,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# TODO(mkozuki): Remove this file as Megatron-LM seems to have done so.
+
import torch
...
@@ -49,13 +52,20 @@ class MemoryBuffer:
        element_size = torch.tensor([], dtype=dtype).element_size()
        print(
            "> building the {} memory buffer with {} num elements "
-            "and {} dtype ({:.1f} MB)...".format(name, numel, dtype, numel * element_size / 1024 / 1024),
+            "and {} dtype ({:.1f} MB)...".format(
+                name, numel, dtype, numel * element_size / 1024 / 1024
+            ),
            flush=True,
        )
        self.name = name
        self.numel = numel
        self.dtype = dtype
-        self.data = torch.empty(self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False)
+        self.data = torch.empty(
+            self.numel,
+            dtype=self.dtype,
+            device=torch.cuda.current_device(),
+            requires_grad=False,
+        )

        # Index tracking the start of the free memory.
        self._start = 0
...
@@ -81,13 +91,17 @@ class MemoryBuffer:
    def add(self, tensor):
        """Allocate a chunk of memory from the buffer to tensor and copy
        the values."""
-        assert tensor.dtype == self.dtype, "Input tensor type {} different from buffer type {}".format(
+        assert (
+            tensor.dtype == self.dtype
+        ), "Input tensor type {} different from buffer type {}".format(
            tensor.dtype, self.dtype
        )
        # Number of elements of the input tensor.
        tensor_numel = torch.numel(tensor)
        new_start = self._start + tensor_numel
-        assert new_start <= self.numel, "Not enough memory left in the buffer ({} > {})".format(
+        assert (
+            new_start <= self.numel
+        ), "Not enough memory left in the buffer ({} > {})".format(
            tensor_numel, self.numel - self._start
        )
        # New tensor is a view into the memory.
...
@@ -124,7 +138,8 @@ class RingMemBuffer:
    def __init__(self, name, num_buffers, numel, dtype, track_usage):
        self.num_buffers = num_buffers
        self.buffers = [
-            allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage) for i in range(num_buffers)
+            allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage)
+            for i in range(num_buffers)
        ]
        self._index = -1
...
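The buffer's `add` bumps a start index and hands back a view into the preallocated storage. A toy, class-free version of that bookkeeping (plain CPU tensors, sizes invented), just to make the arithmetic concrete:

import torch

data = torch.empty(16, dtype=torch.float32)   # pretend this is MemoryBuffer.data
start = 0
tensor = torch.randn(2, 3)
new_start = start + tensor.numel()
assert new_start <= data.numel(), "Not enough memory left in the buffer"
chunk = data[start:new_start].view(tensor.shape)
chunk.copy_(tensor)                           # the returned view aliases the buffer storage
start = new_start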
apex/transformer/tensor_parallel/random.py
# coding=utf-8
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# TODO (mkozuki): Audit this file.
-# I don't think some functions strongly relate to `random` in tensor_parallel.
-# Rather, some functions are mainly for gradient checkpointing (torch.utils.checkpoint).
+# NOTE(mkozuki): This file is based on megatron-lm/mpu/random.py with some differences:
+# - Not using "viewless" tensor:
+#   - _kernel_make_viewless_tensor
+#   - MakeViewlessTensor
+#   - make_viewless_tensor
+#   - assert_viewless_tensor
+#   - safely_set_viewless_tensor_data
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
...
@@ -35,13 +39,12 @@ from apex.transformer.utils import gather_split_1d_tensor
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"

+# TODO(mkozuki): Remove `_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER` as megatron-lm doesn't seem to use.
# Whether apply model parallelism to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None


-# TODO (mkozuki): Consider the possibility of removing `tensor_model_parallel_size`,
-# `get_tensor_model_parallel_world_size()` might be alternative.
+# TODO(mkozuki): Remove `init_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use.
def init_checkpointed_activations_memory_buffer(
    micro_batch_size,
    max_position_embeddings,
...
@@ -53,8 +56,15 @@ def init_checkpointed_activations_memory_buffer(
):
    """Initializ the memory buffer for the checkpointed activations."""
-    per_layer = micro_batch_size * max_position_embeddings * hidden_size // tensor_model_parallel_size
-    assert num_layers % checkpoint_num_layers == 0, "number of layers is not divisible by checkpoint-num-layers"
+    per_layer = (
+        micro_batch_size
+        * max_position_embeddings
+        * hidden_size
+        // tensor_model_parallel_size
+    )
+    assert (
+        num_layers % checkpoint_num_layers == 0
+    ), "number of layers is not divisible by checkpoint-num-layers"
    num_checkpointer_layers = num_layers // checkpoint_num_layers
    numel = per_layer * num_checkpointer_layers
    dtype = torch.half
...
@@ -70,6 +80,7 @@ def init_checkpointed_activations_memory_buffer(
    )


+# TODO(mkozuki): Remove `reset_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use.
def reset_checkpointed_activations_memory_buffer():
    """Reset the memory used for checkpointing."""
    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
...
@@ -79,7 +90,7 @@ def reset_checkpointed_activations_memory_buffer():
def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

-    Argumentss:
+    Arguments:
        new_state (torch.ByteTensor): The desired state
    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
...
@@ -217,7 +228,9 @@ def model_parallel_cuda_manual_seed(seed):
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    # and model parallel state.
-    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
+    _CUDA_RNG_STATE_TRACKER.add(
+        _MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed
+    )


# TODO (mkozuki): Move the below gradient checkpoint related features to another (new) file.
...
@@ -230,8 +243,9 @@ class CheckpointFunction(torch.autograd.Function):
    """

    @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, distribute_saved_activations, *args):
        ctx.run_function = run_function
+        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
...
@@ -243,10 +257,8 @@ class CheckpointFunction(torch.autograd.Function):
        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
-        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
-            ctx.input_0_shape = args[0].data.shape
-            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
-            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(args[0].data)
+        if ctx.distribute_saved_activations:
+            ctx.input_0_shape = args[0].shape

        # Store everything.
        ctx.save_for_backward(*args)
...
@@ -255,11 +267,11 @@ class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
-            raise RuntimeError("Checkpointing is not compatible with .grad(), " "please use .backward() if possible")
+            raise RuntimeError(
+                "Checkpointing is not compatible with .grad(), "
+                "please use .backward() if possible"
+            )
        inputs = ctx.saved_tensors
        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
...
@@ -284,11 +296,16 @@ class CheckpointFunction(torch.autograd.Function):
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
-        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
-        return (None,) + grads
+        grads = tuple(
+            inp.grad if isinstance(inp, torch.Tensor) else inp
+            for inp in detached_inputs
+        )
+        return (None, None) + grads


-def checkpoint(function, *args):
+# NOTE(mkozuki): It doesn't look like `distribute_saved_activations` is used in apex.transformer
+# but I added this change to reduce the superficial difference from Megatron-LM.
+def checkpoint(function, distribute_saved_activations, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function, distribute_saved_activations, *args)
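With the new signature, callers pass `distribute_saved_activations` explicitly. A hedged usage sketch (the checkpointed function is illustrative; a CUDA device is required because the implementation snapshots the GPU RNG state, and in a real model-parallel run the tracker would be seeded via `model_parallel_cuda_manual_seed`):

import torch
from apex.transformer.tensor_parallel import random as tp_random

def custom_forward(x):
    return torch.nn.functional.gelu(x) * 2.0

x = torch.randn(4, 8, requires_grad=True)
y = tp_random.checkpoint(custom_forward, False, x)  # False: do not distribute saved activations
y.sum().backward()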
apex/transformer/tensor_parallel/utils.py
...
@@ -12,12 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Sequence
+
import torch

from apex.transformer.utils import divide


-def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
+def split_tensor_along_last_dim(
+    tensor: torch.Tensor,
+    num_partitions: int,
+    contiguous_split_chunks: bool = False,
+) -> List[torch.Tensor]:
    """Split a tensor along its last dimension.
    Arguments:
        tensor: input tensor.
...
@@ -43,12 +49,16 @@ class VocabUtility:
        partition: Note that indices in [fist, last)"""

    @staticmethod
-    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
+    def vocab_range_from_per_partition_vocab_size(
+        per_partition_vocab_size: int, rank, world_size: int
+    ) -> Sequence[int]:
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
-    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
+    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
        per_partition_vocab_size = divide(global_vocab_size, world_size)
-        return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
+        return VocabUtility.vocab_range_from_per_partition_vocab_size(
+            per_partition_vocab_size, rank, world_size
+        )
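A quick worked example of the vocabulary partitioning helpers; it is pure arithmetic, so it runs without any distributed setup (the vocabulary size is invented and must be divisible by the world size):

from apex.transformer.tensor_parallel.utils import VocabUtility

# 50432 token ids split across a tensor-parallel world of 4: each rank owns 12608 ids.
print(VocabUtility.vocab_range_from_global_vocab_size(50432, rank=0, world_size=4))  # (0, 12608)
print(VocabUtility.vocab_range_from_global_vocab_size(50432, rank=3, world_size=4))  # (37824, 50432)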
apex/transformer/testing/arguments.py
...
@@ -39,9 +39,13 @@ def parse_args(extra_args_provider=None, defaults={},
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_biencoder_args(parser)
-    parser = _add_vit_args(parser)
+    parser = _add_vision_args(parser)
    parser = _add_logging_args(parser)

+    # NOTE(mkozuki): This option is added to investigate the potential of `torch.autograd.graph.save_on_cpu()`.
+    # ref: https://pytorch.org/docs/stable/autograd.html#torch.autograd.graph.save_on_cpu.
+    parser.add_argument('--cpu-offload', action='store_true', default=False, help='Turns on CPU offloading')
+
    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)
...
@@ -65,6 +69,11 @@ def parse_args(extra_args_provider=None, defaults={},
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
+    args.transformer_pipeline_model_parallel_size = (
+        args.pipeline_model_parallel_size - 1
+        if args.standalone_embedding_stage else
+        args.pipeline_model_parallel_size
+    )

    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
...
@@ -98,13 +107,18 @@ def parse_args(extra_args_provider=None, defaults={},
            'longer valid, use --tensor-model-parallel-size instead'
        del args.model_parallel_size

    if args.checkpoint_activations:
-        args.activations_checkpoint_method = 'uniform'
+        args.recompute_granularity = 'full'
+        args.recompute_method = 'uniform'
        if args.rank == 0:
            print('--checkpoint-activations is no longer valid, '
-                  'use --activation-checkpoint-method instead. '
-                  'Defaulting to activation-checkpoint-method=uniform.')
+                  'use --recompute-granularity and --recompute-method instead. '
+                  'Defaulting to recompute-granularity=full and recompute-method=uniform.')
        del args.checkpoint_activations

+    if args.recompute_activations:
+        args.recompute_granularity = 'selective'
+    del args.recompute_activations
+
    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
...
@@ -166,6 +180,14 @@ def parse_args(extra_args_provider=None, defaults={},
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp
+    else:
+        if args.gradient_accumulation_fusion:
+            args.gradient_accumulation_fusion = False
+            if args.rank == 0:
+                print('Gradient accumulation fusion to linear layer weight '
+                      'gradient computation is supported only with fp32 '
+                      'gradient accumulation. Setting gradient_accumulation_fusion '
+                      'to False', flush=True)

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
...
@@ -244,17 +266,51 @@ def parse_args(extra_args_provider=None, defaults={},
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'

-    # Activation checkpointing.
-    if args.distribute_checkpointed_activations:
+    if args.weight_decay_incr_style == 'constant':
+        assert args.start_weight_decay is None
+        assert args.end_weight_decay is None
+        args.start_weight_decay = args.weight_decay
+        args.end_weight_decay = args.weight_decay
+    else:
+        assert args.start_weight_decay is not None
+        assert args.end_weight_decay is not None
+
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    # Persistent fused layer norm.
+    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
+        args.no_persist_layer_norm = True
+        if args.rank == 0:
+            print('Persistent fused layer norm kernel is supported from '
+                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
+                  'Defaulting to no_persist_layer_norm=True')
+
+    # Activation recomputing.
+    if args.distribute_saved_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
-            'checkpointed activations only across tensor model ' \
+            'recomputed activations only across tensor model ' \
            'parallel groups'
-        assert args.activations_checkpoint_method is not None, \
-            'for distribute-checkpointed-activations to work you '\
-            'need to use a activation-checkpoint method '
-        assert args.num_layers_per_virtual_pipeline_stage is None, \
-            'currently distrobuted checkpoint activations only supported for '\
-            'nointerleaved pipeline parallelism'
+        assert args.recompute_granularity == 'full', \
+            'distributed recompute activations is only '\
+            'application to full recompute granularity'
+        assert args.recompute_method is not None, \
+            'for distributed recompute activations to work you '\
+            'need to use a recompute method '
+        assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
+            'distributed recompute activations are supported for pytorch ' \
+            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
+            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
+
+    if args.recompute_granularity == 'selective':
+        assert args.recompute_method is None, \
+            'recompute method is not yet supported for ' \
+            'selective recomputing granularity'
+
+    # disable async_tensor_model_parallel_allreduce when
+    # model parallel memory optimization is enabled
+    if args.sequence_parallel:
+        args.async_tensor_model_parallel_allreduce = False

    _print_args(args)
    return args
...
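The persistent-layer-norm gate added above keys off the installed PyTorch version; the parsing is just string splitting on the version tag. A tiny standalone check of the same logic:

import torch

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
persist_layer_norm_supported = not (TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11))
print(TORCH_MAJOR, TORCH_MINOR, persist_layer_norm_supported)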
@@ -279,6 +335,18 @@ def _check_arg_is_not_none(args, arg):
assert
getattr
(
args
,
arg
)
is
not
None
,
'{} argument is None'
.
format
(
arg
)
def
_add_inference_args
(
parser
):
group
=
parser
.
add_argument_group
(
title
=
'inference'
)
group
.
add_argument
(
'--inference-batch-times-seqlen-threshold'
,
type
=
int
,
default
=
512
,
help
=
'During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.'
)
return
parser
def
_add_network_size_args
(
parser
):
group
=
parser
.
add_argument_group
(
title
=
'network size'
)
...
...
@@ -318,6 +386,8 @@ def _add_network_size_args(parser):
group
.
add_argument
(
'--bert-no-binary-head'
,
action
=
'store_false'
,
help
=
'Disable BERT binary head.'
,
dest
=
'bert_binary_head'
)
group
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
None
,
help
=
'Number of Experts in Switch Transformer (None means no Switch)'
)
return
parser
...
...
@@ -354,6 +424,9 @@ def _add_logging_args(parser):
group
.
add_argument
(
'--log-memory-to-tensorboard'
,
action
=
'store_true'
,
help
=
'Enable memory logging to tensorboard.'
)
group
.
add_argument
(
'--log-world-size-to-tensorboard'
,
action
=
'store_true'
,
help
=
'Enable world size logging to tensorboard.'
)
return
parser
...
...
@@ -367,6 +440,13 @@ def _add_regularization_args(parser):
help
=
'Dropout probability for hidden state transformer.'
)
group
.
add_argument
(
'--weight-decay'
,
type
=
float
,
default
=
0.01
,
help
=
'Weight decay coefficient for L2 regularization.'
)
group
.
add_argument
(
'--start-weight-decay'
,
type
=
float
,
help
=
'Initial weight decay coefficient for L2 regularization.'
)
group
.
add_argument
(
'--end-weight-decay'
,
type
=
float
,
help
=
'End of run weight decay coefficient for L2 regularization.'
)
group
.
add_argument
(
'--weight-decay-incr-style'
,
type
=
str
,
default
=
'constant'
,
choices
=
[
'constant'
,
'linear'
,
'cosine'
],
help
=
'Weight decay increment function.'
)
group
.
add_argument
(
'--clip-grad'
,
type
=
float
,
default
=
1.0
,
help
=
'Gradient clipping based on global L2 norm.'
)
group
.
add_argument
(
'--adam-beta1'
,
type
=
float
,
default
=
0.9
,
...
...
@@ -413,27 +493,40 @@ def _add_training_args(parser):
                       ' (1024 - 16) / 8 = 126 intervals will increase'
                       'the batch size linearly to 1024. In each interval'
                       'we will use approximately 300000 / 126 = 2380 samples.')
-    group.add_argument('--checkpoint-activations', action='store_true',
-                       help='Checkpoint activation to allow for training '
+    group.add_argument('--recompute-activations', action='store_true',
+                       help='recompute activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
-    group.add_argument('--distribute-checkpointed-activations',
+    group.add_argument('--recompute-granularity', type=str, default=None,
+                       choices=['full', 'selective'],
+                       help='Checkpoint activations to allow for training '
+                       'with larger models, sequences, and batch sizes. '
+                       'It is supported at two granularities 1) full: '
+                       'whole transformer layer is recomputed, '
+                       '2) selective: core attention part of the transformer '
+                       'layer is recomputed.')
+    group.add_argument('--distribute-saved-activations',
                       action='store_true',
-                       help='If set, distribute checkpointed activations '
+                       help='If set, distribute recomputed activations '
                       'across model parallel group.')
-    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+    group.add_argument('--recompute-method', type=str, default=None,
                       choices=['uniform', 'block'],
                       help='1) uniform: uniformly divide the total number of '
-                       'Transformer layers and checkpoint the input activation of '
-                       'each divided chunk, '
-                       '2) checkpoint the input activations of only a set number of '
+                       'Transformer layers and recompute the input activation of '
+                       'each divided chunk at specified granularity, '
+                       '2) recompute the input activations of only a set number of '
                       'individual Transformer layers per pipeline stage and do the '
-                       'rest without any checkpointing'
-                       'default) do not apply activations checkpoint to any layers')
-    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+                       'rest without any recomputing at specified granularity'
+                       'default) do not apply activations recompute to any layers')
+    group.add_argument('--recompute-num-layers', type=int, default=1,
                       help='1) uniform: the number of Transformer layers in each '
-                       'uniformly divided checkpoint unit, '
+                       'uniformly divided recompute unit, '
                       '2) block: the number of individual Transformer layers '
-                       'to checkpoint within each pipeline stage.')
+                       'to recompute within each pipeline stage.')
+    # deprecated
+    group.add_argument('--checkpoint-activations', action='store_true',
+                       help='Checkpoint activation to allow for training '
+                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
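To make the `--recompute-method` help text above concrete, here is a small hypothetical helper (an illustration only, not code from this commit or from Megatron/apex) that maps the `uniform` and `block` policies onto the set of layer indices whose input activations would be recomputed, given `--recompute-num-layers`:

# Hypothetical sketch of the two recompute policies described in the help text.
def layers_to_recompute(num_layers: int, method: str, recompute_num_layers: int):
    if method == "uniform":
        # Divide the stage's layers into chunks of `recompute_num_layers` and
        # recompute the input activation at the start of each chunk.
        return [layer for layer in range(num_layers) if layer % recompute_num_layers == 0]
    if method == "block":
        # Recompute only the first `recompute_num_layers` layers of the stage
        # and run the rest without recomputation.
        return list(range(min(recompute_num_layers, num_layers)))
    return []  # default: no activation recompute


print(layers_to_recompute(12, "uniform", 4))  # [0, 4, 8]
print(layers_to_recompute(12, "block", 3))    # [0, 1, 2]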
...
...
@@ -472,7 +565,20 @@ def _add_training_args(parser):
                       action='store_true',
                       help='Disable asynchronous execution of '
                       'tensor-model-parallel all-reduce with weight '
-                       'gradient compuation of a column-linear layer.')
+                       'gradient compuation of a column-linear layer.',
+                       dest='async_tensor_model_parallel_allreduce')
+    group.add_argument('--no-persist-layer-norm', action='store_true',
+                       help='Disable using persistent fused layer norm kernel. '
+                       'This kernel supports only a set of hidden sizes. Please '
+                       'check persist_ln_hidden_sizes if your hidden '
+                       'size is supported.')
+    group.add_argument('--sequence-parallel', action='store_true',
+                       help='Enable sequence parallel optimization.')
+    group.add_argument('--no-gradient-accumulation-fusion',
+                       action='store_false',
+                       help='Disable fusing gradient accumulation to weight '
+                       'gradient computation of linear layers',
+                       dest='gradient_accumulation_fusion')
    return parser
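For readers unfamiliar with the `action='store_false'` plus `dest=...` idiom used by several flags above, a standalone argparse sketch (not code from this diff) shows how the negative `--no-*` flag maps onto a positively named attribute:

# Standalone illustration of the store_false/dest pattern: the attribute
# defaults to True, and passing the --no-* flag flips it to False.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-gradient-accumulation-fusion', action='store_false',
                    dest='gradient_accumulation_fusion',
                    help='Disable fusing gradient accumulation to weight '
                    'gradient computation of linear layers')

print(parser.parse_args([]).gradient_accumulation_fusion)                       # True
print(parser.parse_args(['--no-gradient-accumulation-fusion'])
      .gradient_accumulation_fusion)                                            # False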
...
...
@@ -645,6 +751,11 @@ def _add_distributed_args(parser):
                       help='Call torch.cuda.empty_cache() each iteration '
                       '(training and eval), to reduce fragmentation.'
                       '0=off, 1=moderate, 2=aggressive.')
+    group.add_argument('--standalone-embedding-stage', action='store_true',
+                       default=False, help='If set, *input* embedding layer '
+                       'is placed on its own pipeline stage, without any '
+                       'transformer layers. (For T5, this flag currently only '
+                       'affects the encoder embedding.)')
    return parser
...
...
@@ -791,16 +902,70 @@ def _add_biencoder_args(parser):
    return parser


-def _add_vit_args(parser):
-    group = parser.add_argument_group(title="vit")
+def _add_vision_args(parser):
+    group = parser.add_argument_group(title="vision")

+    # general vision arguments
    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classificaiton task')
-    group.add_argument('--img-dim', type=int, default=224,
-                       help='Image size for vision classification task')
+    group.add_argument('--img-h', type=int, default=224,
+                       help='Image height for vision classification task')
+    group.add_argument('--img-w', type=int, default=224,
+                       help='Image height for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
-                       help='patch dimension used in vit')
+                       help='patch dimension')
+    group.add_argument('--classes-fraction', type=float, default=1.0,
+                       help='training with fraction of classes.')
+    group.add_argument('--data-per-class-fraction', type=float, default=1.0,
+                       help='training with fraction of data per class.')
+    group.add_argument('--no-data-sharding', action='store_false',
+                       help='Disable data sharding.',
+                       dest='data_sharding')
+    group.add_argument('--head-lr-mult', type=float, default=1.0,
+                       help='learning rate multiplier for head during finetuning')
+
+    # pretraining type and backbone selection`
+    group.add_argument('--vision-pretraining', action='store_true',
+                       help='flag to indicate vision pretraining')
+    group.add_argument('--vision-pretraining-type', type=str, default='classify',
+                       choices=['classify', 'inpaint', 'dino'],
+                       help='pretraining objectives')
+    group.add_argument('--vision-backbone-type', type=str, default='vit',
+                       choices=['vit', 'mit', 'swin'],
+                       help='backbone types types')
+    group.add_argument('--swin-backbone-type', type=str, default='tiny',
+                       choices=['tiny', 'base', 'h3'],
+                       help='pretraining objectives')
+
+    # inpainting arguments
+    group.add_argument('--mask-type', type=str, default='random',
+                       choices=['random', 'row'],
+                       help='mask types')
+    group.add_argument('--mask-factor', type=float, default=1.0,
+                       help='mask size scaling parameter')
+
+    # dino arguments
+    group.add_argument('--iter-per-epoch', type=int, default=1250,
+                       help='iterations per epoch')
+    group.add_argument('--dino-local-img-size', type=int, default=96,
+                       help='Image size for vision classification task')
+    group.add_argument('--dino-local-crops-number', type=int, default=10,
+                       help='Number of local crops')
+    group.add_argument('--dino-head-hidden-size', type=int, default=2048,
+                       help='Hidden dimension size in dino head')
+    group.add_argument('--dino-bottleneck-size', type=int, default=256,
+                       help='Bottle neck dimension in dino head ')
+    group.add_argument('--dino-freeze-last-layer', type=float, default=1,
+                       help='Freezing last layer weights')
+    group.add_argument('--dino-norm-last-layer', action='store_true',
+                       help='Disable Norm in last layer.')
+    group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
+                       help='warump teacher temperature')
+    group.add_argument('--dino-teacher-temp', type=float, default=0.07,
+                       help='teacher temperature')
+    group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
+                       help='warmup teacher temperaure epochs')

    return parser
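For context on the `--img-h`/`--img-w`/`--patch-dim` split introduced above: a ViT-style backbone tiles the image into non-overlapping patches, so the resulting token count per image is (img_h // patch_dim) * (img_w // patch_dim). A quick worked check with the defaults above (illustrative arithmetic only, not code from the diff):

# Hypothetical sanity check of the patch arithmetic for the default sizes.
img_h, img_w, patch_dim = 224, 224, 16
num_patches = (img_h // patch_dim) * (img_w // patch_dim)
print(num_patches)  # 196 = 14 * 14 patches per image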
apex/transformer/testing/commons.py
View file @
795a5e5b
...
...
@@ -12,15 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from dataclasses import dataclass
+import datetime
import os
import random
-from typing import Optional, Union, List
+from typing import Optional, Union, List, Tuple, Callable, Dict

import numpy
import torch
import torch.nn as nn

from apex import transformer
+from apex.transformer.tensor_parallel import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+    scatter_to_sequence_parallel_region,
+)
+from apex.transformer.pipeline_parallel.utils import (
+    average_losses_across_data_parallel_group,
+)
+from apex.transformer.pipeline_parallel.schedules.common import (
+    Batch,
+)
from apex.transformer.testing import global_vars
...
...
@@ -29,7 +42,6 @@ TEST_SUCCESS_MESSAGE = ">> passed the test :-)"

# note (mkozuki): `pre_process` and `post_process` are a placeholder until interleaving schedule test comes.
class MyLayer(nn.Module):
    def __init__(self, hidden_size: int, pre_process: bool, post_process: bool):
        super().__init__()
        self.pre_process = pre_process
...
...
@@ -39,17 +51,28 @@ class MyLayer(nn.Module):
    def forward(self, x):
        return self.layer(x)


class MyModel(nn.Module):
-    def __init__(self, hidden_size: int, pre_process: bool = False, post_process: bool = False) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        pre_process: bool = False,
+        post_process: bool = False,
+        *,
+        add_encoder: bool = False,
+        add_decoder: bool = False,
+    ) -> None:
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
-        self.layer = MyLayer(hidden_size=hidden_size, pre_process=pre_process, post_process=post_process)
+        self.layer = MyLayer(
+            hidden_size=hidden_size, pre_process=pre_process, post_process=post_process
+        )
        self.input_tensor = None

-    def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
-        self.input_tensor = input_tensor
+    def set_input_tensor(
+        self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> None:
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
+        self.input_tensor = input_tensor[0]

    def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
        if self.input_tensor is None:
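A small hypothetical check of the new `set_input_tensor` behavior (assumes apex is installed with the transformer testing module; shapes and values are made up): both a bare tensor and a single-element list are accepted, and the model ends up holding one tensor, which is what the pipeline schedules hand between stages.

# Hypothetical usage sketch of MyModel.set_input_tensor after this change.
import torch
from apex.transformer.testing.commons import MyModel

model = MyModel(hidden_size=4)
t = torch.randn(2, 4)

model.set_input_tensor(t)        # bare tensor
assert torch.equal(model.input_tensor, t)

model.set_input_tensor([t])      # single-element list, as the schedule passes it
assert torch.equal(model.input_tensor, t)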
...
...
@@ -57,8 +80,154 @@ class MyModel(nn.Module):
        return self.layer(self.input_tensor)


-def model_provider_func(hidden_size, pre_process, post_process) -> MyModel:
-    return MyModel(hidden_size, pre_process, post_process)
+class ToyParallelMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        pre_process: bool = False,
+        post_process: bool = False,
+        *,
+        sequence_parallel_enabled: bool = False,
+        # TODO(mkozuki): Support these two?
+        add_encoder: bool = False,
+        add_decoder: bool = False,
+    ) -> None:
+        super().__init__()
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.sequence_parallel_enabled = sequence_parallel_enabled
+        ffn_hidden_size = 4 * hidden_size
+        self.dense_h_to_4h = ColumnParallelLinear(
+            hidden_size,
+            ffn_hidden_size,
+            gather_output=False,
+            # init_method=init_method,
+            skip_bias_add=True,
+            # use_cpu_initialization=use_cpu_initialization,
+            bias=True,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+            no_async_tensor_model_parallel_allreduce=True,
+        )
+        self.dense_4h_to_h = RowParallelLinear(
+            ffn_hidden_size,
+            hidden_size,
+            input_is_parallel=True,
+            # init_method=output_layer_init_method,
+            skip_bias_add=False,
+            # use_cpu_initialization=use_cpu_initialization,
+            bias=True,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+        )
+        self.activation_func = torch.nn.GELU()
+
+    def set_input_tensor(
+        self, input_tensor: Union[torch.Tensor, List[torch.Tensor]],
+    ) -> None:
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
+        self.input_tensor = input_tensor[0]
+
+    def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor:
+        """Forward of Simplified ParallelMLP.
+
+        Args:
+            x: :obj:`None` if pipeline rank != pippeline first rank. When :obj:`None`,
+                `self.input_tensor` is taken care of by `forward_step` defined in
+                apex/transformer/pipeline_parallel/schedules/common.py
+        """
+        # [s, b, h]
+        if self.input_tensor is None:
+            input = x
+        else:
+            input = self.input_tensor
+
+        intermediate_parallel, bias_parallel = self.dense_h_to_4h(input)
+        if bias_parallel is not None:
+            intermediate_parallel += bias_parallel
+        intermediate_parallel = self.activation_func(intermediate_parallel)
+
+        # [s, b, h]
+        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
+        return output
+
+
+def model_provider_func(
+    hidden_size: int,
+    pre_process: bool,
+    post_process: bool,
+    *,
+    add_encoder: bool = False,
+    add_decoder: bool = False,
+) -> MyModel:
+    return MyModel(hidden_size, pre_process, post_process,
+                   add_encoder=add_encoder, add_decoder=add_decoder)
+
+
+def mlp_provider_func(
+    hidden_size: int,
+    pre_process: bool,
+    post_process: bool,
+    *,
+    add_encoder: bool = False,
+    add_decoder: bool = False,
+    sequence_parallel_enabled: bool = False,
+) -> ToyParallelMLP:
+    return ToyParallelMLP(
+        hidden_size,
+        pre_process,
+        post_process,
+        add_encoder=add_encoder,
+        add_decoder=add_decoder,
+        sequence_parallel_enabled=sequence_parallel_enabled,
+    )


def process_batch(batch):
    if isinstance(batch, list):
        x = batch[0]
    else:
        x = batch
    return x


def fwd_step_func(batch, model):
    x = process_batch(batch)
    y = model(x)

    # note (mkozuki): I don't think this function is nice but I do think this is enough for now
    # just to check the sanity of ported pipeline functions.
    def loss_func(x):
        loss = torch.sum(x)
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {"avg": averaged_loss}

    return y, loss_func


+@dataclass(frozen=True)
+class ToyParallelMLPFwdBwdStepFunc:
+
+    sequence_parallel_enabled: bool
+
+    def __call__(
+        self,
+        batch: Batch,
+        model: torch.nn.Module,
+    ) -> Tuple[torch.Tensor, Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]]:
+        x = batch[0] if isinstance(batch, list) else batch
+        if isinstance(x, torch.Tensor):
+            x = x.transpose(0, 1).contiguous()
+            if self.sequence_parallel_enabled:
+                x = scatter_to_sequence_parallel_region(x)
+        y = model(x)
+
+        # note (mkozuki): I don't think this function is nice but I do think this is enough for now
+        # just to check the sanity of ported pipeline functions.
+        def loss_func(x):
+            loss = torch.sum(x)
+            averaged_loss = average_losses_across_data_parallel_group([loss])
+            return loss, {"avg": averaged_loss}
+
+        return y, loss_func
+
+
class IdentityLayer(torch.nn.Module):
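The step functions above all follow the same contract the apex pipeline schedules expect: given a batch and a model, they return `(output, loss_func)`, and the caller applies `loss_func` to the output to obtain the loss plus a logging dict. A self-contained sketch of that contract (plain PyTorch, with a hypothetical model and batch; no apex required):

# Minimal, standalone illustration of the (output, loss_func) step contract.
import torch


def toy_step_func(batch, model):
    x = batch[0] if isinstance(batch, list) else batch
    y = model(x)

    def loss_func(output):
        loss = torch.sum(output)
        return loss, {"avg": loss.detach()}

    return y, loss_func


model = torch.nn.Linear(8, 8)
batch = [torch.randn(4, 8)]
output, loss_func = toy_step_func(batch, model)
loss, stats = loss_func(output)
loss.backward()
print(stats["avg"])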
...
...
@@ -78,22 +247,28 @@ def set_random_seed(seed):
    transformer.tensor_parallel.model_parallel_cuda_manual_seed(seed)


-def initialize_distributed(backend='nccl'):
+def initialize_distributed(backend="nccl"):
    """Initialize torch.distributed."""
    # Get local rank in case it is provided.
    # parser = argparse.ArgumentParser()
    # parser.add_argument('--local_rank', type=int, default=None,
    #                     help='local rank passed from distributed launcher')
    # args = parser.parse_args()
+    if backend not in ("nccl", "ucc"):
+        raise RuntimeError(f"Currently only nccl & ucc are supported but {backend}")
+    if backend == "ucc":
+        import torch_ucc  # NOQA
    args = global_vars.get_args()
    local_rank = args.local_rank

    # Get rank and world size.
-    rank = int(os.getenv('RANK', '0'))
-    world_size = int(os.getenv("WORLD_SIZE", '1'))
+    rank = int(os.getenv("RANK", "0"))
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
-    print('> initializing torch.distributed with local rank: {}, '
-          'rank: {}, world size: {}'.format(local_rank, rank, world_size))
+    print(
+        "> initializing torch.distributed with local rank: {}, "
+        "rank: {}, world size: {}".format(local_rank, rank, world_size)
+    )

    # Set the device id.
    device = rank % torch.cuda.device_count()
...
...
@@ -102,22 +277,21 @@ def initialize_distributed(backend='nccl'):
    torch.cuda.set_device(device)

    # Call the init process.
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
+    init_method = "tcp://"
+    master_ip = os.getenv("MASTER_ADDR", "localhost")
+    master_port = os.getenv("MASTER_PORT", "6000")
+    init_method += master_ip + ":" + master_port
    torch.distributed.init_process_group(
-        backend=backend, world_size=world_size, rank=rank, init_method=init_method)
+        backend=backend,
+        world_size=world_size,
+        rank=rank,
+        init_method=init_method,
+        timeout=datetime.timedelta(seconds=60),
+    )


def print_separator(message):
    torch.distributed.barrier()
    filler_len = (78 - len(message)) // 2
-    filler = '-' * filler_len
-    string = '\n' + filler + ' {} '.format(message) + filler
+    filler = "-" * filler_len
+    string = "\n" + filler + " {} ".format(message) + filler
    if torch.distributed.get_rank() == 0:
        print(string, flush=True)
    torch.distributed.barrier()
apex/transformer/testing/distributed_test_base.py
0 → 100644
View file @
795a5e5b
import os
import sys
import unittest

from packaging.version import Version, parse
import torch
from torch import distributed as dist
from torch.utils import collect_env
from torch.testing._internal import common_utils
from torch.testing._internal import common_distributed


HAS_TORCH_UCC = None
try:
    import torch_ucc
    HAS_TORCH_UCC = True
except ImportError:
    HAS_TORCH_UCC = False

# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01")
_driver_version = None
if torch.cuda.is_available():
    _driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run))
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = (
    _driver_version is not None
    and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
)


class DistributedTestBase(common_distributed.MultiProcessTestCase):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def setUp(self) -> None:
        super().setUp()
        self._setup_pre_spawn()
        self._spawn_processes()

    def tearDown(self) -> None:
        super().tearDown()

    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 4)

    @property
    def init_method(self):
        return f"{common_utils.FILE_SCHEMA}{self.file_name}"

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe):
        self = cls(test_name)
        self.assertTrue(torch.cuda.is_available())
        self.assertTrue(hasattr(self, "DISTRIBUTED_BACKEND"))
        self.rank = rank
        self.file_name = file_name

        print(f"[dist init] rank = {self.rank}, world_size = {self.world_size}")

        try:
            dist.init_process_group(
                init_method=self.init_method,
                backend=self.DISTRIBUTED_BACKEND,
                world_size=int(self.world_size),
                rank=self.rank,
            )
        except RuntimeError as e:
            if "recompile" in e.args[0]:
                print(f"Backend of {self.DISTRIBUTED_BACKEND} not available")
                sys.exit(0)
            raise

        torch.cuda.set_device(self.rank % torch.cuda.device_count())

        dist.barrier()
        self.run_test(test_name, pipe)
        dist.barrier()

        dist.destroy_process_group()
        sys.exit(0)

    def _setup_pre_spawn(self):
        pass


class NcclDistributedTestBase(DistributedTestBase):

    DISTRIBUTED_BACKEND = "nccl"


@unittest.skipUnless(
    HAS_TORCH_UCC,
    "Requires [`torch_ucc`](https://github.com/facebookresearch/torch_ucc)",
)
@unittest.skipUnless(
    HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER,
    f"`torch_ucc` requires NVIDIA driver >= {_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION} but {_driver_version} found. "
    "See https://github.com/openucx/ucc/issues/496",
)
class UccDistributedTestBase(DistributedTestBase):

    DISTRIBUTED_BACKEND = "ucc"

    def _setup_pre_spawn(self) -> None:
        self.master_addr = "localhost"
        os.environ["MASTER_ADDR"] = "localhost"
        self._has_master_port = "MASTER_PORT" in os.environ
        if self._has_master_port:
            self.master_port = os.environ["MASTER_PORT"]
        else:
            try:
                from caffe2.torch.fb.common.utils import get_free_port
                self.master_port = str(get_free_port())
            except ImportError:
                self.master_port = "12375"
            os.environ["MASTER_PORT"] = self.master_port

        self._has_ucx_tls = "UCX_TLS" in os.environ
        if not self._has_ucx_tls:
            os.environ["UCX_TLS"] = "tcp,cuda"

        print('os.environ[\"UCX_TLS\"] = {}'.format(os.environ["UCX_TLS"]))

    def tearDown(self) -> None:
        super().tearDown()
        if not self._has_master_port:
            del os.environ["MASTER_PORT"]
        if not self._has_ucx_tls:
            del os.environ["UCX_TLS"]

    @property
    def init_method(self):
        return "tcp://localhost:" + os.environ["MASTER_PORT"]
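A hypothetical example (test class name and body invented for illustration) of how the new base class is meant to be used: subclass `NcclDistributedTestBase`, and each test method then runs in `world_size` spawned processes with an initialized default process group, so collectives can be exercised directly.

# Hypothetical test module built on the new base class; not part of this commit.
import torch
from torch import distributed as dist
from torch.testing._internal import common_utils

from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase


class AllReduceSmokeTest(NcclDistributedTestBase):
    def test_all_reduce(self):
        # Each rank contributes rank + 1; the all-reduced value is the sum over ranks.
        t = torch.ones(1, device="cuda") * (self.rank + 1)
        dist.all_reduce(t)
        self.assertEqual(t.item(), sum(range(1, self.world_size + 1)))


if __name__ == "__main__":
    common_utils.run_tests()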