OpenDAS / apex / Commits

Commit db92ee13 (unverified)
Authored Dec 14, 2021 by Jithun Nair; committed by GitHub, Dec 14, 2021

Merge pull request #64 from ROCmSoftwarePlatform/IFU-master-2021-12-08

    IFU-master-2021-12-08

Parents: d150afdc, 68364b49
Changes: 98. Showing 20 changed files with 1886 additions and 48 deletions (+1886 / -48).
apex/transformer/amp/__init__.py                                                          +6    -0
apex/transformer/amp/grad_scaler.py                                                       +106  -0
apex/transformer/enums.py                                                                 +1    -1
apex/transformer/functional/__init__.py                                                   +1    -1
apex/transformer/functional/fused_softmax.py                                              +2    -2
apex/transformer/log_util.py                                                              +19   -0
apex/transformer/microbatches.py                                                          +23   -15
apex/transformer/parallel_state.py                                                        +57   -7
apex/transformer/pipeline_parallel/__init__.py                                            +8    -0
apex/transformer/pipeline_parallel/_timers.py                                             +83   -0
apex/transformer/pipeline_parallel/p2p_communication.py                                   +404  -0
apex/transformer/pipeline_parallel/schedules/__init__.py                                  +39   -0
apex/transformer/pipeline_parallel/schedules/common.py                                    +218  -0
apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py                     +91   -0
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py      +308  -0
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py   +170  -0
apex/transformer/pipeline_parallel/utils.py                                               +333  -0
apex/transformer/tensor_parallel/__init__.py                                              +6    -11
apex/transformer/tensor_parallel/cross_entropy.py                                         +7    -7
apex/transformer/tensor_parallel/data.py                                                  +4    -4
apex/transformer/amp/__init__.py (new file, mode 100644)

from apex.transformer.amp.grad_scaler import GradScaler


__all__ = [
    "GradScaler",
]
apex/transformer/amp/grad_scaler.py (new file, mode 100644)

from collections import defaultdict

import torch

from apex.transformer import parallel_state


class GradScaler(torch.cuda.amp.GradScaler):
    """
    Gradient scaler for model-parallel inf check. The inf in gradients are checked across tensor-parallel
    ranks in (1) executing optimizer step and (2) gradient scaler update.
    """

    def __init__(
        self, init_scale=2.0 ** 16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True
    ):
        super().__init__(
            init_scale=init_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )

    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
        retval = None
        found_inf = torch.cuda.FloatTensor(
            [sum(v.item() for v in optimizer_state["found_inf_per_device"].values())])

        # Update across all model parallel instances.
        torch.distributed.all_reduce(
            found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
        )

        if found_inf.item() == 0:
            retval = optimizer.step(*args, **kwargs)
        return retval

    def update(self, new_scale=None):
        """
        Updates the scale factor.
        If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.
        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly, it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]

            # Update across all model parallel instances.
            torch.distributed.all_reduce(
                found_inf_combined, op=torch.distributed.ReduceOp.MAX,
                group=parallel_state.get_model_parallel_group()
            )

            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf = found_infs[i]
                    # Update across all model parallel instances.
                    torch.distributed.all_reduce(
                        found_inf, op=torch.distributed.ReduceOp.MAX,
                        group=parallel_state.get_model_parallel_group()
                    )
                    found_inf_combined += found_inf

            torch._amp_update_scale_(
                _scale,
                _growth_tracker,
                found_inf_combined,
                self._growth_factor,
                self._backoff_factor,
                self._growth_interval,
            )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(torch.cuda.amp.grad_scaler._refresh_per_optimizer_state)
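
Because this class keeps the stock `torch.cuda.amp.GradScaler` interface and only adds the model-parallel all-reduce of the found-inf flags, a minimal, hedged usage sketch looks like ordinary AMP training; `model`, `optimizer`, `criterion`, and `loader` below are assumptions, not part of this commit:

import torch

from apex.transformer.amp.grad_scaler import GradScaler

scaler = GradScaler(init_scale=2.0 ** 16)       # defaults mirror torch.cuda.amp.GradScaler
for batch, target in loader:                    # hypothetical data loader
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(batch), target)  # hypothetical model/criterion
    scaler.scale(loss).backward()
    scaler.step(optimizer)                      # skipped if any tensor-parallel rank saw inf/nan
    scaler.update()                             # scale backoff/growth, synchronized across ranks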
apex/transformer/enums.py

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
apex/transformer/functional/__init__.py

-from .fused_softmax import FusedScaleMaskSoftmax
+from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax

 __all__ = [
     "FusedScaleMaskSoftmax",
 ...
apex/transformer/functional/fused_softmax.py

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

@@ -15,7 +15,7 @@
 import torch

 from apex._autocast_utils import _cast_if_autocast_enabled
-from ..enums import AttnMaskType
+from apex.transformer.enums import AttnMaskType


 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
 ...
apex/transformer/log_util.py (new file, mode 100644)

from typing import Optional
import logging
import os
import threading


def get_transformer_logger(name: str) -> logging.Logger:
    name_wo_ext = os.path.splitext(name)[0]
    return logging.getLogger(name_wo_ext)


def set_logging_level(verbosity) -> None:
    """Change logging severity.

    Args:
        verbosity
    """
    from apex import _library_root_logger
    _library_root_logger.setLevel(verbosity)
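
The schedule modules added below all obtain their module-level logger through this helper; a minimal sketch (the chosen severity is just an example):

import logging

from apex.transformer.log_util import get_transformer_logger, set_logging_level

_logger = get_transformer_logger(__name__)   # same pattern the pipeline-parallel modules use
set_logging_level(logging.INFO)              # adjusts apex's library root logger
_logger.info("transformer logging configured")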
apex/transformer/tensor_parallel/microbatches.py → apex/transformer/microbatches.py (renamed)

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

@@ -15,35 +15,41 @@
 """Megatron number of micro-batches calculators."""
 from abc import ABC
 from abc import abstractmethod
 from typing import Optional, List


-def build_num_microbatches_calculator(args):
+def build_num_microbatches_calculator(
+    rank: int,
+    rampup_batch_size: Optional[List[int]],
+    global_batch_size: int,
+    micro_batch_size: int,
+    data_parallel_size: int,
+):

     # Constant num micro-batches.
-    if args.rampup_batch_size is None:
+    if rampup_batch_size is None:
         num_microbatches_calculator = ConstantNumMicroBatches(
-            args.global_batch_size, args.micro_batch_size, args.data_parallel_size
+            global_batch_size, micro_batch_size, data_parallel_size
         )
-        if args.rank == 0:
+        if rank == 0:
             print(
                 "setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()),
                 flush=True,
             )
     else:
-        assert len(args.rampup_batch_size) == 3, (
+        assert len(rampup_batch_size) == 3, (
             "expected the following "
             "format: --rampup-batch-size <start batch size> "
             "<batch size incerement> <ramp-up samples>"
         )
-        start_batch_size = int(args.rampup_batch_size[0])
-        batch_size_increment = int(args.rampup_batch_size[1])
-        ramup_samples = int(args.rampup_batch_size[2])
-        if args.rank == 0:
+        start_batch_size = int(rampup_batch_size[0])
+        batch_size_increment = int(rampup_batch_size[1])
+        ramup_samples = int(rampup_batch_size[2])
+        if rank == 0:
             print(
                 "will use batch size rampup starting from global batch "
                 "size {} to global batch size {} with batch size increments "
                 "{} over {} samples.".format(
                     start_batch_size,
-                    args.global_batch_size,
+                    global_batch_size,
                     batch_size_increment,
                     ramup_samples,
                 ),
                 flush=True,
             )
 ...

@@ -51,9 +57,9 @@ def build_num_microbatches_calculator(args):
             start_batch_size,
             batch_size_increment,
             ramup_samples,
-            args.global_batch_size,
-            args.micro_batch_size,
-            args.data_parallel_size,
+            global_batch_size,
+            micro_batch_size,
+            data_parallel_size,
         )
     return num_microbatches_calculator
 ...

@@ -86,6 +92,8 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
         assert self.num_micro_batches >= 1
         self.current_global_batch_size = global_batch_size

+        self.micro_batch_size = micro_batch_size
+
     def update(self, consumed_samples, consistency_check):
         pass
 ...
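
The new signature drops the Megatron-style `args` namespace, so callers pass the five values explicitly; a hedged sketch with arbitrary example numbers:

from apex.transformer.microbatches import build_num_microbatches_calculator

calculator = build_num_microbatches_calculator(
    rank=0,
    rampup_batch_size=None,        # or e.g. [32, 32, 1_000_000]: <start> <increment> <ramp-up samples>
    global_batch_size=512,
    micro_batch_size=4,
    data_parallel_size=8,
)
# With rampup_batch_size=None this is a ConstantNumMicroBatches calculator,
# presumably 512 // (4 * 8) = 16 micro-batches per step.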
apex/transformer/parallel_state.py

 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

@@ -13,11 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Model and data parallel groups."""
+from typing import Tuple
+
 import torch

-# TODO (mkozuki): Consider dissecting utils as this utils import is here
-# only for ensure_divisibility
-from .tensor_parallel import utils
+from apex.transformer.utils import ensure_divisibility


 # Intra-layer model parallel group that the current rank belongs to.
 ...

@@ -40,6 +42,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_TENSOR_MODEL_PARALLEL_RANK = None
 _MPU_PIPELINE_MODEL_PARALLEL_RANK = None

+# A list of ranks that have a copy of the embedding.
+_EMBEDDING_GLOBAL_RANKS = None
+
 # A list of global ranks for each pipeline group to ease calculation of the source
 # rank when broadcasting from the first or last pipeline stage
 _PIPELINE_GLOBAL_RANKS = None
 ...

@@ -76,23 +81,26 @@ def initialize_model_parallel(
     with a total of 16 GPUs, rank 0 to 7 belong to the first box and
     ranks 8 to 15 belong to the second box.
     """
-    if torch.distributed.get_rank() == 0:
-        print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size_))
-        print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size_))
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
     tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
     pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
-    utils.ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
+    # TODO (mkozuki): Consider moving `ensure_divisibility` to this file.
+    ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
     data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
+    if torch.distributed.get_rank() == 0:
+        print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size))
+        print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size))
+        print("> initializing data parallel with size {}".format(data_parallel_size))

     num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
     num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
     num_data_parallel_groups = world_size // data_parallel_size

     if virtual_pipeline_model_parallel_size_ is not None:
         assert pipeline_model_parallel_size_ > 2, \
             'pipeline-model-parallel size should be greater than 2 with ' \
             'interleaved schedule'
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
         _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
 ...

@@ -138,6 +146,7 @@ def initialize_model_parallel(
     global _PIPELINE_GLOBAL_RANKS
     assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
     global _EMBEDDING_GROUP
+    global _EMBEDDING_GLOBAL_RANKS
     assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size, num_pipeline_model_parallel_groups)
 ...

@@ -154,6 +163,19 @@ def initialize_model_parallel(
         group = torch.distributed.new_group(embedding_ranks)
         if rank in embedding_ranks:
             _EMBEDDING_GROUP = group
+        if rank in ranks:
+            _EMBEDDING_GLOBAL_RANKS = embedding_ranks
+
+
+def get_rank_info() -> Tuple[int, int, int]:
+    """Returns a tuple of (tensor, pipeline, data)-parallel-rank for logger."""
+    if model_parallel_is_initialized():
+        return (
+            get_tensor_model_parallel_rank(),
+            get_pipeline_model_parallel_rank(),
+            # get_virtual_pipeline_model_parallel_rank(),
+            get_data_parallel_rank(),
+        )
+    return (0, 0, 0)


 def model_parallel_is_initialized():
 ...

@@ -193,6 +215,22 @@ def get_embedding_group():
     return _EMBEDDING_GROUP


+def is_rank_in_embedding_group(ignore_virtual=False):
+    """Return true if current rank is in embedding group, False otherwise."""
+    rank = torch.distributed.get_rank()
+    global _EMBEDDING_GLOBAL_RANKS
+    if ignore_virtual:
+        return rank in _EMBEDDING_GLOBAL_RANKS
+    if rank in _EMBEDDING_GLOBAL_RANKS:
+        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
+            return is_pipeline_first_stage(ignore_virtual=False)
+        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
+            return is_pipeline_last_stage(ignore_virtual=False)
+        else:
+            return True
+    return False
+
+
 def set_tensor_model_parallel_world_size(world_size):
     """Set the tensor model parallel size"""
     global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
 ...

@@ -344,3 +382,15 @@ def destroy_model_parallel():
     _DATA_PARALLEL_GROUP = None
     global _EMBEDDING_GROUP
     _EMBEDDING_GROUP = None
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
+    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+    global _MPU_TENSOR_MODEL_PARALLEL_RANK
+    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
+    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
+    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
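
A hedged sketch of how the helpers added here would be used inside an already-launched torch.distributed job (the parallel sizes are arbitrary examples):

import torch
from apex.transformer import parallel_state

torch.distributed.init_process_group(backend="nccl")
parallel_state.initialize_model_parallel(2, 4)   # tensor- and pipeline-model-parallel sizes

tp_rank, pp_rank, dp_rank = parallel_state.get_rank_info()        # new in this commit
if parallel_state.is_rank_in_embedding_group():                   # new in this commit
    print(f"this rank holds an embedding copy: tp={tp_rank} pp={pp_rank} dp={dp_rank}")

parallel_state.destroy_model_parallel()   # now also clears the virtual-pipeline and MPU globals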
apex/transformer/pipeline_parallel/__init__.py (new file, mode 100644)

from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.schedules.common import build_model


__all__ = [
    "get_forward_backward_func",
    "build_model",
]
apex/transformer/pipeline_parallel/_timers.py (new file, mode 100644)

import time

import torch


class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, "timer has already been started"
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, "timer is not started"
        torch.cuda.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If the timing in progress, end it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if started_:
            self.start()
        return elapsed_


class _Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer"""
        # currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # polutes the runs list, so we just add each as a scalar
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + "-time", value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = "time (ms)"
        for name in names:
            elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
            string += " | {}: {:.2f}".format(name, elapsed_time)
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)
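
A minimal sketch of the timer group in use; `run_forward_pass` is a hypothetical workload:

from apex.transformer.pipeline_parallel._timers import _Timers

timers = _Timers()
timers("forward-compute").start()     # start/stop both call torch.cuda.synchronize()
run_forward_pass()                    # hypothetical GPU work being timed
timers("forward-compute").stop()
timers.log(["forward-compute"])       # prints "time (ms) | forward-compute: ..." from the last rank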
apex/transformer/pipeline_parallel/p2p_communication.py (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import operator
from typing import Union, Optional, Tuple
import warnings

import torch

from apex._autocast_utils import _get_current_dtype
from apex.transformer import parallel_state
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
from apex.transformer.utils import gather_split_1d_tensor
from apex.transformer.pipeline_parallel.utils import Shape
from apex.transformer.pipeline_parallel._timers import _Timers


def _run_p2pops(
    tensor_send_prev: Union[torch.Tensor, None],
    tensor_send_next: Union[torch.Tensor, None],
    tensor_recv_prev: Union[torch.Tensor, None],
    tensor_recv_next: Union[torch.Tensor, None],
):
    ops = []
    if tensor_send_prev is not None:
        send_prev_op = torch.distributed.P2POp(
            torch.distributed.isend,
            tensor_send_prev,
            parallel_state.get_pipeline_model_parallel_prev_rank(),
        )
        ops.append(send_prev_op)
    if tensor_recv_prev is not None:
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv,
            tensor_recv_prev,
            parallel_state.get_pipeline_model_parallel_prev_rank(),
        )
        ops.append(recv_prev_op)
    if tensor_send_next is not None:
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend,
            tensor_send_next,
            parallel_state.get_pipeline_model_parallel_next_rank(),
        )
        ops.append(send_next_op)
    if tensor_recv_next is not None:
        recv_next_op = torch.distributed.P2POp(
            torch.distributed.irecv,
            tensor_recv_next,
            parallel_state.get_pipeline_model_parallel_next_rank(),
        )
        ops.append(recv_next_op)
    if len(ops) > 0:
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()


def _communicate(
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Optional[Shape] = None,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    dtype_: torch.dtype = torch.float,
    *,
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
    """Base function for communication of tensors between stages.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
        recv_prev: boolean for whether tensor should be received from previous rank.
        recv_next: boolean for whether tensor should be received from next rank.
        tensor_shape: optional, use when the input sequence contains less tokens than the default sequence length
        override_scatter_gather_tensors_in_pipeline:
            optional, this is used when tensor_shape is provided to override scatter gather tensors
        dtype_: This is used when tensor_shape is provided and what is the type of tensor_shape

    Keyword args:
        scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors.
        params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on
            your model deliberately, pass this argument.
        fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32.

    Returns:
        tuple containing

        - tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise.
        - tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise.
    """
    # Create placeholder tensors for receive in forward and backward directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    if tensor_shape is None:
        # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)`
        raise RuntimeError(
            "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`")
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (
            reduce(operator.mul, tensor_shape, 1) // parallel_state.get_tensor_model_parallel_world_size(),)
    else:
        tensor_chunk_shape = tensor_shape

    # NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
    # FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
    # It might be possible if we restrict model architecture.
    # dtype = params_dtype or torch.float
    # if fp32_residual_connection:
    #     dtype = torch.float
    # if dtype_ is not None:
    #     dtype = dtype_
    #     requires_grad = False
    if dtype_ != torch.float32 or params_dtype is not None:
        if torch.distributed.get_rank() == 0:
            warnings.warn("Tensor P2P communications are executed in FP32")
    dtype = torch.float32
    requires_grad = True

    if recv_prev:
        tensor_recv_prev = torch.empty(
            tensor_chunk_shape, requires_grad=requires_grad,
            device=torch.cuda.current_device(), dtype=dtype,
        )
    if recv_next:
        tensor_recv_next = torch.empty(
            tensor_chunk_shape, requires_grad=requires_grad,
            device=torch.cuda.current_device(), dtype=dtype,
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = (
                gather_split_1d_tensor(tensor_recv_prev)
                .view(tensor_shape)
                .requires_grad_()
            )

        if recv_next:
            tensor_recv_next = (
                gather_split_1d_tensor(tensor_recv_next)
                .view(tensor_shape)
                .requires_grad_()
            )

    return tensor_recv_prev, tensor_recv_next


def recv_forward(
    tensor_shape: Shape,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    *, dtype: Optional[torch.dtype] = None, timers: _Timers = None,
) -> torch.Tensor:
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if parallel_state.is_pipeline_first_stage():
        return None
    if timers is not None:
        timers("forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=None, tensor_send_prev=None, recv_prev=True, recv_next=False,
        tensor_shape=tensor_shape,
        override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
        dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("forward-recv").stop()
    return input_tensor


def recv_backward(
    tensor_shape: Shape = None,
    *, dtype: Optional[torch.dtype] = None, timers: _Timers = None,
):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if parallel_state.is_pipeline_last_stage():
        return None
    if timers is not None:
        timers("backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None, tensor_send_prev=None, recv_prev=False, recv_next=True,
        tensor_shape=tensor_shape, dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("backward-recv").stop()
    return output_tensor_grad


def send_forward(
    output_tensor: torch.Tensor,
    override_scatter_gather_tensors_in_pipeline: bool = False,
    tensor_shape: Shape = None,
    *, dtype: Optional[torch.dtype] = None, timers: _Timers = None,
) -> None:
    """Send tensor to next rank in pipeline (forward send)."""
    if parallel_state.is_pipeline_last_stage():
        return
    if timers is not None:
        timers("forward-send").start()
    _communicate(
        tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=False, recv_next=False,
        override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
        tensor_shape=tensor_shape, dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("forward-send").stop()


def send_backward(
    input_tensor_grad: torch.Tensor,
    tensor_shape: Shape,
    *, dtype: Optional[torch.dtype] = None, timers: _Timers = None,
) -> None:
    """Send tensor to previous rank in pipeline (backward send)."""
    if parallel_state.is_pipeline_first_stage():
        return
    if timers is not None:
        timers("backward-send").start()
    _communicate(
        tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=False, recv_next=False,
        tensor_shape=tensor_shape, dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("backward-send").stop()


def send_forward_recv_backward(
    output_tensor: torch.Tensor,
    tensor_shape: Shape,
    *, dtype: Optional[torch.dtype] = None, timers: _Timers = None,
) -> None:
    """Batched send and recv with next rank in pipeline."""
    if parallel_state.is_pipeline_last_stage():
        return None
    if timers is not None:
        timers("forward-send-backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=False, recv_next=True,
        tensor_shape=tensor_shape, dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("forward-send-backward-recv").stop()
    return output_tensor_grad


def send_backward_recv_forward(
    input_tensor_grad: torch.Tensor,
    tensor_shape: Shape,
    *, dtype: Optional[torch.dtype] = None, timers: _Timers = None,
) -> torch.Tensor:
    """Batched send and recv with previous rank in pipeline."""
    if parallel_state.is_pipeline_first_stage():
        return None
    if timers is not None:
        timers("backward-send-forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=True, recv_next=False,
        tensor_shape=tensor_shape, dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("backward-send-forward-recv").stop()
    return input_tensor


def send_forward_recv_forward(
    output_tensor: torch.Tensor,
    recv_prev: bool,
    tensor_shape: Shape,
    *, dtype: Optional[torch.dtype] = None, timers: _Timers = None,
) -> torch.Tensor:
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers("forward-send-forward-recv").start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=recv_prev, recv_next=False,
        tensor_shape=tensor_shape, dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("forward-send-forward-recv").stop()
    return input_tensor


def send_backward_recv_backward(
    input_tensor_grad: torch.Tensor,
    recv_next: bool,
    tensor_shape: Shape,
    *, dtype: torch.dtype = torch.float, timers: _Timers = None,
) -> torch.Tensor:
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers("backward-send-backward-recv").start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=False, recv_next=recv_next,
        tensor_shape=tensor_shape, dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("backward-send-backward-recv").stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
    output_tensor: torch.Tensor,
    input_tensor_grad: torch.Tensor,
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Shape,
    *, dtype: Optional[torch.dtype] = None, timers: _Timers = None,
):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers("forward-backward-send-forward-backward-recv").start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor, tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev, recv_next=recv_next,
        tensor_shape=tensor_shape, dtype_=_get_current_dtype(dtype),
    )
    if timers is not None:
        timers("forward-backward-send-forward-backward-recv").stop()
    return input_tensor, output_tensor_grad
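
Each public helper above is a thin wrapper over `_communicate`. A hedged sketch of how a middle pipeline stage would use them for one micro-batch; `run_stage` and `run_stage_backward` are hypothetical, and the shape is an arbitrary `(seq_length, micro_batch_size, hidden_size)`:

from apex.transformer.pipeline_parallel import p2p_communication

tensor_shape = (1024, 4, 1536)   # arbitrary (seq_length, micro_batch_size, hidden_size)

# Forward: pull activations from the previous stage, compute, push to the next stage.
input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
output_tensor = run_stage(input_tensor)                                    # hypothetical
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)

# Backward: pull the output gradient from the next stage, backprop, push the input gradient upstream.
output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)
input_tensor_grad = run_stage_backward(output_tensor, output_tensor_grad)  # hypothetical
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)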
apex/transformer/pipeline_parallel/schedules/__init__.py (new file, mode 100644)

import warnings

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import _forward_backward_pipelining_with_interleaving
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)


class ExperimentalWarning(Warning):
    pass


def get_forward_backward_func(
    virtual_pipeline_model_parallel_size, pipeline_model_parallel_size,
):
    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        if virtual_pipeline_model_parallel_size is not None:
            if get_num_microbatches() % pipeline_model_parallel_size != 0:
                msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule"
                raise RuntimeError(msg)
            warnings.warn(
                "Pipeline Model Parallel with interleaving scheduling is experimental. "
                f"To use Pipeline Parallel without interleaving, set `virtual_pipeline_model_parallel_size` to `None`: {virtual_pipeline_model_parallel_size}",
                ExperimentalWarning
            )
            forward_backward_func = _forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func


__all__ = [
    "get_forward_backward_func",
]
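
The dispatch above can be exercised directly; a minimal sketch (the sizes are arbitrary and assume parallel_state has already been initialized):

from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func

# Returns forward_backward_no_pipelining when the pipeline world size is 1,
# forward_backward_pipelining_without_interleaving when it is >1 and no virtual size is given,
# and the experimental interleaved schedule otherwise.
fwd_bwd_func = get_forward_backward_func(
    virtual_pipeline_model_parallel_size=None,
    pipeline_model_parallel_size=4,
)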
apex/transformer/pipeline_parallel/schedules/common.py (new file, mode 100644)

# NOTE (mkozuki): For simplicity, tentatively `timers` related operations are commented out.
from typing import Any, Callable, Dict, List, Tuple, Union, Optional

import torch

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes


Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
FwdStepFunc = Callable[[Batch, torch.nn.Module], Tuple[torch.Tensor, LossFunc]]


def build_model(
    model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module],
    wrap_with_ddp: bool = True,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    *args,
    **kwargs
) -> List[torch.nn.Module]:
    """Build the model satisfying pipeline model parallel requirements.

    This function sets `pre_process` and `post_process` to `**kwargs` and pass `*args` and `**kwargs` to
    `model_provider_func`.

    Args:
        model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`.
        wrap_with_ddp: If :obj:`True`, wrap the instantiated model
            with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`.
        virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel.
        *args: arguments for model provider func
        **kwargs: Keyword arguments for model provider func

    Returns:
        a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None,
        the list has multiple models, otherwise one.
    """
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(virtual_pipeline_model_parallel_size):
            cur_args = args
            cur_kwargs = kwargs
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = parallel_state.is_pipeline_first_stage()
            post_process = parallel_state.is_pipeline_last_stage()
            cur_kwargs.update({
                "pre_process": pre_process,
                "post_process": post_process,
            })
            this_model = model_provider_func(*cur_args, **cur_kwargs)
            model.append(this_model)
    else:
        cur_args = args
        cur_kwargs = kwargs
        pre_process = parallel_state.is_pipeline_first_stage()
        post_process = parallel_state.is_pipeline_last_stage()
        cur_kwargs.update({
            "pre_process": pre_process,
            "post_process": post_process,
        })
        model = model_provider_func(*cur_args, **cur_kwargs)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if parallel_state.get_data_parallel_rank() == 0:
        msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format(
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_pipeline_model_parallel_rank(),
            sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
        )
        print(msg, flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    if wrap_with_ddp:
        i = torch.cuda.current_device()
        model = [
            torch.nn.parallel.distributed.DistributedDataParallel(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=parallel_state.get_data_parallel_group(),
            )
            for model_module in model
        ]
    return model


def _get_params_for_weight_decay_optimization(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
) -> Dict[str, torch.nn.Parameter]:
    """Divide params into with-weight-decay and without-weight-decay groups.

    Layernorms and biases will have no weight decay but the rest will.
    """
    modules = listify_model(model)
    from apex.normalization.fused_layer_norm import FusedLayerNorm  # NOQA
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, FusedLayerNorm):
                no_weight_decay_params['params'].extend(
                    [p for p in list(module_._parameters.values()) if p is not None])
            else:
                weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias'])
                no_weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias'])

    return weight_decay_params, no_weight_decay_params


def forward_step(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: torch.nn.Module,
    input_tensor: Optional[torch.Tensor],
    losses_reduced: List[torch.Tensor],
):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor.

    Args:
        forward_step_func: Model specific function. This takes a minibatch and model as its arguments and
            returns the model's output and the loss function.
        batch: minibatch
        model: unwrappable model
        input_tensor:
        losses_reduced:

    Returns:
        output_tensor
    """
    # timers = get_timers()
    # timers("forward-compute").start()
    unwrapped_model = unwrap_model(model)
    # NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`.
    # See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679  # NOQA
    # for the details of `set_input_tensor`.
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(batch, model)
    # print(f"forward_step| pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()} is_pipeline_last_stage?: {parallel_state.is_pipeline_last_stage()}")
    if parallel_state.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    # timers("forward-compute").stop()
    return output_tensor


def backward_step(
    input_tensor: Optional[torch.Tensor],
    output_tensor: torch.Tensor,
    output_tensor_grad: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage).

    Args:
        input_tensor:
        output_tensor:
        output_tensor_grad:

    Returns:
        input_tensor_grad
    """
    # timers = get_timers()
    # timers("backward-compute").start()
    # Retain the grad on the input_tensor.
    # if parallel_state.get_pipeline_model_parallel_rank() == 0:
    #     print(f"{input_tensor}, {output_tensor}, {output_tensor_grad}")
    if input_tensor is not None:
        input_tensor.retain_grad()
    # Backward pass.
    # if output_tensor_grad is None:
    #     output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad
    # timers("backward-compute").stop()
    return input_tensor_grad
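
The `FwdStepFunc` alias above is the contract every schedule in this directory builds on: it maps `(batch, model)` to the model output plus a loss closure that returns `(loss, {name: reduced_loss})`. A hedged sketch of a conforming function; the batch layout and loss are hypothetical:

import torch


def my_forward_step(batch, model):
    """Hypothetical FwdStepFunc: run the model and hand back (output, loss_func)."""
    tokens, labels = batch                     # assumed micro-batch layout
    output_tensor = model(tokens)

    def loss_func(output_tensor):
        loss = torch.nn.functional.cross_entropy(output_tensor, labels)
        # (loss, dict of reduced losses), as the forward_step docstring above expects.
        return loss, {"lm loss": loss.detach()}

    return output_tensor, loss_func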
apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py (new file, mode 100644)

from contextlib import contextmanager
from typing import List, Union

import torch

from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger


_all__ = ["forward_backward_no_pipelining"]

_logger = get_transformer_logger(__name__)


@contextmanager
def placeholder_handler():
    try:
        yield
    finally:
        pass


def forward_backward_no_pipelining(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    **kwargs,
):
    """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).

    This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients.

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A List of torch.Tensors
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only:
        **kwargs: Added to handle `tensor_shape` which has no effect on this function.

    Returns:
        a list of dictionaries of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    context_handler = placeholder_handler
    if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    num_micro_batches = get_num_microbatches()
    with context_handler():
        for i in range(num_micro_batches - 1):
            _logger.info(f"Iter {i} of {num_micro_batches - 1}")
            cur_micro_batch = get_kth_microbatch(batch, i)
            _logger.debug("Call `forward_step`")
            output_tensor = forward_step(forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
            if not forward_only:
                _logger.debug("Call `backward_step`")
                backward_step(input_tensor, output_tensor, output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    _logger.info("Cooldown")
    _logger.debug("Call `forward_step`")
    output_tensor = forward_step(
        forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced
    )
    if not forward_only:
        _logger.debug("Call `backward_step`")
        backward_step(input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced
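
Putting the pieces together, a hedged end-to-end sketch of driving this schedule for one step; `model_provider` (which must accept the `pre_process`/`post_process` kwargs that `build_model` injects), `minibatch`, `optimizer`, and `my_forward_step` (see the sketch after common.py) are assumptions, not part of this commit:

from apex.transformer.pipeline_parallel import build_model, get_forward_backward_func

model = build_model(model_provider, wrap_with_ddp=True)       # list with one DDP-wrapped module here
fwd_bwd_func = get_forward_backward_func(None, pipeline_model_parallel_size=1)
# With a pipeline world size of 1 this resolves to forward_backward_no_pipelining.
losses_reduced = fwd_bwd_func(my_forward_step, minibatch, model, forward_only=False)
optimizer.step()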
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py
0 → 100644
View file @
db92ee13
from
typing
import
List
,
Union
,
Optional
import
torch
from
apex.transformer
import
parallel_state
from
apex.transformer.pipeline_parallel
import
p2p_communication
from
apex.transformer.pipeline_parallel.schedules.common
import
Batch
,
FwdStepFunc
from
apex.transformer.pipeline_parallel.schedules.common
import
backward_step
from
apex.transformer.pipeline_parallel.schedules.common
import
forward_step
from
apex.transformer.pipeline_parallel.utils
import
get_kth_microbatch
from
apex.transformer.pipeline_parallel.utils
import
get_num_microbatches
from
apex.transformer.log_util
import
get_transformer_logger
__all__
=
[
"_forward_backward_pipelining_with_interleaving"
]
_logger
=
get_transformer_logger
(
__name__
)
# TODO (mkozuki): Reduce cyclomatic complexity
def
_forward_backward_pipelining_with_interleaving
(
forward_step_func
:
FwdStepFunc
,
batch
:
List
[
Batch
],
model
:
List
[
torch
.
nn
.
Module
],
*
,
forward_only
:
bool
,
tensor_shape
:
Optional
[
Union
[
List
[
int
],
torch
.
Size
]]
=
None
,
):
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
This function assumes `batch` and `model` is a list of `Batch`'s and a list of `torch.nn.Module`, respectively.
This means that model is split into model chunks.
This pipeline parallel scheduling consists of three steps:
1. warmup
2. 1F1B a.k.a. steady state
3. cooldown
Note that if `forward_only` this scheduling consists of only warmup phase.
Args:
forward_step_func: A function which takes a minibatch and model as its arguments and
returns model's forward output and the loss function.
The loss function is supposed to take one `torch.Tensor` and
return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
batch: A minibatch, i.e., a list of `torch.Tensor`'s.
model: A `torch.nn.Module` or a list of `torch.nn.Module`.
Keyword args:
forward_only:
tensor_shape: Shape of tensor.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
"""
if
not
isinstance
(
model
,
list
):
raise
RuntimeError
(
"`model` must be a list of `nn.Module`'s'"
)
num_model_chunks
=
len
(
model
)
input_tensors
=
[[]
for
_
in
range
(
num_model_chunks
)]
output_tensors
=
[[]
for
_
in
range
(
num_model_chunks
)]
curr_iters
=
[
0
for
_
in
range
(
num_model_chunks
)]
losses_reduced
=
[]
if
not
forward_only
:
output_tensor_grads
=
[[]
for
_
in
range
(
num_model_chunks
)]
pipeline_parallel_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
pipeline_parallel_rank
=
parallel_state
.
get_pipeline_model_parallel_rank
()
# Compute number of warmup and remaining microbatches.
num_microbatches
=
get_num_microbatches
()
*
num_model_chunks
all_warmup_microbatches
=
False
if
forward_only
:
num_warmup_microbatches
=
num_microbatches
else
:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
# all workers, followed by more microbatches after depending on
# stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B).
if
get_num_microbatches
()
==
pipeline_parallel_size
:
num_warmup_microbatches
=
num_microbatches
all_warmup_microbatches
=
True
else
:
num_warmup_microbatches
=
(
pipeline_parallel_size
-
pipeline_parallel_rank
-
1
)
*
2
num_warmup_microbatches
+=
(
num_model_chunks
-
1
)
*
pipeline_parallel_size
num_warmup_microbatches
=
min
(
num_warmup_microbatches
,
num_microbatches
)
num_microbatches_remaining
=
num_microbatches
-
num_warmup_microbatches
_logger
.
info
(
f
"num_microbatches:
{
num_microbatches
}
, "
f
"num_warmup_microbatches:
{
num_warmup_microbatches
}
, "
f
"num_microbatches_remaining:
{
num_microbatches_remaining
}
"
)
###################################################################################################################
# Helper function definitions.
###################################################################################################################
def
get_model_chunk_id
(
microbatch_id
:
int
,
forward
:
bool
)
->
int
:
"""Helper function to get the model chunk ID given the iteration number."""
pipeline_parallel_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
microbatch_id_in_group
=
microbatch_id
%
(
pipeline_parallel_size
*
num_model_chunks
)
model_chunk_id
=
microbatch_id_in_group
//
pipeline_parallel_size
if
not
forward
:
model_chunk_id
=
num_model_chunks
-
model_chunk_id
-
1
return
model_chunk_id
def
forward_step_helper
(
microbatch_id
,
curr_iters
):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id
=
get_model_chunk_id
(
microbatch_id
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
model_chunk_id
)
# forward step
if
(
parallel_state
.
is_pipeline_first_stage
()
and
len
(
input_tensors
[
model_chunk_id
])
==
len
(
output_tensors
[
model_chunk_id
])
):
input_tensors
[
model_chunk_id
].
append
(
None
)
input_tensor
=
input_tensors
[
model_chunk_id
][
-
1
]
output_tensor
=
forward_step
(
forward_step_func
,
get_kth_microbatch
(
batch
,
curr_iters
[
model_chunk_id
]),
model
[
model_chunk_id
],
input_tensor
,
losses_reduced
,
)
curr_iters
[
model_chunk_id
]
+=
1
output_tensors
[
model_chunk_id
].
append
(
output_tensor
)
# if forward-only, no need to save tensors for a backward pass
if
forward_only
:
input_tensors
[
model_chunk_id
].
pop
()
output_tensors
[
model_chunk_id
].
pop
()
return
output_tensor
def
backward_step_helper
(
microbatch_id
):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id
=
get_model_chunk_id
(
microbatch_id
,
forward
=
False
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
model_chunk_id
)
if
parallel_state
.
is_pipeline_last_stage
():
if
len
(
output_tensor_grads
[
model_chunk_id
])
==
0
:
output_tensor_grads
[
model_chunk_id
].
append
(
None
)
input_tensor
=
input_tensors
[
model_chunk_id
].
pop
(
0
)
output_tensor
=
output_tensors
[
model_chunk_id
].
pop
(
0
)
output_tensor_grad
=
output_tensor_grads
[
model_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
)
return
input_tensor_grad
###################################################################################################################
# Run warmup forward passes.
###################################################################################################################
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
0
)
input_tensors
[
0
].
append
(
p2p_communication
.
recv_forward
(
tensor_shape
=
tensor_shape
))
_logger
.
info
(
"Warmup phase"
)
for
k
in
range
(
num_warmup_microbatches
):
_logger
.
debug
(
f
"warmup iter:
{
k
}
/
{
num_warmup_microbatches
}
"
)
output_tensor
=
forward_step_helper
(
k
,
curr_iters
)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id
=
get_model_chunk_id
(
k
+
1
,
forward
=
True
)
recv_prev
=
True
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
True
):
if
next_forward_model_chunk_id
==
0
:
recv_prev
=
False
if
k
==
(
num_microbatches
-
1
):
recv_prev
=
False
_logger
.
debug
(
f
"next fwd model chunk ID:
{
next_forward_model_chunk_id
}
, recv_prev:
{
recv_prev
}
"
)
# Don't send tensor downstream if on last stage.
if
parallel_state
.
is_pipeline_last_stage
():
_logger
.
debug
(
"Pipeline last stage, not sending tensor downstream"
)
output_tensor
=
None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if
k
==
(
num_warmup_microbatches
-
1
)
and
not
forward_only
and
not
all_warmup_microbatches
:
input_tensor_grad
=
None
recv_next
=
True
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
recv_next
=
False
            _logger.debug("send fwd&bwd and receive fwd&bwd")
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
            )
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
            _logger.debug("send fwd and receive fwd")
            input_tensor = p2p_communication.send_forward_recv_forward(
                output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape
            )
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for k in range(num_microbatches_remaining):
        # Forward pass.
        _logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}")
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k, curr_iters)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.
        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        _logger.debug(f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}")
        if parallel_state.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

        recv_next = True
        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        _logger.debug("send fwd&bwd and receive fwd&bwd")
        (
            input_tensor,
            output_tensor_grad,
        ) = p2p_communication.send_forward_backward_recv_forward_backward(
            output_tensor,
            input_tensor_grad,
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
        )

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)

    ###################################################################################################################
    # Run cooldown backward passes (flush out pipeline).
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape=tensor_shape)
            )
        for k in range(num_microbatches_remaining, num_microbatches):
            _logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape
                )
            )

    return losses_reduced
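The steady-state and cooldown loops above lean on `get_model_chunk_id`, a helper defined earlier in this schedule, to map a global microbatch index onto a virtual model chunk, together with the fact that the first stage runs ahead of the last stage by `pipeline_parallel_size - 1` microbatches. The snippet below is a self-contained sketch of the round-robin mapping an interleaved schedule typically uses, not the production helper; the values `pipeline_parallel_size = 2` and `num_model_chunks = 2` are made up for illustration.

# Standalone sketch; constants are assumed, not taken from the schedule above.
pipeline_parallel_size = 2   # number of pipeline stages
num_model_chunks = 2         # virtual chunks per stage (interleaving factor)


def sketch_get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
    """Round-robin mapping of a global microbatch index to a virtual model chunk."""
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        # Backward passes walk the chunks in reverse order.
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id


if __name__ == "__main__":
    for k in range(8):
        print(k, sketch_get_model_chunk_id(k, forward=True), sketch_get_model_chunk_id(k, forward=False))
    # Forward chunk ids cycle 0, 0, 1, 1, 0, 0, 1, 1; backward ids are the mirror image.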
apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py
0 → 100644
View file @ db92ee13
from typing import Union, List, Optional

import torch

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.schedules.common import Batch, FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.log_util import get_transformer_logger


__all__ = ["forward_backward_pipelining_without_interleaving"]

_logger = get_transformer_logger(__name__)


def forward_backward_pipelining_without_interleaving(
    forward_step_func: FwdStepFunc,
    batch: Batch,
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    forward_only: bool,
    tensor_shape: Optional[Union[List[int], torch.Size]] = None,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.

    This pipeline parallel scheduling consists of three steps:
        1. warmup
        2. 1F1B a.k.a. steady state
        3. cooldown if not forward_only

    Args:
        forward_step_func: A function which takes a minibatch and model as its arguments and
            returns model's forward output and the loss function.
            The loss function is supposed to take one `torch.Tensor` and
            return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`.
        batch: A minibatch, i.e., a list of `torch.Tensor`'s.
        model: A `torch.nn.Module` or a list of `torch.nn.Module`.

    Keyword args:
        forward_only: If True, run forward passes only; backward passes and the tensors
            they would need are skipped.
        tensor_shape: Shape of tensor. Required for P2P communication.

    Returns:
        a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
    """
    # timers = get_timers()

    model = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected to be a `nn.Module`, but got {type(model)}"
        raise RuntimeError(msg)
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    _logger.info(
        f"num_microbatches: {num_microbatches}, "
        f"num_warmup_microbatches: {num_warmup_microbatches}, "
        f"num_microbatches_remaining: {num_microbatches_remaining}"
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    ###################################################################################################################
    # Run warmup forward passes.
    ###################################################################################################################
    _logger.info("Warmup")
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
        input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)
        cur_microbatch = get_kth_microbatch(batch, i)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
        _logger.debug("send fwd")
        p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
        input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)

    ###################################################################################################################
    # Run 1F1B in steady state.
    ###################################################################################################################
    _logger.info("Steady phase")
    for i in range(num_microbatches_remaining):
        _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}")
        last_iteration = i == (num_microbatches_remaining - 1)

        cur_microbatch = get_kth_microbatch(batch, i + num_warmup_microbatches)
        output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
        if forward_only:
            _logger.debug("send fwd")
            p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)

            if not last_iteration:
                _logger.debug("receive fwd (not last iteration)")
                input_tensor = p2p_communication.recv_forward(tensor_shape=tensor_shape)

        else:
            _logger.debug("send fwd & receive bwd")
            output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)

            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
                p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
            else:
                _logger.debug("send bwd and receive fwd")
                input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape)

    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
    _logger.info("Cooldown phase")
    if not forward_only:
        for i in range(num_warmup_microbatches):
            _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}")
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
            output_tensor_grad = p2p_communication.recv_backward(tensor_shape=tensor_shape)

            input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)

            _logger.debug("send bwd")
            p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)

    return losses_reduced
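For orientation, here is a minimal sketch of how this schedule might be driven, following the contract spelled out in the docstring above: `forward_step_func` receives a microbatch and the model and returns the forward output together with a loss function mapping that output to `(loss, {"name": tensor})`. The model class `MyTinyModel`, the placeholder loss, and the tensor shapes are all hypothetical, and the sketch assumes `parallel_state` and the global microbatch calculator have already been initialized (e.g. via `setup_microbatch_calculator` from the utils module added in this commit); the final call is therefore left commented out.

# Hypothetical driver; MyTinyModel, the loss wiring, and the shapes are made up.
import torch

from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)
from apex.transformer.pipeline_parallel.utils import average_losses_across_data_parallel_group


class MyTinyModel(torch.nn.Module):
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.linear = torch.nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.linear(x)


def forward_step_func(microbatch, model):
    (tokens,) = microbatch
    output_tensor = model(tokens)

    def loss_func(output_tensor):
        loss = output_tensor.float().pow(2).mean()  # placeholder loss for the sketch
        averaged = average_losses_across_data_parallel_group([loss])
        return loss, {"lm loss": averaged[0]}

    return output_tensor, loss_func


# `batch` is one local minibatch: a list of tensors, each holding
# global_batch_size / data_parallel_size samples. The tensor_shape convention
# below (sequence, micro batch, hidden) is an assumption for illustration.
# losses = forward_backward_pipelining_without_interleaving(
#     forward_step_func,
#     batch,
#     MyTinyModel().cuda(),
#     forward_only=False,
#     tensor_shape=(sequence_length, micro_batch_size, hidden_size),
# )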
apex/transformer/pipeline_parallel/utils.py
0 → 100644
View file @ db92ee13
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for pipeline model parallel."""
from typing import Optional, List, Union

import torch
from torch.nn.parallel import DistributedDataParallel

from apex.multi_tensor_apply import multi_tensor_applier
from apex.transformer import parallel_state
from apex.transformer.microbatches import build_num_microbatches_calculator
from apex.transformer.pipeline_parallel._timers import _Timers

if multi_tensor_applier.available:
    import amp_C


_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_AUTORESUME = None
_GLOBAL_TIMERS = None


Shape = Union[List[int], torch.Size]


def listify_model(model: Union[torch.nn.Module, List[torch.nn.Module]]) -> List[torch.nn.Module]:
    if isinstance(model, list):
        return model
    return [model]


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, "{} is not initialized.".format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is None, i.e. not yet initialized."""
    assert var is None, "{} is already initialized.".format(name)


def setup_microbatch_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
) -> None:
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, 'num microbatches calculator')

    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
        rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size
    )


def _reconfigure_microbatch_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
) -> None:
    if torch.distributed.get_rank() == 0:
        import warnings
        warnings.warn("This function is only for unittest")
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
        rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size
    )


def get_micro_batch_size():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.micro_batch_size


def get_num_microbatches():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def get_current_global_batch_size():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()


def update_num_microbatches(consumed_samples, consistency_check=True):
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check)


# note (mkozuki): Comment out in favor of `get_kth_microbatch`
def _split_batch_into_microbatch(
    batch: List[torch.Tensor],
    *,
    _micro_batch_size: Optional[int] = None,
    _global_batch_size: Optional[int] = None,
) -> List[List[torch.Tensor]]:
    micro_batch_size = _micro_batch_size
    global_batch_size = _global_batch_size
    if micro_batch_size is None:
        micro_batch_size = get_micro_batch_size()
    if global_batch_size is None:
        global_batch_size = get_current_global_batch_size()
    for i in range(0, global_batch_size, micro_batch_size):
        # `i` is already a sample offset (it steps by `micro_batch_size`).
        yield [x[i:i + micro_batch_size] for x in batch]


# TODO(mkozuki): Support non-tensor local minibatches?
def get_kth_microbatch(batch: List[torch.Tensor], k: int) -> List[torch.Tensor]:
    """Create the `k`-th microbatch from a list of local minibatches.

    A local minibatch consists of `global_batch_size / data_parallel_size` samples.
    """
    micro_batch_size = get_micro_batch_size()
    return [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in batch]
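As a concrete illustration of the slicing above (standalone, without touching the global microbatch calculator): with a micro batch size of 2 and a local minibatch of 8 samples, the k-th microbatch is simply rows [2k, 2k + 2) of every tensor in the batch. The sizes and values here are arbitrary.

# Standalone illustration of the k-th microbatch slicing; values are arbitrary.
import torch

micro_batch_size = 2
local_minibatch = [torch.arange(8).view(8, 1), torch.arange(8, 16).view(8, 1)]  # two tensors, 8 samples each

for k in range(8 // micro_batch_size):
    microbatch = [x[k * micro_batch_size:(k + 1) * micro_batch_size] for x in local_minibatch]
    print(k, [t.squeeze(1).tolist() for t in microbatch])
# k=0 -> [[0, 1], [8, 9]], k=1 -> [[2, 3], [10, 11]], and so on.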
def get_autoresume():
    return _GLOBAL_AUTORESUME


def _set_timers():
    """Initialize timers."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers")
    _GLOBAL_TIMERS = _Timers()


def get_timers():
    """Return timers."""
    _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers")
    return _GLOBAL_TIMERS


def print_rank_0(message: str) -> None:
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)


def is_last_rank():
    return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1)


def print_rank_last(message):
    """If distributed is initialized, print only on last rank."""
    if torch.distributed.is_initialized():
        if is_last_rank():
            print(message, flush=True)
    else:
        print(message, flush=True)


def param_is_not_shared(param: torch.nn.Parameter) -> bool:
    return not getattr(param, "shared", False)


def unwrap_model(model, module_instances=(DistributedDataParallel,)):
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model


def calc_params_l2_norm(model: torch.nn.Module, bf16: bool):
    """Calculate the L2 norm of parameters."""
    # args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = parallel_state.param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                if bf16:
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)
    # Calculate norm
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False,  # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(
        norm_2, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group()
    )
    return norm_2.item() ** 0.5
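The fused `amp_C.multi_tensor_l2norm` path above computes the same quantity as the plain PyTorch reduction below: square the per-parameter norms, sum them, and take the square root. The sketch is a single-process cross-check only, with no model-parallel all-reduce and no shared-parameter or tensor-parallel-duplicate filtering; its purpose is just to make the math explicit.

# Single-process reference for the norm computed above (no model-parallel all-reduce).
import torch


def params_l2_norm_reference(module: torch.nn.Module) -> float:
    squared = torch.zeros((), dtype=torch.float32)
    for param in module.parameters():
        # Accumulate the squared L2 norm of every parameter tensor.
        squared += param.data.float().pow(2).sum().cpu()
    return squared.sqrt().item()


if __name__ == "__main__":
    m = torch.nn.Linear(4, 4)
    print(params_l2_norm_reference(m))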
def average_losses_across_data_parallel_group(losses):
    """Reduce a tensor of losses across all GPUs."""
    averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses, group=parallel_state.get_data_parallel_group())
    averaged_losses = averaged_losses / torch.distributed.get_world_size(
        group=parallel_state.get_data_parallel_group()
    )
    return averaged_losses


def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + " memory (MB)"
    string += " | allocated: {}".format(torch.cuda.memory_allocated() / mega_bytes)
    string += " | max allocated: {}".format(torch.cuda.max_memory_allocated() / mega_bytes)
    string += " | reserved: {}".format(torch.cuda.memory_reserved() / mega_bytes)
    string += " | max reserved: {}".format(torch.cuda.max_memory_reserved() / mega_bytes)
    if parallel_state.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = "iteration, rank, index, tensor-model-parallel, min, max, norm\n"
    optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group["params"]:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = torch.linalg.norm(param.data)
            string += "{:7d}, {:4d}, {:4d}, {:2d}, ".format(iteration, rank, index, int(param.tensor_model_parallel))
            string += "{:.6E}, {:.6E}, {:.6E}\n".format(min_, max_, norm)
    print(string, flush=True)


# NOTE (mkozuki): APEX doesn't have anything equivalent for
# `_GLOBAL_ADLR_AUTORESUME` like Megatron-LM.
# def check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler, save: bool):
#     """Check for autoresume signal and exit if it is received."""
#     from apex.ppu.checkpointing import save_checkpoint
#
#     autoresume = get_adlr_autoresume()
#     # Add barrier to ensure consistency.
#     torch.distributed.barrier()
#     if autoresume.termination_requested():
#         if save:
#             save_checkpoint(iteration, model, optimizer, lr_scheduler)
#         print_rank_0(">>> autoresume termination request found!")
#         if torch.distributed.get_rank() == 0:
#             autoresume.request_resume()
#         print_rank_0(">>> training terminated. Returning")
#         sys.exit(0)


def get_ltor_masks_and_position_ids(data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss):
    """Build masks and position ids for a left-to-right model."""
    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(
        torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
    ).view(att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= i + 1 - prev_index
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = attention_mask < 0.5

    return attention_mask, loss_mask, position_ids
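A small worked example makes the EOD handling above concrete. The token ids and the choice of eod = 0 below are arbitrary; with both reset flags enabled, positions restart after each EOD and attention is blocked across EOD boundaries, while eod_mask_loss zeroes the loss mask at the EOD positions themselves.

# Tiny illustration of get_ltor_masks_and_position_ids; token values are arbitrary.
import torch

from apex.transformer.pipeline_parallel.utils import get_ltor_masks_and_position_ids

eod = 0
data = torch.tensor([[5, 7, eod, 3, 4, eod, 9, 2]])  # one sequence of length 8

attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
    data,
    eod,
    reset_position_ids=True,
    reset_attention_mask=True,
    eod_mask_loss=True,
)

print(position_ids)          # tensor([[0, 1, 2, 0, 1, 2, 0, 1]])
print(loss_mask)             # zeros at the two EOD positions, ones elsewhere
print(attention_mask.shape)  # (1, 1, 8, 8); True marks positions that must NOT attend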
apex/transformer/tensor_parallel/__init__.py
View file @ db92ee13
 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -15,11 +15,11 @@
 """Model parallel utility interface."""
-from .cross_entropy import vocab_parallel_cross_entropy
+from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
-from .data import broadcast_data
+from apex.transformer.tensor_parallel.data import broadcast_data
-from .layers import (
+from apex.transformer.tensor_parallel.layers import (
     ColumnParallelLinear,
     RowParallelLinear,
     VocabParallelEmbedding,
...
@@ -28,7 +28,7 @@ from .layers import (
     copy_tensor_model_parallel_attributes,
 )
-from .mappings import (
+from apex.transformer.tensor_parallel.mappings import (
     copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
...
@@ -41,11 +41,9 @@ from .random import (
     init_checkpointed_activations_memory_buffer,
     model_parallel_cuda_manual_seed,
     reset_checkpointed_activations_memory_buffer,
-    gather_split_1d_tensor,
-    split_tensor_into_1d_equal_chunks,
 )
-from .utils import divide, split_tensor_along_last_dim
+from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim
 __all__ = [
...
@@ -71,9 +69,6 @@ __all__ = [
     "init_checkpointed_activations_memory_buffer",
     "model_parallel_cuda_manual_seed",
     "reset_checkpointed_activations_memory_buffer",
-    "gather_split_1d_tensor",
-    "split_tensor_into_1d_equal_chunks",
     # utils.py
-    "divide",
     "split_tensor_along_last_dim",
 ]
apex/transformer/tensor_parallel/cross_entropy.py
View file @ db92ee13
 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -14,10 +14,10 @@
 # limitations under the License.
 import torch
-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_rank
-from ..parallel_state import get_tensor_model_parallel_world_size
-from .utils import VocabUtility
+from apex.transformer.parallel_state import get_tensor_model_parallel_group
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
+from apex.transformer.tensor_parallel.utils import VocabUtility
 class _VocabParallelCrossEntropy(torch.autograd.Function):
...
@@ -30,7 +30,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
             logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
         )
         # Subtract the maximum value.
-        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+        vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
         # Get the partition's vocab indecies
         get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
...
@@ -100,4 +100,4 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
 def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
     """Helper function for the cross entropy."""
-    return _VocabParallelCrossEntropy.apply(torch.clone(vocab_parallel_logits), target)
+    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
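The hunk above swaps the in-place `sub_` for an out-of-place subtraction (which is why the wrapper no longer needs `torch.clone`), but the underlying trick is unchanged: subtracting the per-row maximum before exponentiating keeps the softmax finite. Below is a single-process sketch of that stabilization with made-up shapes and no tensor-model-parallel all-reduce; it is an illustration of the idea, not the parallel implementation.

# Single-process illustration of the max-subtraction trick; shapes are arbitrary.
import torch

logits = torch.randn(4, 8) * 100.0          # large values to provoke overflow in exp()
naive = torch.exp(logits) / torch.exp(logits).sum(dim=-1, keepdim=True)

logits_max = logits.max(dim=-1, keepdim=True).values
stable_logits = logits - logits_max          # out-of-place, mirrors the new code path
stable = torch.exp(stable_logits) / torch.exp(stable_logits).sum(dim=-1, keepdim=True)

print(torch.isfinite(naive).all())   # typically False once exp() overflows to inf
print(torch.isfinite(stable).all())  # True: all exponents are <= 0 after the shift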
apex/transformer/tensor_parallel/data.py
View file @ db92ee13
 # coding=utf-8
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -14,9 +14,9 @@
 # limitations under the License.
 import torch
-from ..parallel_state import get_tensor_model_parallel_group
-from ..parallel_state import get_tensor_model_parallel_rank
-from ..parallel_state import get_tensor_model_parallel_src_rank
+from apex.transformer.parallel_state import get_tensor_model_parallel_group
+from apex.transformer.parallel_state import get_tensor_model_parallel_rank
+from apex.transformer.parallel_state import get_tensor_model_parallel_src_rank
 _MAX_DATA_DIM = 5
...