OpenDAS / apex / Commits / 79906517

Commit 79906517, authored Dec 08, 2021 by hubertlu-tw

    Merge remote-tracking branch 'upstream/master' into IFU-master-2021-12-08

Parents: cc92a4b4, aa756cec
Changes: 71 files in the commit; this page shows 20 changed files with 1791 additions and 65 deletions (+1791 -65).
Files shown (20):

  apex/transformer/enums.py                                                                 +1    -1
  apex/transformer/functional/__init__.py                                                   +1    -1
  apex/transformer/functional/fused_softmax.py                                              +2    -2
  apex/transformer/log_util.py                                                              +19   -0
  apex/transformer/microbatches.py                                                          +23   -15
  apex/transformer/parallel_state.py                                                        +57   -7
  apex/transformer/pipeline_parallel/__init__.py                                            +8    -0
  apex/transformer/pipeline_parallel/_timers.py                                             +83   -0
  apex/transformer/pipeline_parallel/p2p_communication.py                                   +404  -0
  apex/transformer/pipeline_parallel/schedules/__init__.py                                  +39   -0
  apex/transformer/pipeline_parallel/schedules/common.py                                    +218  -0
  apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py                     +91   -0
  apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py      +308  -0
  apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py   +170  -0
  apex/transformer/pipeline_parallel/utils.py                                               +333  -0
  apex/transformer/tensor_parallel/__init__.py                                              +6    -11
  apex/transformer/tensor_parallel/cross_entropy.py                                         +7    -7
  apex/transformer/tensor_parallel/data.py                                                  +4    -4
  apex/transformer/tensor_parallel/layers.py                                                +12   -12
  apex/transformer/tensor_parallel/mappings.py                                              +5    -5
apex/transformer/enums.py  (+1 -1)

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
apex/transformer/functional/__init__.py  (+1 -1)

-from .fused_softmax import FusedScaleMaskSoftmax
+from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax

 __all__ = [
     "FusedScaleMaskSoftmax",
 ...
apex/transformer/functional/fused_softmax.py  (+2 -2)

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -15,7 +15,7 @@
 import torch

 from apex._autocast_utils import _cast_if_autocast_enabled
-from ..enums import AttnMaskType
+from apex.transformer.enums import AttnMaskType


 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
 ...
apex/transformer/log_util.py  (new file, +19)

from typing import Optional
import logging
import os
import threading


def get_transformer_logger(name: str) -> logging.Logger:
    name_wo_ext = os.path.splitext(name)[0]
    return logging.getLogger(name_wo_ext)


def set_logging_level(verbosity) -> None:
    """Change logging severity.

    Args:
        verbosity
    """
    from apex import _library_root_logger
    _library_root_logger.setLevel(verbosity)
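The two helpers above are thin wrappers over the standard logging module: get_transformer_logger normalizes the name through os.path.splitext before calling logging.getLogger, and set_logging_level forwards the verbosity to apex's library root logger. A minimal usage sketch (not part of this diff; the logging.INFO level is an arbitrary example):

import logging

from apex.transformer.log_util import get_transformer_logger, set_logging_level

# The helper strips a trailing extension via os.path.splitext and then defers
# to logging.getLogger, so repeated calls with the same name share one logger.
_logger = get_transformer_logger(__name__)

# Raise or lower the severity threshold of apex's library-wide root logger.
set_logging_level(logging.INFO)

_logger.info("transformer logging configured")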
apex/transformer/tensor_parallel/microbatches.py → apex/transformer/microbatches.py  (+23 -15)

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -15,35 +15,41 @@
 """Megatron number of micro-batches calculators."""
 from abc import ABC
 from abc import abstractmethod
+from typing import Optional, List


-def build_num_microbatches_calculator(args):
+def build_num_microbatches_calculator(
+    rank: int,
+    rampup_batch_size: Optional[List[int]],
+    global_batch_size: int,
+    micro_batch_size: int,
+    data_parallel_size: int,
+):
     # Constant num micro-batches.
-    if args.rampup_batch_size is None:
+    if rampup_batch_size is None:
         num_microbatches_calculator = ConstantNumMicroBatches(
-            args.global_batch_size, args.micro_batch_size, args.data_parallel_size
+            global_batch_size, micro_batch_size, data_parallel_size
         )
-        if args.rank == 0:
+        if rank == 0:
             print(
                 "setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()), flush=True
             )
     else:
-        assert len(args.rampup_batch_size) == 3, (
+        assert len(rampup_batch_size) == 3, (
            "expected the following "
            "format: --rampup-batch-size <start batch size> "
            "<batch size incerement> <ramp-up samples>"
         )
-        start_batch_size = int(args.rampup_batch_size[0])
-        batch_size_increment = int(args.rampup_batch_size[1])
-        ramup_samples = int(args.rampup_batch_size[2])
-        if args.rank == 0:
+        start_batch_size = int(rampup_batch_size[0])
+        batch_size_increment = int(rampup_batch_size[1])
+        ramup_samples = int(rampup_batch_size[2])
+        if rank == 0:
             print(
                 "will use batch size rampup starting from global batch "
                 "size {} to global batch size {} with batch size increments "
                 "{} over {} samples.".format(
-                    start_batch_size, args.global_batch_size, batch_size_increment, ramup_samples
+                    start_batch_size, global_batch_size, batch_size_increment, ramup_samples
                 ),
                 flush=True,
             )
 ...
@@ -51,9 +57,9 @@ def build_num_microbatches_calculator(args):
             start_batch_size,
             batch_size_increment,
             ramup_samples,
-            args.global_batch_size,
-            args.micro_batch_size,
-            args.data_parallel_size,
+            global_batch_size,
+            micro_batch_size,
+            data_parallel_size,
         )
     return num_microbatches_calculator
 ...
@@ -86,6 +92,8 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
         assert self.num_micro_batches >= 1
         self.current_global_batch_size = global_batch_size
+        self.micro_batch_size = micro_batch_size

     def update(self, consumed_samples, consistency_check):
         pass
 ...
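The signature change above drops the Megatron-style args namespace in favor of explicit arguments, which is what the new pipeline-parallel utilities call into. A call sketch under the new signature (not part of this diff; the sizes are illustrative, and rampup_batch_size=None selects the constant-calculator branch shown above):

from apex.transformer.microbatches import build_num_microbatches_calculator

calculator = build_num_microbatches_calculator(
    rank=0,                  # only rank 0 prints the summary message
    rampup_batch_size=None,  # or [start batch size, increment, ramp-up samples]
    global_batch_size=256,
    micro_batch_size=4,
    data_parallel_size=8,
)
# ConstantNumMicroBatches: 256 // (4 * 8) == 8 micro-batches per step.
print(calculator.get())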
apex/transformer/parallel_state.py  (+57 -7)

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -13,11 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Model and data parallel groups."""
+from typing import Tuple

 import torch

 # TODO (mkozuki): Consider dissecting utils as this utils import is here
 # only for ensure_divisibility
-from .tensor_parallel import utils
+from apex.transformer.utils import ensure_divisibility

 # Intra-layer model parallel group that the current rank belongs to.
 ...
@@ -40,6 +42,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_TENSOR_MODEL_PARALLEL_RANK = None
 _MPU_PIPELINE_MODEL_PARALLEL_RANK = None

+# A list of ranks that have a copy of the embedding.
+_EMBEDDING_GLOBAL_RANKS = None
+
 # A list of global ranks for each pipeline group to ease calculation of the source
 # rank when broadcasting from the first or last pipeline stage
 _PIPELINE_GLOBAL_RANKS = None
 ...
@@ -76,23 +81,26 @@ def initialize_model_parallel(
     with a total of 16 GPUs, rank 0 to 7 belong to the first box and
     ranks 8 to 15 belong to the second box.
     """
-    if torch.distributed.get_rank() == 0:
-        print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size_))
-        print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size_))
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
     tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
     pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
-    utils.ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
+    # TODO (mkozuki): Consider moving `ensure_divisibility` to this file.
+    ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
     data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
+    if torch.distributed.get_rank() == 0:
+        print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size))
+        print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size))
+        print("> initializing data parallel with size {}".format(data_parallel_size))

     num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
     num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
     num_data_parallel_groups = world_size // data_parallel_size

     if virtual_pipeline_model_parallel_size_ is not None:
+        assert pipeline_model_parallel_size_ > 2, \
+            'pipeline-model-parallel size should be greater than 2 with ' \
+            'interleaved schedule'
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
         _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
 ...
@@ -138,6 +146,7 @@ def initialize_model_parallel(
     global _PIPELINE_GLOBAL_RANKS
     assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
     global _EMBEDDING_GROUP
+    global _EMBEDDING_GLOBAL_RANKS
     assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size, num_pipeline_model_parallel_groups)
 ...
@@ -154,6 +163,19 @@ def initialize_model_parallel(
         group = torch.distributed.new_group(embedding_ranks)
         if rank in embedding_ranks:
             _EMBEDDING_GROUP = group
+        if rank in ranks:
+            _EMBEDDING_GLOBAL_RANKS = embedding_ranks
+
+
+def get_rank_info() -> Tuple[int, int, int]:
+    """Returns a tuple of (tensor, pipeline, data)-parallel-rank for logger."""
+    if model_parallel_is_initialized():
+        return (
+            get_tensor_model_parallel_rank(),
+            get_pipeline_model_parallel_rank(),
+            # get_virtual_pipeline_model_parallel_rank(),
+            get_data_parallel_rank(),
+        )
+    return (0, 0, 0)


 def model_parallel_is_initialized():
 ...
@@ -193,6 +215,22 @@ def get_embedding_group():
     return _EMBEDDING_GROUP


+def is_rank_in_embedding_group(ignore_virtual=False):
+    """Return true if current rank is in embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _EMBEDDING_GLOBAL_RANKS
+    if ignore_virtual:
+        return rank in _EMBEDDING_GLOBAL_RANKS
+    if rank in _EMBEDDING_GLOBAL_RANKS:
+        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
+            return is_pipeline_first_stage(ignore_virtual=False)
+        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
+            return is_pipeline_last_stage(ignore_virtual=False)
+        else:
+            return True
+    return False
+
+
 def set_tensor_model_parallel_world_size(world_size):
     """Set the tensor model parallel size"""
     global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
 ...
@@ -344,3 +382,15 @@ def destroy_model_parallel():
     _DATA_PARALLEL_GROUP = None
     global _EMBEDDING_GROUP
     _EMBEDDING_GROUP = None
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
+    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+    global _MPU_TENSOR_MODEL_PARALLEL_RANK
+    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
+    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
+    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
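Besides the embedding-rank bookkeeping, the update adds get_rank_info for log prefixes and teaches destroy_model_parallel to reset the cached world sizes and ranks. A sketch of where these calls sit in a typical setup (not part of this diff; the 2x2 sizes assume a 4-GPU job started with torchrun or an equivalent launcher):

import torch

from apex.transformer import parallel_state

# The default process group is expected to exist already.
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

# tensor model parallel size 2, pipeline model parallel size 2, on 4 ranks.
parallel_state.initialize_model_parallel(2, 2)

# New helper: (tensor, pipeline, data) parallel ranks, or (0, 0, 0) when
# model parallelism has not been initialized.
print(parallel_state.get_rank_info())

# Now also resets the virtual-pipeline and cached world-size/rank globals.
parallel_state.destroy_model_parallel()
torch.distributed.destroy_process_group()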
apex/transformer/pipeline_parallel/__init__.py  (new file, +8)

from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import build_model


__all__ = [
    "get_forward_backward_func",
    "build_model",
]
apex/transformer/pipeline_parallel/_timers.py  (new file, +83)

import time

import torch


class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, "timer has already been started"
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, "timer is not started"
        torch.cuda.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If the timing in progress, end it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if started_:
            self.start()
        return elapsed_


class _Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer"""
        # currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # polutes the runs list, so we just add each as a scalar
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + "-time", value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = "time (ms)"
        for name in names:
            elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
            string += " | {}: {:.2f}".format(name, elapsed_time)
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)
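_Timers is a private port of Megatron-LM's timer group; start and stop both call torch.cuda.synchronize, so a CUDA device is assumed. A short usage sketch (not part of this diff):

from apex.transformer.pipeline_parallel._timers import _Timers

timers = _Timers()  # lazily creates one _Timer per name on first use

timers("forward-recv").start()
# ... the CUDA work being measured would run here ...
timers("forward-recv").stop()

# Prints "time (ms) | forward-recv: ..."; when torch.distributed is
# initialized only the last rank prints, otherwise every process does.
timers.log(["forward-recv"], normalizer=1.0, reset=True)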
apex/transformer/pipeline_parallel/p2p_communication.py  (new file, +404)

# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import operator
from typing import Union, Optional, Tuple
import warnings

import torch

from apex._autocast_utils import _get_current_dtype
from apex.transformer import parallel_state
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
from apex.transformer.utils import gather_split_1d_tensor
from apex.transformer.pipeline_parallel.utils import Shape
from apex.transformer.pipeline_parallel._timers import _Timers


def _run_p2pops(
    tensor_send_prev: Union[torch.Tensor, None],
    tensor_send_next: Union[torch.Tensor, None],
    tensor_recv_prev: Union[torch.Tensor, None],
    tensor_recv_next: Union[torch.Tensor, None],
):
    ops = []
    if tensor_send_prev is not None:
        send_prev_op = torch.distributed.P2POp(
            torch.distributed.isend,
            tensor_send_prev,
            parallel_state.get_pipeline_model_parallel_prev_rank(),
        )
        ops.append(send_prev_op)
    if tensor_recv_prev is not None:
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv,
            tensor_recv_prev,
            parallel_state.get_pipeline_model_parallel_prev_rank(),
        )
        ops.append(recv_prev_op)
    if tensor_send_next is not None:
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend,
            tensor_send_next,
            parallel_state.get_pipeline_model_parallel_next_rank(),
        )
        ops.append(send_next_op)
    if tensor_recv_next is not None:
        recv_next_op = torch.distributed.P2POp(
            torch.distributed.irecv,
            tensor_recv_next,
            parallel_state.get_pipeline_model_parallel_next_rank(),
        )
        ops.append(recv_next_op)
    if len(ops) > 0:
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()


def _communicate(
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Optional[Shape] = None,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    dtype_: torch.dtype = torch.float,
    *,
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
    """Base function for communication of tensors between stages.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
        recv_prev: boolean for whether tensor should be received from previous rank.
        recv_next: boolean for whether tensor should be received from next rank.
        tensor_shape: optional, use when the input sequence contains less tokens than the default sequence length
        override_scatter_gather_tensors_in_pipeline:
            optional, this is used when tensor_shape is provided to override scatter gather tensors
        dtype_: This is used when tensor_shape is provided and what is the type of tensor_shape

    Keyword args:
        scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
        params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
            your model deliberately, pass this argument.
        fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.

    Returns:
        tuple containing

        - tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
        - tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
    """
    # Create placeholder tensors for receive in forward and backward directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
        raise RuntimeError(
            "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (
            reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
    else:
        tensor_chunk_shape = tensor_shape
    # NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
    # FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
    # It might be possible if we restrict model architecture.
    # dtype = params_dtype or torch.float
    # if fp32_residual_connection:
    #     dtype = torch.float
    # if dtype_ is not None:
    #     dtype = dtype_
    # requires_grad = False
    if dtype_ != torch.float32 or params_dtype is not None:
        if torch.distributed.get_rank() == 0:
            warnings.warn("Tensor P2P communications are executed in FP32")
    dtype = torch.float32
    requires_grad = True

    if recv_prev:
        tensor_recv_prev = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )
    if recv_next:
        tensor_recv_next = torch.empty(
            tensor_chunk_shape,
            requires_grad=requires_grad,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
        if tensor_send_prev is not None:
            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = (
                gather_split_1d_tensor(tensor_recv_prev)
                .view(tensor_shape)
                .requires_grad_()
            )
        if recv_next:
            tensor_recv_next = (
                gather_split_1d_tensor(tensor_recv_next)
                .view(tensor_shape)
                .requires_grad_()
            )

    return tensor_recv_prev, tensor_recv_next


def recv_forward(
    tensor_shape: Shape,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> torch.Tensor:
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if parallel_state.is_pipeline_first_stage():
        return None
    if timers is not None:
        timers("forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
        dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("forward-recv").stop()
    return input_tensor


def recv_backward(
    tensor_shape: Shape = None,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if parallel_state.is_pipeline_last_stage():
        return None
    if timers is not None:
        timers("backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("backward-recv").stop()
    return output_tensor_grad


def send_forward(
    output_tensor: torch.Tensor,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    tensor_shape: Shape = None,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> None:
    """Send tensor to next rank in pipeline (forward send)."""
    if parallel_state.is_pipeline_last_stage():
        return
    if timers is not None:
        timers("forward-send").start()
    _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=False,
        override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("forward-send").stop()


def send_backward(
    input_tensor_grad: torch.Tensor,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> None:
    """Send tensor to previous rank in pipeline (backward send)."""
    if parallel_state.is_pipeline_first_stage():
        return
    if timers is not None:
        timers("backward-send").start()
    _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("backward-send").stop()


def send_forward_recv_backward(
    output_tensor: torch.Tensor,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> None:
    """Batched send and recv with next rank in pipeline."""
    if parallel_state.is_pipeline_last_stage():
        return None
    if timers is not None:
        timers("forward-send-backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("forward-send-backward-recv").stop()
    return output_tensor_grad


def send_backward_recv_forward(
    input_tensor_grad: torch.Tensor,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> torch.Tensor:
    """Batched send and recv with previous rank in pipeline."""
    if parallel_state.is_pipeline_first_stage():
        return None
    if timers is not None:
        timers("backward-send-forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("backward-send-forward-recv").stop()
    return input_tensor


def send_forward_recv_forward(
    output_tensor: torch.Tensor,
    recv_prev: bool,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
) -> torch.Tensor:
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers("forward-send-forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("forward-send-forward-recv").stop()
    return input_tensor


def send_backward_recv_backward(
    input_tensor_grad: torch.Tensor,
    recv_next: bool,
    tensor_shape: Shape,
    *,
    dtype: torch.dtype = torch.float,
    timers: _Timers = None,
) -> torch.Tensor:
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers("backward-send-backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("backward-send-backward-recv").stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
    output_tensor: torch.Tensor,
    input_tensor_grad: torch.Tensor,
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Shape,
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers("forward-backward-send-forward-backward-recv").start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("forward-backward-send-forward-backward-recv").stop()
    return input_tensor, output_tensor_grad
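All of the helpers above funnel into _communicate, which allocates receive buffers of tensor_shape, optionally scatters tensors over the tensor-parallel group before the batched isend/irecv, and gathers them back afterwards. A sketch of the basic forward handshake as the schedules use it (not part of this diff; the activation shape is an illustrative (sequence length, micro batch size, hidden size), and apex's parallel state is assumed to be initialized):

import torch

from apex.transformer.pipeline_parallel import p2p_communication

tensor_shape = (2048, 2, 1024)  # illustrative only

# Every rank can call both helpers unconditionally: recv_forward returns None
# on the first pipeline stage, and send_forward is a no-op on the last stage.
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
# ... the stage's forward pass would turn input_tensor into output_tensor ...
output_tensor = torch.randn(tensor_shape, device="cuda")
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)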
apex/transformer/pipeline_parallel/schedules/__init__.py  (new file, +39)

import warnings

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)


class ExperimentalWarning(Warning):
    pass


def get_forward_backward_func(
    virtual_pipeline_model_parallel_size,
    pipeline_model_parallel_size,
):
    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        if virtual_pipeline_model_parallel_size is not None:
            if get_num_microbatches() % pipeline_model_parallel_size != 0:
                msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
                raise RuntimeError(msg)
            warnings.warn(
                "Pipeline Model Parallel with interleaving scheduling is experimental. "
                f"To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: {virtual_pipeline_model_parallel_size}",
                ExperimentalWarning,
            )
            forward_backward_func = _forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func


__all__ = [
    "get_forward_backward_func",
]
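get_forward_backward_func is a pure dispatcher over the three schedules: no pipelining when the pipeline group has size 1, non-interleaved 1F1B otherwise, and the experimental interleaved variant when a virtual pipeline size is given and the microbatch count divides evenly by the pipeline-parallel size. A selection sketch (not part of this diff; assumes parallel state is already initialized):

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func

virtual_pipeline_model_parallel_size = None  # None -> no interleaving
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()

forward_backward_func = get_forward_backward_func(
    virtual_pipeline_model_parallel_size,
    pipeline_model_parallel_size,
)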
apex/transformer/pipeline_parallel/schedules/common.py  (new file, +218)

# NOTE (mkozuki): For simplicity, tentatively `timers` related operations are commented out.
from typing import Any, Callable, Dict, List, Tuple, Union, Optional

import torch

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes


Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
FwdStepFunc = Callable[[Batch, torch.nn.Module], Tuple[torch.Tensor, LossFunc]]


def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    *args,
    **kwargs
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function sets `pre_process` and `post_process` to `**kwargs` and pass `*args` and `**kwargs` to
    `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model
            with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        pre_process = parallel_state.is_pipeline_first_stage()
        post_process = parallel_state.is_pipeline_last_stage()
        cur_kwargs.update({
            "pre_process": pre_process,
            "post_process": post_process,
        })
        model = model_provider_func(*cur_args, **cur_kwargs)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if parallel_state.get_data_parallel_rank() == 0:
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            )
            for model_module in model
        ]
    return model


def _get_params_for_weight_decay_optimization(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
) -> Dict[str, torch.nn.Parameter]:
    """Divide params into with-weight-decay and without-weight-decay groups.

    Layernorms and biases will have no weight decay but the rest will.
    """
    modules = listify_model(model)
    from apex.normalization.fused_layer_norm import FusedLayerNorm  # NOQA
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, FusedLayerNorm):
                no_weight_decay_params['params'].extend(
                    [p for p in list(module_._parameters.values()) if p is not None])
            else:
                weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias'])
                no_weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias'])

    return weight_decay_params, no_weight_decay_params


def forward_step(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: torch.nn.Module,
    input_tensor: Optional[torch.Tensor],
    losses_reduced: List[torch.Tensor],
):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor.

    Args:
        forward_step_func: Model specific function. This takes a minibatch and model as its arguments and
            returns the model's output and the loss function.
        batch: minibatch
        model: unwrappable model
        input_tensor:
        losses_reduced:

    Returns:
        output_tensor
    """
    # timers = get_timers()
    # timers("forward-compute").start()
    unwrapped_model = unwrap_model(model)
    # NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
    # See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679  # NOQA
    # for the details of `set_input_tensor`.
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(batch, model)
    # print(f"forward_step| pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()} is_pipeline_last_stage?: {parallel_state.is_pipeline_last_stage()}")
    if parallel_state.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    # timers("forward-compute").stop()
    return output_tensor


def backward_step(
    input_tensor: Optional[torch.Tensor],
    output_tensor: torch.Tensor,
    output_tensor_grad: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage).

    Args:
        input_tensor:
        output_tensor:
        output_tensor_grad:

    Returns:
        input_tensor_grad
    """
    # timers = get_timers()
    # timers("backward-compute").start()
    # Retain the grad on the input_tensor.
    # if parallel_state.get_pipeline_model_parallel_rank() == 0:
    #     print(f"{input_tensor}, {output_tensor}, {output_tensor_grad}")
    if input_tensor is not None:
        input_tensor.retain_grad()
    # Backward pass.
    # if output_tensor_grad is None:
    #     output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad
    # timers("backward-compute").stop()
    return input_tensor_grad
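A sketch of build_model with a toy provider (not part of this diff; it assumes torch.distributed and apex's parallel state are already initialized and a GPU is visible, and the provider itself is invented, though the pre_process/post_process keywords are exactly what build_model injects):

import torch

from apex.transformer.pipeline_parallel.schedules.common import build_model


def model_provider_func(hidden_size, *, pre_process=True, post_process=True):
    # Toy stand-in: a real Megatron-style provider uses pre_process /
    # post_process to decide whether to build the embedding and output head.
    return torch.nn.Linear(hidden_size, hidden_size)


# Returns a list with one module (or one per chunk when a virtual pipeline
# size is given), moved to the current GPU; wrap_with_ddp=False skips the
# DistributedDataParallel wrapping over the data-parallel group.
model = build_model(
    model_provider_func,
    wrap_with_ddp=False,
    virtual_pipeline_model_parallel_size=None,
    hidden_size=1024,
)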
apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py  (new file, +91)

from contextlib import contextmanager
from typing import List, Union

import torch

from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger


_all__ = ["forward_backward_no_pipelining"]

_logger = get_transformer_logger(__name__)


@contextmanager
def placeholder_handler():
    try:
        yield
    finally:
        pass


def forward_backward_no_pipelining(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    **kwargs,
):
    """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).

    This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A List of torch.Tensors
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        **kwargs: Added to handle `tensor_shape` which has no effect on this function.

    Returns:
        a list of dictionaries of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    context_handler = placeholder_handler
    if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    num_micro_batches = get_num_microbatches()
    with context_handler():
        for i in range(num_micro_batches - 1):
            _logger.info(f"Iter {i} of {num_micro_batches - 1}")
            cur_micro_batch = get_kth_microbatch(batch, i)
            _logger.debug("Call `forward_step`")
            output_tensor = forward_step(forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
            if not forward_only:
                _logger.debug("Call `backward_step`")
                backward_step(input_tensor, output_tensor, output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    _logger.info("Cooldown")
    _logger.debug("Call `forward_step`")
    output_tensor = forward_step(
        forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced)
    if not forward_only:
        _logger.debug("Call `backward_step`")
        backward_step(input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced
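The docstring above fixes the contract between the schedules and user code: forward_step_func takes (micro-batch, model) and returns the model output together with a loss function that maps that output to a loss tensor and a dict of reduced values. A standalone illustration of just that contract (not part of this diff; the toy model and batch are made up, and a real model must additionally implement set_input_tensor for forward_step):

import torch


def forward_step_func(micro_batch, model):
    output_tensor = model(micro_batch)

    def loss_func(output_tensor):
        loss = output_tensor.float().mean()
        return loss, {"avg_loss": loss.detach()}

    return output_tensor, loss_func


# Exercise the contract outside of any schedule.
toy_model = torch.nn.Linear(8, 8)
toy_micro_batch = torch.randn(4, 8)
output, loss_fn = forward_step_func(toy_micro_batch, toy_model)
loss, reduced = loss_fn(output)
loss.backward()

Inside forward_backward_no_pipelining this callable is invoked once per micro-batch, and the reduced dictionaries are what accumulate in losses_reduced.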
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py
0 → 100644
View file @
79906517
from
typing
import
List
,
Union
,
Optional
import
torch
from
apex.transformer
import
parallel_state
from
apex.transformer.pipeline_parallel
import
p2p_communication
from
apex.transformer.pipeline_parallel.schedules.common
import
Batch
,
FwdStepFunc
from
apex.transformer.pipeline_parallel.schedules.common
import
backward_step
from
apex.transformer.pipeline_parallel.schedules.common
import
forward_step
from
apex.transformer.pipeline_parallel.utils
import
get_kth_microbatch
from
apex.transformer.pipeline_parallel.utils
import
get_num_microbatches
from
apex.transformer.log_util
import
get_transformer_logger
__all__
=
[
"_forward_backward_pipelining_with_interleaving"
]
_logger
=
get_transformer_logger
(
__name__
)
# TODO (mkozuki): Reduce cyclomatic complexity
def
_forward_backward_pipelining_with_interleaving
(
forward_step_func
:
FwdStepFunc
,
batch
:
List
[
Batch
],
model
:
List
[
torch
.
nn
.
Module
],
*
,
forward_only
:
bool
,
tensor_shape
:
Optional
[
Union
[
List
[
int
],
torch
.
Size
]]
=
None
,
):
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
This function assumes `batch` and `model` is a list of `Batch`'s and a list of `torch.nn.Module`, respectively.
This means that model is split into model chunks.
This pipeline parallel scheduling consists of three steps:
1. warmup
2. 1F1B a.k.a. steady state
3. cooldown
Note that if `forward_only` this scheduling consists of only warmup phase.
Args:
forward_step_func: A function which takes a minibatch and model as its arguments and
returns model's forward output and the loss function.
The loss function is supposed to take one `torch.Tensor` and
return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
batch: A minibatch, i.e., a list of `torch.Tensor`'s.
model: A `torch.nn.Module` or a list of `torch.nn.Module`.
Keyword args:
forward_only:
tensor_shape: Shape of tensor.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
"""
if
not
isinstance
(
model
,
list
):
raise
RuntimeError
(
"`model` must be a list of `nn.Module`'s'"
)
num_model_chunks
=
len
(
model
)
input_tensors
=
[[]
for
_
in
range
(
num_model_chunks
)]
output_tensors
=
[[]
for
_
in
range
(
num_model_chunks
)]
curr_iters
=
[
0
for
_
in
range
(
num_model_chunks
)]
losses_reduced
=
[]
if
not
forward_only
:
output_tensor_grads
=
[[]
for
_
in
range
(
num_model_chunks
)]
pipeline_parallel_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
pipeline_parallel_rank
=
parallel_state
.
get_pipeline_model_parallel_rank
()
# Compute number of warmup and remaining microbatches.
num_microbatches
=
get_num_microbatches
()
*
num_model_chunks
all_warmup_microbatches
=
False
if
forward_only
:
num_warmup_microbatches
=
num_microbatches
else
:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
# all workers, followed by more microbatches after depending on
# stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B).
if
get_num_microbatches
()
==
pipeline_parallel_size
:
num_warmup_microbatches
=
num_microbatches
all_warmup_microbatches
=
True
else
:
num_warmup_microbatches
=
(
pipeline_parallel_size
-
pipeline_parallel_rank
-
1
)
*
2
num_warmup_microbatches
+=
(
num_model_chunks
-
1
)
*
pipeline_parallel_size
num_warmup_microbatches
=
min
(
num_warmup_microbatches
,
num_microbatches
)
num_microbatches_remaining
=
num_microbatches
-
num_warmup_microbatches
_logger
.
info
(
f
"num_microbatches:
{
num_microbatches
}
, "
f
"num_warmup_microbatches:
{
num_warmup_microbatches
}
, "
f
"num_microbatches_remaining:
{
num_microbatches_remaining
}
"
)
###################################################################################################################
# Helper function definitions.
###################################################################################################################
def
get_model_chunk_id
(
microbatch_id
:
int
,
forward
:
bool
)
->
int
:
"""Helper function to get the model chunk ID given the iteration number."""
pipeline_parallel_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
microbatch_id_in_group
=
microbatch_id
%
(
pipeline_parallel_size
*
num_model_chunks
)
model_chunk_id
=
microbatch_id_in_group
//
pipeline_parallel_size
if
not
forward
:
model_chunk_id
=
num_model_chunks
-
model_chunk_id
-
1
return
model_chunk_id
def
forward_step_helper
(
microbatch_id
,
curr_iters
):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id
=
get_model_chunk_id
(
microbatch_id
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
model_chunk_id
)
# forward step
if
(
parallel_state
.
is_pipeline_first_stage
()
and
len
(
input_tensors
[
model_chunk_id
])
==
len
(
output_tensors
[
model_chunk_id
])
):
input_tensors
[
model_chunk_id
].
append
(
None
)
input_tensor
=
input_tensors
[
model_chunk_id
][
-
1
]
output_tensor
=
forward_step
(
forward_step_func
,
get_kth_microbatch
(
batch
,
curr_iters
[
model_chunk_id
]),
model
[
model_chunk_id
],
input_tensor
,
losses_reduced
,
)
curr_iters
[
model_chunk_id
]
+=
1
output_tensors
[
model_chunk_id
].
append
(
output_tensor
)
# if forward-only, no need to save tensors for a backward pass
if
forward_only
:
input_tensors
[
model_chunk_id
].
pop
()
output_tensors
[
model_chunk_id
].
pop
()
return
output_tensor
def
backward_step_helper
(
microbatch_id
):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id
=
get_model_chunk_id
(
microbatch_id
,
forward
=
False
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
model_chunk_id
)
if
parallel_state
.
is_pipeline_last_stage
():
if
len
(
output_tensor_grads
[
model_chunk_id
])
==
0
:
output_tensor_grads
[
model_chunk_id
].
append
(
None
)
input_tensor
=
input_tensors
[
model_chunk_id
].
pop
(
0
)
output_tensor
=
output_tensors
[
model_chunk_id
].
pop
(
0
)
output_tensor_grad
=
output_tensor_grads
[
model_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
)
return
input_tensor_grad
###################################################################################################################
# Run warmup forward passes.
###################################################################################################################
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
0
)
input_tensors
[
0
].
append
(
p2p_communication
.
recv_forward
(
tensor_shape
=
tensor_shape
))
_logger
.
info
(
"Warmup phase"
)
for
k
in
range
(
num_warmup_microbatches
):
_logger
.
debug
(
f
"warmup iter:
{
k
}
/
{
num_warmup_microbatches
}
"
)
output_tensor
=
forward_step_helper
(
k
,
curr_iters
)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id
=
get_model_chunk_id
(
k
+
1
,
forward
=
True
)
recv_prev
=
True
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
True
):
if
next_forward_model_chunk_id
==
0
:
recv_prev
=
False
if
k
==
(
num_microbatches
-
1
):
recv_prev
=
False
_logger
.
debug
(
f
"next fwd model chunk ID:
{
next_forward_model_chunk_id
}
, recv_prev:
{
recv_prev
}
"
)
# Don't send tensor downstream if on last stage.
if
parallel_state
.
is_pipeline_last_stage
():
_logger
.
debug
(
"Pipeline last stage, not sending tensor downstream"
)
output_tensor
=
None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if
k
==
(
num_warmup_microbatches
-
1
)
and
not forward_only and not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            _logger.debug("send fwd&bwd and receive fwd&bwd")
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
            )
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape
            )
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for k in range(num_microbatches_remaining):
        # Forward pass.
        _logger.debug(f"steady phase iter {k} / {num_microbatches_remaining}")
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k, curr_iters)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.
        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        _logger.debug(f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}")
        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        _logger.debug("send fwd&bwd and receive fwd&bwd")
        (
            input_tensor,
            output_tensor_grad,
        ) = p2p_communication.send_forward_backward_recv_forward_backward(
            output_tensor,
            input_tensor_grad,
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
        )

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

    ###################################################################################################################
    # Run cooldown backward passes (flush out pipeline).
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape=tensor_shape)
            )
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape
                )
            )

    return losses_reduced
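The chunk bookkeeping above leans on a helper, get_model_chunk_id, defined earlier in this file. As a rough standalone illustration of how a global microbatch index can map onto a virtual model chunk in an interleaved schedule (the function name and formula below are an assumption made for illustration, not a copy of the library's helper):

# Illustrative sketch only: maps a microbatch index to a virtual model chunk,
# assuming microbatches are processed in groups of
# pipeline_parallel_size * num_model_chunks and the backward direction walks
# the chunks in reverse order.
def illustrative_model_chunk_id(microbatch_id: int, *, forward: bool,
                                pipeline_parallel_size: int, num_model_chunks: int) -> int:
    in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    chunk_id = in_group // pipeline_parallel_size
    if not forward:
        chunk_id = num_model_chunks - chunk_id - 1
    return chunk_id

# With 2 pipeline stages and 2 chunks per stage, forward microbatches 0..7
# visit chunks [0, 0, 1, 1, 0, 0, 1, 1].
print([illustrative_model_chunk_id(m, forward=True, pipeline_parallel_size=2, num_model_chunks=2)
       for m in range(8)])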
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py
0 → 100644
View file @
79906517
from typing import Union, List, Optional

import torch

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger


__all__ = ["forward_backward_pipelining_without_interleaving"]


_logger = get_transformer_logger(__name__)


def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns the model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only: If :obj:`True`, run only forward passes and skip the backward/cooldown steps.
        tensor_shape: Shape of tensor. Required for P2P communication.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    # timers = get_timers()
    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
        f"num_warmup_microbatches: {num_warmup_microbatches}, "
        f"num_microbatches_remaining: {num_microbatches_remaining}"
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
        cur_microbatch = get_kth_microbatch(batch, i)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
        _logger.debug("send fwd")
        p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration = i == (num_microbatches_remaining - 1)

        cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
        if forward_only:
            _logger.debug("send fwd")
            p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)

            if not last_iteration:
                _logger.debug("receive fwd (last iteration)")
                input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)

        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = p2p_communication.send_backward_recv_forward(
                    input_tensor_grad, tensor_shape=tensor_shape
                )

    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)

            input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)

            _logger.debug("send bwd")
            p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)

    return losses_reduced
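For orientation, here is a rough sketch of how a caller might drive this schedule. It assumes torch.distributed, apex.transformer.parallel_state, and the microbatch calculator have already been set up; `MyModel`, the batch layout, and the loss bookkeeping are placeholders for this sketch, not anything defined in this commit.

# Rough usage sketch, not a complete training script.
import torch
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)


def forward_step_func(microbatch, model):
    # Placeholder batch layout: a pair of (tokens, labels) tensors.
    tokens, labels = microbatch
    output = model(tokens)

    def loss_func(output_tensor):
        # Placeholder loss; the dict value is what shows up in `losses_reduced`.
        loss = torch.nn.functional.cross_entropy(output_tensor, labels)
        averaged = average_losses_across_data_parallel_group([loss])
        return loss, {"lm loss": averaged[0]}

    return output, loss_func


# model = MyModel().cuda()                      # placeholder model
# batch = [tokens_minibatch, labels_minibatch]  # one local minibatch; the schedule slices it into microbatches
# losses = forward_backward_pipelining_without_interleaving(
#     forward_step_func,
#     batch,
#     model,
#     forward_only=False,
#     tensor_shape=(sequence_length, micro_batch_size, hidden_size),
# )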
apex/transformer/pipeline_parallel/utils.py
0 → 100644
View file @
79906517
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for pipeline model parallel."""
from typing import Optional, List, Union

import torch
from torch.nn.parallel import DistributedDataParallel

from apex.multi_tensor_apply import multi_tensor_applier
from apex.transformer import parallel_state
from apex.transformer.microbatches import build_num_microbatches_calculator
from apex.transformer.pipeline_parallel._timers import _Timers

if multi_tensor_applier.available:
    import amp_C


_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_AUTORESUME = None
_GLOBAL_TIMERS = None


Shape = Union[List[int], torch.Size]


def listify_model(model: Union[torch.nn.Module, List[torch.nn.Module]]) -> List[torch.nn.Module]:
    if isinstance(model, list):
        return model
    return [model]


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, "{} is not initialized.".format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is None, i.e. not yet initialized."""
    assert var is None, "{} is already initialized.".format(name)


def setup_microbatch_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
) -> None:
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, 'num microbatches calculator')

    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
        rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size
    )


def _reconfigure_microbatch_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
) -> None:
    if torch.distributed.get_rank() == 0:
        import warnings
        warnings.warn("This function is only for unittest")
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
        rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size
    )


def get_micro_batch_size():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.micro_batch_size


def get_num_microbatches():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def get_current_global_batch_size():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()


def update_num_microbatches(consumed_samples, consistency_check=True):
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)


# note (mkozuki): Comment out in favor of `get_kth_microbatch`
def _split_batch_into_microbatch(
    batch: List[torch.Tensor],
    *,
    _micro_batch_size: Optional[int] = None,
    _global_batch_size: Optional[int] = None,
) -> List[List[torch.Tensor]]:
    micro_batch_size = _micro_batch_size
    global_batch_size = _global_batch_size
    if micro_batch_size is None:
        micro_batch_size = get_micro_batch_size()
    if global_batch_size is None:
        global_batch_size = get_current_global_batch_size()
    for i in range(0, global_batch_size, micro_batch_size):
        yield [x[i:i + micro_batch_size] for x in batch]


# TODO(mkozuki): Support non-tensor local minibatches?
def get_kth_microbatch(batch: List[torch.Tensor], k: int) -> List[torch.Tensor]:
    """Create a list of microbatches from a list of local minibatches.

    This function creates the `k`th microbatch from a list of local minibatches.
    A local minibatch consists of `global_batch_size / data_parallel_size` samples.
    """
    micro_batch_size = get_micro_batch_size()
    return [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in batch]


def get_autoresume():
    return _GLOBAL_AUTORESUME


def _set_timers():
    """Initialize timers."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers")
    _GLOBAL_TIMERS = _Timers()


def get_timers():
    """Return timers."""
    _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers")
    return _GLOBAL_TIMERS


def print_rank_0(message: str) -> None:
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)


def is_last_rank():
    return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1)


def print_rank_last(message):
    """If distributed is initialized, print only on last rank."""
    if torch.distributed.is_initialized():
        if is_last_rank():
            print(message, flush=True)
    else:
        print(message, flush=True)


def param_is_not_shared(param: torch.nn.Parameter) -> bool:
    # Return True only when the parameter is not marked as shared.
    return not getattr(param, "shared", False)


def unwrap_model(model, module_instances=(DistributedDataParallel,)):
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model


def calc_params_l2_norm(model: torch.nn.Module, bf16: bool):
    """Calculate the l2 norm of parameters."""
    # args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = parallel_state.param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                if bf16:
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)
    # Calculate norm
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False,  # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(
        norm_2, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group()
    )
    return norm_2.item() ** 0.5


def average_losses_across_data_parallel_group(losses):
    """Reduce a tensor of losses across all GPUs."""
    averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses, group=parallel_state.get_data_parallel_group())
    averaged_losses = averaged_losses / torch.distributed.get_world_size(
        group=parallel_state.get_data_parallel_group()
    )
    return averaged_losses


def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + " memory (MB)"
    string += " | allocated: {}".format(torch.cuda.memory_allocated() / mega_bytes)
    string += " | max allocated: {}".format(torch.cuda.max_memory_allocated() / mega_bytes)
    string += " | reserved: {}".format(torch.cuda.memory_reserved() / mega_bytes)
    string += " | max reserved: {}".format(torch.cuda.max_memory_reserved() / mega_bytes)
    if parallel_state.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = "iteration, rank, index, tensor-model-parallel, min, max, norm\n"
    optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group["params"]:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = torch.linalg.norm(param.data)
            string += "{:7d}, {:4d}, {:4d}, {:2d}, ".format(
                iteration, rank, index, int(param.tensor_model_parallel)
            )
            string += "{:.6E}, {:.6E}, {:.6E}\n".format(min_, max_, norm)
    print(string, flush=True)


# NOTE (mkozuki): APEX doesn't have anything equivalent for
# `_GLOBAL_ADLR_AUTORESUME` like Megatron-LM.
# def check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler, save: bool):
#     """Check for autoresume signal and exit if it is received."""
#     from apex.ppu.checkpointing import save_checkpoint
#
#     autoresume = get_adlr_autoresume()
#     # Add barrier to ensure consistency.
#     torch.distributed.barrier()
#     if autoresume.termination_requested():
#         if save:
#             save_checkpoint(iteration, model, optimizer, lr_scheduler)
#         print_rank_0(">>> autoresume termination request found!")
#         if torch.distributed.get_rank() == 0:
#             autoresume.request_resume()
#         print_rank_0(">>> training terminated. Returning")
#         sys.exit(0)


def get_ltor_masks_and_position_ids(data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss):
    """Build masks and position ids for a left-to-right model."""
    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(
        torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
    ).view(att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where the EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= i + 1 - prev_index
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = attention_mask < 0.5

    return attention_mask, loss_mask, position_ids
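`get_kth_microbatch` simply slices every tensor of the local minibatch along the batch dimension. A self-contained illustration follows; the sizes are made up, and the inline helper mirrors the slicing instead of reading the global microbatch calculator:

# Illustration of the slicing performed by `get_kth_microbatch`: each tensor in
# the local minibatch is cut into chunks of `micro_batch_size` along dim 0.
import torch

local_minibatch = [torch.arange(8).view(8, 1), torch.arange(8)]  # e.g. tokens, labels
micro_batch_size = 2  # made-up value; normally comes from setup_microbatch_calculator

def kth_microbatch(batch, k, micro_batch_size):
    return [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in batch]

for k in range(4):
    tokens, labels = kth_microbatch(local_minibatch, k, micro_batch_size)
    print(k, tokens.shape, labels.tolist())
# 0 torch.Size([2, 1]) [0, 1]
# 1 torch.Size([2, 1]) [2, 3]
# 2 torch.Size([2, 1]) [4, 5]
# 3 torch.Size([2, 1]) [6, 7]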
apex/transformer/tensor_parallel/__init__.py
View file @
79906517
 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -15,11 +15,11 @@
 """Model parallel utility interface."""
-from .cross_entropy import vocab_parallel_cross_entropy
-from .data import broadcast_data
-from .layers import (
+from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
+from apex.transformer.tensor_parallel.data import broadcast_data
+from apex.transformer.tensor_parallel.layers import (
     ColumnParallelLinear,
     RowParallelLinear,
     VocabParallelEmbedding,
...
@@ -28,7 +28,7 @@ from .layers import (
     copy_tensor_model_parallel_attributes,
 )
-from .mappings import (
+from apex.transformer.tensor_parallel.mappings import (
     copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
...
@@ -41,11 +41,9 @@ from .random import (
     init_checkpointed_activations_memory_buffer,
     model_parallel_cuda_manual_seed,
     reset_checkpointed_activations_memory_buffer,
-    gather_split_1d_tensor,
-    split_tensor_into_1d_equal_chunks,
 )
-from .utils import divide, split_tensor_along_last_dim
+from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
 __all__ = [
...
@@ -71,9 +69,6 @@ __all__ = [
     "init_checkpointed_activations_memory_buffer",
     "model_parallel_cuda_manual_seed",
     "reset_checkpointed_activations_memory_buffer",
-    "gather_split_1d_tensor",
-    "split_tensor_into_1d_equal_chunks",
     # utils.py
-    "divide",
     "split_tensor_along_last_dim",
 ]
apex/transformer/tensor_parallel/cross_entropy.py
View file @
79906517
 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -14,10 +14,10 @@
 # limitations under the License.
 import torch
-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_rank
-from ..parallel_state import get_tensor_model_parallel_world_size
-from .utils import VocabUtility
+from apex.transformer.parallel_state import get_tensor_model_parallel_group
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
+from apex.transformer.tensor_parallel.utils import VocabUtility
 class _VocabParallelCrossEntropy(torch.autograd.Function):
...
@@ -30,7 +30,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
             logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
         )
         # Subtract the maximum value.
-        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+        vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
         # Get the partition's vocab indecies
         get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
...
@@ -100,4 +100,4 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
 def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
     """Helper function for the cross entropy."""
-    return _VocabParallelCrossEntropy.apply(torch.clone(vocab_parallel_logits), target)
+    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
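Two behavioral tweaks ride along with the import cleanup here: the forward pass now subtracts the max out of place instead of calling `sub_`, so the helper no longer needs to `torch.clone` the caller's logits before `apply`. A tiny illustration of the difference between the two subtraction styles (generic PyTorch behavior, not code from this repository):

# In-place `sub_` mutates the tensor the caller passed in; `a - b` allocates a new one.
import torch

logits = torch.tensor([[1.0, 3.0], [2.0, 5.0]])
logits_max = logits.max(dim=-1, keepdim=True)[0]   # [[3.], [5.]]
original = logits.clone()

shifted = logits - logits_max      # out-of-place: `logits` is untouched
assert torch.equal(logits, original)

logits.sub_(logits_max)            # in-place: `logits` itself is modified
assert not torch.equal(logits, original)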
apex/transformer/tensor_parallel/data.py
View file @
79906517
 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -14,9 +14,9 @@
 # limitations under the License.
 import torch
-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_rank
-from ..parallel_state import get_tensor_model_parallel_src_rank
+from apex.transformer.parallel_state import get_tensor_model_parallel_group
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.parallel_state import get_tensor_model_parallel_src_rank
 _MAX_DATA_DIM = 5
...
apex/transformer/tensor_parallel/layers.py
View file @
79906517
 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -21,17 +21,17 @@ import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter
-from ..._autocast_utils import _cast_if_autocast_enabled
-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_rank
-from ..parallel_state import get_tensor_model_parallel_world_size
-from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
-from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_tensor_model_parallel_region
-from .random import get_cuda_rng_tracker
-from .utils import divide
-from .utils import VocabUtility
+from apex._autocast_utils import _cast_if_autocast_enabled
+from apex.transformer.parallel_state import get_tensor_model_parallel_group
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
+from apex.transformer.utils import divide
+from apex.transformer.tensor_parallel.mappings import copy_to_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
+from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
+from apex.transformer.tensor_parallel.utils import VocabUtility
 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
...
apex/transformer/tensor_parallel/mappings.py
View file @
79906517
 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -14,10 +14,10 @@
 # limitations under the License.
 import torch
-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_world_size
-from ..parallel_state import get_tensor_model_parallel_rank
-from .utils import split_tensor_along_last_dim
+from apex.transformer.parallel_state import get_tensor_model_parallel_group
+from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
 def _reduce(input_):
...