evt_fugx1 / dcu_megatron · Commits · 9eb8683b

Commit 9eb8683b, authored Apr 10, 2025 by dongcl

Merge branch 'main' into megatron_v0.11.0

Parents: 6f016785, be9a69d7

Showing 7 changed files with 1020 additions and 1038 deletions.
Changed files:

    dcu_megatron/adaptor/megatron_adaptor.py               +20    -6
    dcu_megatron/core/__init__.py                            +0    -1
    dcu_megatron/core/extensions/transformer_engine.py     +126  -124
    dcu_megatron/core/tensor_parallel/__init__.py            +7    -0
    dcu_megatron/core/tensor_parallel/layers.py            +847    -2
    dcu_megatron/core/transformer/transformer_config.py     +11  -905
    dcu_megatron/training/arguments.py                       +9    -0
dcu_megatron/adaptor/megatron_adaptor.py  (view file @ 9eb8683b)

@@ -123,7 +123,7 @@ class CoreAdaptation(MegatronAdaptationABC):
    def patch_core_transformers(self):
        from ..core import transformer_block_init_wrapper, transformer_block_forward
-       from ..core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
+       from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

        # Transformer block
        MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
...
@@ -133,9 +133,9 @@ class CoreAdaptation(MegatronAdaptationABC):
        # Transformer config
-       MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig', TransformerConfig)
+       MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig', TransformerConfigPatch)
-       MegatronAdaptation.register('megatron.core.transformer.transformer_config.MLATransformerConfig', MLATransformerConfig)
+       MegatronAdaptation.register('megatron.core.transformer.transformer_config.MLATransformerConfig', MLATransformerConfigPatch)

        # Moe
        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
...
@@ -154,18 +154,19 @@ class CoreAdaptation(MegatronAdaptationABC):
    def patch_core_extentions(self):
        import transformer_engine as te
-       from ..core.extensions.transformer_engine import te_dot_product_attention_init
+       from ..core.extensions.transformer_engine import TEDotProductAttentionPatch
        from megatron.core.extensions.transformer_engine import TEGroupedLinear

-       MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__', te_dot_product_attention_init)
+       MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__', TEDotProductAttentionPatch.__init__)

        if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
            TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)

    def patch_tensor_parallel(self):
-       from ..core import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
        from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
+       from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
+       from ..core.tensor_parallel import ColumnParallelLinearPatch, RowParallelLinearPatch, parallel_linear_init_wrapper

        # VocabParallelEmbedding
        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
...
@@ -186,6 +187,19 @@ class CoreAdaptation(MegatronAdaptationABC):
                                    staticmethod, apply_wrapper=True)

+       # flux
+       MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
+                                   parallel_linear_init_wrapper, apply_wrapper=True)
+       MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
+                                   ColumnParallelLinearPatch.forward)
+       MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
+                                   parallel_linear_init_wrapper, apply_wrapper=True)
+       MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
+                                   RowParallelLinearPatch.forward)

    def patch_training(self):
        from ..training.tokenizer import build_tokenizer
        from ..training.initialize import _initialize_distributed
...
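The adaptor swaps DCU-specific behaviour in by registering replacement callables against dotted Megatron import paths and monkey-patching them onto the imported objects. A minimal sketch of that registration idea, assuming a hypothetical register/apply_patches pair rather than the project's actual MegatronAdaptation class (which additionally supports apply_wrapper=True to wrap the original attribute instead of replacing it):

    import importlib

    _patches = {}

    def register(dotted_path, replacement):
        """Remember that the object at `dotted_path` should become `replacement`."""
        _patches[dotted_path] = replacement

    def apply_patches():
        """Resolve each dotted path and overwrite the attribute in place."""
        for dotted_path, replacement in _patches.items():
            module_path, attr = dotted_path.rsplit('.', 1)
            parts = module_path.split('.')
            # The target may hang off a class (e.g. `...TransformerBlock.__init__`),
            # so import the longest importable prefix and walk the rest via getattr.
            for i in range(len(parts), 0, -1):
                try:
                    obj = importlib.import_module('.'.join(parts[:i]))
                    break
                except ImportError:
                    continue
            else:
                raise ImportError(f"cannot resolve {dotted_path}")
            for name in parts[i:]:
                obj = getattr(obj, name)
            setattr(obj, attr, replacement)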
dcu_megatron/core/__init__.py  (view file @ 9eb8683b)

-from .tensor_parallel.layers import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
 from .transformer.transformer_block import transformer_block_init_wrapper, transformer_block_forward
dcu_megatron/core/extensions/transformer_engine.py  (view file @ 9eb8683b)

import os
import dataclasses

import transformer_engine as te
from typing import Any, Optional
from packaging.version import Version as PkgVersion
...
@@ -19,135 +20,136 @@ from megatron.core.parallel_state import (
)


-def te_dot_product_attention_init(
-    self,
-    config: TransformerConfig,
-    layer_number: int,
-    attn_mask_type: AttnMaskType,
-    attention_type: str,
-    attention_dropout: Optional[float] = None,
-    softmax_scale: Optional[float] = None,
-    k_channels: Optional[int] = None,
-    v_channels: Optional[int] = None,
-    cp_comm_type: str = "p2p",
-):
+class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
+    def __init__(
+        self,
+        config: TransformerConfig,
+        layer_number: int,
+        attn_mask_type: AttnMaskType,
+        attention_type: str,
+        attention_dropout: Optional[float] = None,
+        softmax_scale: Optional[float] = None,
+        k_channels: Optional[int] = None,
+        v_channels: Optional[int] = None,
+        cp_comm_type: str = "p2p",
+    ):
        self.config = config
        self.te_forward_mask_type = False
        self.qkv_format: str = 'sbhd'

        if self.config.apply_query_key_layer_scaling != bool(
            int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
        ):
            raise ValueError(
                f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
                f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
                f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
                f"setting query key layer scaling via argument, so these two must match."
            )

        extra_kwargs: dict[str, Any] = {}
        if is_te_min_version("0.11.0"):
            extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
        elif self.config.num_query_groups != self.config.num_attention_heads:
            raise ValueError(
                f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
                f"use a newer version of Transformer Engine. "
                f"(num_query_groups ({self.config.num_query_groups}) != "
                f"num_attention_heads ({self.config.num_attention_heads}))"
            )

        if is_te_min_version("0.10.0"):
            extra_kwargs["attention_type"] = attention_type
            # older version don't need attention_type

        if is_te_min_version("0.12.0", check_equality=False):
            self.te_forward_mask_type = True

        # This check is important as CP config can be disabled while having a valid CP group
        # Example - Disabling CP for encoder while a valid CP group exists for decoder
        if self.config.context_parallel_size > 1:
            assert is_te_min_version(
                "1.0.0"
            ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
            if getattr(TEDotProductAttention, "cp_stream") is None:
                TEDotProductAttention.cp_stream = torch.cuda.Stream()
            extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
            extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
                check_initialized=False
            )
            extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
            if is_te_min_version("1.10.0"):
                if cp_comm_type is None:
                    extra_kwargs["cp_comm_type"] = "p2p"
                elif cp_comm_type == "a2a+p2p":
                    assert is_te_min_version("1.12.0"), (
                        f"Transformer-Engine v{get_te_version()} must be >= 1.12.0 to support"
                        "hierarchical cp commucation."
                    )
                    extra_kwargs["cp_comm_type"] = "a2a+p2p"
                    extra_kwargs["cp_group"] = get_hierarchical_context_parallel_groups(
                        check_initialized=False
                    )
                else:
                    extra_kwargs["cp_comm_type"] = cp_comm_type

        if self.config.deterministic_mode:
            if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
                raise RuntimeError(
                    "deterministic_mode is on and we are using DotProductAttention from "
                    "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
                    f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
                )

        if config.window_size is not None:
            # Check version
            assert is_te_min_version("1.2.0"), (
                f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support"
                "sliding window attention."
            )
            extra_kwargs['window_size'] = config.window_size

        if is_te_min_version("1.9.0"):
            # TE 1.10.0 introduces the ability to set the different k and v channels
            kv_channels = (
                (k_channels, v_channels)
                if k_channels is not None and v_channels is not None
                else self.config.kv_channels
            )
            extra_kwargs['softmax_scale'] = softmax_scale
        else:
            kv_channels = self.config.kv_channels

        self.kept_packed_seq_params = set(
            field.name for field in dataclasses.fields(PackedSeqParams)
        )

        if get_te_version() < PkgVersion("1.3.0"):
            # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
            # copies (#555)
            # These two arguments did not exist prior to 1.3.0
            self.kept_packed_seq_params.discard("max_seqlen_q")
            self.kept_packed_seq_params.discard("max_seqlen_kv")

        if get_te_version() < PkgVersion("1.10.0"):
            # TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted
            # in each individual sequence in THD format dataset
            # These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012)
            self.kept_packed_seq_params.discard("cu_seqlens_q_padded")
            self.kept_packed_seq_params.discard("cu_seqlens_kv_padded")

        super(TEDotProductAttention, self).__init__(
            num_attention_heads=self.config.num_attention_heads,
            kv_channels=kv_channels,
            attention_dropout=(
                self.config.attention_dropout if attention_dropout is None else attention_dropout
            ),
            attn_mask_type=attn_mask_type.name,
            sequence_parallel=self.config.sequence_parallel,
            tp_size=self.config.tensor_model_parallel_size,
            get_rng_state_tracker=(
                get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
            ),
            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            layer_number=layer_number,
            **extra_kwargs,
        )
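The patched constructor only forwards keyword arguments that the installed Transformer Engine release understands, accumulating them in extra_kwargs behind is_te_min_version checks. A small self-contained sketch of that gating pattern (the is_min_version helper and the version/value literals below are illustrative stand-ins, not the Megatron utilities themselves):

    from packaging.version import Version as PkgVersion

    def is_min_version(current: str, required: str, check_equality: bool = True) -> bool:
        """Return True when `current` satisfies the required minimum version."""
        if check_equality:
            return PkgVersion(current) >= PkgVersion(required)
        return PkgVersion(current) > PkgVersion(required)

    te_version = "1.11.0"   # assumed installed Transformer Engine version
    extra_kwargs = {}
    if is_min_version(te_version, "0.11.0"):
        extra_kwargs["num_gqa_groups"] = 8       # e.g. config.num_query_groups
    if is_min_version(te_version, "0.10.0"):
        extra_kwargs["attention_type"] = "self"
    if is_min_version(te_version, "1.10.0"):
        extra_kwargs["cp_comm_type"] = "p2p"
    # Only keys accepted by this TE release end up in the constructor call.
    print(extra_kwargs)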
dcu_megatron/core/tensor_parallel/__init__.py  (new file, mode 100644, view file @ 9eb8683b)

from .layers import (
    parallel_linear_init_wrapper,
    ColumnParallelLinearPatch,
    RowParallelLinearPatch,
    vocab_parallel_embedding_forward,
    vocab_parallel_embedding_init,
)
\ No newline at end of file
dcu_megatron/core/tensor_parallel/layers.py  (view file @ 9eb8683b)

from typing import Callable

import flux
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
...
@@ -9,18 +10,29 @@ from megatron.core.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
-from megatron.core.utils import is_torch_min_version
+from megatron.core.utils import (
+    is_torch_min_version,
+    prepare_input_tensors_for_wgrad_compute
+)
from megatron.core.tensor_parallel.layers import (
    _initialize_affine_weight_cpu,
    _initialize_affine_weight_gpu,
    VocabParallelEmbedding,
)
from megatron.core.tensor_parallel.mappings import (
    reduce_from_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.tensor_parallel.utils import VocabUtility
+from megatron.core.tensor_parallel.mappings import _reduce

_grad_accum_fusion_available = True
try:
    import fused_weight_gradient_mlp_cuda
except ImportError:
    _grad_accum_fusion_available = False

+from flux.cpp_mod import ReduceScatterOption


def vocab_parallel_embedding_init(
...
@@ -124,3 +136,836 @@ def vocab_parallel_embedding_forward(self, input_, weight=None):
    # Reduce across all the model parallel GPUs.
    output = reduce_from_tensor_model_parallel_region(output_parallel)
    return output


class AGLinear(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        allreduce_dgrad,
        sequence_parallel,
        grad_output_buffer,
        wgrad_deferral_limit,
        transpose_weight=False,
    ):
        """Forward."""
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.allreduce_dgrad = allreduce_dgrad
        ctx.sequence_parallel = sequence_parallel
        ctx.wgrad_deferral_limit = wgrad_deferral_limit
        ctx.grad_output_buffer = grad_output_buffer
        ctx.transpose_weight = transpose_weight

        sequence_len = input.size(0)
        # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
        input = input.view(input.shape[0] * input.shape[1], input.shape[2])
        M, K = list(input.size())
        N = weight.size(0)
        M = M * get_tensor_model_parallel_world_size()

        if transpose_weight:
            weight = weight.t().contiguous()

        if sequence_parallel:
            ag_gemm_kernel = flux.AGKernel(
                get_tensor_model_parallel_group(),
                get_tensor_model_parallel_world_size() // torch.cuda.device_count(),
                M,
                N,
                K,
                input.dtype,
                output_dtype=input.dtype,
                transpose_weight=transpose_weight,
                local_copy=False,
                ring_mode=flux.AgRingMode.Auto,
            )
            output = ag_gemm_kernel.forward(
                input,
                weight,
                bias=bias,
                input_scale=input_scale,
                weight_scale=weight_scale,
                output_scale=None,
                fast_accum=False
            )
        else:
            output_buf = torch.empty([M, N], dtype=input.dtype, device=input.device)
            gemm_only_op = flux.GemmOnly(
                input_dtype=input.dtype,
                output_dtype=input.dtype,
                transpose_weight=transpose_weight,
                use_fp8_gemm=False,
            )
            output = gemm_only_op.forward(
                input,
                weight,
                bias=bias,
                output_buf=output_buf,
                input_scale=None,
                weight_scale=None,
                output_scale=None,
                fast_accum=False,
            )

        torch.cuda.current_stream().synchronize()
        output = output.view(sequence_len, input.size(0) // sequence_len, -1)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        """Backward."""
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias
        grad_output_buffer = ctx.grad_output_buffer
        wgrad_deferral_limit = ctx.wgrad_deferral_limit
        transpose_weight = ctx.transpose_weight

        wgrad_compute = True
        if grad_output_buffer is not None:
            if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
                grad_output_buffer.append(grad_output)
                wgrad_compute = False

        world_size = get_tensor_model_parallel_world_size()

        if wgrad_compute:
            if ctx.sequence_parallel:
                dim_size = list(input.size())
                dim_size[0] = dim_size[0] * world_size

                all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
                handle = dist_all_gather_func(
                    all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
                )

                # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
                # gather is scheduled before the input gradient computation
                total_input = all_gather_buffer
            else:
                total_input = input

        if ctx.sequence_parallel:
            sequence_len, batch_size, output_hidden_size = grad_output.size()
            # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
            grad_output = grad_output.view(sequence_len * batch_size, output_hidden_size)
            if not transpose_weight:
                weight = weight.t().contiguous()
            gemm_rs_op = flux.GemmRS(
                get_tensor_model_parallel_group(),
                world_size // torch.cuda.device_count(),
                sequence_len * batch_size,
                output_hidden_size,
                input.dtype,
                input.dtype,
                transpose_weight=transpose_weight,
                fuse_reduction=False
            )
            grad_input = gemm_rs_op.forward(
                grad_output,
                weight,
                bias=None,
                input_scale=None,
                weight_scale=None,
                output_scale=None,
                fast_accum=False
            )
            torch.cuda.current_stream().synchronize()
            grad_input = grad_input.view(
                sequence_len // get_tensor_model_parallel_group(), batch_size, -1
            )
        else:
            grad_input = grad_output.matmul(weight)

        if ctx.sequence_parallel and wgrad_compute:
            handle.wait()

        if wgrad_compute:
            grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
                grad_output, total_input
            )

        if not ctx.sequence_parallel and ctx.allreduce_dgrad:
            # Asynchronous all-reduce
            handle = torch.distributed.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True
            )
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # all-reduce is scheduled before the weight gradient computation

        if ctx.gradient_accumulation_fusion:
            if wgrad_compute:
                if weight.main_grad.dtype == torch.float32:
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                        total_input, grad_output, weight.main_grad
                    )
                elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
                        total_input, grad_output, weight.main_grad
                    )
                else:
                    raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")

            if hasattr(weight, 'grad_added_to_main_grad'):
                # When overlap_grad_reduce is True, need to ensure that backward hooks
                # are all run on the main backprop thread to prevent deadlocks. Setup
                # dummy grad_weight tensor to prevent backward hooks from being run
                # in a background thread.
                if getattr(weight, 'zero_out_wgrad', False):
                    grad_weight = torch.zeros(
                        weight.main_grad.shape,
                        dtype=input.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                else:
                    grad_weight = torch.empty(
                        weight.main_grad.shape,
                        dtype=input.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                weight.grad_added_to_main_grad = True
            else:
                grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.allreduce_dgrad:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None, None, None


def ag_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    gradient_accumulation_fusion: bool,
    allreduce_dgrad: bool,
    sequence_parallel: bool,
    grad_output_buffer: Optional[List[torch.Tensor]] = None,
    wgrad_deferral_limit: Optional[int] = 0,
    transpose_weight: Optional[bool] = False,
) -> torch.Tensor:
    """Linear layer execution with asynchronous communication and
    gradient accumulation fusion in backprop.

    This has the option to accumulate the result of backprop
    calculation into an existing gradient buffer, preventing the need
    to do an additional addition kernel after the gradient
    calculation.

    Additionally, the tensor parallel all reduce of the input
    gradients can be done asynchronously with the calculation of
    the weight gradients.

    In the case of sequence parallelism, the reduce scatter of the
    input gradients is done asynchronously with the calcluation of the
    weight gradients.

    Use of this module requires that the environment variable
    CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
    operations, noted in the code, that should be scheduled before
    compute kernels to overlap the communication with the computation,
    which is necessary for a speedup but not for correctness so that
    ordering isn't imposed by the scheduler. Setting
    CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
    in the order they are called.

    Args:
        input (torch.Tensor required): input like torch.nn.functional.linear
        weight (torch.Tensor required): weight like torch.nn.functional.linear
        bias (torch.Tensor optional): bias like torch.nn.functional.linear
        gradient_accumulation_fusion (bool required): Perform the gradient
            accumulation fusion, requires the custom CUDA extension
            fused_weight_gradient_mlp_cuda module. To use
            gradient_accumulation_fusion you must install APEX with
            --cpp_ext and --cuda_ext. For example: "pip install
            --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
            " Note that the extension requires CUDA>=11. Otherwise, you
            must turn off gradient accumulation fusion."
        allreduce_dgrad (bool required): Do the allreduce of input gradients.
            The allreduce is done asynchronously with the computation of weight
            gradients. If sequence_parallel is True, this must be
            False, as no all reduce is performed.
        sequence_parallel (bool required): Indicates that sequence
            parallelism is used and thus in the forward pass the input is
            all gathered, and the backward pass the input gradients are
            reduce scattered.
        grad_output_buffer (List[torch.Tensor] optional): Buffer used to save
            output gradients when embedding table wgrad compute is deferred.
            Defaults to None.
        wgrad_deferral_limit (int optional): Limit on the number of
            micro-batches for which embedding weight gradient GEMM should be
            deferred. Disable by setting this to 0. Defaults to 0.
        transpose_weight: transpose weight.
    """
    args = [
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        allreduce_dgrad,
        sequence_parallel,
        grad_output_buffer,
        wgrad_deferral_limit,
        transpose_weight,
    ]

    if not ag_linear.warned:
        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
            if sequence_parallel:
                warnings.warn(
                    "When using sequence parallelism it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup"
                )
                ag_linear.warned = True

            if allreduce_dgrad:
                warnings.warn(
                    "When using async grad allreduce it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup"
                )
                ag_linear.warned = True

    return AGLinear.apply(*args)


ag_linear.warned = False


class LinearRS(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        allreduce_dgrad,
        sequence_parallel,
        grad_output_buffer,
        wgrad_deferral_limit,
        transpose_weight=False,
    ):
        """Forward."""
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.allreduce_dgrad = allreduce_dgrad
        ctx.sequence_parallel = sequence_parallel
        ctx.wgrad_deferral_limit = wgrad_deferral_limit
        ctx.grad_output_buffer = grad_output_buffer
        ctx.transpose_weight = transpose_weight

        world_size = get_tensor_model_parallel_world_size()
        input_dim = input.dim()
        sequence_len = input.size(0)
        # input: 3D tensor whose order of dimension is [sequence, batch, hidden]
        input = input.view(input.shape[0] * input.shape[1], input.shape[2])
        M = input.size(0)
        N = weight.size(0)

        if sequence_parallel:
            if transpose_weight:
                weight = weight.t().contiguous()
            gemm_rs_op = flux.GemmRS(
                get_tensor_model_parallel_group(),
                world_size // torch.cuda.device_count(),
                M,
                N,
                input.dtype,
                input.dtype,
                transpose_weight=transpose_weight,
                fuse_reduction=False,
            )
            output = gemm_rs_op.forward(
                input,
                weight,
                bias=bias,
                input_scale=None,
                weight_scale=None,
                output_scale=None,
                fast_accum=False,
            )
        else:
            output = torch.empty([M, N], dtype=input.dtype, device=input.device)
            gemm_only_op = flux.GemmOnly(
                input_dtype=input.dtype,
                output_dtype=input.dtype,
                transpose_weight=transpose_weight,
                use_fp8_gemm=False,
            )
            output = gemm_only_op.forward(
                input,
                weight,
                bias=bias,
                output_buf=output,
                input_scale=None,
                weight_scale=None,
                output_scale=None,
                fast_accum=False,
            )

        torch.cuda.current_stream().synchronize()
        output = output.view(sequence_len, input.size(0) // sequence_len, -1)

        if not sequence_parallel:
            _reduce(output)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        """Backward."""
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias
        grad_output_buffer = ctx.grad_output_buffer
        wgrad_deferral_limit = ctx.wgrad_deferral_limit
        transpose_weight = ctx.transpose_weight

        wgrad_compute = True
        if grad_output_buffer is not None:
            if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
                grad_output_buffer.append(grad_output)
                wgrad_compute = False

        if ctx.sequence_parallel:
            world_size = get_tensor_model_parallel_world_size()
            sequence_len, batch_size, _ = grad_output.size()
            grad_output = grad_output.view(sequence_len * batch_size, -1)
            M, K = list(grad_output.size())
            M = M * world_size
            N = weight.size(-1)

            if not transpose_weight:
                weight = weight.t().contiguous()

            grad_input = torch.empty([M, N], dtype=input.dtype, device=input.device)
            ag_kernel = flux.AGKernel(
                get_tensor_model_parallel_group(),
                world_size // torch.cuda.device_count(),
                M,
                N,
                K,
                input.dtype,
                output_dtype=input.dtype,
                transpose_weight=transpose_weight,
                local_copy=False,
                ring_mode=flux.AgRingMode.Auto,
            )
            output = ag_kernel.forward(
                grad_output,
                weight,
                bias=None,
                input_scale=None,
                weight_scale=None,
                output_scale=None,
                fast_accum=False,
            )
            torch.cuda.current_stream().synchronize()
        else:
            grad_input = grad_output.matmul(weight)

        if wgrad_compute:
            grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
                grad_output, input
            )

        if ctx.gradient_accumulation_fusion:
            if wgrad_compute:
                if weight.main_grad.dtype == torch.float32:
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                        total_input, grad_output, weight.main_grad
                    )
                elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
                        total_input, grad_output, weight.main_grad
                    )
                else:
                    raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")

            if hasattr(weight, 'grad_added_to_main_grad'):
                # When overlap_grad_reduce is True, need to ensure that backward hooks
                # are all run on the main backprop thread to prevent deadlocks. Setup
                # dummy grad_weight tensor to prevent backward hooks from being run
                # in a background thread.
                if getattr(weight, 'zero_out_wgrad', False):
                    grad_weight = torch.zeros(
                        weight.main_grad.shape,
                        dtype=input.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                else:
                    grad_weight = torch.empty(
                        weight.main_grad.shape,
                        dtype=input.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                weight.grad_added_to_main_grad = True
            else:
                grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None


def linear_rs(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    gradient_accumulation_fusion: bool,
    allreduce_dgrad: bool,
    sequence_parallel: bool,
    grad_output_buffer: Optional[List[torch.Tensor]] = None,
    wgrad_deferral_limit: Optional[int] = 0,
    transpose_weight: Optional[bool] = False,
) -> torch.Tensor:
    """Linear layer execution with asynchronous communication and
    gradient accumulation fusion in backprop.

    This has the option to accumulate the result of backprop
    calculation into an existing gradient buffer, preventing the need
    to do an additional addition kernel after the gradient
    calculation.

    Additionally, the tensor parallel all reduce of the input
    gradients can be done asynchronously with the calculation of
    the weight gradients.

    In the case of sequence parallelism, the reduce scatter of the
    input gradients is done asynchronously with the calcluation of the
    weight gradients.

    Use of this module requires that the environment variable
    CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
    operations, noted in the code, that should be scheduled before
    compute kernels to overlap the communication with the computation,
    which is necessary for a speedup but not for correctness so that
    ordering isn't imposed by the scheduler. Setting
    CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
    in the order they are called.

    Args:
        input (torch.Tensor required): input like torch.nn.functional.linear
        weight (torch.Tensor required): weight like torch.nn.functional.linear
        bias (torch.Tensor optional): bias like torch.nn.functional.linear
        gradient_accumulation_fusion (bool required): Perform the gradient
            accumulation fusion, requires the custom CUDA extension
            fused_weight_gradient_mlp_cuda module. To use
            gradient_accumulation_fusion you must install APEX with
            --cpp_ext and --cuda_ext. For example: "pip install
            --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
            " Note that the extension requires CUDA>=11. Otherwise, you
            must turn off gradient accumulation fusion."
        allreduce_dgrad (bool required): Do the allreduce of input gradients.
            The allreduce is done asynchronously with the computation of weight
            gradients. If sequence_parallel is True, this must be
            False, as no all reduce is performed.
        sequence_parallel (bool required): Indicates that sequence
            parallelism is used and thus in the forward pass the input is
            all gathered, and the backward pass the input gradients are
            reduce scattered.
        grad_output_buffer (List[torch.Tensor] optional): Buffer used to save
            output gradients when embedding table wgrad compute is deferred.
            Defaults to None.
        wgrad_deferral_limit (int optional): Limit on the number of
            micro-batches for which embedding weight gradient GEMM should be
            deferred. Disable by setting this to 0. Defaults to 0.
        transpose_weight: transpose weight.
    """
    args = [
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        allreduce_dgrad,
        sequence_parallel,
        grad_output_buffer,
        wgrad_deferral_limit,
        transpose_weight,
    ]

    if not linear_rs.warned:
        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
            if sequence_parallel:
                warnings.warn(
                    "When using sequence parallelism it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup"
                )
                linear_rs.warned = True

            if allreduce_dgrad:
                warnings.warn(
                    "When using async grad allreduce it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup"
                )
                linear_rs.warned = True

    return LinearRS.apply(*args)


linear_rs.warned = False


def parallel_linear_init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        # flux params
        self.use_flux = False
        if "use_flux" in kwargs:
            self.use_flux = kwargs["use_flux"]
        elif hasattr(self.config, "use_flux"):
            self.use_flux = self.config.use_flux
        self.flux_transpose_weight = False
        if "flux_transpose_weight" in kwargs:
            self.flux_transpose_weight = kwargs["flux_transpose_weight"]
        elif hasattr(self.config, "flux_transpose_weight"):
            self.flux_transpose_weight = self.config.flux_transpose_weight
    return wrapper


class ColumnParallelLinearPatch(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].
    """

    def forward(
        self,
        input_: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        runtime_gather_output: Optional[bool] = None,
    ):
        """Forward of ColumnParallelLinear

        Args:
            input_:
                3D tensor whose order of dimension is [sequence, batch, hidden]
            weight (optional):
                weight tensor to use, compulsory when skip_weight_param_allocation is True.
            runtime_gather_output (bool): Gather output at runtime. Default None means
                `gather_output` arg in the constructor will be used.

        Returns:
            - output
            - bias
        """
        if weight is None:
            if self.weight is None:
                raise RuntimeError(
                    "weight was not supplied to ColumnParallelLinear forward pass "
                    "and skip_weight_param_allocation is True."
                )
            weight = self.weight
        else:
            # Check the weight passed in is the correct shape
            expected_shape = (self.output_size_per_partition, self.input_size)
            if weight.shape != expected_shape:
                raise RuntimeError(
                    f"supplied weight's shape is {tuple(weight.shape)}, "
                    f"not {expected_shape} as expected"
                )

        if self.config._cpu_offloading_context is not None:
            if self.config._cpu_offloading_context.inside_context is True:
                assert (
                    self.config.cpu_offloading is False
                ), "CPU Offloading cannot be enabled while using non-TE modules"

        bias = self.bias if not self.skip_bias_add else None

        if (
            self.allreduce_dgrad
            or self.sequence_parallel
            or self.explicit_expert_comm
            or self.disable_grad_reduce
        ):
            input_parallel = input_
        else:
            input_parallel = copy_to_tensor_model_parallel_region(input_)

        if self.config.defer_embedding_wgrad_compute:
            if (
                self.config.wgrad_deferral_limit == 0
                or len(self.embedding_activation_buffer) < self.config.wgrad_deferral_limit
            ):
                self.embedding_activation_buffer.append(input_parallel)

        # Matrix multiply.
        if self.use_flux:
            self._forward_impl = ag_linear
        elif not weight.requires_grad:
            self._forward_impl = linear_with_frozen_weight
        else:
            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce

        allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad

        forward_params = {
            "input": input_parallel,
            "weight": weight,
            "bias": bias,
            "gradient_accumulation_fusion": self.gradient_accumulation_fusion,
            "allreduce_dgrad": allreduce_dgrad,
            "sequence_parallel": False if self.explicit_expert_comm else self.sequence_parallel,
            "grad_output_buffer": (
                self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None
            ),
            "wgrad_deferral_limit": (
                self.config.wgrad_deferral_limit
                if self.config.defer_embedding_wgrad_compute
                else None
            ),
        }
        if self.use_flux:
            forward_params.update({"transpose_weight": self.flux_transpose_weight})

        output_parallel = self._forward_impl(**forward_params)

        gather_output = self.gather_output
        # Use the runtime gather output if it's set explicitly.
        if runtime_gather_output is not None:
            gather_output = runtime_gather_output

        if gather_output:
            # All-gather across the partitions.
            assert not self.sequence_parallel
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


class RowParallelLinearPatch(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X
    along its second dimension. A = transpose([A_1 .. A_p]) X = [X_1, ..., X_p]
    """

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]

        Returns:
            - output
            - bias
        """

        if self.config._cpu_offloading_context is not None:
            if self.config._cpu_offloading_context.inside_context is True:
                assert (
                    self.config.cpu_offloading is False
                ), "CPU Offloading cannot be enabled while using non-TE modules"

        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            assert not self.sequence_parallel
            input_parallel = scatter_to_tensor_model_parallel_region(input_)

        # Matrix multiply.
        if self.use_flux:
            self._forward_impl = linear_rs
        elif not self.weight.requires_grad:
            self._forward_impl = linear_with_frozen_weight
        else:
            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce

        allreduce_dgrad = False

        forward_params = {
            "input": input_parallel,
            "weight": self.weight,
            "bias": None if not self.use_flux or self.skip_bias_add else self.bias,
            "gradient_accumulation_fusion": self.gradient_accumulation_fusion,
            "allreduce_dgrad": allreduce_dgrad,
            "sequence_parallel": False if not self.use_flux else self.sequence_parallel,
            "grad_output_buffer": False,
        }
        if self.use_flux:
            forward_params.update({"transpose_weight": self.flux_transpose_weight})

        output_parallel = self._forward_impl(**forward_params)

        if self.use_flux:
            return output_parallel, None if skip_bias_add else self.bias

        # All-reduce across all the partitions.
        if self.explicit_expert_comm:
            assert self.skip_bias_add
            output_ = output_parallel
        elif self.sequence_parallel:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = (output_ + self.bias) if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias
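AGLinear and LinearRS follow the standard torch.autograd.Function layout: the tensor math lives in forward, backward returns gradients for the two tensor inputs, and every non-tensor flag gets a None placeholder so the return arity matches the forward signature. A minimal runnable sketch of that pattern with a plain GEMM (no flux kernels, no tensor-parallel groups; the names below are illustrative only):

    import torch

    class ToyLinear(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias, some_flag):
            # Stash what backward needs; non-tensor state goes on ctx attributes.
            ctx.save_for_backward(input, weight)
            ctx.use_bias = bias is not None
            output = input.matmul(weight.t())
            return output + bias if bias is not None else output

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input = grad_output.matmul(weight)
            grad_weight = grad_output.transpose(-2, -1).matmul(input)
            grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
            # One return slot per forward argument; flags get None.
            return grad_input, grad_weight, grad_bias, None

    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    y = ToyLinear.apply(x, w, None, True)
    y.sum().backward()
    print(x.grad.shape, w.grad.shape)   # torch.Size([4, 8]) torch.Size([16, 8])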
dcu_megatron/core/transformer/transformer_config.py
View file @
9eb8683b
import
warnings
from
dataclasses
import
dataclass
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
import
torch.nn.functional
as
F
from
megatron.core.transformer.enums
import
AttnBackend
from
megatron.core.model_parallel_config
import
ModelParallelConfig
from
megatron.core.utils
import
get_te_version
,
init_method_normal
,
is_te_min_version
,
scaled_init_method_normal
from
megatron.core.transformer.transformer_config
import
TransformerConfig
,
MLATransformerConfig
@
dataclass
class
TransformerConfig
(
ModelParallelConfig
):
"""Configuration object for megatron-core transformers.
The initialization function has an argument for each parameter,
including those in ModelParallelConfig.
"""
####################
# model architecture
####################
num_layers
:
int
=
0
"""Number of transformer layers in a transformer block."""
num_layers_in_first_pipeline_stage
:
Optional
[
int
]
=
None
"""Number of transformer layers on first pipeline stage.
None implies equal layer division across PP ranks."""
num_layers_in_last_pipeline_stage
:
Optional
[
int
]
=
None
"""Number of transformer layers on last pipeline stage.
None implies equal layer division across PP ranks."""
account_for_embedding_in_pipeline_split
:
bool
=
False
"""If set, the embedding layer will be treated as a standard transformer
layer in the context of partition and placement for pipeline parallelism."""
account_for_loss_in_pipeline_split
:
bool
=
False
"""If set, the loss layer will be treated as a standard transformer
layer in the context of partition and placement for pipeline parallelism."""
hidden_size
:
int
=
0
"""Transformer hidden size."""
num_attention_heads
:
int
=
0
"""Number of transformer attention heads."""
attention_backend
:
AttnBackend
=
AttnBackend
.
auto
"""Attention backend to run. By default we let transformer engine
decide the best backend to run (except in the case of local).
If attention backend is local we use the local pytorch implementation in mcore.
Users can specify exact backend by changing this config. """
softmax_scale
:
Optional
[
float
]
=
None
"""Softmax scale for attention scaling."""
num_query_groups
:
Optional
[
int
]
=
None
"""Number of query groups for group query attention. If None, normal attention is used."""
ffn_hidden_size
:
Optional
[
int
]
=
None
"""Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size
if not provided."""
kv_channels
:
Optional
[
int
]
=
None
"""Projection weights dimension in multi-head attention. This is set to hidden_size //
num_attention_heads if not provided."""
hidden_dropout
:
float
=
0.1
"""Dropout probability for transformer hidden state."""
attention_dropout
:
float
=
0.1
"""Post attention dropout probability."""
fp32_residual_connection
:
bool
=
False
"""If true, move residual connections to fp32."""
# @jcasper should we keep this option?
apply_residual_connection_post_layernorm
:
bool
=
False
"""If True, uses the original BERT residule connection ordering."""
layernorm_epsilon
:
float
=
1e-5
"""Epsilon value for any LayerNorm operations."""
layernorm_zero_centered_gamma
:
bool
=
False
"""If set to True, the LayerNorm is adjusted to center the gamma values around 0. This improves
numerical stability."""
add_bias_linear
:
bool
=
True
"""Include a bias term in all linear layers (QKV projections, after core attention, and two in
MLP layer)."""
add_qkv_bias
:
bool
=
False
"""Add a bias term only for QKV projections."""
gated_linear_unit
:
bool
=
False
"""Use a gated linear unit for the first linear layer in the MLP."""
activation_func
:
Callable
=
F
.
gelu
"""Activation function to use for the non-linearity in the MLP."""
activation_func_fp8_input_store
:
bool
=
False
"""Store the input of MLP activation function in FP8 for backprop to save memory.
The stored input is casted back to the original precision before backprop compuatation."""
num_moe_experts
:
Optional
[
int
]
=
None
"""Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None
for no MoE."""
rotary_interleaved
:
bool
=
False
"""True is rotate pairs of even and odd dimensions (RoFormer style), False is rotate pairs of
first half and second half (LLaMa style). Default to False."""
window_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
"""If not None, then will use sliding window attention. The size of the window is specified by
the numbers inside the tuple; -1 is special value meaning "infinite window size"."""
normalization
:
str
=
"LayerNorm"
"""Which norm to use for normalization layers, valid options are `LayerNorm` and `RMSNorm`."""
qk_layernorm
:
bool
=
False
"""Whether to apply LayerNorm to the query and key embeddings."""
test_mode
:
bool
=
False
"""Whether to run real-time tests."""
calculate_per_token_loss
:
bool
=
False
"""Whether cross entropy loss is calculated over the actual number of non-padded tokens in the
global batch, versus the default behavior of assuming all tokens are non-padded."""
multi_latent_attention
:
bool
=
False
"""Whether to use multi-latent attention."""
####################
# initialization
####################
init_method
:
Optional
[
Callable
]
=
None
"""Method to initialize weights. Note that bias is always set to zero. Should be a function that
takes a single Tensor and initializes it. If None, will be set to
megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with
mean=0.0 and std=init_method_std."""
output_layer_init_method
:
Optional
[
Callable
]
=
None
"""Method to initialize weights of the output layer of both attention and MLP blocks. If None,
will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn
init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers)."""
init_method_std
:
float
=
0.02
"""Standard deviation of the zero mean normal for the default initialization method, not used if
init_method and output_layer_init_method are provided."""
####################
# mixed-precision
####################
apply_query_key_layer_scaling
:
bool
=
False
"""If true, scale Q * K^T by 1 / layer-number. This improve numeric stability when training with
fp16."""
attention_softmax_in_fp32
:
bool
=
True
"""If True, run attention masking and softmax in fp32. This should be True if
apply_query_key_layer_scaling is True."""
####################
# fusion
####################
bias_activation_fusion
:
bool
=
False
"""If True, fuses bias addition and the activation function when possible."""
masked_softmax_fusion
:
bool
=
False
"""If True, uses softmax fusion."""
persist_layer_norm
:
bool
=
False
"""If True, uses the persistent fused layer norm kernel. This kernel only supports a fixed set
of hidden sizes."""
memory_efficient_layer_norm
:
bool
=
False
"""If True, and using local layers (not from TransformerEngine), tells Apex to use the memory
efficient fused LayerNorm kernel. Ignored if not using LayerNorm."""
bias_dropout_fusion
:
bool
=
False
# TODO: this should be bias_dropout_add_fusion?
"""If True, uses bias dropout fusion."""
apply_rope_fusion
:
bool
=
False
"""If True, use fused RoPE kernel."""
####################
# activation recomputation
####################
recompute_granularity
:
Optional
[
str
]
=
None
"""Determines which type of activation recompute to use. Megatron-core supports 'selective'
activation checkpointing where only the memory intensive part of attention is checkpointed.
These memory intensive activations are also less compute intensive which makes activation
checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large
Transformer Models (https://arxiv.org/abs/2205.05198) for more details. 'full' will checkpoint
the entire transformer layer. If None, no recompute is performed and all activations are saved.
If set, must be 'selective' or 'full'. 'selective' always uses all layers.
"""
recompute_method
:
Optional
[
str
]
=
None
"""Determines which transformer layers will be recomputed. uniform will uniformly divide the
total number of transformer layers in a transformer block and recompute the input activation of
each divided chunk at the specified granularity. block will recompute the input activations for
only a set number of transformer layers per pipeline stage. The rest of the layers in the
pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all
layers will do recomputation. If set, must be 'uniform' or 'block'."""
recompute_num_layers
:
Optional
[
int
]
=
None
"""When recompute_method is uniform, recompute_num_layers is the number of transformer layers in
each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is
the number of transformer layers to recompute within each pipeline stage. Must be None for
'selective' activation checkpointing."""
distribute_saved_activations
:
Optional
[
bool
]
=
None
"""If True, distribute recomputed activations across the model parallel group."""
####################
# fp8 related
####################
fp8
:
Optional
[
str
]
=
None
"""If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined
choices (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8
activation and weight tensors and e5m2 for all FP8 output activation gradient tensors."""
fp8_margin
:
int
=
0
"""Margin for the scaling factor computation."""
fp8_interval
:
int
=
1
"""DEPRECATED from TransformerEngine v1.8.0. This flag is ignored.
Controls how often the scaling factor is recomputed.
"""
fp8_amax_history_len
:
int
=
1
"""The length of the amax history window used for scaling factor computation."""
fp8_amax_compute_algo
:
str
=
"most_recent"
"""Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
always chooses the most recently seen value.
"""
fp8_wgrad
:
bool
=
True
"""When set to False, override FP8 config options and do the wgrad computation
in higher precision."""
fp8_dot_product_attention
:
bool
=
False
"""When set to True, use the FP8 implementation of Dot Product Attention."""
fp8_multi_head_attention
:
bool
=
False
"""When set to True, use the FP8 implementation of Multi Head Attention."""
tp_only_amax_red
:
bool
=
False
"""When set to True, reduce the FP8 AMAX only in the TP or TP-CP domain"""
####################
# MoE related
####################
moe_shared_expert_intermediate_size
:
Optional
[
int
]
=
None
"""Shared expert total ffn hidden size.
It should be equal to 'num_shared_experts * ffn_size_of_each_shared_expert' if
there are multiple shared experts.
None means no shared expert."""
moe_shared_expert_overlap
:
bool
=
False
"""Enable overlapping between shared expert computations and dispatcher communications.
Without this, the shared epxerts execute after the routed experts."""
moe_layer_freq
:
Union
[
int
,
List
[
int
]]
=
1
"""Frequency between MoE layers and Dense layers. Accepts either:
- An integer N: Represents a 1:N ratio, meaning one expert layer for every N-1 dense layers.
- A list that defines a custom pattern, e.g.: [1,1,1,0,1,1,1,0,1,1,1,0]"""
moe_ffn_hidden_size
:
Optional
[
int
]
=
None
"""MoE Feed-Forward Network hidden size"""
moe_router_load_balancing_type
:
str
=
"aux_loss"
"""The load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss
used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the loss used in DeepSeekV2,
which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing
algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss"."""
moe_router_topk
:
int
=
2
"""Number of experts to route to for each token."""
moe_router_topk_limited_devices
:
Optional
[
int
]
=
None
"""Number of EP ranks to consider for each token in group-limited routing,
DEPRECATED and replaced by moe_router_num_groups and moe_router_group_topk.
"""
moe_router_num_groups
:
Optional
[
int
]
=
None
"""Number of groups to divide experts into for group-limited routing.
When using group-limited routing:
1. Experts are divided into 'moe_router_num_groups' equal-sized groups
2. For each token, 'moe_router_group_topk' groups are selected based on routing scores
(specifically, the sum of top-2 expert scores within each group)
3. From these selected groups, 'moe_router_topk' individual experts are chosen
Two common use cases:
- Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP)
to limit each token to experts on a subset of devices
(See DeepSeek-V2: https://arxiv.org/pdf/2405.04434)
- Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group
to limit each token to experts on a subset of nodes
(See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)
"""
moe_router_group_topk
:
Optional
[
int
]
=
None
"""Number of selected groups for group-limited routing."""
moe_router_pre_softmax
:
bool
=
False
"""Enable pre-softmax routing for MoE, which means softmax is before the top-k selection.
By default, softmax is done after top-k."""
moe_router_topk_scaling_factor
:
Optional
[
float
]
=
None
"""Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax
enabled. Defaults to None, which means no scaling."""
moe_router_score_function
:
str
=
"softmax"
"""Score function for MoE routing. Can be "softmax" or "sigmoid"."""
moe_router_enable_expert_bias
:
bool
=
False
"""TopK routing with dynamic per-expert bias in the aux-loss-free load balancing strategy.
The routing decision is based on the sum of the routing scores and the expert bias.
See https://arxiv.org/abs/2408.15664 for details."""
moe_router_bias_update_rate
:
float
=
1e-3
"""The expert bias is updated based on the number of assigned tokens to each expert
in a global batch, where the bias is increased for the experts with less assigned tokens
and decreased for the experts with more assigned tokens.
The default value 1e-3 is same as that used in DeepSeekV3."""
    moe_grouped_gemm: bool = False
    """When there are multiple experts per rank, compress multiple local (potentially small) gemms
    into a single kernel launch to improve utilization and performance by leveraging the Grouped
    GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
    """

    moe_use_legacy_grouped_gemm: bool = False
    """Use the legacy GroupedMLP rather than TEGroupedMLP.
    Note: The legacy one will be deprecated soon."""

    moe_aux_loss_coeff: float = 0  # 1e-2 would be a good start value for the load balance loss.
    """Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended."""

    moe_z_loss_coeff: Optional[float] = None  # 1e-3 would be a good start value for the z-loss.
    """Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended."""

    moe_input_jitter_eps: Optional[float] = None
    """Add noise to the input tensor by applying jitter with a specified epsilon value."""

    moe_token_dropping: bool = False
    """This feature involves selectively dropping and padding tokens for each expert to achieve a
    specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is
    currently unsupported, so it should remain False."""

    moe_token_dispatcher_type: str = "allgather"
    """The type of token dispatcher to use. The default is 'allgather'.
    Options are 'allgather' and 'alltoall'."""

    moe_per_layer_logging: bool = False
    """Enable per-layer logging for MoE, currently supports auxiliary loss and z loss."""

    moe_expert_capacity_factor: Optional[float] = None
    """The capacity factor for each expert; None means no token will be dropped.
    The default is None."""

    moe_pad_expert_input_to_capacity: bool = False
    """If True, pads the input for each expert to match the expert capacity length,
    effective only after moe_expert_capacity_factor is set. The default is False."""
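    # --------------------------------------------------------------------------
    # Illustrative sketch, not part of this module: a hedged reading of how an
    # expert capacity is typically derived from moe_expert_capacity_factor; the
    # exact rounding and padding rules live in the Megatron-Core token dispatchers.
    #
    #   import math
    #
    #   def expert_capacity(num_tokens, num_experts, topk, capacity_factor):
    #       # Each token produces `topk` routed copies spread over `num_experts`;
    #       # the factor scales the per-expert budget above/below the even split.
    #       return math.ceil(num_tokens * topk / num_experts * capacity_factor)
    #
    #   # e.g. 4096 tokens, 64 experts, topk=2, factor=1.25 -> capacity of 160 per expert
    # --------------------------------------------------------------------------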
    moe_token_drop_policy: str = 'probs'
    """The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with
    the lowest probabilities will be dropped. If "position", tokens at the end of each batch will
    be dropped.
    """

    moe_layer_recompute: bool = False
    """Memory optimization: checkpoint the moe_layer to save activation memory."""

    moe_permute_fusion: bool = False
    """Fuse token rearrangement ops during token dispatching."""


class ExtraTransformerConfig:
    ##################
    # multi-token prediction
    ##################
    ...
    ...
@@ -399,538 +24,19 @@ class TransformerConfig(ModelParallelConfig):
    """share embedding and output weight with mtp layer."""

    ##################
    # Context Parallel
    ##################
    cp_comm_type: Optional[Union[str, List[str]]] = None
    """Inter-gpu communication type for context parallelism.
    str: all layers share the same communication type.
    List[str]: each layer has its separate communication type.
    cp_comm_type of each layer can be "p2p", "all_gather", "a2a", or "a2a+p2p".
    "p2p": Exchange KV chunks with P2P communications in a ring topology. P2P is async and can be
    overlapped with attention compute.
    "all_gather": All-gather to get the full sequence of KV before attention. The all-gather is not
    async and cannot be overlapped.
    "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get
    the full sequence of QKV.
    "a2a+p2p": A hierarchical implementation of context parallelism for attention.
    It uses A2A communications in low-level CP groups (e.g., via NVLink),
    and P2P communications in high-level CP groups (e.g., via IBLink).
    """
    ##################
    # flux
    ##################
    use_flux: bool = False
    """If set, flux will be used in ColumnParallelLinear and RowParallelLinear."""

    ##################
    # Cuda Graphs
    ##################
    enable_cuda_graph: bool = False
    """When set to true, TransformerLayer layers are swapped with a CUDA graphed version."""

    cuda_graph_use_single_mempool: bool = False
    """When set to true, cudagraphs will be captured inside a single mempool, in which case all
    cudagraphs may only be used once per step. If false, cudagraphs may be reused across
    microbatches. Enabling this may reduce cudagraph memory overheads due to memory fragmentation,
    but it may greatly increase the number of cudagraphs created when the number of microbatches
    is high."""

    cuda_graph_retain_backward_graph: bool = False
    """When set to true, cudagraph backward passes will be graph captured with 'retain_grad=True'.
    This may enable cudagraphs for certain modules that are not completely cudagraph safe. For
    more details, see: https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html."""

    cuda_graph_warmup_steps: int = 3
    """Number of warmup steps for CUDA graphs."""

    external_cuda_graph: bool = False
    """When set to true, TransformerLayer layers are swapped with user-provided CUDA graphs."""

    ####################
    # miscellaneous
    ####################
    clone_scatter_output_in_embedding: bool = True
    """When set to True, clone the output of scatter_to_sequence_parallel_region in the embedding
    layer to facilitate garbage collection of the input."""

    disable_parameter_transpose_cache: bool = False
    """When set to true, the parameter transposes are not cached for subsequent iterations."""

    config_logger_dir: str = ""
    """When non-empty, dumps entry-point configs to config_logger_dir."""

    flash_decode: bool = False
    """Use the optimized flash decoding kernel during inference."""

    use_te_rng_tracker: bool = False
    """Whether to use the TE or MCore version of the RNG tracker."""

    inference_rng_tracker: bool = False
    """Whether we should instantiate a separate RNG tracker for inference."""
    def __post_init__(self):
        """Python dataclass method that is used to modify attributes after initialization.
        See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
        details.
        """
        super().__post_init__()
        if self.fp16 and self.bf16:
            raise ValueError(
                f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.'
            )

        if self.num_attention_heads % self.tensor_model_parallel_size != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
                f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
            )
        if self.ffn_hidden_size is None:
            self.ffn_hidden_size = 4 * self.hidden_size

        if self.kv_channels is None:
            self.kv_channels = self.hidden_size // self.num_attention_heads

        if self.num_query_groups is None:
            self.num_query_groups = self.num_attention_heads

        if self.num_query_groups % self.tensor_model_parallel_size != 0:
            raise ValueError(
                f"num_query_groups ({self.num_query_groups}) must be a multiple of "
                f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
            )

        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True

        if self.expert_model_parallel_size > 1 and self.num_moe_experts is None:
            raise ValueError('num_moe_experts must be non None to use expert-parallel.')

        if self.num_moe_experts is not None and self.num_moe_experts <= 0:
            raise ValueError('num_moe_experts must be non-negative.')
        if self.moe_ffn_hidden_size is None:
            self.moe_ffn_hidden_size = self.ffn_hidden_size

        if self.moe_shared_expert_intermediate_size is not None:
            if self.moe_shared_expert_intermediate_size <= 0:
                raise ValueError(
                    f'moe_shared_expert_intermediate_size must be '
                    f'num_shared_experts * ffn_size_of_each_shared_expert, '
                    f'but got {self.moe_shared_expert_intermediate_size}'
                )
            if self.moe_shared_expert_overlap and self.moe_token_dispatcher_type not in [
                "alltoall"
            ]:
                raise ValueError(
                    'moe_shared_expert_overlap only works with alltoall token dispatcher.'
                )

        if self.moe_expert_capacity_factor is not None:
            if self.moe_token_dispatcher_type not in ["alltoall", "alltoall_seq"]:
                raise ValueError(
                    'moe_expert_capacity_factor only works with alltoall token dispatcher'
                )
            if self.moe_expert_capacity_factor < 0:
                self.moe_expert_capacity_factor = None
            if self.moe_router_load_balancing_type not in ["aux_loss", "seq_aux_loss", "none"]:
                raise ValueError(
                    'moe_expert_capacity_factor only works with aux_loss or none load balancing'
                )

        if self.moe_pad_expert_input_to_capacity:
            if self.moe_expert_capacity_factor is None:
                raise ValueError(
                    'moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity'
                )
        if self.cpu_offloading and (
            self.cpu_offloading_num_layers < 0 or self.cpu_offloading_num_layers >= self.num_layers
        ):
            raise ValueError(
                f'CPU offloading can be done only for layers less than {self.num_layers}'
            )

        if self.cpu_offloading and self.pipeline_model_parallel_size > 1:
            raise ValueError(
                'Currently there is no support for Pipeline parallelism with CPU offloading'
            )

        if self.cpu_offloading and self.recompute_granularity is not None:
            raise ValueError(
                'CPU offloading does not work when activation recomputation is enabled'
            )

        if self.recompute_granularity is not None:
            if self.recompute_granularity not in ['full', 'selective']:
                raise ValueError(
                    f'When using recompute_granularity: {self.recompute_granularity} must be "full" '
                    'or "selective".'
                )
            if self.recompute_method is not None:
                if self.recompute_method not in ['block', 'uniform']:
                    raise ValueError(
                        f'recompute_method: {self.recompute_method} must be "block" or "uniform".'
                    )
            elif self.recompute_granularity != 'selective':
                raise ValueError(
                    f'Using recompute_granularity: {self.recompute_granularity} so '
                    'recompute_method must be "block" or "uniform"'
                )

            if self.recompute_granularity != 'selective' and self.recompute_num_layers is None:
                raise ValueError(
                    f'When using recompute_granularity: {self.recompute_granularity} '
                    'recompute_num_layers must be between '
                    '1 and num_layers_per_pipeline_rank: '
                    f'{self.num_layers // self.pipeline_model_parallel_size}'
                )
            elif (
                self.recompute_granularity == 'selective' and self.recompute_num_layers is not None
            ):
                raise ValueError(
                    f'When using recompute_granularity: {self.recompute_granularity} '
                    'recompute_num_layers must be None.'
                )
            if self.distribute_saved_activations and self.sequence_parallel:
                raise ValueError(
                    f'distribute_saved_activations: {self.distribute_saved_activations} must be '
                    f'false when sequence parallel is enabled: {self.sequence_parallel}'
                )

        if (
            self.num_layers_in_first_pipeline_stage is not None
            or self.num_layers_in_last_pipeline_stage is not None
        ) and (
            self.account_for_embedding_in_pipeline_split or self.account_for_loss_in_pipeline_split
        ):
            raise ValueError(
                'num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage cannot '
                'be set at the same time with account_for_embedding_in_pipeline_split '
                'and account_for_loss_in_pipeline_split'
            )
        if (
            self.num_layers_in_first_pipeline_stage is not None
            or self.num_layers_in_last_pipeline_stage is not None
        ):
            pipeline_parallel_size = self.pipeline_model_parallel_size
            num_layers = self.num_layers

            if self.num_layers_in_first_pipeline_stage is not None:
                if self.num_layers_in_first_pipeline_stage <= 0:
                    raise ValueError('num_layers_in_first_pipeline_stage must be larger than 0')

                if self.virtual_pipeline_model_parallel_size is not None:
                    if (
                        self.num_layers_in_first_pipeline_stage
                        % self.virtual_pipeline_model_parallel_size
                        != 0
                    ):
                        raise ValueError(
                            f'number of layers at first stage: '
                            f'{self.num_layers_in_first_pipeline_stage} '
                            f'must be divisible by virtual pipeline '
                            f'parallel degree {self.virtual_pipeline_model_parallel_size}'
                        )
                num_layers -= self.num_layers_in_first_pipeline_stage
                pipeline_parallel_size -= 1

            if self.num_layers_in_last_pipeline_stage is not None:
                if self.num_layers_in_last_pipeline_stage <= 0:
                    raise ValueError('num_layers_in_last_pipeline_stage must be larger than 0')

                if self.virtual_pipeline_model_parallel_size is not None:
                    if (
                        self.num_layers_in_last_pipeline_stage
                        % self.virtual_pipeline_model_parallel_size
                        != 0
                    ):
                        raise ValueError(
                            f'number of layers at last stage: '
                            f'{self.num_layers_in_last_pipeline_stage} '
                            f'must be divisible by virtual pipeline '
                            f'parallel degree {self.virtual_pipeline_model_parallel_size}'
                        )
                num_layers -= self.num_layers_in_last_pipeline_stage
                pipeline_parallel_size -= 1

            if not num_layers % pipeline_parallel_size == 0:
                raise ValueError(
                    f'number of layers at middle stage: {num_layers} must be divisible by '
                    f'the middle pipeline model parallel size {pipeline_parallel_size}'
                )

            if self.virtual_pipeline_model_parallel_size is not None:
                num_layers_per_middle_pipeline_rank = num_layers // pipeline_parallel_size
                if (
                    not num_layers_per_middle_pipeline_rank
                    % self.virtual_pipeline_model_parallel_size
                    == 0
                ):
                    raise ValueError(
                        f'number of layers on each middle pipeline rank: '
                        f'{num_layers_per_middle_pipeline_rank} must be divisible by virtual '
                        f'pipeline parallel degree {self.virtual_pipeline_model_parallel_size}'
                    )
        if self.account_for_embedding_in_pipeline_split or self.account_for_loss_in_pipeline_split:
            if self.virtual_pipeline_model_parallel_size is None:
                pipeline_parallel_size = self.pipeline_model_parallel_size
                if self.account_for_embedding_in_pipeline_split:
                    pipeline_parallel_size -= 1
                if self.account_for_loss_in_pipeline_split:
                    pipeline_parallel_size -= 1

                if not self.num_layers % pipeline_parallel_size == 0:
                    raise ValueError(
                        f'number of middle layers: {self.num_layers} must be divisible by '
                        f'middle pipeline_model_parallel_size {pipeline_parallel_size}'
                    )
            else:
                num_layers = self.num_layers
                if self.account_for_embedding_in_pipeline_split:
                    num_layers += 1
                if self.account_for_loss_in_pipeline_split:
                    num_layers += 1

                if not num_layers % self.pipeline_model_parallel_size == 0:
                    raise ValueError(
                        f'num_layers: {num_layers} after enabling '
                        f'account_for_embedding_in_pipeline_split or '
                        f'account_for_loss_in_pipeline_split must be divisible '
                        f'by pipeline_model_parallel_size '
                        f'{self.pipeline_model_parallel_size}'
                    )

                num_layers_per_pipeline_rank = num_layers // self.pipeline_model_parallel_size
                if (
                    not num_layers_per_pipeline_rank % self.virtual_pipeline_model_parallel_size
                    == 0
                ):
                    raise ValueError(
                        f'number of layers on each pipeline rank: {num_layers_per_pipeline_rank} '
                        f'(after enabling account_for_embedding_in_pipeline_split or '
                        f'account_for_loss_in_pipeline_split) must be divisible by '
                        f'virtual_pipeline_model_parallel_size '
                        f'{self.virtual_pipeline_model_parallel_size}'
                    )
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True

        if self.bias_activation_fusion:
            if self.activation_func not in [F.gelu, F.silu]:
                raise ValueError(
                    "When bias_activation_fusion is True, activation function should be either "
                    "gelu or swiglu"
                )
            if (
                self.activation_func == F.gelu
                and not self.gated_linear_unit
                and not self.add_bias_linear
            ):
                raise ValueError(
                    "When bias_activation_fusion is True, gated_linear_unit is False, "
                    "and activation function is gelu, add_bias_linear must also be True."
                )

        if self.activation_func_fp8_input_store:
            if self.activation_func != F.silu or not self.gated_linear_unit:
                raise ValueError("Storing activation input in FP8 is supported only for SwiGLU.")
        if self.apply_rope_fusion:
            if self.rotary_interleaved:
                raise ValueError("rotary_interleaved does not work with apply_rope_fusion.")

            from megatron.core.models.common.embeddings.rope_utils import (
                fused_apply_rotary_pos_emb,
                fused_apply_rotary_pos_emb_thd,
            )

            if fused_apply_rotary_pos_emb is None and fused_apply_rotary_pos_emb_thd is None:
                raise ValueError(
                    "apply_rope_fusion is not available. Please install TE >= 1.4 or Apex."
                )

            if self.multi_latent_attention:
                raise ValueError("multi_latent_attention does not support apply_rope_fusion.")

        if self.multi_latent_attention and self.rotary_interleaved:
            raise ValueError("rotary_interleaved does not work with multi_latent_attention.")
        if self.init_method is None:
            self.init_method = init_method_normal(self.init_method_std)

        if self.output_layer_init_method is None:
            self.output_layer_init_method = scaled_init_method_normal(
                self.init_method_std, self.num_layers
            )

        if (
            self.moe_token_dispatcher_type == "alltoall_seq"
            and self.tensor_model_parallel_size != self.expert_tensor_parallel_size
        ):
            raise ValueError(
                "alltoall_seq dispatcher does not support different TP sizes for MoE and Dense layers."
            )

        if self.moe_router_enable_expert_bias and self.moe_router_score_function != "sigmoid":
            raise ValueError(
                "Expert bias for aux-loss-free routing only supports the sigmoid score function. "
                "Please set --moe-router-score-function sigmoid to use it."
            )
        if self.num_moe_experts and self.fp8:
            # TE versions below 1.7.0 raise an error when an expert receives zero tokens.
            if not is_te_min_version("1.7.0.dev0"):
                raise ValueError(
                    "Only transformer-engine>=1.7.0 supports MoE FP8 training, "
                    f"but your version is {get_te_version()}."
                )

            if self.moe_grouped_gemm and not is_te_min_version("1.11.0"):
                raise ValueError(
                    "Only transformer-engine>=1.11.0 supports FP8 grouped gemm, "
                    f"but your version is {get_te_version()}."
                )

        if (
            self.moe_router_topk == 1
            and self.moe_router_score_function == 'softmax'
            and not self.moe_router_pre_softmax
            and self.moe_router_load_balancing_type != 'sinkhorn'
        ):
            # Softmax must be applied before selecting the top-k when k is 1,
            # since softmax on a [num_tokens, 1] tensor would yield a zero gradient.
            raise ValueError("Please use --moe-router-pre-softmax when topk is 1.")
        if self.moe_router_group_topk:
            if self.moe_router_topk_limited_devices:
                raise ValueError(
                    "moe_router_topk_limited_devices is deprecated and replaced by "
                    "moe_router_group_topk and moe_router_num_groups."
                )
            if not self.moe_router_num_groups:
                raise ValueError(
                    "When using group limited routing, moe_router_num_groups must be specified."
                )
            else:
                assert self.num_moe_experts % self.moe_router_num_groups == 0, (
                    f"num_moe_experts ({self.num_moe_experts}) should be divisible by "
                    f"moe_router_num_groups ({self.moe_router_num_groups})."
                )
                assert self.moe_router_group_topk <= self.moe_router_num_groups, (
                    f"moe_router_group_topk ({self.moe_router_group_topk}) should be smaller than "
                    f"moe_router_num_groups ({self.moe_router_num_groups})."
                )
        elif self.moe_router_topk_limited_devices:
            warnings.warn(
                "moe_router_topk_limited_devices is deprecated. Use moe_router_group_topk and "
                "moe_router_num_groups instead."
            )
            self.moe_router_group_topk = self.moe_router_topk_limited_devices
            self.moe_router_num_groups = self.expert_model_parallel_size
        if self.flash_decode and self.fp8:
            raise ValueError("FP8 inference is currently not supported with flash decoding.")

        if self.enable_cuda_graph:
            if self.cpu_offloading:
                raise ValueError("CUDA graphs not supported with CPU offloading.")
            if self.recompute_granularity:
                raise ValueError("CUDA graphs not supported with activation recomputation.")

        if self.moe_token_dispatcher_type in ['allgather', 'alltoall_seq']:
            if self.variable_seq_lengths is True:
                raise ValueError(
                    f"Token dispatcher type: {self.moe_token_dispatcher_type} does not support "
                    f"variable sequence length, please use the alltoall dispatcher instead."
                )
        if self.moe_permute_fusion:
            from megatron.core.transformer.moe.moe_utils import (
                fused_permute,
                fused_sort_chunks_by_index,
                fused_unpermute,
            )

            if (
                fused_permute is None
                or fused_sort_chunks_by_index is None
                or fused_unpermute is None
            ):
                raise ValueError("fused permutation is not available. Please install TE >= 2.1.0.")
        if self.cp_comm_type is not None:
            if isinstance(self.cp_comm_type, list):
                assert len(self.cp_comm_type) == self.num_layers, (
                    f"Length of cp_comm_type ({len(self.cp_comm_type)}) should equal "
                    f"the total number of transformer layers ({self.num_layers})!"
                )
            else:
                assert isinstance(
                    self.cp_comm_type, str
                ), "Unsupported communication type for context parallelism!"

        assert self.pipeline_model_parallel_size > 0, (
            "Pipeline model parallel size must be larger than 0 "
            "when enabling --standalone-embedding-stage and --standalone-loss-stage"
        )

    flux_transpose_weight: bool = False
    """Whether to transpose the weight when using the flux kernel."""


class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
    pass


@dataclass
class MLATransformerConfig(TransformerConfig):
    """Configuration object for megatron-core Multi-Latent Attention (MLA) transformers.

    The initialization function has an argument for each parameter, including those in
    ModelParallelConfig. Includes the YaRN RoPE parameters that are fused in MLA.
    """

    multi_latent_attention: bool = True
    """Whether to use Multi-Latent Attention."""

    q_lora_rank: int = 512
    """Rank of the Query tensor's low-rank representation."""

    kv_lora_rank: int = 512
    """Rank of the Key and Value tensors' low-rank representation."""

    qk_head_dim: int = 128
    """Dimension of the head in the QK projection. q_head_dim = qk_head_dim + qk_pos_emb_head_dim"""

    qk_pos_emb_head_dim: int = 64
    """Dimension of the position embedding in the QK projection."""
    v_head_dim: int = 128
    """Dimension of the head in the V projection."""

    normalization: str = "RMSNorm"
    """Default normalization layer for MLA models is RMSNorm."""

    rope_type: str = "yarn"
    """Type of RoPE to use. Defaults to yarn; options are rope and yarn."""

    rotary_base: float = 10000
    """Rotary base for the rotary embeddings, used by rope and yarn."""

    rotary_percent: float = 1.0
    """Rotary percent for the rotary embeddings, used by rope."""

    rotary_scaling_factor: float = 40
    """Rotary scaling factor for the rotary embeddings, used by yarn."""

    max_position_embeddings: int = 4096
    """Maximum position embeddings for the original model, used by yarn."""

    beta_fast: float = 32
    """Beta fast for YaRN RoPE, used by yarn."""

    beta_slow: float = 1
    """Beta slow for YaRN RoPE, used by yarn."""

    mscale: float = 0.707
    """Mscale for YaRN RoPE in Multi-Latent Attention, used by yarn."""

    mscale_all_dim: float = 0.707
    """Mscale for all dimensions for YaRN RoPE in Multi-Latent Attention, used by yarn."""


@dataclass
class MLATransformerConfigPatch(MLATransformerConfig, ExtraTransformerConfig):
    pass
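# ------------------------------------------------------------------------------
# Illustrative usage sketch, not part of this module: the patch classes are
# registered over the upstream Megatron classes by megatron_adaptor.py in this
# commit; constructing one directly might look like the hedged example below,
# with hypothetical model sizes.
#
#   config = TransformerConfigPatch(
#       num_layers=4,
#       hidden_size=1024,
#       num_attention_heads=8,
#   )
#   assert config.ffn_hidden_size == 4 * config.hidden_size  # filled in __post_init__
# ------------------------------------------------------------------------------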
dcu_megatron/training/arguments.py
View file @
9eb8683b
...
...
@@ -165,3 +165,12 @@ def _add_mtp_args(parser):
    group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true',
                       default=False,
                       help='Main model share embedding and output weight with mtp layer.')

    return parser


def _add_flux_args(parser):
    group = parser.add_argument_group(title='multi token prediction')

    group.add_argument('--use-flux', action='store_true', default=False,
                       help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear')
    group.add_argument('--flux-transpose-weight', action='store_true', default=False,
                       help='Whether to transpose weight when using flux kernel')

    return parser
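# ------------------------------------------------------------------------------
# Illustrative usage sketch, not part of this module: in Megatron these helpers
# are chained onto the main argument parser; a hedged standalone example:
#
#   import argparse
#
#   parser = argparse.ArgumentParser()
#   parser = _add_flux_args(parser)
#   args = parser.parse_args(['--use-flux', '--flux-transpose-weight'])
#   assert args.use_flux and args.flux_transpose_weight
# ------------------------------------------------------------------------------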