Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
evt_fugx1
dcu_megatron
Commits
0b2b5417
"scripts/convert_asymmetric_vqgan_to_diffusers.py" did not exist on "2551b73670ee886515fc5a0d49fc304ebd0d7b51"
Commit
0b2b5417
authored
Apr 03, 2025
by
dongcl
Browse files
rewrite transformer_engine
parent
f098f250
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
128 additions
and
126 deletions
+128
-126
dcu_megatron/adaptor/megatron_adaptor.py
dcu_megatron/adaptor/megatron_adaptor.py
+2
-2
dcu_megatron/core/extensions/transformer_engine.py
dcu_megatron/core/extensions/transformer_engine.py
+126
-124
No files found.
dcu_megatron/adaptor/megatron_adaptor.py
View file @
0b2b5417
...
...
@@ -143,11 +143,11 @@ class CoreAdaptation(MegatronAdaptationABC):
def
patch_core_extentions
(
self
):
import
transformer_engine
as
te
from
..core.extensions.transformer_engine
import
te_dot_p
roduct
_a
ttention
_init
from
..core.extensions.transformer_engine
import
TEDotP
roduct
A
ttention
Patch
from
megatron.core.extensions.transformer_engine
import
TEGroupedLinear
MegatronAdaptation
.
register
(
'megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__'
,
te_dot_p
roduct
_a
ttention_init
)
TEDotP
roduct
A
ttention
Patch
.
_
_init
__
)
if
int
(
os
.
getenv
(
"GROUPED_GEMM_BatchLinear"
,
'0'
)):
TEGroupedLinear
.
__bases__
=
(
te
.
pytorch
.
BatchLinear
,)
...
...
dcu_megatron/core/extensions/transformer_engine.py
View file @
0b2b5417
import
os
import
dataclasses
import
transformer_engine
as
te
from
typing
import
Any
,
Optional
from
packaging.version
import
Version
as
PkgVersion
...
...
@@ -19,135 +20,136 @@ from megatron.core.parallel_state import (
)
def
te_dot_product_attention_init
(
self
,
config
:
TransformerConfig
,
layer_number
:
int
,
attn_mask_type
:
AttnMaskType
,
attention_type
:
str
,
attention_dropout
:
Optional
[
float
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
k_channels
:
Optional
[
int
]
=
None
,
v_channels
:
Optional
[
int
]
=
None
,
cp_comm_type
:
str
=
"p2p"
,
):
self
.
config
=
config
self
.
te_forward_mask_type
=
False
self
.
qkv_format
:
str
=
'sbhd'
if
self
.
config
.
apply_query_key_layer_scaling
!=
bool
(
int
(
os
.
getenv
(
'NVTE_APPLY_QK_LAYER_SCALING'
,
'0'
))
class
TEDotProductAttentionPatch
(
te
.
pytorch
.
DotProductAttention
):
def
__init__
(
self
,
config
:
TransformerConfig
,
layer_number
:
int
,
attn_mask_type
:
AttnMaskType
,
attention_type
:
str
,
attention_dropout
:
Optional
[
float
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
k_channels
:
Optional
[
int
]
=
None
,
v_channels
:
Optional
[
int
]
=
None
,
cp_comm_type
:
str
=
"p2p"
,
):
raise
ValueError
(
f
"apply_query_key_layer_scaling is
{
self
.
config
.
apply_query_key_layer_scaling
}
"
f
"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f
"
{
os
.
getenv
(
'NVTE_APPLY_QK_LAYER_SCALING'
)
}
. Transformer Engine does not support "
f
"setting query key layer scaling via argument, so these two must match."
)
self
.
config
=
config
self
.
te_forward_mask_type
=
False
self
.
qkv_format
:
str
=
'sbhd'
if
self
.
config
.
apply_query_key_layer_scaling
!=
bool
(
int
(
os
.
getenv
(
'NVTE_APPLY_QK_LAYER_SCALING'
,
'0'
))
):
raise
ValueError
(
f
"apply_query_key_layer_scaling is
{
self
.
config
.
apply_query_key_layer_scaling
}
"
f
"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f
"
{
os
.
getenv
(
'NVTE_APPLY_QK_LAYER_SCALING'
)
}
. Transformer Engine does not support "
f
"setting query key layer scaling via argument, so these two must match."
)
extra_kwargs
:
dict
[
str
,
Any
]
=
{}
if
is_te_min_version
(
"0.11.0"
):
extra_kwargs
[
"num_gqa_groups"
]
=
self
.
config
.
num_query_groups
elif
self
.
config
.
num_query_groups
!=
self
.
config
.
num_attention_heads
:
raise
ValueError
(
f
"Transformer Engine v
{
get_te_version
()
}
does not support Grouped Query Attention, "
f
"use a newer version of Transformer Engine. "
f
"(num_query_groups (
{
self
.
config
.
num_query_groups
}
) != "
f
"num_attention_heads (
{
self
.
config
.
num_attention_heads
}
))"
)
extra_kwargs
:
dict
[
str
,
Any
]
=
{}
if
is_te_min_version
(
"0.11.0"
):
extra_kwargs
[
"num_gqa_groups"
]
=
self
.
config
.
num_query_groups
elif
self
.
config
.
num_query_groups
!=
self
.
config
.
num_attention_heads
:
raise
ValueError
(
f
"Transformer Engine v
{
get_te_version
()
}
does not support Grouped Query Attention, "
f
"use a newer version of Transformer Engine. "
f
"(num_query_groups (
{
self
.
config
.
num_query_groups
}
) != "
f
"num_attention_heads (
{
self
.
config
.
num_attention_heads
}
))"
)
if
is_te_min_version
(
"0.10.0"
):
extra_kwargs
[
"attention_type"
]
=
attention_type
# older version don't need attention_type
if
is_te_min_version
(
"0.12.0"
,
check_equality
=
False
):
self
.
te_forward_mask_type
=
True
# This check is important as CP config can be disabled while having a valid CP group
# Example - Disabling CP for encoder while a valid CP group exists for decoder
if
self
.
config
.
context_parallel_size
>
1
:
assert
is_te_min_version
(
"1.0.0"
),
"Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if
getattr
(
TEDotProductAttention
,
"cp_stream"
)
is
None
:
TEDotProductAttention
.
cp_stream
=
torch
.
cuda
.
Stream
()
extra_kwargs
[
"cp_group"
]
=
get_context_parallel_group
(
check_initialized
=
False
)
extra_kwargs
[
"cp_global_ranks"
]
=
get_context_parallel_global_ranks
(
check_initialized
=
False
)
extra_kwargs
[
"cp_stream"
]
=
TEDotProductAttention
.
cp_stream
if
is_te_min_version
(
"1.10.0"
):
if
cp_comm_type
is
None
:
extra_kwargs
[
"cp_comm_type"
]
=
"p2p"
elif
cp_comm_type
==
"a2a+p2p"
:
assert
is_te_min_version
(
"1.12.0"
),
(
f
"Transformer-Engine v
{
get_te_version
()
}
must be >= 1.12.0 to support"
"hierarchical cp commucation."
)
extra_kwargs
[
"cp_comm_type"
]
=
"a2a+p2p"
extra_kwargs
[
"cp_group"
]
=
get_hierarchical_context_parallel_groups
(
check_initialized
=
False
if
is_te_min_version
(
"0.10.0"
):
extra_kwargs
[
"attention_type"
]
=
attention_type
# older version don't need attention_type
if
is_te_min_version
(
"0.12.0"
,
check_equality
=
False
):
self
.
te_forward_mask_type
=
True
# This check is important as CP config can be disabled while having a valid CP group
# Example - Disabling CP for encoder while a valid CP group exists for decoder
if
self
.
config
.
context_parallel_size
>
1
:
assert
is_te_min_version
(
"1.0.0"
),
"Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if
getattr
(
TEDotProductAttention
,
"cp_stream"
)
is
None
:
TEDotProductAttention
.
cp_stream
=
torch
.
cuda
.
Stream
()
extra_kwargs
[
"cp_group"
]
=
get_context_parallel_group
(
check_initialized
=
False
)
extra_kwargs
[
"cp_global_ranks"
]
=
get_context_parallel_global_ranks
(
check_initialized
=
False
)
extra_kwargs
[
"cp_stream"
]
=
TEDotProductAttention
.
cp_stream
if
is_te_min_version
(
"1.10.0"
):
if
cp_comm_type
is
None
:
extra_kwargs
[
"cp_comm_type"
]
=
"p2p"
elif
cp_comm_type
==
"a2a+p2p"
:
assert
is_te_min_version
(
"1.12.0"
),
(
f
"Transformer-Engine v
{
get_te_version
()
}
must be >= 1.12.0 to support"
"hierarchical cp commucation."
)
extra_kwargs
[
"cp_comm_type"
]
=
"a2a+p2p"
extra_kwargs
[
"cp_group"
]
=
get_hierarchical_context_parallel_groups
(
check_initialized
=
False
)
else
:
extra_kwargs
[
"cp_comm_type"
]
=
cp_comm_type
if
self
.
config
.
deterministic_mode
:
if
int
(
os
.
getenv
(
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
,
"1"
))
!=
0
:
raise
RuntimeError
(
"deterministic_mode is on and we are using DotProductAttention from "
"Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
f
"Currently set to:
{
os
.
getenv
(
'NVTE_ALLOW_NONDETERMINISTIC_ALGO'
,
'not set'
)
}
."
)
else
:
extra_kwargs
[
"cp_comm_type"
]
=
cp_comm_type
if
self
.
config
.
deterministic_mode
:
if
int
(
os
.
getenv
(
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
,
"1"
))
!=
0
:
raise
RuntimeError
(
"deterministic_mode is on and we are using DotProductAttention from "
"Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
f
"Currently set to:
{
os
.
getenv
(
'NVTE_ALLOW_NONDETERMINISTIC_ALGO'
,
'not set'
)
}
."
if
config
.
window_size
is
not
None
:
# Check version
assert
is_te_min_version
(
"1.2.0"
),
(
f
"Transformer-Engine v
{
get_te_version
()
}
must be >= 1.2.0 to support"
"sliding window attention."
)
extra_kwargs
[
'window_size'
]
=
config
.
window_size
if
is_te_min_version
(
"1.9.0"
):
# TE 1.10.0 introduces the ability to set the different k and v channels
kv_channels
=
(
(
k_channels
,
v_channels
)
if
k_channels
is
not
None
and
v_channels
is
not
None
else
self
.
config
.
kv_channels
)
extra_kwargs
[
'softmax_scale'
]
=
softmax_scale
else
:
kv_channels
=
self
.
config
.
kv_channels
if
config
.
window_size
is
not
None
:
# Check version
assert
is_te_min_version
(
"1.2.0"
),
(
f
"Transformer-Engine v
{
get_te_version
()
}
must be >= 1.2.0 to support"
"sliding window attention."
self
.
kept_packed_seq_params
=
set
(
field
.
name
for
field
in
dataclasses
.
fields
(
PackedSeqParams
)
)
extra_kwargs
[
'window_size'
]
=
config
.
window_size
if
is_te_min_version
(
"1.9.0"
):
# TE 1.10.0 introduces the ability to set the different k and v channels
kv_channels
=
(
(
k_channels
,
v_channels
)
if
k_channels
is
not
None
and
v_channels
is
not
None
else
self
.
config
.
kv_channels
if
get_te_version
()
<
PkgVersion
(
"1.3.0"
):
# TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
# copies (#555)
# These two arguments did not exist prior to 1.3.0
self
.
kept_packed_seq_params
.
discard
(
"max_seqlen_q"
)
self
.
kept_packed_seq_params
.
discard
(
"max_seqlen_kv"
)
if
get_te_version
()
<
PkgVersion
(
"1.10.0"
):
# TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted
# in each individual sequence in THD format dataset
# These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012)
self
.
kept_packed_seq_params
.
discard
(
"cu_seqlens_q_padded"
)
self
.
kept_packed_seq_params
.
discard
(
"cu_seqlens_kv_padded"
)
super
(
TEDotProductAttention
,
self
).
__init__
(
num_attention_heads
=
self
.
config
.
num_attention_heads
,
kv_channels
=
kv_channels
,
attention_dropout
=
(
self
.
config
.
attention_dropout
if
attention_dropout
is
None
else
attention_dropout
),
attn_mask_type
=
attn_mask_type
.
name
,
sequence_parallel
=
self
.
config
.
sequence_parallel
,
tp_size
=
self
.
config
.
tensor_model_parallel_size
,
get_rng_state_tracker
=
(
get_cuda_rng_tracker
if
get_cuda_rng_tracker
().
is_initialized
()
else
None
),
tp_group
=
get_tensor_model_parallel_group
(
check_initialized
=
False
),
layer_number
=
layer_number
,
**
extra_kwargs
,
)
extra_kwargs
[
'softmax_scale'
]
=
softmax_scale
else
:
kv_channels
=
self
.
config
.
kv_channels
self
.
kept_packed_seq_params
=
set
(
field
.
name
for
field
in
dataclasses
.
fields
(
PackedSeqParams
)
)
if
get_te_version
()
<
PkgVersion
(
"1.3.0"
):
# TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
# copies (#555)
# These two arguments did not exist prior to 1.3.0
self
.
kept_packed_seq_params
.
discard
(
"max_seqlen_q"
)
self
.
kept_packed_seq_params
.
discard
(
"max_seqlen_kv"
)
if
get_te_version
()
<
PkgVersion
(
"1.10.0"
):
# TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted
# in each individual sequence in THD format dataset
# These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012)
self
.
kept_packed_seq_params
.
discard
(
"cu_seqlens_q_padded"
)
self
.
kept_packed_seq_params
.
discard
(
"cu_seqlens_kv_padded"
)
super
(
TEDotProductAttention
,
self
).
__init__
(
num_attention_heads
=
self
.
config
.
num_attention_heads
,
kv_channels
=
kv_channels
,
attention_dropout
=
(
self
.
config
.
attention_dropout
if
attention_dropout
is
None
else
attention_dropout
),
attn_mask_type
=
attn_mask_type
.
name
,
sequence_parallel
=
self
.
config
.
sequence_parallel
,
tp_size
=
self
.
config
.
tensor_model_parallel_size
,
get_rng_state_tracker
=
(
get_cuda_rng_tracker
if
get_cuda_rng_tracker
().
is_initialized
()
else
None
),
tp_group
=
get_tensor_model_parallel_group
(
check_initialized
=
False
),
layer_number
=
layer_number
,
**
extra_kwargs
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment