Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
megatron_qwen
Commits
0816dd4a
Commit
0816dd4a
authored
Sep 29, 2024
by
libo11
Browse files
Initial commit
parents
Pipeline
#1728
canceled with stages
Changes
343
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4521 additions
and
0 deletions
+4521
-0
megatron/core/optimizer/optimizer_config.py
megatron/core/optimizer/optimizer_config.py
+116
-0
megatron/core/package_info.py
megatron/core/package_info.py
+29
-0
megatron/core/packed_seq_params.py
megatron/core/packed_seq_params.py
+13
-0
megatron/core/parallel_state.py
megatron/core/parallel_state.py
+1267
-0
megatron/core/pipeline_parallel/__init__.py
megatron/core/pipeline_parallel/__init__.py
+1
-0
megatron/core/pipeline_parallel/__pycache__/__init__.cpython-310.pyc
...re/pipeline_parallel/__pycache__/__init__.cpython-310.pyc
+0
-0
megatron/core/pipeline_parallel/__pycache__/p2p_communication.cpython-310.pyc
...ne_parallel/__pycache__/p2p_communication.cpython-310.pyc
+0
-0
megatron/core/pipeline_parallel/__pycache__/schedules.cpython-310.pyc
...e/pipeline_parallel/__pycache__/schedules.cpython-310.pyc
+0
-0
megatron/core/pipeline_parallel/p2p_communication.py
megatron/core/pipeline_parallel/p2p_communication.py
+597
-0
megatron/core/pipeline_parallel/schedules.py
megatron/core/pipeline_parallel/schedules.py
+1407
-0
megatron/core/requirements.txt
megatron/core/requirements.txt
+2
-0
megatron/core/ssm/__init__.py
megatron/core/ssm/__init__.py
+0
-0
megatron/core/ssm/mamba_block.py
megatron/core/ssm/mamba_block.py
+234
-0
megatron/core/ssm/mamba_hybrid_layer_allocation.py
megatron/core/ssm/mamba_hybrid_layer_allocation.py
+191
-0
megatron/core/ssm/mamba_layer.py
megatron/core/ssm/mamba_layer.py
+62
-0
megatron/core/ssm/mamba_mixer.py
megatron/core/ssm/mamba_mixer.py
+485
-0
megatron/core/ssm/triton_cache_manager.py
megatron/core/ssm/triton_cache_manager.py
+44
-0
megatron/core/tensor_parallel/__init__.py
megatron/core/tensor_parallel/__init__.py
+73
-0
megatron/core/tensor_parallel/__pycache__/__init__.cpython-310.pyc
...core/tensor_parallel/__pycache__/__init__.cpython-310.pyc
+0
-0
megatron/core/tensor_parallel/__pycache__/cross_entropy.cpython-310.pyc
...tensor_parallel/__pycache__/cross_entropy.cpython-310.pyc
+0
-0
No files found.
Too many changes to show.
To preserve performance only
343 of 343+
files are displayed.
Plain diff
Email patch
megatron/core/optimizer/optimizer_config.py
0 → 100644
View file @
0816dd4a
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from
dataclasses
import
dataclass
from
typing
import
Callable
,
Optional
import
torch
@dataclass
class OptimizerConfig:
    """Configuration for optimizer."""

    ##############
    # General
    ##############
    optimizer: str = 'adam'
    """Optimizer to use (one of Adam or SGD)."""

    lr: Optional[float] = None
    """Initial learning rate. Depending on decay style and initial warmup, the learning rate at each
    iteration would be different.
    """

    min_lr: Optional[float] = None
    """Minimum value for learning rate. The scheduler clips values below this threshold."""

    decoupled_lr: Optional[float] = None
    """Separate learning rate for the input and output layer."""

    decoupled_min_lr: Optional[float] = None
    """Minimum value for learning rate for the input and output layer. The scheduler clips values
    below this threshold.
    """

    weight_decay: float = 0.01
    """Weight decay coefficient for L2 regularization."""

    ##############
    # Precision
    ##############
    fp16: bool = False
    """If true, train with fp16 mixed precision training. Defaults to False."""

    bf16: bool = False
    """If true, train with bf16 mixed precision training. Defaults to False."""

    params_dtype: torch.dtype = torch.float32
    """dtype used when initializing the weights. Defaults to torch.float32."""

    ###############
    # Loss scaling
    ###############
    loss_scale: Optional[float] = None
    """Static loss scaling, positive power of 2 values can improve fp16 convergence. If None,
    dynamic loss scaling is used.
    """

    initial_loss_scale: float = 2 ** 32
    """Initial loss-scale for dynamic loss scaling."""

    min_loss_scale: float = 1.0
    """Minimum loss scale for dynamic loss scaling."""

    loss_scale_window: float = 1000
    """Window over which to raise/lower dynamic scale."""

    hysteresis: int = 2
    """Hysteresis for dynamic loss scaling."""

    ##############
    # Optimizer
    ##############
    # Adam
    adam_beta1: float = 0.9
    """First coefficient for computing running averages of gradient and its square in Adam
    optimizer.
    """

    adam_beta2: float = 0.999
    """Second coefficient for computing running averages of gradient and its square in Adam
    optimizer.
    """

    adam_eps: float = 1e-08
    """Term added to the denominator to improve numerical stability in Adam optimizer."""

    # SGD.
    sgd_momentum: float = 0.9
    """Momentum factor for SGD optimizer."""

    #######################
    # Distributed optimizer
    #######################
    use_distributed_optimizer: bool = False
    """Distribute optimizer state over data-parallel replicas."""

    overlap_grad_reduce: bool = False
    """If true, overlap grad reduce-scatter with backward compute in distributed optimizer."""

    overlap_param_gather: bool = False
    """If true, overlap param all-gather with forward compute in distributed optimizer."""

    ################
    # Miscellaneous
    ################
    clip_grad: float = 1.0
    """Gradient clipping based on global L2 norm."""

    log_num_zeros_in_grad: bool = False
    """If true, calculate and log the number of zeros in gradient."""

    barrier_with_L1_time: bool = False
    """If true, use barrier with level 1 time measurements."""

    timers: Optional[Callable] = None
    """Function to get timers."""
megatron/core/package_info.py
0 → 100644
View file @
0816dd4a
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
MAJOR
=
0
MINOR
=
8
PATCH
=
0
PRE_RELEASE
=
'rc0'
# Use the following formatting: (major, minor, patch, pre-release)
VERSION
=
(
MAJOR
,
MINOR
,
PATCH
,
PRE_RELEASE
)
__shortversion__
=
'.'
.
join
(
map
(
str
,
VERSION
[:
3
]))
__version__
=
'.'
.
join
(
map
(
str
,
VERSION
[:
3
]))
+
''
.
join
(
VERSION
[
3
:])
__package_name__
=
'megatron_core'
__contact_names__
=
'NVIDIA'
__contact_emails__
=
'nemo-toolkit@nvidia.com'
# use NeMo Email
__homepage__
=
(
'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/'
# use NeMo homepage
)
__repository_url__
=
'https://github.com/NVIDIA/Megatron-LM/megatron/core'
__download_url__
=
'https://github.com/NVIDIA/Megatron-LM/releases'
__description__
=
(
'Megatron Core - a library for efficient and scalable training of transformer based models'
)
__license__
=
'BSD-3'
__keywords__
=
(
'deep learning, machine learning, gpu, NLP, NLU, language, transformer, nvidia, pytorch, torch'
)
megatron/core/packed_seq_params.py
0 → 100644
View file @
0816dd4a
from
dataclasses
import
dataclass
from
torch
import
Tensor
@dataclass
class PackedSeqParams:
    """Parameters for the `thd` (packed) sequence format.

    Passed to TEDotProductAttention and the fused rope kernels when variable-length
    sequences are packed together instead of padded.
    """

    # Layout string for the qkv tensors (e.g. 'thd'); None when unset.
    qkv_format: str = None
    # Cumulative sequence lengths for the query side — presumably a 1-D int tensor
    # of offsets as expected by TE; TODO confirm against the attention caller.
    cu_seqlens_q: Tensor = None
    # Cumulative sequence lengths for the key/value side.
    cu_seqlens_kv: Tensor = None
    # Maximum query sequence length in the packed batch.
    max_seqlen_q: Tensor = None
    # Maximum key/value sequence length in the packed batch.
    max_seqlen_kv: Tensor = None
megatron/core/parallel_state.py
0 → 100644
View file @
0816dd4a
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
import
os
import
warnings
from
datetime
import
timedelta
from
typing
import
List
,
Optional
import
torch
from
.utils
import
GlobalMemoryBuffer
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra-, pipeline, and expert) that the current rank belongs to.
_MODEL_AND_EXPERT_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# Gloo-backend counterpart of the data-parallel group (created with backend="gloo"
# in initialize_model_parallel).
_DATA_PARALLEL_GROUP_GLOO = None
# tensor model parallel group and data parallel group combined
# used for fp8 and moe training
_TENSOR_AND_DATA_PARALLEL_GROUP = None
# Expert parallel group that the current rank belongs to.
_EXPERT_MODEL_PARALLEL_GROUP = None
_TENSOR_AND_EXPERT_PARALLEL_GROUP = None
_DATA_MODULO_EXPERT_PARALLEL_GROUP = None
_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None

# Interleaved (virtual) pipeline schedule bookkeeping.
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
_MPU_EXPERT_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None

# A list of global ranks for each tensor model parallel group to ease calculation of
# the first local rank in the tensor model parallel group
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None

# Context parallel group that the current rank belongs to
_CONTEXT_PARALLEL_GROUP = None
# A list of global ranks for each context parallel group to ease calculation of the
# destination rank when exchanging KV/dKV between context parallel_ranks
_CONTEXT_PARALLEL_GLOBAL_RANKS = None

# Data parallel group information with context parallel combined.
_DATA_PARALLEL_GROUP_WITH_CP = None
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None

# combined parallel group of TP, DP, and CP used for fp8
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None

# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None

# MOE logging
_MOE_AUX_LOSSES_LOGGING_TRACKER = {}
def get_nccl_options(pg_name, nccl_comm_cfgs):
    """Set the NCCL process group options.

    Args:
        pg_name (str): process group name
        nccl_comm_cfgs (dict): nccl communicator configurations

    Returns ``None`` when no entry exists for ``pg_name``, letting the caller fall
    back to default process-group options. When an option (e.g., max_ctas) is not
    found in the config, use the NCCL default setting.
    """
    # No per-group configuration: let torch.distributed use its defaults.
    if pg_name not in nccl_comm_cfgs:
        return None

    cfg = nccl_comm_cfgs[pg_name]
    options = torch.distributed.ProcessGroupNCCL.Options()
    # Missing keys fall back to the NCCL defaults (4 / 32 / 1).
    options.config.cga_cluster_size = cfg.get('cga_cluster_size', 4)
    options.config.max_ctas = cfg.get('max_ctas', 32)
    options.config.min_ctas = cfg.get('min_ctas', 1)
    return options
def generate_masked_orthogonal_rank_groups(
    world_size: int,
    parallel_size: List[int],
    mask: List[bool],
) -> List[List[int]]:
    """Generate orthogonal parallel groups based on the parallel size and mask.

    Arguments:
        world_size (int): world size

        parallel_size (List[int]):
            The parallel size of each orthogonal parallel type. For example, if
            tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4,
            and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4].

        mask (List[bool]):
            The mask controls which parallel methods the generated groups represent. If mask[i] is
            True, the generated group contains the i-th parallelism method. For example, if
            parallel_size = [tp_size, pp_size, dp_size] and mask = [True, False, True], the
            generated groups are `tp-dp` groups; with mask = [False, True, False] they are `pp`
            groups.

    Treats global rank as a mixed-radix number with digits given by parallel_size:
    the masked digits index the position *within* a group, the unmasked digits
    index *which* group, and each member rank is recovered as the dot product of
    its digit vector with the corresponding strides.
    """

    def scan_sizes(sizes: List[int], start=1) -> List[int]:
        # Running product with a leading seed value; output has len(sizes)+1 entries.
        out = [start]
        acc = start
        for size in sizes:
            acc *= size
            out.append(acc)
        return out

    def dot(u: List[int], v: List[int]) -> int:
        # Inner product of two equally long integer vectors.
        total = 0
        for a, b in zip(u, v):
            total += a * b
        return total

    def unravel(flat_index, shape, stride=None):
        '''Invert flat_index = sum(coords[i] * stride[i]) for the given strides,
        returning the per-dimension coordinates. Used to recover the per-parallelism
        ranks from a group index or a rank-in-group index.
        '''
        if stride is None:
            stride = scan_sizes(shape)
        coords = []
        for extent, step in zip(shape, stride):
            coords.append((flat_index // step) % extent)
        # stride is a running-product result; its last entry is intentionally unused.
        assert dot(coords, stride[:-1]) == flat_index, (
            "idx {} with shape {} mismatch the return idx {}".format(flat_index, shape, coords)
        )
        return coords

    # Split shapes and strides into the masked (in-group) and unmasked
    # (group-index) dimensions.
    selected_shape = []
    rest_shape = []
    for size, chosen in zip(parallel_size, mask):
        (selected_shape if chosen else rest_shape).append(size)

    strides = scan_sizes(parallel_size)
    selected_stride = []
    rest_stride = []
    for step, chosen in zip(strides, mask):
        (selected_stride if chosen else rest_stride).append(step)

    group_size = scan_sizes(selected_shape)[-1]
    num_groups = world_size // group_size

    groups = []
    for group_index in range(num_groups):
        # Coordinates of this group in the unmasked dimensions give a fixed offset.
        offset = dot(unravel(group_index, rest_shape), rest_stride)
        members = [
            dot(unravel(member_index, selected_shape), selected_stride) + offset
            for member_index in range(group_size)
        ]
        groups.append(members)
    return groups
class RankGenerator(object):
    """Generates global-rank groups for each combination of parallelism types.

    Holds the per-parallelism sizes (tp/ep/dp/pp/cp) and the rank-initialization
    order string (e.g. "tp-cp-ep-dp-pp"), and delegates the actual group
    enumeration to `generate_masked_orthogonal_rank_groups`.
    """

    def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str) -> None:
        # Per-parallelism group sizes.
        self.tp = tp
        self.ep = ep
        self.dp = dp
        self.pp = pp
        self.cp = cp
        # ep is excluded: expert parallelism shares ranks with data parallelism.
        self.world_size = tp * dp * pp * cp

        self.name_to_size = {
            "tp": self.tp,
            "pp": self.pp,
            "dp": self.dp,
            "ep": self.ep,
            "cp": self.cp,
        }
        # Keep the caller's original order string for error messages.
        self.order = order
        order = order.lower()

        # ep must sit next to dp in the order, since ep ranks are carved out of dp.
        if 'ep' in order:
            if 'ep-dp' not in order and 'dp-ep' not in order:
                raise RuntimeError(f"The ep and dp must be adjacent in order ({self.order}).")

        # Any parallelism type missing from the order is appended, but only if its
        # size is 1 (a size > 1 without an explicit position is an error).
        for name in self.name_to_size.keys():
            if name not in order and self.name_to_size[name] != 1:
                raise RuntimeError(
                    f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})."
                )
            elif name not in order:
                order = order + '-' + name

        # Two views of the ordering: with ep as its own axis (dp shrunk to dp//ep),
        # and without ep (full dp axis).
        self.order_w_ep = order
        self.order_wo_ep = '-'.join([token for token in order.split('-') if token != 'ep'])
        self.ordered_size_wo_ep = []
        self.ordered_size_w_ep = []

        for token in order.split('-'):
            if token == 'dp':
                # With ep treated independently, dp is split into ep * (dp // ep).
                self.ordered_size_w_ep.append(self.dp // self.ep)
                self.ordered_size_wo_ep.append(self.dp)
            elif token == 'ep':
                self.ordered_size_w_ep.append(self.ep)
            else:
                self.ordered_size_w_ep.append(self.name_to_size[token])
                self.ordered_size_wo_ep.append(self.name_to_size[token])

    def get_mask(self, order: str, token: str):
        """Return a boolean mask over `order`'s axes selecting the axes named in
        `token` (both are '-'-separated strings, e.g. order='tp-dp-pp', token='tp-dp').
        """
        ordered_token = order.split('-')
        token = token.split('-')
        mask = [False] * len(ordered_token)
        for t in token:
            mask[ordered_token.index(t)] = True
        return mask

    def get_ranks(self, token, independent_ep=False):
        '''Get rank group by input token.

        Arguments:
            token (str):
                Specify the ranks type that want to get. If we want
                to obtain multiple parallel types, we can use a hyphen
                '-' to separate them. For example, if we want to obtain
                the TP_DP group, the token should be 'tp-dp'.

            independent_ep (bool: True):
                This flag controls whether we treat EP and DP independently.
                EP shares ranks with DP, if we want to get ranks related to
                EP, we should set the flag. For example, get_ranks('dp', True)
                will get DP modulo EP group, and get_ranks('dp', False) will
                get full DP group.
        '''
        # Choose the axis ordering/sizes that either splits ep out of dp or not.
        if independent_ep:
            parallel_size = self.ordered_size_w_ep
            order = self.order_w_ep
        else:
            parallel_size = self.ordered_size_wo_ep
            order = self.order_wo_ep
        mask = self.get_mask(order, token)
        ranks = generate_masked_orthogonal_rank_groups(self.world_size, parallel_size, mask)
        return ranks
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
    use_sharp: bool = False,
    context_parallel_size: int = 1,
    expert_model_parallel_size: int = 1,
    nccl_communicator_config_path: Optional[str] = None,
    distributed_timeout_minutes: int = 30,
    order: str = "tp-cp-ep-dp-pp",
) -> None:
    """Initialize model data parallel groups.

    Args:
        tensor_model_parallel_size (int, default = 1):
            The number of GPUs to split individual tensors across.

        pipeline_model_parallel_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            Transformer layers across. For example, if
            tensor_model_parallel_size is 4 and
            pipeline_model_parallel_size is 2, the model will be split
            into 2 groups of 4 GPUs.

        virtual_pipeline_model_parallel_size (int, optional):
            The number of stages that each pipeline group will have,
            interleaving as necessary. If None, no interleaving is
            performed. For example, if tensor_model_parallel_size is 1,
            pipeline_model_parallel_size is 4,
            virtual_pipeline_model_parallel_size is 2, and there are
            16 transformer layers in the model, the model will be
            split into 8 stages with two layers each and each GPU
            would get 2 stages as such (layer number starting with 1):

            GPU 0: [1, 2] [9, 10]
            GPU 1: [3, 4] [11, 12]
            GPU 2: [5, 6] [13, 14]
            GPU 3: [7, 8] [15, 16]

        pipeline_model_parallel_split_rank (int, optional):
            For models with both an encoder and decoder, the rank in
            pipeline to switch between encoder and decoder (i.e. the
            first rank of the decoder). This allows the user to set
            the pipeline parallel size of the encoder and decoder
            independently. For example, if
            pipeline_model_parallel_size is 8 and
            pipeline_model_parallel_split_rank is 3, then ranks 0-2
            will be the encoder and ranks 3-7 will be the decoder.

        use_sharp (bool, default = False):
            Set the use of SHARP for the collective communications of
            data-parallel process groups. When `True`, run barrier
            within each data-parallel process group, which specifies
            the SHARP application target groups.

        context_parallel_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            network input sequence length across. Compute of attention
            module requires tokens of full sequence length, so GPUs
            in a context parallel group need to communicate with each
            other to exchange information of other sequence chunks.
            Each GPU and its counterparts in other tensor parallel
            groups compose a context parallel group.

            For example, assume we have 8 GPUs, if tensor model parallel
            size is 4 and context parallel size is 2, the network input
            will be split into two sequence chunks, which are processed
            by 2 different groups of 4 GPUs. One chunk is processed by
            GPU0-3, the other chunk is processed by GPU4-7. Four groups
            are build to do context parallel communications: [GPU0, GPU4],
            [GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].

            Context parallelism partitions sequence length, so it has no
            impact on weights, which means weights are duplicated among
            GPUs in a context parallel group. Hence, weight gradients
            all-reduce is required in backward. For simplicity, we piggyback
            GPUs of context parallelism on data parallel group for
            weight gradient all-reduce.

        expert_model_parallel_size (int, default = 1):
            The number of Mixture of Experts parallel GPUs in each expert
            parallel group.

        nccl_communicator_config_path (str, default = None):
            Path to the yaml file of NCCL communicator configurations.
            `min_ctas`, `max_ctas`, and `cga_cluster_size` can be set
            for each communicator.

        distributed_timeout_minutes (int, default = 30): Timeout, in
            minutes, for operations executed against distributed
            process groups. See PyTorch documentation at
            https://pytorch.org/docs/stable/distributed.html for
            caveats.

        order (str, default=tp-cp-ep-dp-pp):
            The rank initialization order of parallelism. Now we support
            tp-dp-pp and tp-pp-dp orders.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()

    if (
        world_size
        % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size)
        != 0
    ):
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size}) "
            f"x context_parallel_size ({context_parallel_size})"
        )

    data_parallel_size: int = world_size // (
        tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
    )

    if data_parallel_size % expert_model_parallel_size != 0:
        raise RuntimeError(
            f"data_parallel_size ({data_parallel_size}) is not divisible by expert_model_parallel_size "
        )

    # NOTE(review): message typo "prallellism" kept as-is — it is a runtime string.
    if expert_model_parallel_size > 1 and context_parallel_size > 1:
        raise RuntimeError(
            f"combination of expert model prallellism and context parallelism is not supported"
        )

    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size

    if virtual_pipeline_model_parallel_size is not None:
        if not pipeline_model_parallel_size > 1:
            raise RuntimeError(
                "pipeline-model-parallel size should be greater than 1 with interleaved schedule"
            )
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size

    if pipeline_model_parallel_split_rank is not None:
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank

    rank = torch.distributed.get_rank()

    # Optional per-process-group NCCL tuning knobs loaded from a YAML file.
    nccl_comm_cfgs = {}
    if nccl_communicator_config_path is not None:
        try:
            import yaml
        except ImportError:
            raise RuntimeError(
                "Cannot import `yaml`. Setting custom nccl communicator configs "
                "requires the yaml package."
            )

        with open(nccl_communicator_config_path, "r") as stream:
            nccl_comm_cfgs = yaml.safe_load(stream)

    rank_generator = RankGenerator(
        tp=tensor_model_parallel_size,
        ep=expert_model_parallel_size,
        dp=data_parallel_size,
        pp=pipeline_model_parallel_size,
        cp=context_parallel_size,
        order=order,
    )
    timeout = timedelta(minutes=distributed_timeout_minutes)

    # Build the data-parallel groups.
    # NOTE: torch.distributed.new_group must be called by ALL ranks for EVERY
    # group, even ones this rank is not a member of; only the group containing
    # `rank` is stored in the module globals.
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GROUP_GLOO
    global _DATA_PARALLEL_GLOBAL_RANKS
    global _DATA_PARALLEL_GROUP_WITH_CP
    global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
    global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'

    for ranks in rank_generator.get_ranks('dp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
        )
        # Gloo twin for CPU-side collectives on the same rank set.
        group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo")
        if rank in ranks:
            _DATA_PARALLEL_GROUP = group
            _DATA_PARALLEL_GROUP_GLOO = group_gloo
            _DATA_PARALLEL_GLOBAL_RANKS = ranks
    for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
        group_with_cp = torch.distributed.new_group(
            ranks_with_cp, timeout=timeout, pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs)
        )
        group_with_cp_gloo = torch.distributed.new_group(
            ranks_with_cp, timeout=timeout, backend="gloo"
        )
        if rank in ranks_with_cp:
            _DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
            _DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp

    # Apply SHARP to DP process groups
    if use_sharp:
        if rank == 0:
            # NOTE(review): "SAHRP" typo below is in a runtime string; left unchanged.
            print(
                "The number of process groups to use SHARP with depends on the type "
                "of the network switch. Nvidia QM1 switch supports SAHRP up to 8 "
                "process groups and QM2 supports up to 256 process groups. We apply "
                "SHARP to the communications of the data-parallel domain. If the "
                "number of data-parallel process groups is larger than the max "
                "process groups that the network switch supports, the communication "
                "will fall back to non-SHARP operators. To enable SHARP, "
                "`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
            )
        torch.distributed.barrier(
            group=get_data_parallel_group(with_context_parallel=True),
            device_ids=[torch.cuda.current_device()],
        )
        # Set `NCCL_COLLNET_ENABLE=0` to restrict SHARP application to DP process groups
        os.environ["NCCL_COLLNET_ENABLE"] = "0"

    # Build the context-parallel groups.
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_GLOBAL_RANKS
    assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
    for ranks in rank_generator.get_ranks('cp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('cp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _CONTEXT_PARALLEL_GROUP = group
            _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks

    # Build the model-parallel groups.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
    for ranks in rank_generator.get_ranks('tp-pp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('mp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the model-parallel groups with expert parallel
    global _MODEL_AND_EXPERT_PARALLEL_GROUP
    assert (
        _MODEL_AND_EXPERT_PARALLEL_GROUP is None
    ), 'model and expert parallel group is already initialized'
    for ranks in rank_generator.get_ranks('tp-ep-pp', independent_ep=True):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('mp_exp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _MODEL_AND_EXPERT_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
    assert (
        _TENSOR_MODEL_PARALLEL_GROUP is None
    ), 'tensor model parallel group is already initialized'
    for ranks in rank_generator.get_ranks('tp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group
            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert (
        _PIPELINE_MODEL_PARALLEL_GROUP is None
    ), 'pipeline model parallel group is already initialized'
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'
    for ranks in rank_generator.get_ranks('pp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('pp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank is not None:
                # Encoder/decoder split: the first decoder stage also holds a
                # copy of the embeddings.
                if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
                    embedding_ranks = [
                        ranks[0],
                        ranks[pipeline_model_parallel_split_rank],
                        ranks[-1],
                    ]
                if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks

        group = torch.distributed.new_group(
            embedding_ranks, timeout=timeout, pg_options=get_nccl_options('embd', nccl_comm_cfgs)
        )
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

        group = torch.distributed.new_group(
            position_embedding_ranks,
            timeout=timeout,
            pg_options=get_nccl_options('embd', nccl_comm_cfgs),
        )
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

    # Build the tensor + data parallel groups.
    global _TENSOR_AND_DATA_PARALLEL_GROUP
    global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    assert (
        _TENSOR_AND_DATA_PARALLEL_GROUP is None
    ), 'Tensor + data parallel group is already initialized'
    for ranks in rank_generator.get_ranks('tp-dp-cp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp_cp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
    for ranks in rank_generator.get_ranks('tp-dp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _TENSOR_AND_DATA_PARALLEL_GROUP = group

    # Build the tensor + expert parallel groups
    global _EXPERT_MODEL_PARALLEL_GROUP
    assert _EXPERT_MODEL_PARALLEL_GROUP is None, 'Expert parallel group is already initialized'
    global _TENSOR_AND_EXPERT_PARALLEL_GROUP
    assert (
        _TENSOR_AND_EXPERT_PARALLEL_GROUP is None
    ), 'Tensor + expert parallel group is already initialized'
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP
    assert (
        _DATA_MODULO_EXPERT_PARALLEL_GROUP is None
    ), 'Data modulo expert group is already initialized'
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
    for ranks in rank_generator.get_ranks('tp-ep', independent_ep=True):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _TENSOR_AND_EXPERT_PARALLEL_GROUP = group

    for ranks in rank_generator.get_ranks('ep', independent_ep=True):
        # NOTE(review): unlike the other groups, no explicit timeout is passed
        # here — presumably intentional, but worth confirming.
        group = torch.distributed.new_group(
            ranks, pg_options=get_nccl_options('exp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _EXPERT_MODEL_PARALLEL_GROUP = group

    for ranks in rank_generator.get_ranks('dp', independent_ep=True):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('dp_modulo_exp', nccl_comm_cfgs)
        )
        group_gloo = torch.distributed.new_group(ranks, backend="gloo")
        if rank in ranks:
            _DATA_MODULO_EXPERT_PARALLEL_GROUP = group
            _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = group_gloo

    # Initialize global memory buffer
    # This isn't really "parallel state" but there isn't another good place to
    # put this. If we end up with a more generic initialization of megatron-core
    # we could stick it there
    _set_global_memory_buffer()
def is_initialized():
    """Useful for code segments that may be accessed with or without mpu initialization"""
    # The data parallel group is created during model-parallel initialization,
    # so its presence serves as a proxy for "parallel state has been set up".
    return _DATA_PARALLEL_GROUP is not None
def is_unitialized() -> bool:
    """Check if parallel state has been initialized

    Deprecated. Use is_initialized instead.
    """
    # NOTE: the misspelled name is kept for backward compatibility; callers
    # should migrate to is_initialized().
    warnings.warn(
        "is_unitialized is deprecated, use is_initialized instead",
        DeprecationWarning,
    )
    return not is_initialized()
def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    # All three core groups must exist for the mpu state to be usable.
    required_groups = (
        _TENSOR_MODEL_PARALLEL_GROUP,
        _PIPELINE_MODEL_PARALLEL_GROUP,
        _DATA_PARALLEL_GROUP,
    )
    return all(group is not None for group in required_groups)
def get_model_parallel_group(with_expert_parallel=False):
    """Return the model parallel group the caller rank belongs to.

    When ``with_expert_parallel`` is True the combined model+expert group
    is returned instead of the plain model parallel group.
    """
    if with_expert_parallel:
        group = _MODEL_AND_EXPERT_PARALLEL_GROUP
    else:
        group = _MODEL_PARALLEL_GROUP
    assert group is not None, 'model parallel group is not initialized'
    return group
def get_tensor_model_parallel_group(check_initialized=True):
    """Return the tensor model parallel group the caller rank belongs to.

    With ``check_initialized=False`` the (possibly None) group is returned
    without validation.
    """
    group = _TENSOR_MODEL_PARALLEL_GROUP
    if check_initialized:
        assert group is not None, 'tensor model parallel group is not initialized'
    return group
def get_pipeline_model_parallel_group():
    """Return the pipeline model parallel group the caller rank belongs to."""
    group = _PIPELINE_MODEL_PARALLEL_GROUP
    assert group is not None, 'pipeline_model parallel group is not initialized'
    return group
def get_data_parallel_group(with_context_parallel=False):
    """Return the data parallel group the caller rank belongs to.

    With ``with_context_parallel=True`` the combined DP x CP group is returned.
    """
    if with_context_parallel:
        group = _DATA_PARALLEL_GROUP_WITH_CP
        assert (
            group is not None
        ), 'data parallel group with context parallel combined is not initialized'
        return group
    group = _DATA_PARALLEL_GROUP
    assert group is not None, 'data parallel group is not initialized'
    return group
def get_data_parallel_group_gloo(with_context_parallel=False):
    """Return the gloo-backend data parallel group the caller rank belongs to.

    With ``with_context_parallel=True`` the combined DP x CP gloo group is
    returned.
    """
    if with_context_parallel:
        group = _DATA_PARALLEL_GROUP_WITH_CP_GLOO
        assert (
            group is not None
        ), 'data parallel group-gloo with context parallel combined is not initialized'
        return group
    group = _DATA_PARALLEL_GROUP_GLOO
    assert group is not None, 'data parallel group-gloo is not initialized'
    return group
def get_context_parallel_group(check_initialized=True):
    """Return the context parallel group the caller rank belongs to."""
    group = _CONTEXT_PARALLEL_GROUP
    if check_initialized:
        assert group is not None, 'context parallel group is not initialized'
    return group
def get_context_parallel_global_ranks(check_initialized=True):
    """Return all global ranks of the context parallel group for this rank."""
    ranks = _CONTEXT_PARALLEL_GLOBAL_RANKS
    if check_initialized:
        assert ranks is not None, 'context parallel group is not initialized'
    return ranks
def get_embedding_group():
    """Return the embedding group the caller rank belongs to."""
    group = _EMBEDDING_GROUP
    assert group is not None, 'embedding group is not initialized'
    return group
def get_position_embedding_group():
    """Return the position embedding group the caller rank belongs to."""
    group = _POSITION_EMBEDDING_GROUP
    assert group is not None, 'position embedding group is not initialized'
    return group
def get_amax_reduction_group(with_context_parallel=False):
    """Return the FP8 amax reduction group the caller rank belongs to.

    The amax reduction reuses the tensor+data parallel group (optionally
    combined with context parallelism).
    """
    if with_context_parallel:
        group = _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    else:
        group = _TENSOR_AND_DATA_PARALLEL_GROUP
    assert group is not None, 'FP8 amax reduction group is not initialized'
    return group
def get_tensor_and_data_parallel_group(with_context_parallel=False):
    """Return the tensor+data parallel group the caller rank belongs to."""
    if with_context_parallel:
        group = _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    else:
        group = _TENSOR_AND_DATA_PARALLEL_GROUP
    assert group is not None, 'tensor and data parallel group is not initialized'
    return group
def get_expert_model_parallel_group():
    """Return the expert model parallel group the caller rank belongs to."""
    group = _EXPERT_MODEL_PARALLEL_GROUP
    assert group is not None, 'expert model parallel group is not initialized'
    return group
def get_tensor_and_expert_parallel_group():
    """Return the tensor+expert parallel group the caller rank belongs to."""
    group = _TENSOR_AND_EXPERT_PARALLEL_GROUP
    assert group is not None, 'tensor and expert parallel group is not initialized'
    return group
def get_data_modulo_expert_parallel_group():
    """Return the data-modulo-expert parallel group the caller rank belongs to."""
    group = _DATA_MODULO_EXPERT_PARALLEL_GROUP
    assert group is not None, 'data modulo expert parallel group is not initialized'
    return group
def get_data_modulo_expert_parallel_group_gloo():
    """Return the gloo-backend data-modulo-expert parallel group for this rank."""
    group = _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
    assert group is not None, 'data modulo expert parallel group-gloo is not initialized'
    return group
def set_expert_model_parallel_world_size(world_size):
    """Set the expert model parallel size."""
    # Manual override consulted by get_expert_model_parallel_world_size()
    # before falling back to torch.distributed.
    global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
    _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor model parallel size"""
    # Manual override consulted by get_tensor_model_parallel_world_size()
    # before falling back to torch.distributed.
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_pipeline_model_parallel_world_size(world_size):
    """Set the pipeline model parallel size"""
    # Manual override consulted by get_pipeline_model_parallel_world_size()
    # before falling back to torch.distributed.
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_virtual_pipeline_model_parallel_world_size(world_size):
    """Set the virtual pipeline model parallel size."""
    # Returned verbatim by get_virtual_pipeline_model_parallel_world_size().
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    # A manually-set override takes precedence over torch.distributed.
    override = _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if override is not None:
        return override
    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    # A manually-set override takes precedence over torch.distributed.
    override = _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if override is not None:
        return override
    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
def set_expert_model_parallel_rank(rank):
    """Set expert model parallel rank."""
    # Manual override consulted by get_expert_model_parallel_rank()
    # before falling back to torch.distributed.
    global _MPU_EXPERT_MODEL_PARALLEL_RANK
    _MPU_EXPERT_MODEL_PARALLEL_RANK = rank
def set_tensor_model_parallel_rank(rank):
    """Set tensor model parallel rank."""
    # Manual override consulted by get_tensor_model_parallel_rank()
    # before falling back to torch.distributed.
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_rank(rank):
    """Set pipeline model parallel rank."""
    # Manual override consulted by get_pipeline_model_parallel_rank()
    # before falling back to torch.distributed.
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_split_rank(rank):
    """Set pipeline model parallel split rank."""
    # Pipeline stage index where the encoder/decoder split occurs; read by
    # is_pipeline_stage_before_split() / is_pipeline_stage_after_split().
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    # A manually-set override takes precedence over torch.distributed.
    override = _MPU_TENSOR_MODEL_PARALLEL_RANK
    if override is not None:
        return override
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    # A manually-set override takes precedence over torch.distributed.
    override = _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if override is not None:
        return override
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_pipeline_model_parallel_split_rank():
    """Return pipeline model parallel split rank."""
    # None when the model has no encoder/decoder split.
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise.

    Unless ``ignore_virtual`` is set, a rank only counts as first stage when
    its virtual pipeline rank (if any) is 0 as well.
    """
    if not ignore_virtual:
        vpp_world_size = get_virtual_pipeline_model_parallel_world_size()
        if vpp_world_size is not None and get_virtual_pipeline_model_parallel_rank() != 0:
            return False
    return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise.

    Unless ``ignore_virtual`` is set, a rank only counts as last stage when
    its virtual pipeline rank (if any) is the final virtual stage as well.
    """
    if not ignore_virtual:
        vpp_world_size = get_virtual_pipeline_model_parallel_world_size()
        if vpp_world_size is not None:
            if get_virtual_pipeline_model_parallel_rank() != vpp_world_size - 1:
                return False
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1
    )
def is_rank_in_embedding_group(ignore_virtual=False):
    """Return true if current rank is in embedding group, False otherwise.

    When virtual stages are considered, the first/last embedding ranks are
    only "in" the group while executing their first/last virtual stage.
    """
    rank = torch.distributed.get_rank()
    global _EMBEDDING_GLOBAL_RANKS
    if ignore_virtual:
        return rank in _EMBEDDING_GLOBAL_RANKS
    if rank not in _EMBEDDING_GLOBAL_RANKS:
        return False
    if rank == _EMBEDDING_GLOBAL_RANKS[0]:
        return is_pipeline_first_stage(ignore_virtual=False)
    if rank == _EMBEDDING_GLOBAL_RANKS[-1]:
        return is_pipeline_last_stage(ignore_virtual=False)
    return True
def is_rank_in_position_embedding_group():
    """Return true if current rank is in position embedding group, False otherwise."""
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    current_rank = torch.distributed.get_rank()
    return current_rank in _POSITION_EMBEDDING_GLOBAL_RANKS
def is_pipeline_stage_before_split(rank=None):
    """Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder."""
    # With a single stage there is no split; everything counts as "before".
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or rank < split_rank
def is_pipeline_stage_after_split(rank=None):
    """Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder."""
    # With a single stage there is no split; everything counts as "after".
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or rank >= split_rank
def is_pipeline_stage_at_split():
    """Return true if pipeline stage executes decoder block and next
    stage executes encoder block for a model with both encoder and
    decoder."""
    stage = get_pipeline_model_parallel_rank()
    return is_pipeline_stage_before_split(stage) and is_pipeline_stage_after_split(
        stage + 1
    )
def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank."""
    # None when interleaved (virtual) pipelining is not in use.
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
def set_virtual_pipeline_model_parallel_rank(rank):
    """Set the virtual pipeline-parallel rank."""
    # Read back by get_virtual_pipeline_model_parallel_rank().
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size."""
    # None when interleaved (virtual) pipelining is not in use.
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    ranks = _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
    assert ranks is not None, "Tensor model parallel group is not initialized"
    return ranks[0]
def get_data_parallel_src_rank(with_context_parallel=False):
    """Calculate the global rank corresponding to the first local rank
    in the data parallel group."""
    if with_context_parallel:
        ranks = _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
        assert (
            ranks is not None
        ), "Data parallel group with context parallel combined is not initialized"
        return ranks[0]
    ranks = _DATA_PARALLEL_GLOBAL_RANKS
    assert ranks is not None, "Data parallel group is not initialized"
    return ranks[0]
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first process in the pipeline for the
    current tensor parallel group"""
    ranks = _PIPELINE_GLOBAL_RANKS
    assert ranks is not None, "Pipeline parallel group is not initialized"
    return ranks[0]
def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last process in the pipeline for the
    current tensor parallel group"""
    ranks = _PIPELINE_GLOBAL_RANKS
    assert ranks is not None, "Pipeline parallel group is not initialized"
    return ranks[get_pipeline_model_parallel_world_size() - 1]
def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline"""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    pipeline_rank = get_pipeline_model_parallel_rank()
    pipeline_world_size = get_pipeline_model_parallel_world_size()
    # Wrap around so the last stage's "next" is the first stage.
    return _PIPELINE_GLOBAL_RANKS[(pipeline_rank + 1) % pipeline_world_size]
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline"""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    pipeline_rank = get_pipeline_model_parallel_rank()
    pipeline_world_size = get_pipeline_model_parallel_world_size()
    # Wrap around so the first stage's "previous" is the last stage.
    return _PIPELINE_GLOBAL_RANKS[(pipeline_rank - 1) % pipeline_world_size]
def get_data_parallel_world_size(with_context_parallel=False):
    """Return world size for the data parallel group (0 if torch.distributed
    is unavailable or uninitialized)."""
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_world_size(
        group=get_data_parallel_group(with_context_parallel=with_context_parallel)
    )
def get_data_parallel_rank(with_context_parallel=False):
    """Return my rank for the data parallel group (0 if torch.distributed
    is unavailable or uninitialized)."""
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_rank(
        group=get_data_parallel_group(with_context_parallel=with_context_parallel)
    )
def get_context_parallel_world_size():
    """Return world size for the context parallel group (0 if torch.distributed
    is unavailable or uninitialized)."""
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_world_size(group=get_context_parallel_group())
def get_context_parallel_rank():
    """Return my rank for the context parallel group (0 if torch.distributed
    is unavailable or uninitialized)."""
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_rank(group=get_context_parallel_group())
def get_expert_model_parallel_world_size():
    """Return world size for the expert model parallel group.

    Honors the manual override set via set_expert_model_parallel_world_size();
    otherwise derives it from the tensor+expert group size. Returns 0 when
    torch.distributed is unavailable or uninitialized.
    """
    # Bug fix: use an explicit None check instead of truthiness so an
    # explicitly-set override value of 0 is honored rather than ignored.
    if _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        tensor_and_expert_parallel_world_size = torch.distributed.get_world_size(
            group=get_tensor_and_expert_parallel_group()
        )
        # Experts are sharded over TP as well, so divide the combined size out.
        return tensor_and_expert_parallel_world_size // get_tensor_model_parallel_world_size()
    else:
        return 0
def get_tensor_and_expert_parallel_world_size():
    """Return world size for the expert model parallel group times model parallel group.

    Currently, each expert will also be distributed across TP group by default.
    Returns 0 when torch.distributed is unavailable or uninitialized.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_world_size(group=get_tensor_and_expert_parallel_group())
def get_expert_model_parallel_rank():
    """Return my rank for the expert parallel group.

    Honors the manual override set via set_expert_model_parallel_rank();
    otherwise derives it from the tensor+expert group rank. Returns 0 when
    torch.distributed is unavailable or uninitialized.
    """
    # Bug fix: rank 0 is a valid override, but the original truthiness test
    # (`if _MPU_EXPERT_MODEL_PARALLEL_RANK:`) treated 0 as "not set" and fell
    # through to the torch.distributed computation. Check against None instead.
    if _MPU_EXPERT_MODEL_PARALLEL_RANK is not None:
        return _MPU_EXPERT_MODEL_PARALLEL_RANK
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        tensor_and_expert_parallel_rank = torch.distributed.get_rank(
            group=get_tensor_and_expert_parallel_group()
        )
        # Experts are sharded over TP as well, so divide the TP size out.
        return tensor_and_expert_parallel_rank // get_tensor_model_parallel_world_size()
    else:
        return 0
def get_data_modulo_expert_parallel_rank():
    """Return my rank for the data-modulo-expert parallel group (0 if
    torch.distributed is unavailable or uninitialized)."""
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_rank(group=get_data_modulo_expert_parallel_group())
def get_tensor_and_expert_parallel_rank():
    """Return my rank for the tensor and expert parallel group (0 if
    torch.distributed is unavailable or uninitialized)."""
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_rank(group=get_tensor_and_expert_parallel_group())
def _set_global_memory_buffer():
    """Initialize global buffer"""
    # Created once; re-initialization is a programming error.
    global _GLOBAL_MEMORY_BUFFER
    assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def get_global_memory_buffer():
    """Return the global GlobalMemoryBuffer object"""
    buffer = _GLOBAL_MEMORY_BUFFER
    assert buffer is not None, 'global memory buffer is not initialized'
    return buffer
def destroy_global_memory_buffer():
    """Sets the global memory buffer to None"""
    # Allows _set_global_memory_buffer() to be called again afterwards.
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None
def destroy_model_parallel():
    """Set the groups to none.

    Resets all module-level parallel-state globals so that
    initialize_model_parallel() can be called again cleanly.
    """
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _MODEL_AND_EXPERT_PARALLEL_GROUP
    _MODEL_AND_EXPERT_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP_WITH_CP
    _DATA_PARALLEL_GROUP_WITH_CP = None
    global _CONTEXT_PARALLEL_GROUP
    _CONTEXT_PARALLEL_GROUP = None
    global _CONTEXT_PARALLEL_GLOBAL_RANKS
    _CONTEXT_PARALLEL_GLOBAL_RANKS = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None
    global _TENSOR_AND_DATA_PARALLEL_GROUP
    _TENSOR_AND_DATA_PARALLEL_GROUP = None
    global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
    global _EXPERT_MODEL_PARALLEL_GROUP
    _EXPERT_MODEL_PARALLEL_GROUP = None
    global _TENSOR_AND_EXPERT_PARALLEL_GROUP
    _TENSOR_AND_EXPERT_PARALLEL_GROUP = None
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP
    _DATA_MODULO_EXPERT_PARALLEL_GROUP = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None
    global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
    _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_EXPERT_MODEL_PARALLEL_RANK
    _MPU_EXPERT_MODEL_PARALLEL_RANK = None
    # Bug fix: also clear the gloo groups, cached global-rank lists, and the
    # pipeline split rank. Leaving them stale meant the rank-query helpers
    # (e.g. get_pipeline_model_parallel_next_rank, get_data_parallel_src_rank)
    # could keep returning values from the destroyed topology after a
    # destroy + re-initialize cycle.
    global _DATA_PARALLEL_GROUP_GLOO
    _DATA_PARALLEL_GROUP_GLOO = None
    global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
    _DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
    _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None
    global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
    _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None
    global _PIPELINE_GLOBAL_RANKS
    _PIPELINE_GLOBAL_RANKS = None
    global _EMBEDDING_GLOBAL_RANKS
    _EMBEDDING_GLOBAL_RANKS = None
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    _POSITION_EMBEDDING_GLOBAL_RANKS = None
    global _DATA_PARALLEL_GLOBAL_RANKS
    _DATA_PARALLEL_GLOBAL_RANKS = None
    global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
    _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
megatron/core/pipeline_parallel/__init__.py
0 → 100644
View file @
0816dd4a
from
.schedules
import
get_forward_backward_func
megatron/core/pipeline_parallel/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
0816dd4a
File added
megatron/core/pipeline_parallel/__pycache__/p2p_communication.cpython-310.pyc
0 → 100644
View file @
0816dd4a
File added
megatron/core/pipeline_parallel/__pycache__/schedules.cpython-310.pyc
0 → 100644
View file @
0816dd4a
File added
megatron/core/pipeline_parallel/p2p_communication.py
0 → 100644
View file @
0816dd4a
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import
operator
from
functools
import
reduce
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
megatron
import
core
from
megatron.core
import
ModelParallelConfig
from
megatron.core.parallel_state
import
(
get_pipeline_model_parallel_group
,
get_pipeline_model_parallel_next_rank
,
get_pipeline_model_parallel_prev_rank
,
get_pipeline_model_parallel_rank
,
get_pipeline_model_parallel_world_size
,
)
# Types
# Shape of a tensor exchanged between pipeline stages; either a plain list of
# ints or a torch.Size.
Shape = Union[List[int], torch.Size]
def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config):
    """Communicate tensor shapes between stages. Used to communicate
    tensor shapes before the actual tensor communication happens.
    This is required when the sequence lengths across micro batches
    are not uniform.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        config: ModelParallelConfig; only ``use_ring_exchange_p2p`` is
                consulted here.

    Returns:
        (recv_prev_shape, recv_next_shape) as 3-element lists; ``[0, 0, 0]``
        for a direction with no receive.
    """
    recv_prev_shape_tensor = None
    recv_next_shape_tensor = None
    send_prev_shape_tensor = None
    send_next_shape_tensor = None
    # Shapes are exchanged as 3-element int64 CUDA tensors.
    if recv_prev:
        recv_prev_shape_tensor = torch.empty(
            (3), device=torch.cuda.current_device(), dtype=torch.int64
        )
    if recv_next:
        recv_next_shape_tensor = torch.empty(
            (3), device=torch.cuda.current_device(), dtype=torch.int64
        )
    if tensor_send_prev is not None:
        send_prev_shape_tensor = torch.tensor(
            tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64
        )
    if tensor_send_next is not None:
        send_next_shape_tensor = torch.tensor(
            tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64
        )
    if config.use_ring_exchange_p2p:
        torch.distributed.ring_exchange(
            tensor_send_prev=send_prev_shape_tensor,
            tensor_recv_prev=recv_prev_shape_tensor,
            tensor_send_next=send_next_shape_tensor,
            tensor_recv_next=recv_next_shape_tensor,
            group=get_pipeline_model_parallel_group(),
        )
    else:
        # NOTE: the op ordering (send_prev, recv_prev, send_next, recv_next)
        # must stay consistent across ranks for batch_isend_irecv matching.
        ops = []
        if send_prev_shape_tensor is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend,
                send_prev_shape_tensor,
                get_pipeline_model_parallel_prev_rank(),
            )
            ops.append(send_prev_op)
        if recv_prev_shape_tensor is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                recv_prev_shape_tensor,
                get_pipeline_model_parallel_prev_rank(),
            )
            ops.append(recv_prev_op)
        if send_next_shape_tensor is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend,
                send_next_shape_tensor,
                get_pipeline_model_parallel_next_rank(),
            )
            ops.append(send_next_op)
        if recv_next_shape_tensor is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                recv_next_shape_tensor,
                get_pipeline_model_parallel_next_rank(),
            )
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()

        # To protect against race condition when using batch_isend_irecv().
        # should take this out once the bug with batch_isend_irecv is resolved.
        torch.cuda.synchronize()

    recv_prev_shape = [0, 0, 0]
    if recv_prev_shape_tensor is not None:
        recv_prev_shape = recv_prev_shape_tensor.tolist()

    recv_next_shape = [0, 0, 0]
    if recv_next_shape_tensor is not None:
        recv_next_shape = recv_next_shape_tensor.tolist()

    return recv_prev_shape, recv_next_shape
def
_batched_p2p_ops
(
*
,
tensor_send_prev
:
Optional
[
torch
.
Tensor
],
tensor_recv_prev
:
Optional
[
torch
.
Tensor
],
tensor_send_next
:
Optional
[
torch
.
Tensor
],
tensor_recv_next
:
Optional
[
torch
.
Tensor
],
group
:
torch
.
distributed
.
ProcessGroup
):
ops
=
[]
if
tensor_send_prev
is
not
None
:
send_prev_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
isend
,
tensor_send_prev
,
get_pipeline_model_parallel_prev_rank
(),
group
,
)
ops
.
append
(
send_prev_op
)
if
tensor_recv_prev
is
not
None
:
recv_prev_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
irecv
,
tensor_recv_prev
,
get_pipeline_model_parallel_prev_rank
(),
group
,
)
ops
.
append
(
recv_prev_op
)
if
tensor_send_next
is
not
None
:
send_next_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
isend
,
tensor_send_next
,
get_pipeline_model_parallel_next_rank
(),
group
,
)
ops
.
append
(
send_next_op
)
if
tensor_recv_next
is
not
None
:
recv_next_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
irecv
,
tensor_recv_next
,
get_pipeline_model_parallel_next_rank
(),
group
,
)
ops
.
append
(
recv_next_op
)
if
len
(
ops
)
>
0
:
reqs
=
torch
.
distributed
.
batch_isend_irecv
(
ops
)
else
:
reqs
=
[]
return
reqs
def _p2p_ops(
    *,
    tensor_send_prev: Optional[torch.Tensor],
    tensor_recv_prev: Optional[torch.Tensor],
    tensor_send_next: Optional[torch.Tensor],
    tensor_recv_next: Optional[torch.Tensor],
    group: torch.distributed.ProcessGroup
):
    """Issue the four directional p2p transfers as individual isend/irecv calls.

    Even- and odd-ranked pipeline stages post their sends and receives in
    mirrored orders so that every send is matched by a receive on the peer
    without deadlocking. Returns the list of outstanding requests.
    """
    reqs = []
    # Fix: the original assigned `rank` and then never used it, re-calling
    # get_pipeline_model_parallel_rank() in the parity branch below.
    rank = get_pipeline_model_parallel_rank()
    even_send_odd_recv_group = group
    if get_pipeline_model_parallel_world_size() == 2:
        # Use the global process group for one of the two p2p communications
        # to allow the overlap of the independent communications.
        # Using the global process group is compatible because the pipeline-parallel
        # communications set the source and destination by global rank.
        even_recv_odd_send_group = torch.distributed.group.WORLD
    else:
        even_recv_odd_send_group = group
    if rank % 2 == 0:
        # Even ranks: send first, then receive (mirrored by odd ranks below).
        if tensor_send_next is not None:
            send_next_req = torch.distributed.isend(
                tensor=tensor_send_next,
                dst=get_pipeline_model_parallel_next_rank(),
                group=even_send_odd_recv_group,
            )
            reqs.append(send_next_req)
        if tensor_recv_prev is not None:
            recv_prev_req = torch.distributed.irecv(
                tensor=tensor_recv_prev,
                src=get_pipeline_model_parallel_prev_rank(),
                group=even_recv_odd_send_group,
            )
            reqs.append(recv_prev_req)
        if tensor_send_prev is not None:
            send_prev_req = torch.distributed.isend(
                tensor=tensor_send_prev,
                dst=get_pipeline_model_parallel_prev_rank(),
                group=even_send_odd_recv_group,
            )
            reqs.append(send_prev_req)
        if tensor_recv_next is not None:
            recv_next_req = torch.distributed.irecv(
                tensor=tensor_recv_next,
                src=get_pipeline_model_parallel_next_rank(),
                group=even_recv_odd_send_group,
            )
            reqs.append(recv_next_req)
    else:
        # Odd ranks: receive first, then send (mirrors the even-rank order).
        if tensor_recv_prev is not None:
            recv_prev_req = torch.distributed.irecv(
                tensor=tensor_recv_prev,
                src=get_pipeline_model_parallel_prev_rank(),
                group=even_send_odd_recv_group,
            )
            reqs.append(recv_prev_req)
        if tensor_send_next is not None:
            send_next_req = torch.distributed.isend(
                tensor=tensor_send_next,
                dst=get_pipeline_model_parallel_next_rank(),
                group=even_recv_odd_send_group,
            )
            reqs.append(send_next_req)
        if tensor_recv_next is not None:
            recv_next_req = torch.distributed.irecv(
                tensor=tensor_recv_next,
                src=get_pipeline_model_parallel_next_rank(),
                group=even_send_odd_recv_group,
            )
            reqs.append(recv_next_req)
        if tensor_send_prev is not None:
            send_prev_req = torch.distributed.isend(
                tensor=tensor_send_prev,
                dst=get_pipeline_model_parallel_prev_rank(),
                group=even_recv_odd_send_group,
            )
            reqs.append(send_prev_req)
    return reqs
def _communicate(
    *,
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    wait_on_reqs: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Args:
        tensor_send_next (torch.Tensor, optional):
            Tensor to send to next rank (no tensor sent if None)

        tensor_send_prev (torch.Tensor, optional):
            Tensor to send to prev rank (no tensor sent if None)

        recv_prev (boolean, required):
            whether tensor should be received from previous rank.

        recv_next (boolean, required):
            whether tensor should be received from next rank.

        tensor_shape (List[int] or torch.Size, required):
            shape of tensor to receive (this method assumes that all
            tensors sent and received in a single function call are
            the same shape). Ignored when config.variable_seq_lengths
            is set, in which case shapes are exchanged first via
            _communicate_shapes().

        wait_on_reqs (boolean, optional, default=True):
            For non-batched p2p communication, wait on each request
            before returning.

    Returns:
        tuple containing

        - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
        - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
        - reqs: list of outstanding requests when wait_on_reqs is False and
          the non-batched path was used, None otherwise.
    """

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None

    if not config.variable_seq_lengths:
        recv_prev_shape = tensor_shape
        recv_next_shape = tensor_shape
    else:
        # Sequence lengths differ per micro batch; exchange shapes first.
        recv_prev_shape, recv_next_shape = _communicate_shapes(
            tensor_send_next, tensor_send_prev, recv_prev, recv_next, config
        )

    if recv_prev:
        if config.pipeline_dtype is None:
            raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
        if tensor_shape is None:
            raise RuntimeError(
                "tensor_shape must be specified if recv_prev is True. "
                "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
            )
        tensor_recv_prev = torch.empty(
            recv_prev_shape,
            requires_grad=True,
            device=torch.cuda.current_device(),
            dtype=config.pipeline_dtype,
        )
    if recv_next:
        if config.pipeline_dtype is None:
            raise RuntimeError("dtype must be provided if recv_next is True")
        if tensor_shape is None:
            raise RuntimeError(
                "tensor_shape must be specified if recv_next is True. "
                "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
            )
        tensor_recv_next = torch.empty(
            recv_next_shape,
            requires_grad=True,
            device=torch.cuda.current_device(),
            dtype=config.pipeline_dtype,
        )

    # Send tensors in both the forward and backward directions as appropriate.
    if config.use_ring_exchange_p2p:

        def _ring_exchange_wrapper(**kwargs):
            torch.distributed.ring_exchange(**kwargs)
            return []

        p2p_func = _ring_exchange_wrapper
    elif config.batch_p2p_comm:
        # Batched ops complete inside batch_isend_irecv; deferred waiting is
        # not supported on this path.
        assert wait_on_reqs
        p2p_func = _batched_p2p_ops
    else:
        p2p_func = _p2p_ops

    reqs = p2p_func(
        tensor_send_prev=tensor_send_prev,
        tensor_recv_prev=tensor_recv_prev,
        tensor_send_next=tensor_send_next,
        tensor_recv_next=tensor_recv_next,
        group=get_pipeline_model_parallel_group(),
    )

    if wait_on_reqs and len(reqs) > 0:
        for req in reqs:
            req.wait()
        reqs = None

    if config.batch_p2p_comm and config.batch_p2p_sync:
        # To protect against race condition when using batch_isend_irecv().
        # User should assert that we have a modern enough PyTorch to not need this
        torch.cuda.synchronize()

    return tensor_recv_prev, tensor_recv_next, reqs
def recv_forward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
    """Receive tensor from previous rank in pipeline (forward receive).

    See _communicate for argument details.
    """
    # The first stage has no predecessor to receive from.
    if core.parallel_state.is_pipeline_first_stage():
        return None
    timer_name = 'forward-recv'
    if config.timers is not None:
        config.timers(timer_name, log_level=2).start()
    input_tensor, _, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        config=config,
    )
    if config.timers is not None:
        config.timers(timer_name).stop()
    return input_tensor
def recv_backward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
    """Receive tensor from next rank in pipeline (backward receive).

    See _communicate for argument details.
    """
    # The last stage has no successor to receive gradients from.
    if core.parallel_state.is_pipeline_last_stage():
        return None
    timer_name = 'backward-recv'
    if config.timers is not None:
        config.timers(timer_name, log_level=2).start()
    _, output_tensor_grad, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape,
        config=config,
    )
    if config.timers is not None:
        config.timers(timer_name).stop()
    return output_tensor_grad
def send_forward(output_tensor: torch.Tensor, config: ModelParallelConfig) -> None:
    """Send tensor to next rank in pipeline (forward send).

    See _communicate for argument details.
    """
    # The last stage has no successor to send activations to.
    if core.parallel_state.is_pipeline_last_stage():
        return
    timer_name = 'forward-send'
    if config.timers is not None:
        config.timers(timer_name, log_level=2).start()
    _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=False,
        tensor_shape=None,
        config=config,
    )
    if config.timers is not None:
        config.timers(timer_name).stop()
def send_backward(input_tensor_grad: torch.Tensor, config: ModelParallelConfig) -> None:
    """Send tensor to previous rank in pipeline (backward send).

    See _communicate for argument details.
    """
    # The first stage has no predecessor to send gradients to.
    if core.parallel_state.is_pipeline_first_stage():
        return
    timer_name = 'backward-send'
    if config.timers is not None:
        config.timers(timer_name, log_level=2).start()
    _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=False,
        tensor_shape=None,
        config=config,
    )
    if config.timers is not None:
        config.timers(timer_name).stop()
def send_forward_recv_backward(
    output_tensor: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
) -> torch.Tensor:
    """Batched send and recv with next rank in pipeline.

    See _communicate for argument details.
    """
    # The last stage neither sends forward nor receives a gradient.
    if core.parallel_state.is_pipeline_last_stage():
        return None
    timer_name = 'forward-send-backward-recv'
    if config.timers is not None:
        config.timers(timer_name, log_level=2).start()
    _, output_tensor_grad, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=True,
        tensor_shape=tensor_shape,
        config=config,
    )
    if config.timers is not None:
        config.timers(timer_name).stop()
    return output_tensor_grad
def send_backward_recv_forward(
    input_tensor_grad: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
) -> torch.Tensor:
    """Batched send and recv with previous rank in pipeline.

    Sends ``input_tensor_grad`` to the previous stage and receives the next
    activation from the previous stage in one batched operation. Returns None
    on the first pipeline stage (there is no previous rank).

    See _communicate for argument details.
    """
    if core.parallel_state.is_pipeline_first_stage():
        return None

    timers = config.timers
    if timers is not None:
        timers('backward-send-forward-recv', log_level=2).start()
    input_tensor, _, _ = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=True,
        recv_next=False,
        tensor_shape=tensor_shape,
        config=config,
    )
    if timers is not None:
        timers('backward-send-forward-recv').stop()
    return input_tensor
def send_forward_recv_forward(
    output_tensor: torch.Tensor,
    recv_prev: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    overlap_p2p_comm: bool = False,
) -> torch.Tensor:
    """Batched recv from previous rank and send to next rank in pipeline.

    When ``overlap_p2p_comm`` is True the communication requests are not
    waited on and ``(input_tensor, wait_handles)`` is returned so the caller
    can overlap compute with the outstanding requests; otherwise only
    ``input_tensor`` is returned.

    See _communicate for argument details.
    """
    timers = config.timers
    if timers is not None:
        timers('forward-send-forward-recv', log_level=2).start()
    input_tensor, _, wait_handles = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape,
        wait_on_reqs=(not overlap_p2p_comm),
        config=config,
    )
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    if overlap_p2p_comm:
        return input_tensor, wait_handles
    return input_tensor
def send_backward_recv_backward(
    input_tensor_grad: torch.Tensor,
    recv_next: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    overlap_p2p_comm: bool = False,
) -> torch.Tensor:
    """Batched recv from next rank and send to previous rank in pipeline.

    When ``overlap_p2p_comm`` is True the communication requests are not
    waited on and ``(output_tensor_grad, wait_handles)`` is returned so the
    caller can overlap compute with the outstanding requests; otherwise only
    ``output_tensor_grad`` is returned.

    See _communicate for argument details.
    """
    timers = config.timers
    if timers is not None:
        timers('backward-send-backward-recv', log_level=2).start()
    _, output_tensor_grad, wait_handles = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        wait_on_reqs=(not overlap_p2p_comm),
        config=config,
    )
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    if overlap_p2p_comm:
        return output_tensor_grad, wait_handles
    return output_tensor_grad
def send_forward_backward_recv_forward_backward(
    output_tensor: torch.Tensor,
    input_tensor_grad: torch.Tensor,
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
) -> torch.Tensor:
    """Batched send and recv with both neighbors in the pipeline.

    Sends ``output_tensor`` to the next rank and ``input_tensor_grad`` to the
    previous rank, optionally receiving an activation (``recv_prev``) and a
    gradient (``recv_next``) in the same batched operation. Returns the pair
    ``(input_tensor, output_tensor_grad)``.

    See _communicate for argument details.
    """
    timers = config.timers
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv', log_level=2).start()
    input_tensor, output_tensor_grad, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        config=config,
    )
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
megatron/core/pipeline_parallel/schedules.py
0 → 100644
View file @
0816dd4a
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import
contextlib
from
typing
import
Callable
,
Iterator
,
List
,
Optional
,
Union
import
torch
from
torch.autograd.variable
import
Variable
from
megatron.core
import
parallel_state
from
megatron.core.enums
import
ModelType
from
megatron.core.pipeline_parallel
import
p2p_communication
from
megatron.core.transformer.moe.router
import
MoEAuxLossAutoScaler
from
megatron.core.utils
import
get_attr_wrapped_model
,
get_model_config
,
get_model_type
# Types

# Shape of a tensor exchanged between pipeline stages: either an explicit
# list of dimensions or a torch.Size.
Shape = Union[List[int], torch.Size]
def get_forward_backward_func():
    """Retrieve the forward-backward schedule matching the parallel_state setup.

    Selection:
      * pipeline-model-parallel world size == 1 -> forward_backward_no_pipelining
      * world size > 1 and a virtual pipeline world size is set ->
        forward_backward_pipelining_with_interleaving
      * world size > 1 otherwise -> forward_backward_pipelining_without_interleaving

    The returned callable runs all forward and backward passes for a global
    batch. Its keyword arguments: forward_step_func (takes a data iterator and
    a model, returns the forward output and a loss function), data_iterator
    (list of iterators for interleaved pipelining), model (list of modules for
    interleaved pipelining), num_microbatches, seq_length, micro_batch_size,
    decoder_seq_length (dual-stack transformers only), forward_only,
    collect_non_loss_data, and first_val_step (lets Transformer Engine modules
    update fp8 weights only on the first validation step). Note that with
    sequence parallelism the sequence-length component of the tensor shape is
    divided by the tensor-model-parallel world size.
    """
    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
            return forward_backward_pipelining_with_interleaving
        return forward_backward_pipelining_without_interleaving
    return forward_backward_no_pipelining
def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
    '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    Called right after the output tensor has been sent to the next pipeline
    stage: at that point the tensor is only needed for its '.grad_fn', so its
    '.data' is replaced with a 1-element placeholder of the same device/dtype.
    '''
    if out is None:
        return
    if not deallocate_pipeline_outputs:
        return
    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
    # A view shares storage with its base, so shrinking it frees nothing.
    assert out._base is None, "counter-productive to free a view of another tensor."
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
def custom_backward(output, grad_output):
    '''Directly call the C++ autograd engine.

    To make the 'deallocate_output_tensor' (above) optimization work, the C++
    autograd engine must be invoked directly, bypassing torch.autograd.backward:
    PyTorch's 'backward' checks that output and grad have the same shape,
    while the C++ engine does not.
    '''
    assert output.numel() == 1, "output should be pseudo-'freed' in schedule, to optimize memory"
    assert isinstance(output, torch.Tensor), "output == '%s'." % type(output).__name__
    assert isinstance(grad_output, (torch.Tensor, type(None))), (
        "grad_output == '%s'." % type(grad_output).__name__
    )

    # Implicit gradient for a scalar output, mirroring torch.autograd.backward.
    if grad_output is None:
        assert output.numel() == 1, "implicit grad requires scalar output."
        grad_output = torch.ones_like(output, memory_format=torch.preserve_format)

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )
def set_current_microbatch(model, microbatch_id):
    """Record the current microbatch id on the model's decoder, if present.

    Silently does nothing when the (possibly wrapped) model has no 'decoder'
    attribute or when that attribute is None.
    """
    try:
        decoder = get_attr_wrapped_model(model, "decoder")
    except RuntimeError:
        # Model has no decoder attribute anywhere in its wrapper chain.
        return
    if decoder is not None:
        decoder.current_microbatch = microbatch_id
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Args:
        forward_step_func: callable taking (data_iterator, model) — plus
            checkpoint_activations_microbatch when that is not None — and
            returning (output_tensor, loss_func).
        data_iterator: passed through to forward_step_func.
        model: the (possibly wrapped) model module for this stage.
        num_microbatches: total microbatch count; used to average the loss.
        input_tensor: activation(s) from the previous stage (None on the
            first stage); wrapped in a list here if not one already.
        forward_data_store: list that collects the reduced loss (or non-loss
            data) on the last pipeline stage.
        config: config object providing timers, autocast settings,
            grad_scale_func, calculate_per_token_loss, etc.
        collect_non_loss_data: on the last stage, call loss_func with
            non_loss_data=True and store that result instead of a loss.
        checkpoint_activations_microbatch: forwarded to forward_step_func
            (partial activation checkpointing) when not None.
        is_first_microbatch: if True and the model exposes
            set_is_first_microbatch, call it before the forward pass.
        current_microbatch: if not None, recorded on the model via
            set_current_microbatch.

    Returns output tensor paired with a token count: (output, num_tokens),
    where output is a list unless the incoming input_tensor was unwrapped,
    and num_tokens is an int tensor (0 unless the last stage's loss_func
    returned a 3-tuple including a token count)."""
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    # Let the model know about microbatch position (used e.g. for fp8 weight
    # caching and decoder bookkeeping) before running the forward pass.
    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    # Normalize input to a list; remember whether to unwrap on return so the
    # caller sees the same shape it passed in.
    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
    set_input_tensor(input_tensor)

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()
    with context_manager:
        if checkpoint_activations_microbatch is None:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
        else:
            output_tensor, loss_func = forward_step_func(
                data_iterator, model, checkpoint_activations_microbatch
            )

    num_tokens = torch.tensor(0, dtype=torch.int)
    if parallel_state.is_pipeline_last_stage():
        if not collect_non_loss_data:
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                # loss_func returned (loss, num_tokens, reduced); average over
                # tokens and microbatches unless per-token loss is requested.
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    output_tensor /= num_tokens
                    output_tensor /= num_microbatches
            else:
                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)
        else:
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Set the loss scale for the auxiliary loss of the MoE layer.
    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
            if config.grad_scale_func is not None
            else torch.tensor(1.0)
        )
        # Set the loss scale
        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    model_type = get_model_type(model)
    if (
        parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        return [output_tensor, input_tensor[-1]], num_tokens
    if unwrap_output_tensor:
        return output_tensor, num_tokens
    return [output_tensor], num_tokens
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Args:
        input_tensor: this stage's input activation(s); tensor or list.
        output_tensor: this stage's output (the loss on the last stage);
            tensor or list.
        output_tensor_grad: gradient received from the next stage, or None
            on the last stage; tensor or list.
        model_type: ModelType enum, used to detect the encoder-decoder
            skip-connection case.
        config: config object providing timers, grad_scale_func, and
            deallocate_pipeline_outputs.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.

    if config.timers is not None:
        config.timers('backward-compute', log_level=2).start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_input_tensor_grad = True
    for x in input_tensor:
        if x is not None:
            # retain_grad() keeps .grad populated on these non-leaf tensors.
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass.
    if output_tensor_grad[0] is None and config.grad_scale_func is not None:
        # Last stage: scale the loss (e.g. for mixed-precision) before backward.
        output_tensor[0] = config.grad_scale_func(output_tensor[0])

    if config.deallocate_pipeline_outputs:
        # Output .data was pseudo-freed; must bypass torch.autograd.backward's
        # shape check via the C++ engine.
        custom_backward(output_tensor[0], output_tensor_grad[0])
    else:
        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
    if input_tensor is not None:
        input_tensor_grad = []
        for x in input_tensor:
            if x is None:
                input_tensor_grad.append(None)
            else:
                input_tensor_grad.append(x.grad)

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        if output_tensor_grad[1] is not None:
            # Accumulate the skip-connection grad into the encoder activation grad.
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    if config.timers is not None:
        config.timers('backward-compute').stop()

    return input_tensor_grad
def check_first_val_step(first_val_step, forward_only, cond):
    """Combine ``cond`` with the validation-phase first-step flag.

    During a forward-only (validation) pass with an explicit first_val_step
    flag, the result is True only when both the flag and ``cond`` hold;
    otherwise ``cond`` is returned unchanged.
    """
    if forward_only and first_val_step is not None:
        return first_val_step and cond
    return cond
def forward_backward_no_pipelining(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,  # unused
    micro_batch_size: int,  # unused
    decoder_seq_length: int = None,  # unused
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    Runs all microbatches but the last inside config.no_sync_func (so grad
    reductions are deferred), then runs the final microbatch outside it to
    trigger gradient synchronization.

    Returns dictionary with losses.

    See get_forward_backward_func() for argument details
    """

    # Single-chunk schedule: unwrap one-element lists from the generic API.
    if isinstance(model, list):
        assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking"
        model = model[0]
    if isinstance(data_iterator, list):
        assert (
            len(data_iterator) == 1
        ), "non-pipeline-parallel schedule does not support model chunking"
        data_iterator = data_iterator[0]

    config = get_model_config(model)
    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    no_sync_func = config.no_sync_func
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext

    model_type = get_model_type(model)

    forward_data_store = []
    # No pipeline: nothing received from neighbors, so both stay None.
    input_tensor, output_tensor_grad = None, None
    total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
    with no_sync_func():
        for i in range(num_microbatches - 1):
            output_tensor, num_tokens = forward_step(
                forward_step_func,
                data_iterator,
                model,
                num_microbatches,
                input_tensor,
                forward_data_store,
                config,
                collect_non_loss_data,
                is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
                current_microbatch=i,
            )
            total_num_tokens += num_tokens.item()
            if not forward_only:
                backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor, num_tokens = forward_step(
        forward_step_func,
        data_iterator,
        model,
        num_microbatches,
        input_tensor,
        forward_data_store,
        config,
        collect_non_loss_data,
        is_first_microbatch=check_first_val_step(first_val_step, forward_only, num_microbatches == 1),
        current_microbatch=num_microbatches - 1,
    )
    total_num_tokens += num_tokens.item()

    if not forward_only:
        backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

    if config.finalize_model_grads_func is not None and not forward_only:
        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism and layernorm all-reduce for sequence parallelism).
        config.finalize_model_grads_func(
            [model], total_num_tokens if config.calculate_per_token_loss else None
        )

    if config.timers is not None:
        config.timers('forward-backward').stop()

    return forward_data_store
def
forward_backward_pipelining_with_interleaving
(
*
,
forward_step_func
,
data_iterator
:
Union
[
Iterator
,
List
[
Iterator
]],
model
:
Union
[
torch
.
nn
.
Module
,
List
[
torch
.
nn
.
Module
]],
num_microbatches
:
int
,
seq_length
:
int
,
micro_batch_size
:
int
,
decoder_seq_length
:
int
=
None
,
forward_only
:
bool
=
False
,
collect_non_loss_data
:
bool
=
False
,
first_val_step
:
bool
=
None
,
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
assert
isinstance
(
model
,
list
),
"interleaved pipeline parallelism expected model chunking"
assert
all
(
isinstance
(
chunk
,
torch
.
nn
.
Module
)
for
chunk
in
model
),
"invalid model chunking"
assert
isinstance
(
data_iterator
,
list
),
"interleaved pipeline parallelism expected each model chunk to have a data iterator"
config
=
get_model_config
(
model
[
0
])
if
config
.
overlap_p2p_comm
and
config
.
batch_p2p_comm
:
raise
ValueError
(
"Can not use both overlap_p2p_comm and batch_p2p_comm"
)
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-backward'
,
log_level
=
1
).
start
(
barrier
=
config
.
barrier_with_L1_time
)
# Disable async grad reductions
no_sync_func
=
config
.
no_sync_func
if
isinstance
(
no_sync_func
,
list
):
def
multi_no_sync
():
stack
=
contextlib
.
ExitStack
()
for
model_chunk_no_sync_func
in
config
.
no_sync_func
:
stack
.
enter_context
(
model_chunk_no_sync_func
())
return
stack
no_sync_func
=
multi_no_sync
if
no_sync_func
is
None
:
no_sync_func
=
contextlib
.
nullcontext
no_sync_context
=
None
if
config
.
grad_sync_func
is
not
None
and
not
isinstance
(
config
.
grad_sync_func
,
list
):
config
.
grad_sync_func
=
[
config
.
grad_sync_func
for
_
in
model
]
if
config
.
param_sync_func
is
not
None
and
not
isinstance
(
config
.
param_sync_func
,
list
):
config
.
param_sync_func
=
[
config
.
param_sync_func
for
_
in
model
]
def
disable_grad_sync
():
"""Disable asynchronous grad reductions"""
nonlocal
no_sync_context
if
no_sync_context
is
None
:
no_sync_context
=
no_sync_func
()
no_sync_context
.
__enter__
()
def
enable_grad_sync
():
"""Enable asynchronous grad reductions"""
nonlocal
no_sync_context
if
no_sync_context
is
not
None
:
no_sync_context
.
__exit__
(
None
,
None
,
None
)
no_sync_context
=
None
disable_grad_sync
()
# Model chunk IDs with synchronized grads
synchronized_model_chunks
=
set
()
input_tensors
=
[[]
for
_
in
range
(
len
(
model
))]
output_tensors
=
[[]
for
_
in
range
(
len
(
model
))]
total_num_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
int
).
cuda
()
forward_data_store
=
[]
if
not
forward_only
:
output_tensor_grads
=
[[]
for
_
in
range
(
len
(
model
))]
pipeline_parallel_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
pipeline_parallel_rank
=
parallel_state
.
get_pipeline_model_parallel_rank
()
if
num_microbatches
%
pipeline_parallel_size
!=
0
:
msg
=
f
'number of microbatches (
{
num_microbatches
}
) is not divisible by '
msg
+=
f
'pipeline-model-parallel-size (
{
pipeline_parallel_size
}
) '
msg
+=
'when using interleaved schedule'
raise
RuntimeError
(
msg
)
model_type
=
get_model_type
(
model
[
0
])
if
model_type
==
ModelType
.
encoder_and_decoder
:
raise
RuntimeError
(
"Interleaving is not supported with an encoder and decoder model."
)
if
decoder_seq_length
is
not
None
and
decoder_seq_length
!=
seq_length
:
raise
RuntimeError
(
"Interleaving is not supported with a different decoder sequence length."
)
tensor_shape
=
[
seq_length
,
micro_batch_size
,
config
.
hidden_size
]
tensor_shape
[
0
]
=
tensor_shape
[
0
]
//
parallel_state
.
get_context_parallel_world_size
()
if
config
.
sequence_parallel
:
tensor_shape
[
0
]
=
tensor_shape
[
0
]
//
parallel_state
.
get_tensor_model_parallel_world_size
()
# Compute number of warmup and remaining microbatches.
num_model_chunks
=
len
(
model
)
total_num_microbatches
=
num_microbatches
*
num_model_chunks
all_warmup_microbatches
=
False
if
forward_only
:
num_warmup_microbatches
=
total_num_microbatches
else
:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
# all workers, followed by more microbatches after depending on
# stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B).
if
num_microbatches
==
pipeline_parallel_size
:
num_warmup_microbatches
=
total_num_microbatches
all_warmup_microbatches
=
True
else
:
num_warmup_microbatches
=
(
pipeline_parallel_size
-
pipeline_parallel_rank
-
1
)
*
2
num_warmup_microbatches
+=
(
num_model_chunks
-
1
)
*
pipeline_parallel_size
num_warmup_microbatches
=
min
(
num_warmup_microbatches
,
total_num_microbatches
)
num_microbatches_remaining
=
total_num_microbatches
-
num_warmup_microbatches
# Checkpoint the activations of partial Transformer layers in a number of micro-batches
# within the maximum outstanding micro-batch backpropagations.
# Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
# checkpoint partial Transformer layers (or skip checkpointing) and
# the rest of micro-batches within a window of micro-batches checkpoint
# all Transformer layers. The window of micro-batches is set by the maximum
# outstanding backpropagations and becomes smaller at later pipeline stages.
# Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
max_outstanding_backprops
=
None
if
config
.
num_microbatches_with_partial_activation_checkpoints
is
not
None
:
max_outstanding_backprops
=
num_warmup_microbatches
+
1
# Synchronize params for first two model chunks
if
config
.
param_sync_func
is
not
None
:
config
.
param_sync_func
[
0
](
model
[
0
].
parameters
())
config
.
param_sync_func
[
1
](
model
[
1
].
parameters
())
def
get_model_chunk_id
(
microbatch_id
,
forward
):
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group
=
microbatch_id
%
(
pipeline_parallel_size
*
num_model_chunks
)
model_chunk_id
=
microbatch_id_in_group
//
pipeline_parallel_size
if
not
forward
:
model_chunk_id
=
num_model_chunks
-
model_chunk_id
-
1
return
model_chunk_id
def
get_microbatch_id_in_model_chunk
(
iteration_id
,
forward
):
"""Helper method to get the microbatch_id within model chunk given the iteration number."""
assert
forward
iteration_group_id
=
iteration_id
//
(
pipeline_parallel_size
*
num_model_chunks
)
microbatch_id_in_model_chunk
=
(
iteration_group_id
*
pipeline_parallel_size
)
+
(
iteration_id
%
pipeline_parallel_size
)
return
microbatch_id_in_model_chunk
def
is_first_microbatch_for_model_chunk
(
microbatch_id
:
int
)
->
bool
:
"""Check if an iteration is the first for a model chunk."""
microbatch_group_size
=
pipeline_parallel_size
*
num_model_chunks
num_microbatch_groups
=
total_num_microbatches
//
microbatch_group_size
microbatch_group_id
=
microbatch_id
//
microbatch_group_size
microbatch_id_in_group
=
microbatch_id
%
microbatch_group_size
if
microbatch_group_id
==
0
:
return
microbatch_id_in_group
%
pipeline_parallel_size
==
0
else
:
return
False
def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
    """Check if an iteration is the last for a model chunk."""
    group_size = pipeline_parallel_size * num_model_chunks
    num_groups = total_num_microbatches // group_size
    group_id, position_in_group = divmod(microbatch_id, group_size)
    # Only microbatches in the final group can be the last for some chunk.
    if group_id != num_groups - 1:
        return False
    # Within the final group, a chunk's last microbatch sits at the final
    # offset of its pipeline-parallel sub-window.
    return position_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activations_microbatch):
    """Helper method to run forward step with model split into chunks
    (run set_virtual_pipeline_model_parallel_rank() before calling
    forward_step()).

    Reads/writes closure state of the enclosing interleaved schedule:
    input_tensors, output_tensors, total_num_tokens, config, model,
    data_iterator, forward_data_store.
    """
    # Select the virtual-pipeline chunk this microbatch belongs to.
    model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
    parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

    # launch param synchronization for next model chunk
    # Note: Asynchronous communication tends to slow down compute.
    # To reduce idling from mismatched microbatch times, we launch
    # asynchronous communication at the same time across the
    # pipeline-parallel group.
    if config.param_sync_func is not None:
        # Offset by pipeline rank so all ranks trigger the sync in the
        # same wall-clock iteration.
        param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
        if (
            param_sync_microbatch_id < total_num_microbatches
            and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
        ):
            # Sync the *next* chunk's params; chunks 0 and 1 are synced
            # up front, before the warmup loop.
            param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
            if 1 < param_sync_chunk_id < num_model_chunks:
                config.param_sync_func[param_sync_chunk_id](
                    model[param_sync_chunk_id].parameters()
                )

    # forward step
    if parallel_state.is_pipeline_first_stage():
        # First stage has no upstream activation; pad with None so
        # input/output lists stay aligned per microbatch.
        if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
            input_tensors[model_chunk_id].append(None)
    input_tensor = input_tensors[model_chunk_id][-1]

    output_tensor, num_tokens = forward_step(
        forward_step_func,
        data_iterator[model_chunk_id],
        model[model_chunk_id],
        num_microbatches,
        input_tensor,
        forward_data_store,
        config,
        collect_non_loss_data,
        checkpoint_activations_microbatch,
        check_first_val_step(
            first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id),
        ),
        current_microbatch=current_microbatch,
    )
    output_tensors[model_chunk_id].append(output_tensor)

    nonlocal total_num_tokens
    total_num_tokens += num_tokens.item()

    # if forward-only, no need to save tensors for a backward pass
    if forward_only:
        input_tensors[model_chunk_id].pop()
        output_tensors[model_chunk_id].pop()

    return output_tensor
def backward_step_helper(microbatch_id):
    """Helper method to run backward step with model split into chunks
    (run set_virtual_pipeline_model_parallel_rank() before calling
    backward_step()).

    Pops the oldest saved input/output activation pair for the chunk and
    returns the gradient w.r.t. the chunk's input tensor.
    """
    model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
    parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

    # launch grad synchronization (default)
    if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
        enable_grad_sync()
        synchronized_model_chunks.add(model_chunk_id)

    if parallel_state.is_pipeline_last_stage():
        # Last stage has no downstream gradient; pad with None.
        if len(output_tensor_grads[model_chunk_id]) == 0:
            output_tensor_grads[model_chunk_id].append(None)
    # FIFO order: backward consumes activations in the order forward
    # produced them.
    input_tensor = input_tensors[model_chunk_id].pop(0)
    output_tensor = output_tensors[model_chunk_id].pop(0)
    output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
    input_tensor_grad = backward_step(
        input_tensor, output_tensor, output_tensor_grad, model_type, config
    )

    # launch grad synchronization (custom grad sync)
    # Note: Asynchronous communication tends to slow down compute.
    # To reduce idling from mismatched microbatch times, we launch
    # asynchronous communication at the same time across the
    # pipeline-parallel group.
    if config.grad_sync_func is not None:
        # Offset by pipeline rank so all ranks trigger the sync in the
        # same wall-clock iteration.
        grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
        if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
            grad_sync_microbatch_id
        ):
            grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
            enable_grad_sync()
            config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
            synchronized_model_chunks.add(grad_sync_chunk_id)
    # Re-disable async grad reductions until the next sync point.
    disable_grad_sync()

    return input_tensor_grad
# Run warmup forward passes.
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
0
)
input_tensors
[
0
].
append
(
p2p_communication
.
recv_forward
(
tensor_shape
,
config
))
fwd_wait_handles
=
None
bwd_wait_handles
=
None
for
k
in
range
(
num_warmup_microbatches
):
if
fwd_wait_handles
is
not
None
:
for
req
in
fwd_wait_handles
:
req
.
wait
()
cur_model_chunk_id
=
get_model_chunk_id
(
k
,
forward
=
True
)
# Decide to checkpoint all layers' activations of the current micro-batch
if
max_outstanding_backprops
is
not
None
:
checkpoint_activations_microbatch
=
(
k
%
max_outstanding_backprops
>=
config
.
num_microbatches_with_partial_activation_checkpoints
)
else
:
checkpoint_activations_microbatch
=
None
current_microbatch
=
get_microbatch_id_in_model_chunk
(
k
,
forward
=
True
)
output_tensor
=
forward_step_helper
(
k
,
current_microbatch
,
checkpoint_activations_microbatch
)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id
=
get_model_chunk_id
(
k
+
1
,
forward
=
True
)
recv_prev
=
True
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
True
):
if
next_forward_model_chunk_id
==
0
:
recv_prev
=
False
if
k
==
(
total_num_microbatches
-
1
):
recv_prev
=
False
# Don't send tensor downstream if on last stage.
if
parallel_state
.
is_pipeline_last_stage
():
output_tensor
=
None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if
not
config
.
overlap_p2p_comm
:
if
(
k
==
(
num_warmup_microbatches
-
1
)
and
not
forward_only
and
not
all_warmup_microbatches
):
input_tensor_grad
=
None
recv_next
=
True
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
recv_next
=
False
(
input_tensor
,
output_tensor_grad
,
)
=
p2p_communication
.
send_forward_backward_recv_forward_backward
(
output_tensor
,
input_tensor_grad
,
recv_prev
=
recv_prev
,
recv_next
=
recv_next
,
tensor_shape
=
tensor_shape
,
config
=
config
,
)
output_tensor_grads
[
num_model_chunks
-
1
].
append
(
output_tensor_grad
)
else
:
input_tensor
=
p2p_communication
.
send_forward_recv_forward
(
output_tensor
,
recv_prev
=
recv_prev
,
tensor_shape
=
tensor_shape
,
config
=
config
)
input_tensors
[
next_forward_model_chunk_id
].
append
(
input_tensor
)
else
:
input_tensor
,
fwd_wait_handles
=
p2p_communication
.
send_forward_recv_forward
(
output_tensor
,
recv_prev
=
recv_prev
,
tensor_shape
=
tensor_shape
,
config
=
config
,
overlap_p2p_comm
=
True
,
)
if
(
k
==
(
num_warmup_microbatches
-
1
)
and
not
forward_only
and
not
all_warmup_microbatches
):
input_tensor_grad
=
None
recv_next
=
True
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
recv_next
=
False
(
output_tensor_grad
,
bwd_wait_handles
,
)
=
p2p_communication
.
send_backward_recv_backward
(
input_tensor_grad
,
recv_next
=
recv_next
,
tensor_shape
=
tensor_shape
,
config
=
config
,
overlap_p2p_comm
=
True
,
)
output_tensor_grads
[
num_model_chunks
-
1
].
append
(
output_tensor_grad
)
input_tensors
[
next_forward_model_chunk_id
].
append
(
input_tensor
)
deallocate_output_tensor
(
output_tensor
,
config
.
deallocate_pipeline_outputs
)
# Run 1F1B in steady state.
for
k
in
range
(
num_microbatches_remaining
):
# Forward pass.
forward_k
=
k
+
num_warmup_microbatches
# Decide to checkpoint all layers' activations of the current micro-batch
if
max_outstanding_backprops
is
not
None
:
checkpoint_activations_microbatch
=
(
forward_k
%
max_outstanding_backprops
>=
config
.
num_microbatches_with_partial_activation_checkpoints
)
else
:
checkpoint_activations_microbatch
=
None
cur_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
current_microbatch
=
get_microbatch_id_in_model_chunk
(
forward_k
,
forward
=
True
)
if
config
.
overlap_p2p_comm
:
if
fwd_wait_handles
is
not
None
:
for
req
in
fwd_wait_handles
:
req
.
wait
()
deallocate_output_tensor
(
output_tensor
,
config
.
deallocate_pipeline_outputs
)
output_tensor
=
forward_step_helper
(
forward_k
,
current_microbatch
,
checkpoint_activations_microbatch
)
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
forward_model_chunk_id
)
# Last virtual stage no activation tensor to send
if
parallel_state
.
is_pipeline_last_stage
():
output_tensor
=
None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev
=
True
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
True
):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
-
(
pipeline_parallel_size
-
1
),
forward
=
True
)
if
next_forward_model_chunk_id
==
(
num_model_chunks
-
1
):
recv_prev
=
False
next_forward_model_chunk_id
+=
1
else
:
next_forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
+
1
,
forward
=
True
)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if
k
==
(
num_microbatches_remaining
-
1
):
recv_prev
=
False
# Send activation tensor to the next stage and receive activation tensor from the
# previous stage
input_tensor
,
fwd_wait_handles
=
p2p_communication
.
send_forward_recv_forward
(
output_tensor
,
recv_prev
=
recv_prev
,
tensor_shape
=
tensor_shape
,
config
=
config
,
overlap_p2p_comm
=
True
,
)
# assert fwd_wait_handles is not None
if
bwd_wait_handles
is
not
None
:
for
req
in
bwd_wait_handles
:
req
.
wait
()
# Backward pass.
backward_k
=
k
input_tensor_grad
=
backward_step_helper
(
backward_k
)
backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
,
forward
=
False
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
backward_model_chunk_id
)
# First virtual stage no activation gradient tensor to send
if
parallel_state
.
is_pipeline_first_stage
():
input_tensor_grad
=
None
# Determine if the current virtual stage has an activation gradient tensor to receive
recv_next
=
True
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
-
(
pipeline_parallel_size
-
1
),
forward
=
False
)
if
next_backward_model_chunk_id
==
0
:
recv_next
=
False
next_backward_model_chunk_id
-=
1
else
:
next_backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
+
1
,
forward
=
False
)
output_tensor_grad
,
bwd_wait_handles
=
p2p_communication
.
send_backward_recv_backward
(
input_tensor_grad
,
recv_next
=
recv_next
,
tensor_shape
=
tensor_shape
,
config
=
config
,
overlap_p2p_comm
=
True
,
)
else
:
# no p2p overlap
output_tensor
=
forward_step_helper
(
forward_k
,
current_microbatch
,
checkpoint_activations_microbatch
)
# Backward pass.
backward_k
=
k
input_tensor_grad
=
backward_step_helper
(
backward_k
)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
forward_model_chunk_id
)
if
parallel_state
.
is_pipeline_last_stage
():
output_tensor
=
None
backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
,
forward
=
False
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
backward_model_chunk_id
)
if
parallel_state
.
is_pipeline_first_stage
():
input_tensor_grad
=
None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev
=
True
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
True
):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
-
(
pipeline_parallel_size
-
1
),
forward
=
True
)
if
next_forward_model_chunk_id
==
(
num_model_chunks
-
1
):
recv_prev
=
False
next_forward_model_chunk_id
+=
1
else
:
next_forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
+
1
,
forward
=
True
)
recv_next
=
True
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
-
(
pipeline_parallel_size
-
1
),
forward
=
False
)
if
next_backward_model_chunk_id
==
0
:
recv_next
=
False
next_backward_model_chunk_id
-=
1
else
:
next_backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
+
1
,
forward
=
False
)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if
k
==
(
num_microbatches_remaining
-
1
):
recv_prev
=
False
# Communicate tensors.
(
input_tensor
,
output_tensor_grad
,
)
=
p2p_communication
.
send_forward_backward_recv_forward_backward
(
output_tensor
,
input_tensor_grad
,
recv_prev
=
recv_prev
,
recv_next
=
recv_next
,
tensor_shape
=
tensor_shape
,
config
=
config
,
)
deallocate_output_tensor
(
output_tensor
,
config
.
deallocate_pipeline_outputs
)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if
recv_prev
:
input_tensors
[
next_forward_model_chunk_id
].
append
(
input_tensor
)
if
recv_next
:
output_tensor_grads
[
next_backward_model_chunk_id
].
append
(
output_tensor_grad
)
deallocate_output_tensor
(
output_tensor
,
config
.
deallocate_pipeline_outputs
)
# Run cooldown backward passes (flush out pipeline).
if
not
forward_only
:
if
config
.
overlap_p2p_comm
and
bwd_wait_handles
is
not
None
:
for
wait_handle
in
bwd_wait_handles
:
wait_handle
.
wait
()
if
all_warmup_microbatches
:
output_tensor_grads
[
num_model_chunks
-
1
].
append
(
p2p_communication
.
recv_backward
(
tensor_shape
,
config
=
config
)
)
for
k
in
range
(
num_microbatches_remaining
,
total_num_microbatches
):
input_tensor_grad
=
backward_step_helper
(
k
)
next_backward_model_chunk_id
=
get_model_chunk_id
(
k
+
1
,
forward
=
False
)
recv_next
=
True
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
if
next_backward_model_chunk_id
==
(
num_model_chunks
-
1
):
recv_next
=
False
if
k
==
(
total_num_microbatches
-
1
):
recv_next
=
False
output_tensor_grads
[
next_backward_model_chunk_id
].
append
(
p2p_communication
.
send_backward_recv_backward
(
input_tensor_grad
,
recv_next
=
recv_next
,
tensor_shape
=
tensor_shape
,
config
=
config
)
)
# Launch any remaining grad reductions.
enable_grad_sync
()
if
config
.
grad_sync_func
is
not
None
:
for
model_chunk_id
in
range
(
num_model_chunks
):
if
model_chunk_id
not
in
synchronized_model_chunks
:
config
.
grad_sync_func
[
model_chunk_id
](
model
[
model_chunk_id
].
parameters
())
synchronized_model_chunks
.
add
(
model_chunk_id
)
if
config
.
finalize_model_grads_func
is
not
None
and
not
forward_only
:
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism, layernorm all-reduce for sequence parallelism, and
# embedding all-reduce for pipeline parallelism).
config
.
finalize_model_grads_func
(
model
,
total_num_tokens
if
config
.
calculate_per_token_loss
else
None
)
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-backward'
).
stop
()
return
forward_data_store
def get_tensor_shapes(
    *,
    rank: int,
    model_type: ModelType,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int,
    config,
):
    """Compute the (seq, batch, hidden) shapes exchanged between pipeline stages.

    Sequence lengths are first divided by the context-parallel world size and,
    when sequence parallelism is on, additionally by the tensor-parallel world
    size. For T5-style encoder_and_decoder models, ranks after the split
    exchange two tensors (decoder first, then encoder); all other cases
    exchange one.
    """
    is_encoder_decoder = model_type == ModelType.encoder_and_decoder

    cp_world_size = parallel_state.get_context_parallel_world_size()
    seq_length = seq_length // cp_world_size
    if is_encoder_decoder:
        decoder_seq_length = decoder_seq_length // cp_world_size

    if config.sequence_parallel:
        tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
        seq_length = seq_length // tp_world_size
        if is_encoder_decoder:
            decoder_seq_length = decoder_seq_length // tp_world_size

    encoder_shape = (seq_length, micro_batch_size, config.hidden_size)
    if not is_encoder_decoder:
        return [encoder_shape]
    if parallel_state.is_pipeline_stage_before_split(rank):
        return [encoder_shape]
    decoder_shape = (decoder_seq_length, micro_batch_size, config.hidden_size)
    return [decoder_shape, encoder_shape]
def recv_forward(tensor_shapes, config):
    """Receive one forward activation per shape; None shapes yield None entries."""
    return [
        None if shape is None else p2p_communication.recv_forward(shape, config)
        for shape in tensor_shapes
    ]
def recv_backward(tensor_shapes, config):
    """Receive one backward gradient per shape; None shapes yield None entries."""
    return [
        None if shape is None else p2p_communication.recv_backward(shape, config)
        for shape in tensor_shapes
    ]
def send_forward(output_tensors, tensor_shapes, config):
    """Send each forward activation downstream, skipping None shapes."""
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    for tensor, shape in zip(output_tensors, tensor_shapes):
        if shape is not None:
            p2p_communication.send_forward(tensor, config)
def send_backward(input_tensor_grads, tensor_shapes, config):
    """Send each input gradient upstream, skipping None shapes."""
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    for grad, shape in zip(input_tensor_grads, tensor_shapes):
        if shape is not None:
            p2p_communication.send_backward(grad, config)
def send_forward_recv_backward(output_tensors, tensor_shapes, config):
    """Send forward activations and receive the matching backward gradients.

    A None shape produces a None gradient entry and no communication.
    """
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    grads = []
    for tensor, shape in zip(output_tensors, tensor_shapes):
        if shape is None:
            grads.append(None)
        else:
            grads.append(p2p_communication.send_forward_recv_backward(tensor, shape, config))
    return grads
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
    """Send input gradients upstream and receive the next forward activations.

    A None shape produces a None activation entry and no communication.
    """
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    activations = []
    for grad, shape in zip(input_tensor_grads, tensor_shapes):
        if shape is None:
            activations.append(None)
        else:
            activations.append(p2p_communication.send_backward_recv_forward(grad, shape, config))
    return activations
def forward_backward_pipelining_without_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Three phases: warmup forward passes (pipeline fill), steady-state 1F1B
    (one forward then one backward per iteration), and cooldown backward
    passes (pipeline flush). Skipped backward phases when forward_only.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    if isinstance(model, list):
        assert (
            len(model) == 1
        ), "non-interleaved pipeline parallelism does not support model chunking"
        model = model[0]
    if isinstance(data_iterator, list):
        assert (
            len(data_iterator) == 1
        ), "non-pipeline-parallel schedule does not support model chunking"
        data_iterator = data_iterator[0]

    config = get_model_config(model)
    if config.overlap_p2p_comm:
        raise ValueError(
            "Non-interleaved pipeline parallelism does not support overlapping p2p communication"
        )

    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    # Disable async grad reductions
    no_sync_func = config.no_sync_func
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()

    # Compute number of warmup microbatches.
    # Earlier stages need more warmup forwards to fill the pipeline.
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches

    # Checkpoint the activations of partial Transformer layers in a number of micro-batches
    # within the maximum outstanding micro-batch backpropagations.
    # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
    # checkpoint partial Transformer layers (or skip checkpointing) and
    # the rest of micro-batches within a window of micro-batches checkpoint
    # all Transformer layers. The window of micro-batches is set by the maximum
    # outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1

    model_type = get_model_type(model)

    rank = parallel_state.get_pipeline_model_parallel_rank()
    # Shapes received from the previous stage are what that stage sends.
    recv_tensor_shapes = get_tensor_shapes(
        rank=rank - 1,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
    )
    send_tensor_shapes = get_tensor_shapes(
        rank=rank,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
    )

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
    if not forward_only:
        input_tensors = []
        output_tensors = []
    forward_data_store = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                i % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None

        input_tensor = recv_forward(recv_tensor_shapes, config)
        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
            check_first_val_step(first_val_step, forward_only, i == 0),
            current_microbatch=i,
        )
        send_forward(output_tensor, send_tensor_shapes, config)
        total_num_tokens += num_tokens.item()

        if not forward_only:
            # Save activations for the matching cooldown backward pass.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = recv_forward(recv_tensor_shapes, config)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = i == (num_microbatches_remaining - 1)

        # Decide to checkpoint all layers' activations of the current micro-batch
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                (i + num_warmup_microbatches) % max_outstanding_backprops
            ) >= config.num_microbatches_with_partial_activation_checkpoints
        else:
            checkpoint_activations_microbatch = None

        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
            check_first_val_step(
                first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0)
            ),
            current_microbatch=i + num_warmup_microbatches,
        )
        total_num_tokens += num_tokens.item()

        if forward_only:
            send_forward(output_tensor, send_tensor_shapes, config)

            if not last_iteration:
                input_tensor = recv_forward(recv_tensor_shapes, config)

        else:
            output_tensor_grad = send_forward_recv_backward(
                output_tensor, send_tensor_shapes, config
            )

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            # Enable grad sync for the last microbatch in the batch if the full
            # backward pass completes in the 1F1B stage.
            if num_warmup_microbatches == 0 and last_iteration:
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()

            input_tensor_grad = backward_step(
                input_tensor, output_tensor, output_tensor_grad, model_type, config
            )

            if last_iteration:
                input_tensor = None
                send_backward(input_tensor_grad, recv_tensor_shapes, config)
            else:
                input_tensor = send_backward_recv_forward(
                    input_tensor_grad, recv_tensor_shapes, config
                )

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):

            # Enable async grad reduction in the last backward pass
            # Note: If grad sync function is provided, only enable
            # async grad reduction in first pipeline stage. Other
            # pipeline stages do grad reduction during pipeline
            # bubble.
            if i == num_warmup_microbatches - 1:
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()

            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = recv_backward(send_tensor_shapes, config)

            input_tensor_grad = backward_step(
                input_tensor, output_tensor, output_tensor_grad, model_type, config
            )

            send_backward(input_tensor_grad, recv_tensor_shapes, config)

        # Launch any remaining grad reductions.
        if no_sync_context is not None:
            enable_grad_sync()
            if config.grad_sync_func is not None:
                config.grad_sync_func(model.parameters())

    if config.finalize_model_grads_func is not None and not forward_only:

        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism, layernorm all-reduce for sequence parallelism, and
        # embedding all-reduce for pipeline parallelism).
        config.finalize_model_grads_func(
            [model], total_num_tokens if config.calculate_per_token_loss else None
        )

    if config.timers is not None:
        config.timers('forward-backward').stop()

    return forward_data_store
megatron/core/requirements.txt
0 → 100644
View file @
0816dd4a
torch
\ No newline at end of file
megatron/core/ssm/__init__.py
0 → 100644
View file @
0816dd4a
megatron/core/ssm/mamba_block.py
0 → 100644
View file @
0816dd4a
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import
math
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
Union
from
torch
import
Tensor
,
nn
from
megatron.core
import
parallel_state
from
megatron.core.ssm.mamba_hybrid_layer_allocation
import
Symbols
as
LayerSymbols
from
megatron.core.ssm.mamba_hybrid_layer_allocation
import
allocate_layers
from
megatron.core.tensor_parallel
import
get_cuda_rng_tracker
from
megatron.core.transformer.custom_layers.transformer_engine
import
TENorm
from
megatron.core.transformer.identity_op
import
IdentityOp
from
megatron.core.transformer.module
import
MegatronModule
from
megatron.core.transformer.spec_utils
import
ModuleSpec
,
build_module
from
megatron.core.transformer.transformer_config
import
TransformerConfig
from
megatron.core.utils
import
make_viewless_tensor
def create_mamba_block(
    config, mamba_layer_spec, residual_in_fp32=False, layer_idx=None,
):
    """Build a single Mamba layer from its spec and tag it with its layer index."""
    mamba_block = build_module(
        mamba_layer_spec,
        config,
        residual_in_fp32=residual_in_fp32,
        layer_idx=layer_idx,
    )
    # Record the index on the module as well (used elsewhere for bookkeeping;
    # mirrors the layer_idx passed to the spec).
    mamba_block.layer_idx = layer_idx
    return mamba_block
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    """Initialize the weights of one module of a Mamba stack.

    Linear and Embedding weights get a normal init (unless tagged with
    `_no_reinit`); selected projection weights get a Kaiming-uniform init;
    residual-path output projections are rescaled per the GPT-2 scheme
    (1/sqrt(N) with N residual contributions).

    All randomness is drawn under the tensor-parallel CUDA RNG tracker so
    initialization is reproducible across model-parallel configurations.
    """
    with get_cuda_rng_tracker().fork():
        if isinstance(module, nn.Linear):
            if not getattr(module.weight, "_no_reinit", False):
                nn.init.normal_(module.weight, std=initializer_range)
            if module.bias is not None:
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=initializer_range)

        for name, p in module.named_parameters():
            if name in ["in_proj.weight", "x_proj.weight", "conv1d.weight", "out_proj.weight"]:
                # Fix: use the supported in-place initializer `kaiming_uniform_`.
                # The original called `nn.init.kaiming_uniform` (no underscore),
                # a deprecated alias that warns and is slated for removal from
                # PyTorch; both forms initialize the tensor in place.
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight", "fc2.weight"]:
                    # Special Scaled Initialization
                    nn.init.normal_(
                        p,
                        mean=0.0,
                        std=initializer_range / math.sqrt(n_residuals_per_layer * n_layer),
                    )
@dataclass
class MambaStackSubmodules:
    """Specs for the three layer types a hybrid Mamba stack may contain.

    Each field is a ModuleSpec (or class) that MambaStack passes to
    build_module; the IdentityOp default stands in for unused layer types.
    """

    # Spec for 'M' (mamba) layers.
    mamba_layer: Union[ModuleSpec, type] = IdentityOp
    # Spec for '*' (attention) layers.
    attention_layer: Union[ModuleSpec, type] = IdentityOp
    # Spec for '-' (mlp) layers.
    mlp_layer: Union[ModuleSpec, type] = IdentityOp
class MambaStack(MegatronModule):
    """A stack of Mamba / attention / MLP layers laid out by a hybrid pattern.

    The per-layer types come from ``allocate_layers`` (driven by the hybrid
    ratios or an explicit override pattern); when pipeline parallelism is
    active, each rank builds only its own contiguous slice of layers.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MambaStackSubmodules,
        residual_in_fp32=False,
        pre_process: bool = True,
        hybrid_attention_ratio: float = 0.0,
        hybrid_mlp_ratio: float = 0.0,
        hybrid_override_pattern: str = None,
        post_layer_norm: bool = True,
        post_process: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(config=config)
        self.residual_in_fp32 = residual_in_fp32
        self.pre_process = pre_process
        self.post_layer_norm = post_layer_norm
        self.post_process = post_process

        # Required for pipeline parallel schedules
        self.input_tensor = None

        self.hybrid_attention_ratio = hybrid_attention_ratio
        self.hybrid_mlp_ratio = hybrid_mlp_ratio
        self.hybrid_override_pattern = hybrid_override_pattern

        # Decide the layer type ('M'/'*'/'-') for every layer in the model.
        layer_type_list = allocate_layers(
            self.config.num_layers,
            self.hybrid_attention_ratio,
            self.hybrid_mlp_ratio,
            self.hybrid_override_pattern,
        )

        pp_layer_offset = 0
        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
            pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel(
                layer_type_list
            )

        self.layers = nn.ModuleList()
        for i, layer_type in enumerate(layer_type_list):
            if layer_type == LayerSymbols.MAMBA:
                # Mamba layers need a globally-unique index for inference-state caching.
                layer_idx = i + pp_layer_offset
                block = create_mamba_block(
                    self.config,
                    submodules.mamba_layer,
                    residual_in_fp32=residual_in_fp32,
                    layer_idx=layer_idx,
                )
            elif layer_type == LayerSymbols.ATTENTION:
                # Wondering if layer_number should be i+1. See TransformerBlock
                # and TransformerLayer::sharded_state_dict
                # Also, transformer layers apply their own pp_layer_offset
                block = build_module(submodules.attention_layer, config=self.config, layer_number=i)
            elif layer_type == LayerSymbols.MLP:
                # Wondering if layer_number should be i+1. See TransformerBlock
                # and TransformerLayer::sharded_state_dict
                # Also, transformer layers apply their own pp_layer_offset
                block = build_module(submodules.mlp_layer, config=self.config, layer_number=i)
            else:
                # Fix: the original `assert True, "unexpected layer_type"` could
                # never fire, silently falling through on an unknown symbol.
                raise ValueError(f"unexpected layer_type: {layer_type}")
            self.layers.append(block)

        # Required for activation recomputation
        self.num_layers_per_pipeline_rank = len(self.layers)

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_norm = TENorm(
                config=self.config,
                hidden_size=self.config.hidden_size,
                eps=self.config.layernorm_epsilon,
            )

        self.apply(partial(_init_weights, n_layer=self.config.num_layers,))

    def _select_layers_for_pipeline_parallel(self, layer_type_list):
        """Return (offset, slice) of the layer types owned by this pipeline rank.

        Assumes num_layers divides evenly by the pipeline world size; virtual
        (interleaved) pipeline parallelism is not supported.
        """
        pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
        num_layers_per_pipeline_rank = (
            self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
        )

        assert parallel_state.get_virtual_pipeline_model_parallel_world_size() is None, (
            "The Mamba hybrid model does not currently support "
            "virtual/interleaved pipeline parallelism"
        )

        offset = pipeline_rank * num_layers_per_pipeline_rank
        selected_list = layer_type_list[offset : offset + num_layers_per_pipeline_rank]

        return offset, selected_list

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate per-layer inference state, keyed by local layer index."""
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
            for i, layer in enumerate(self.layers)
        }

    def set_input_tensor(self, input_tensor: Tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor,
        inference_params=None,
        rotary_pos_emb: Tensor = None,
    ):
        """Run the stack; returns hidden states of the same shape as the input."""
        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if inference_params:
            # NOTE(bnorick): match InferenceParams attributes for mamba_ssm.utils.generation.InferenceParams,
            # this hack supports eval
            inference_params.max_seqlen = inference_params.max_sequence_length
            inference_params.seqlen_offset = inference_params.sequence_len_offset

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
            )

            # The attention layer (currently a simplified transformer layer)
            # outputs a tuple of (hidden_states, context). Context is intended
            # for cross-attention, and is not needed in our model.
            if isinstance(hidden_states, tuple):
                hidden_states = hidden_states[0]

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            hidden_states = self.final_norm(hidden_states)

        # Ensure that the tensor passed between pipeline parallel stages is
        # viewless. See related notes in TransformerBlock and TransformerLayer
        output = make_viewless_tensor(
            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
        )

        # Fix: return the viewless tensor; the original returned
        # `hidden_states`, discarding the conversion above.
        return output
megatron/core/ssm/mamba_hybrid_layer_allocation.py
0 → 100644
View file @
0816dd4a
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
logging
# When imported as part of megatron, use the shared rank-aware logging helper.
# When this module is run directly as a script (__name__ == "__main__"), fall
# back to a print-based stand-in so it can be exercised without megatron.
if __name__ != "__main__":
    from megatron.core.utils import log_single_rank
else:
    from typing import Any

    def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any):
        """Standalone fallback: drop the first positional arg (the log level)
        and print the rest."""
        print(*args[1:], **kwargs)


# Module-level logger used by the allocation functions in this module.
logger = logging.getLogger(__name__)
class Symbols:
    """Single-character codes used in hybrid layer-allocation patterns."""

    MAMBA = 'M'
    ATTENTION = '*'
    MLP = '-'
    # Set of all recognized pattern characters, for validation.
    VALID = {MAMBA, ATTENTION, MLP}
def _allocate_auto(
    total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float
) -> list:
    """Derive a layer-type pattern from the target attention and mlp ratios.

    Attention layers are spread evenly through the stack (the pattern starts
    and ends with mamba); mlp layers are then spread evenly over the
    remaining mamba slots, right-justified, never displacing attention.
    """
    # First, allocate attention (evenly spaced, starting and ending with mamba).
    attention_layers_count: int = round(total_layers_count * target_attention_ratio)
    mamba_layers_count: int = total_layers_count - attention_layers_count
    mamba_sections_count: int = attention_layers_count + 1
    mamba_section_length: float = mamba_layers_count / mamba_sections_count

    layer_type_list = [Symbols.MAMBA for _ in range(total_layers_count)]

    # Walk the stack with a running "distance to the next attention slot";
    # when it falls below half a layer, place an attention layer there.
    remaining: float = mamba_section_length
    for position in range(total_layers_count):
        if remaining >= 0.5:
            remaining -= 1
        else:
            layer_type_list[position] = Symbols.ATTENTION
            remaining += mamba_section_length

    # Next, allocate mlp
    # (evenly distributed, but right-justified, not replacing attention).
    mlp_layers_count: int = round(total_layers_count * target_mlp_ratio)
    if mlp_layers_count > 0:
        mamba_layers_count -= mlp_layers_count
        mamba_to_mlp_ratio: float = mamba_layers_count / mlp_layers_count

        # Same scattering scheme, but only mamba slots are eligible.
        remaining = mamba_to_mlp_ratio
        for position in range(total_layers_count):
            if layer_type_list[position] != Symbols.MAMBA:
                continue
            if remaining >= 0.5:
                remaining -= 1
            else:
                layer_type_list[position] = Symbols.MLP
                remaining += mamba_to_mlp_ratio

    return layer_type_list
def _allocate_override(total_layers_count: int, override_pattern: str) -> list:
    """Convert an explicit pattern string into a layer-type list.

    Raises ValueError if the pattern length does not match the layer count or
    if it contains a character outside Symbols.VALID.
    """
    layer_type_list = list(override_pattern)

    override_pattern_length = len(layer_type_list)
    if override_pattern_length != total_layers_count:
        raise ValueError(
            f"The hybrid override pattern is the wrong length: "
            f"got {override_pattern_length}, expected {total_layers_count}"
        )

    for symbol in layer_type_list:
        if symbol not in Symbols.VALID:
            raise ValueError(
                f"In hybrid override pattern, '{symbol}' is not one of {Symbols.VALID}"
            )

    return layer_type_list
def _layer_counts_match(a: list, b: list) -> bool:
    """Return True iff *a* and *b* contain the same number of each valid layer symbol."""
    return all(a.count(symbol) == b.count(symbol) for symbol in Symbols.VALID)
def allocate_layers(
    total_layers_count: int,
    target_attention_ratio: float,
    target_mlp_ratio: float,
    override_pattern: str = None,
) -> list:
    """Return the per-layer type pattern for a hybrid Mamba model.

    The pattern is first derived automatically from the target ratios; if
    *override_pattern* is given it replaces the derived pattern, but (when
    non-zero ratios were requested) must contain the same number of each
    layer type. The chosen allocation is logged on rank 0.
    """
    assert total_layers_count > 0
    assert target_attention_ratio >= 0.0 and target_attention_ratio <= 1.0
    assert target_mlp_ratio >= 0.0 and target_mlp_ratio <= 1.0
    assert target_attention_ratio + target_mlp_ratio <= 1.0
    # Note: target_mamba_ratio = 1.0 - target_attention_ratio - target_mlp_ratio

    layer_type_list = _allocate_auto(total_layers_count, target_attention_ratio, target_mlp_ratio)

    if override_pattern is not None:
        layer_type_list_override = _allocate_override(total_layers_count, override_pattern)
        log_single_rank(logger, logging.INFO, "Using hybrid override pattern")
        # An override may rearrange layers, but not change their counts,
        # when explicit ratios were also requested.
        if (target_attention_ratio > 0.0 or target_mlp_ratio > 0.0) and not _layer_counts_match(
            layer_type_list_override, layer_type_list
        ):
            raise ValueError(
                "The number of each type of layer in the override "
                "pattern must match the number in the overridden "
                "pattern."
            )
        if layer_type_list_override == layer_type_list:
            log_single_rank(
                logger, logging.INFO, "The override pattern matches the overridden pattern"
            )
        else:
            log_single_rank(
                logger, logging.INFO, "Warning: overriding pattern A with pattern B"
            )
            log_single_rank(logger, logging.INFO, f"A: {''.join(layer_type_list)}")
            log_single_rank(logger, logging.INFO, f"B: {''.join(layer_type_list_override)}")
        layer_type_list = layer_type_list_override

    # Log the final allocation and how close it came to the targets
    # (skipped for the trivial all-mamba case with no override).
    if target_attention_ratio > 0.0 or target_mlp_ratio > 0.0 or override_pattern is not None:
        actual_attention_layers_count = layer_type_list.count(Symbols.ATTENTION)
        actual_attention_ratio = actual_attention_layers_count / total_layers_count
        actual_mlp_layers_count = layer_type_list.count(Symbols.MLP)
        actual_mlp_ratio = actual_mlp_layers_count / total_layers_count
        allocation_string = ''.join(layer_type_list)
        log_single_rank(
            logger,
            logging.INFO,
            f"Hybrid allocation ({Symbols.MAMBA} is mamba, "
            f"{Symbols.ATTENTION} is attention, "
            f"{Symbols.MLP} is mlp):",
        )
        log_single_rank(logger, logging.INFO, allocation_string)
        log_single_rank(
            logger,
            logging.INFO,
            f"{actual_attention_layers_count} attention layers in "
            f"{total_layers_count} total layers.",
        )
        log_single_rank(
            logger,
            logging.INFO,
            f"Target attention ratio: {target_attention_ratio:.2f}. "
            f"Actual attention ratio: {actual_attention_ratio:.2f}.",
        )
        log_single_rank(
            logger,
            logging.INFO,
            f"{actual_mlp_layers_count} mlp layers in "
            f"{total_layers_count} total layers.",
        )
        log_single_rank(
            logger,
            logging.INFO,
            f"Target mlp ratio: {target_mlp_ratio:.2f}. "
            f"Actual mlp ratio: {actual_mlp_ratio:.2f}.",
        )

    return layer_type_list
if __name__ == "__main__":
    # Ad-hoc manual test driver: each tuple is
    # (total_layers_count, target_attention_ratio, target_mlp_ratio[, override_pattern]).
    # Uncomment cases to inspect their allocations.
    test_cases = [
        # (10, 0.2, 0.0),
        # (48, 0.0, 0.0), # will not print anything
        # (48, 0.1, 0.0),
        # (48, 0.3, 0.0),
        # (48, 0.5, 0.0),
        # (48, 0.6, 0.0),
        # (48, 0.7, 0.0),
        # (10, 0.0, 0.1),
        # (10, 0.0, 0.3),
        # (10, 0.0, 0.5),
        # (10, 0.1, 0.1),
        # (10, 0.2, 0.2),
        # (10, 0.3, 0.3),
        # (10, 0.5, 0.5),
        # (48, 0.2, 0.3),
        # (48, 0.5, 0.2),
        # (48, 0.5, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
        # (48, 0.25, 0.25, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
        # (48, 0.25, 0.25, "MM-*MM-*MM*-MM*-MM*-MM*-M*M-M*M-M*M-M*M-*MM-*MM-"),
        # (48, 0.0, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
        # (48, 0.2, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
        # (48, 0.0, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
        # (48, 0.5, 0.5),
        # (10, 0.3, 0.2, "MMM*-*M*M-"),
        # (10, 0.3, 0.2, "MM*M-*M*M-"),
        (9, 0.0, 0.0, "M*-M*-M*-"),
        (9, 0.0, 0.0, "MMMMMMMMM"),
    ]
    for t in test_cases:
        print("")
        allocate_layers(*t)
megatron/core/ssm/mamba_layer.py
0 → 100644
View file @
0816dd4a
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
from
dataclasses
import
dataclass
from
typing
import
Union
import
torch
from
torch
import
Tensor
from
megatron.core.transformer.identity_op
import
IdentityOp
from
megatron.core.transformer.module
import
MegatronModule
from
megatron.core.transformer.spec_utils
import
ModuleSpec
,
build_module
from
megatron.core.transformer.transformer_config
import
TransformerConfig
@dataclass
class MambaLayerSubmodules:
    """Specs for the two submodules of a MambaLayer."""

    # Pre-mixer normalization spec.
    norm: Union[ModuleSpec, type] = IdentityOp
    # State-space mixer spec.
    mixer: Union[ModuleSpec, type] = IdentityOp
class MambaLayer(MegatronModule):
    """Top level Mamba Layer: pre-norm followed by a Mamba mixer, wrapped in a
    residual connection.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MambaLayerSubmodules,
        layer_idx=None,
        residual_in_fp32=False,
    ):
        """
        Top level Mamba Layer
        """
        super().__init__(config)
        self.config = config
        self.residual_in_fp32 = residual_in_fp32
        # Build the state-space mixer and its preceding norm from their specs.
        self.mixer = build_module(
            submodules.mixer,
            self.config,
            self.config.hidden_size,
            layer_idx=layer_idx,
        )
        self.norm = build_module(submodules.norm, self.config, self.config.hidden_size)

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor,  # Not used in MambaLayer
        inference_params=None,
        rotary_pos_emb: Tensor = None,  # Not used in MambaLayer
    ):
        # Keep the raw input for the residual branch.
        residual = hidden_states
        # Normalize in the norm's own parameter dtype.
        normed = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)
        mixed = self.mixer(normed, inference_params=inference_params)
        return mixed + residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Delegate inference-state allocation to the mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
megatron/core/ssm/mamba_mixer.py
0 → 100644
View file @
0816dd4a
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
megatron.core.parallel_state
import
get_tensor_model_parallel_world_size
from
megatron.core.tensor_parallel
import
(
ColumnParallelLinear
,
RowParallelLinear
,
copy_to_tensor_model_parallel_region
,
gather_from_sequence_parallel_region
,
get_cuda_rng_tracker
,
reduce_from_tensor_model_parallel_region
,
reduce_scatter_to_sequence_parallel_region
,
)
from
megatron.core.transformer.module
import
MegatronModule
from
megatron.core.transformer.transformer_config
import
TransformerConfig
try
:
from
mamba_ssm.ops.triton.selective_state_update
import
selective_state_update
except
ImportError
:
selective_state_update
=
None
try
:
from
causal_conv1d
import
causal_conv1d_fn
,
causal_conv1d_update
except
ImportError
:
causal_conv1d_fn
=
None
causal_conv1d_update
=
None
try
:
from
mamba_ssm.ops.triton.layernorm_gated
import
RMSNorm
as
RMSNormGated
from
mamba_ssm.ops.triton.ssd_combined
import
mamba_chunk_scan_combined
except
ImportError
:
raise
ImportError
(
"mamba-ssm is required by the Mamba model but cannot be imported"
)
try
:
from
einops
import
rearrange
,
repeat
except
ImportError
:
raise
ImportError
(
"einops is required by the Mamba model but cannot be imported"
)
class
Mamba
(
MegatronModule
):
def
__init__
(
self
,
config
:
TransformerConfig
,
d_model
,
d_state
=
128
,
d_conv
=
4
,
conv_init
=
None
,
expand
=
2
,
headdim
=
64
,
ngroups
=
8
,
A_init_range
=
(
1
,
16
),
D_has_hdim
=
False
,
rmsnorm
=
True
,
norm_before_gate
=
False
,
dt_min
=
0.001
,
dt_max
=
0.1
,
dt_init
=
"random"
,
dt_scale
=
1.0
,
dt_init_floor
=
1e-4
,
bias
=
False
,
conv_bias
=
True
,
# Fused kernel and sharding options
chunk_size
=
128
,
use_fast_path
=
True
,
layer_idx
=
None
,
):
super
().
__init__
(
config
)
self
.
config
=
config
self
.
d_model
=
d_model
self
.
d_state
=
d_state
self
.
d_conv
=
d_conv
self
.
conv_init
=
conv_init
self
.
expand
=
expand
self
.
d_inner
=
int
(
self
.
expand
*
self
.
d_model
)
self
.
headdim
=
headdim
self
.
ngroups
=
ngroups
assert
self
.
d_inner
%
self
.
headdim
==
0
self
.
nheads
=
self
.
d_inner
//
self
.
headdim
self
.
D_has_hdim
=
D_has_hdim
self
.
rmsnorm
=
rmsnorm
self
.
norm_before_gate
=
norm_before_gate
self
.
chunk_size
=
chunk_size
self
.
use_fast_path
=
use_fast_path
self
.
layer_idx
=
layer_idx
self
.
tensor_model_parallel_size
=
get_tensor_model_parallel_world_size
()
assert
self
.
d_inner
%
self
.
tensor_model_parallel_size
==
0
assert
self
.
ngroups
%
self
.
tensor_model_parallel_size
==
0
assert
self
.
nheads
%
self
.
tensor_model_parallel_size
==
0
assert
not
bias
self
.
d_inner_local
=
self
.
d_inner
//
self
.
tensor_model_parallel_size
self
.
ngroups_local
=
self
.
ngroups
//
self
.
tensor_model_parallel_size
self
.
nheads_local
=
self
.
nheads
//
self
.
tensor_model_parallel_size
assert
self
.
d_inner_local
%
self
.
ngroups_local
==
0
# Assume sequence parallelism: input is already partitioned along the
# sequence dimension
self
.
in_proj
=
ColumnParallelLinear
(
self
.
d_model
,
self
.
d_inner
*
2
+
2
*
self
.
ngroups
*
self
.
d_state
+
self
.
nheads
,
config
=
self
.
config
,
init_method
=
self
.
config
.
init_method
,
gather_output
=
False
,
bias
=
bias
,
)
conv_dim
=
self
.
d_inner_local
+
2
*
self
.
ngroups_local
*
self
.
d_state
with
get_cuda_rng_tracker
().
fork
():
self
.
conv1d
=
nn
.
Conv1d
(
in_channels
=
conv_dim
,
out_channels
=
conv_dim
,
bias
=
conv_bias
,
kernel_size
=
d_conv
,
groups
=
conv_dim
,
padding
=
d_conv
-
1
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
config
.
params_dtype
,
)
setattr
(
self
.
conv1d
.
weight
,
'tensor_model_parallel'
,
True
)
setattr
(
self
.
conv1d
.
bias
,
'tensor_model_parallel'
,
True
)
if
self
.
conv_init
is
not
None
:
nn
.
init
.
uniform_
(
self
.
conv1d
.
weight
,
-
self
.
conv_init
,
self
.
conv_init
)
self
.
activation
=
"silu"
self
.
act
=
nn
.
SiLU
()
with
get_cuda_rng_tracker
().
fork
():
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt
=
torch
.
exp
(
torch
.
rand
(
self
.
nheads_local
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
config
.
params_dtype
)
*
(
math
.
log
(
dt_max
)
-
math
.
log
(
dt_min
))
+
math
.
log
(
dt_min
)
).
clamp
(
min
=
dt_init_floor
)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt
=
dt
+
torch
.
log
(
-
torch
.
expm1
(
-
dt
))
with
torch
.
no_grad
():
self
.
dt_bias
=
nn
.
Parameter
(
inv_dt
)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self
.
dt_bias
.
_no_reinit
=
True
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self
.
dt_bias
.
_no_weight_decay
=
True
assert
A_init_range
[
0
]
>
0
and
A_init_range
[
1
]
>=
A_init_range
[
0
]
A
=
torch
.
empty
(
self
.
nheads_local
,
dtype
=
torch
.
float32
,
device
=
torch
.
cuda
.
current_device
()
).
uniform_
(
*
A_init_range
)
A_log
=
torch
.
log
(
A
)
# Keep A_log in fp32
self
.
A_log
=
nn
.
Parameter
(
A_log
)
self
.
A_log
.
_no_weight_decay
=
True
setattr
(
self
.
A_log
,
'tensor_model_parallel'
,
True
)
# D "skip" parameter
self
.
D
=
nn
.
Parameter
(
torch
.
ones
(
self
.
d_inner_local
if
self
.
D_has_hdim
else
self
.
nheads_local
,
device
=
torch
.
cuda
.
current_device
(),
)
)
# Keep in fp32
self
.
D
.
_no_weight_decay
=
True
setattr
(
self
.
D
,
'tensor_model_parallel'
,
True
)
if
self
.
rmsnorm
:
assert
RMSNormGated
is
not
None
self
.
norm
=
RMSNormGated
(
self
.
d_inner_local
,
eps
=
1e-5
,
group_size
=
self
.
d_inner_local
//
self
.
ngroups_local
,
norm_before_gate
=
False
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
config
.
params_dtype
,
)
# Assume sequence parallelism: input is partitioned along d_inner and
# output is partitioned along the sequence dimension
self
.
out_proj
=
RowParallelLinear
(
self
.
d_inner
,
self
.
d_model
,
config
=
self
.
config
,
init_method
=
self
.
config
.
output_layer_init_method
,
bias
=
bias
,
input_is_parallel
=
True
,
skip_bias_add
=
False
,
)
def
forward
(
self
,
hidden_states
,
inference_params
=
None
):
"""
hidden_states: (nL, B, D) / (L B D)
Returns: same shape as hidden_states
"""
_
,
batch
,
dim
=
hidden_states
.
shape
conv_state
,
ssm_state
=
None
,
None
if
inference_params
is
not
None
:
assert
not
self
.
config
.
sequence_parallel
conv_state
,
ssm_state
=
self
.
_get_states_from_cache
(
inference_params
,
batch
)
if
inference_params
.
seqlen_offset
>
0
:
# The states are updated inplace
out
,
_
,
_
=
self
.
step
(
hidden_states
,
conv_state
,
ssm_state
)
return
out
# (nheads_local)
A
=
-
torch
.
exp
(
self
.
A_log
.
float
())
# pl b d -> l b p(2d)
# TODO move transpose to GEMM
if
self
.
config
.
sequence_parallel
:
# gather data along sequenece dimension
hidden_states
=
gather_from_sequence_parallel_region
(
hidden_states
)
else
:
hidden_states
=
copy_to_tensor_model_parallel_region
(
hidden_states
)
xz
=
hidden_states
@
self
.
in_proj
.
weight
.
t
()
z
,
xBC
,
dt
=
torch
.
split
(
xz
,
[
self
.
d_inner_local
,
self
.
d_inner_local
+
2
*
self
.
ngroups_local
*
self
.
d_state
,
self
.
nheads_local
,
],
dim
=-
1
,
)
# transpose: l b pd --> b pd l
xBC
=
rearrange
(
xBC
,
"l b d -> b d l"
)
xBC
=
xBC
.
contiguous
()
# Compute short convolution
if
conv_state
is
not
None
:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state
.
copy_
(
F
.
pad
(
xBC
,
(
self
.
d_conv
-
xBC
.
shape
[
-
1
],
0
)))
# Update state (B D W)
seqlen
=
xBC
.
size
(
2
)
if
causal_conv1d_fn
is
None
:
xBC
=
self
.
act
(
self
.
conv1d
(
xBC
)[...,
:
seqlen
])
else
:
assert
self
.
activation
in
[
"silu"
,
"swish"
]
xBC
=
causal_conv1d_fn
(
x
=
xBC
,
weight
=
rearrange
(
self
.
conv1d
.
weight
,
"d 1 w -> d w"
),
bias
=
self
.
conv1d
.
bias
,
activation
=
self
.
activation
,
)
# transpose b pd l --> l b pd
xBC
=
rearrange
(
xBC
,
"b d l -> l b d"
)
xBC
=
xBC
.
contiguous
()
x
,
B
,
C
=
torch
.
split
(
xBC
,
[
self
.
d_inner_local
,
self
.
ngroups_local
*
self
.
d_state
,
self
.
ngroups_local
*
self
.
d_state
,
],
dim
=-
1
,
)
# TODO Vijay: fuse most of the transposes with the GEMMS
x
=
rearrange
(
x
,
"l b (h p) -> b l h p"
,
p
=
self
.
headdim
).
contiguous
()
dt
=
rearrange
(
dt
,
"l b d -> b l d"
).
contiguous
()
B
=
rearrange
(
B
,
"l b (g n) -> b l g n"
,
n
=
self
.
d_state
).
contiguous
()
C
=
rearrange
(
C
,
"l b (g n) -> b l g n"
,
n
=
self
.
d_state
).
contiguous
()
z
=
rearrange
(
z
,
"l b (h p) -> b l h p"
,
p
=
self
.
headdim
).
contiguous
()
y
=
mamba_chunk_scan_combined
(
x
,
dt
,
A
,
B
,
C
,
self
.
chunk_size
,
D
=
rearrange
(
self
.
D
.
float
(),
"(h p) -> h p"
,
p
=
self
.
headdim
)
if
self
.
D_has_hdim
else
self
.
D
,
z
=
z
if
not
self
.
rmsnorm
else
None
,
dt_bias
=
self
.
dt_bias
.
float
(),
dt_softplus
=
True
,
return_final_states
=
ssm_state
is
not
None
,
)
if
ssm_state
is
not
None
:
y
,
last_state
=
y
ssm_state
.
copy_
(
last_state
)
if
self
.
rmsnorm
:
y
=
rearrange
(
y
,
"b l h p -> b l (h p)"
).
contiguous
()
z
=
rearrange
(
z
,
"b l h p -> b l (h p)"
).
contiguous
()
y
=
self
.
norm
(
y
,
z
)
y
=
rearrange
(
y
,
"b l d -> l b d"
).
contiguous
()
else
:
y
=
rearrange
(
y
,
"b l h p -> l b (h p)"
).
contiguous
()
# l b pd --> pl b d
out_full
=
y
@
self
.
out_proj
.
weight
.
t
()
if
self
.
config
.
sequence_parallel
:
out
=
reduce_scatter_to_sequence_parallel_region
(
out_full
)
else
:
out
=
reduce_from_tensor_model_parallel_region
(
out_full
)
return
out
def
step
(
self
,
hidden_states
,
conv_state
,
ssm_state
):
# assert self.ngroups_local == 1, "Only support ngroups=1 for inference for now"
dtype
=
hidden_states
.
dtype
assert
hidden_states
.
shape
[
0
]
==
1
,
"Only support decoding with 1 token at a time for now"
# l b d --> b d
hidden_states
=
hidden_states
.
squeeze
(
0
)
# b d_model --> b p(2d)
xz
=
hidden_states
@
self
.
in_proj
.
weight
.
t
()
z
,
xBC
,
dt
=
torch
.
split
(
xz
,
[
self
.
d_inner_local
,
self
.
d_inner_local
+
2
*
self
.
ngroups_local
*
self
.
d_state
,
self
.
nheads_local
,
],
dim
=-
1
,
)
# Conv step
if
causal_conv1d_update
is
None
:
conv_state
.
copy_
(
torch
.
roll
(
conv_state
,
shifts
=-
1
,
dims
=-
1
))
# Update state (B D W)
conv_state
[:,
:,
-
1
]
=
xBC
xBC
=
torch
.
sum
(
conv_state
*
rearrange
(
self
.
conv1d
.
weight
,
"d 1 w -> d w"
),
dim
=-
1
)
# (B D)
if
self
.
conv1d
.
bias
is
not
None
:
xBC
=
xBC
+
self
.
conv1d
.
bias
xBC
=
self
.
act
(
xBC
).
to
(
dtype
=
dtype
)
else
:
xBC
=
causal_conv1d_update
(
xBC
,
conv_state
,
rearrange
(
self
.
conv1d
.
weight
,
"d 1 w -> d w"
),
self
.
conv1d
.
bias
,
self
.
activation
,
)
x
,
B
,
C
=
torch
.
split
(
xBC
,
[
self
.
d_inner_local
,
self
.
ngroups_local
*
self
.
d_state
,
self
.
ngroups_local
*
self
.
d_state
,
],
dim
=-
1
,
)
A
=
-
torch
.
exp
(
self
.
A_log
.
float
())
# SSM step
if
selective_state_update
is
None
:
if
self
.
ngroups_local
>
1
:
B
=
rearrange
(
B
,
"b (g n) -> b g n"
,
n
=
self
.
d_state
)
C
=
rearrange
(
C
,
"b (g n) -> b g n"
,
n
=
self
.
d_state
)
B
=
repeat
(
B
,
"b g n -> b (g h) n"
,
h
=
self
.
d_inner_local
//
self
.
ngroups_local
)
C
=
repeat
(
C
,
"b g n -> b (g h) n"
,
h
=
self
.
d_inner_local
//
self
.
ngroups_local
)
dt
=
repeat
(
dt
,
"b h -> b (h p)"
,
p
=
self
.
headdim
)
dt_bias
=
repeat
(
self
.
dt_bias
,
"h -> (h p)"
,
p
=
self
.
headdim
)
A
=
repeat
(
A
,
"h -> (h p) n"
,
p
=
self
.
headdim
,
n
=
self
.
d_state
)
D
=
repeat
(
self
.
D
,
"h -> (h p)"
,
p
=
self
.
headdim
)
dt
=
F
.
softplus
(
dt
+
dt_bias
.
to
(
dtype
=
dt
.
dtype
))
dA
=
torch
.
exp
(
torch
.
einsum
(
"bd,dn->bdn"
,
dt
,
A
))
dB_x
=
torch
.
einsum
(
'bd,bdn,bd->bdn'
,
dt
,
B
,
x
)
ssm_state
.
copy_
(
ssm_state
*
rearrange
(
dA
,
"b (h p) n -> b h p n"
,
p
=
self
.
headdim
)
+
rearrange
(
dB_x
,
"b (h p) n -> b h p n"
,
p
=
self
.
headdim
)
)
y
=
torch
.
einsum
(
"bdn,bdn->bd"
,
rearrange
(
ssm_state
.
to
(
dtype
),
"b h p n -> b (h p) n"
,
p
=
self
.
headdim
),
C
,
)
y
=
y
+
D
.
to
(
dtype
)
*
x
if
not
self
.
rmsnorm
:
y
=
y
*
self
.
act
(
z
)
# (B D)
else
:
# Discretize A and B (b (g n))
dt
=
F
.
softplus
(
dt
+
self
.
dt_bias
.
to
(
dtype
=
dt
.
dtype
))
# (batch, nheads)
dA
=
torch
.
exp
(
dt
*
A
)
x
=
rearrange
(
x
,
"b (h p) -> b h p"
,
p
=
self
.
headdim
)
dBx
=
torch
.
einsum
(
"bh,bn,bhp->bhpn"
,
dt
,
B
,
x
)
ssm_state
.
copy_
(
ssm_state
*
rearrange
(
dA
,
"b h -> b h 1 1"
)
+
dBx
)
y
=
torch
.
einsum
(
"bhpn,bn->bhp"
,
ssm_state
.
to
(
dtype
),
C
)
y
=
y
+
rearrange
(
self
.
D
.
to
(
dtype
),
"h -> h 1"
)
*
x
y
=
rearrange
(
y
,
"b h p -> b (h p)"
)
if
not
self
.
rmsnorm
:
y
=
y
*
self
.
act
(
z
)
# (B D)
else
:
A
=
repeat
(
A
,
"h -> h p n"
,
p
=
self
.
headdim
,
n
=
self
.
d_state
).
to
(
dtype
=
torch
.
float32
)
dt
=
repeat
(
dt
,
"b h -> b h p"
,
p
=
self
.
headdim
)
dt_bias
=
repeat
(
self
.
dt_bias
,
"h -> h p"
,
p
=
self
.
headdim
)
D
=
repeat
(
self
.
D
,
"h -> h p"
,
p
=
self
.
headdim
)
B
=
rearrange
(
B
,
"b (g n) -> b g n"
,
g
=
self
.
ngroups_local
)
C
=
rearrange
(
C
,
"b (g n) -> b g n"
,
g
=
self
.
ngroups_local
)
x_reshaped
=
rearrange
(
x
,
"b (h p) -> b h p"
,
p
=
self
.
headdim
)
if
not
self
.
rmsnorm
:
z
=
rearrange
(
z
,
"b (h p) -> b h p"
,
p
=
self
.
headdim
)
y
=
selective_state_update
(
ssm_state
,
x_reshaped
,
dt
,
A
,
B
,
C
,
D
,
z
=
z
if
not
self
.
rmsnorm
else
None
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
)
y
=
rearrange
(
y
,
"b h p -> b (h p)"
)
if
self
.
rmsnorm
:
y
=
self
.
norm
(
y
,
z
)
# b pd --> b d
out
=
y
@
self
.
out_proj
.
weight
.
t
()
out
=
reduce_from_tensor_model_parallel_region
(
out
)
return
out
.
unsqueeze
(
0
),
conv_state
,
ssm_state
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
    """Build fresh zero-initialized decode-time caches for this mixer.

    Returns a ``(conv_state, ssm_state)`` pair:
      * ``conv_state`` — shape ``(batch_size, conv1d out-channels, d_conv)``,
        using the conv1d weight's dtype unless ``dtype`` overrides it.
      * ``ssm_state`` — shape ``(batch_size, nheads_local, headdim, d_state)``,
        using the in_proj weight's dtype unless ``dtype`` overrides it.

    Both tensors live on the same device as the out_proj weight.
    ``max_seqlen`` is accepted for interface compatibility but unused here.
    """
    target_device = self.out_proj.weight.device

    conv_dtype = dtype if dtype is not None else self.conv1d.weight.dtype
    conv_state = torch.zeros(
        batch_size,
        self.conv1d.weight.shape[0],
        self.d_conv,
        device=target_device,
        dtype=conv_dtype,
    )

    ssm_dtype = dtype if dtype is not None else self.in_proj.weight.dtype
    ssm_state = torch.zeros(
        batch_size,
        self.nheads_local,
        self.headdim,
        self.d_state,
        device=target_device,
        dtype=ssm_dtype,
    )
    return conv_state, ssm_state
def
_get_states_from_cache
(
self
,
inference_params
,
batch_size
,
initialize_states
=
False
):
assert
self
.
layer_idx
is
not
None
if
self
.
layer_idx
not
in
inference_params
.
key_value_memory_dict
:
conv_state
=
torch
.
zeros
(
batch_size
,
self
.
conv1d
.
weight
.
shape
[
0
],
self
.
d_conv
,
device
=
self
.
conv1d
.
weight
.
device
,
dtype
=
self
.
conv1d
.
weight
.
dtype
,
)
ssm_state
=
torch
.
zeros
(
batch_size
,
self
.
nheads_local
,
self
.
headdim
,
self
.
d_state
,
device
=
self
.
in_proj
.
weight
.
device
,
dtype
=
self
.
in_proj
.
weight
.
dtype
,
)
inference_params
.
key_value_memory_dict
[
self
.
layer_idx
]
=
(
conv_state
,
ssm_state
)
else
:
conv_state
,
ssm_state
=
inference_params
.
key_value_memory_dict
[
self
.
layer_idx
]
# TODO: What if batch size changes between generation, and we reuse the same states?
if
initialize_states
:
conv_state
.
zero_
()
ssm_state
.
zero_
()
return
conv_state
,
ssm_state
megatron/core/ssm/triton_cache_manager.py
0 → 100644
View file @
0816dd4a
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
os
import
socket
from
pathlib
import
Path
import
torch
try
:
from
triton.runtime.cache
import
FileCacheManager
except
ImportError
:
raise
ImportError
(
"triton is required by the Mamba model but cannot be imported"
)
def get_rank():
    """Global rank of this process in the default torch.distributed group."""
    dist = torch.distributed
    return dist.get_rank()
def default_cache_dir():
    """Return Triton's conventional cache root, ``~/.triton/cache``."""
    return str(Path.home() / ".triton" / "cache")
class ParallelFileCacheManager(FileCacheManager):
    # See https://github.com/triton-lang/triton/blob/main/python/triton/runtime/cache.py
    #
    # Each distributed rank constructs its own Triton cache manager, and the
    # cache keys those managers receive are mostly (but not entirely) identical
    # across ranks. With a shared cache directory, many ranks therefore write
    # the same per-key subdirectories concurrently during compilation and
    # conflict. This subclass sidesteps the races by giving every process its
    # own "rank_<host>_<pid>" subdirectory under the cache root.

    def __init__(self, key):
        self.key = key
        self.lock_path = None
        # Resolve the cache root (env var wins), then make it process-private
        # so no two ranks ever share a directory. Created lazily if missing.
        cache_root = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
        self.cache_dir = os.path.join(
            cache_root, "rank_{}_{}".format(socket.gethostname(), os.getpid())
        )
        if self.cache_dir:
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
megatron/core/tensor_parallel/__init__.py
0 → 100644
View file @
0816dd4a
from
.cross_entropy
import
vocab_parallel_cross_entropy
from
.data
import
broadcast_data
from
.layers
import
(
ColumnParallelLinear
,
RowParallelLinear
,
VocabParallelEmbedding
,
copy_tensor_model_parallel_attributes
,
linear_with_grad_accumulation_and_async_allreduce
,
param_is_not_tensor_parallel_duplicate
,
set_defaults_if_not_set_tensor_model_parallel_attributes
,
set_tensor_model_parallel_attributes
,
)
from
.mappings
import
(
all_gather_last_dim_from_tensor_parallel_region
,
all_to_all
,
all_to_all_hp2sp
,
all_to_all_sp2hp
,
copy_to_tensor_model_parallel_region
,
gather_from_sequence_parallel_region
,
gather_from_sequence_parallel_region_to_moe
,
gather_from_tensor_model_parallel_region
,
reduce_from_tensor_model_parallel_region
,
reduce_scatter_last_dim_to_tensor_parallel_region
,
reduce_scatter_to_sequence_parallel_region
,
reduce_scatter_to_sequence_parallel_region_from_moe
,
scatter_to_sequence_parallel_region
,
scatter_to_tensor_model_parallel_region
,
)
from
.random
import
(
checkpoint
,
get_cuda_rng_tracker
,
get_data_parallel_rng_tracker_name
,
model_parallel_cuda_manual_seed
,
)
from
.utils
import
(
gather_split_1d_tensor
,
split_tensor_along_last_dim
,
split_tensor_into_1d_equal_chunks
,
)
# Public names re-exported by `from megatron.core.tensor_parallel import *`.
# Kept in sync with the imports above: the all-to-all / last-dim collectives
# from mappings.py and random.py's get_data_parallel_rng_tracker_name were
# previously imported but absent from __all__, so star-imports dropped them.
__all__ = [
    # cross_entropy.py
    "vocab_parallel_cross_entropy",
    # data.py
    "broadcast_data",
    # layers.py
    "ColumnParallelLinear",
    "RowParallelLinear",
    "VocabParallelEmbedding",
    "set_tensor_model_parallel_attributes",
    "set_defaults_if_not_set_tensor_model_parallel_attributes",
    "copy_tensor_model_parallel_attributes",
    "param_is_not_tensor_parallel_duplicate",
    "linear_with_grad_accumulation_and_async_allreduce",
    # mappings.py
    "all_gather_last_dim_from_tensor_parallel_region",
    "all_to_all",
    "all_to_all_hp2sp",
    "all_to_all_sp2hp",
    "copy_to_tensor_model_parallel_region",
    "gather_from_tensor_model_parallel_region",
    "gather_from_sequence_parallel_region",
    "gather_from_sequence_parallel_region_to_moe",
    "reduce_from_tensor_model_parallel_region",
    "reduce_scatter_last_dim_to_tensor_parallel_region",
    "reduce_scatter_to_sequence_parallel_region",
    "reduce_scatter_to_sequence_parallel_region_from_moe",
    "scatter_to_tensor_model_parallel_region",
    "scatter_to_sequence_parallel_region",
    # random.py
    "checkpoint",
    "get_cuda_rng_tracker",
    "get_data_parallel_rng_tracker_name",
    "model_parallel_cuda_manual_seed",
    # utils.py
    "split_tensor_along_last_dim",
    "split_tensor_into_1d_equal_chunks",
    "gather_split_1d_tensor",
]
megatron/core/tensor_parallel/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
0816dd4a
File added
megatron/core/tensor_parallel/__pycache__/cross_entropy.cpython-310.pyc
0 → 100644
View file @
0816dd4a
File added
Prev
1
…
7
8
9
10
11
12
13
14
15
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment