Commit 0b2b5417 (evt_fugx1 / dcu_megatron)

rewrite transformer_engine

Authored Apr 03, 2025 by dongcl
Parent: f098f250
Showing 2 changed files with 128 additions and 126 deletions:

    dcu_megatron/adaptor/megatron_adaptor.py              +2    -2
    dcu_megatron/core/extensions/transformer_engine.py    +126  -124
dcu_megatron/adaptor/megatron_adaptor.py
@@ -143,11 +143,11 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_core_extentions(self):
         import transformer_engine as te
-        from ..core.extensions.transformer_engine import te_dot_product_attention_init
+        from ..core.extensions.transformer_engine import TEDotProductAttentionPatch
         from megatron.core.extensions.transformer_engine import TEGroupedLinear

         MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
-                                    te_dot_product_attention_init)
+                                    TEDotProductAttentionPatch.__init__)

         if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
             TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)
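This hunk leans on two runtime-patching techniques. `MegatronAdaptation.register` is defined elsewhere in the repository and its internals are not part of this commit; a minimal sketch of the resolve-and-setattr behavior such an adaptor typically has (an assumption, not the repository's actual implementation):

```python
# Hypothetical sketch of what MegatronAdaptation.register likely does:
# resolve 'pkg.module.Class.attr' and overwrite attr on Class.
import importlib

def register(dotted_path, replacement):
    module_path, class_name, attr_name = dotted_path.rsplit('.', 2)
    cls = getattr(importlib.import_module(module_path), class_name)
    setattr(cls, attr_name, replacement)  # e.g. swap in a patched __init__
```

The `__bases__` assignment is the second technique: it re-parents `TEGroupedLinear` in place, so every existing and future instance resolves inherited methods through `te.pytorch.BatchLinear` instead of the original base. A self-contained toy, with illustrative names:

```python
# Toy demonstration of re-parenting a class via __bases__;
# all names here are illustrative, not from the repository.
class GroupedBase:
    def forward(self):
        return "grouped path"

class BatchBase:
    def forward(self):
        return "batched path"

class Linear(GroupedBase):
    pass

Linear.__bases__ = (BatchBase,)              # re-parent the class in place
assert Linear().forward() == "batched path"  # MRO now goes through BatchBase
```

CPython only permits this assignment when the old and new bases have compatible instance layouts; the `GROUPED_GEMM_BatchLinear` environment variable makes the swap opt-in at startup.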
dcu_megatron/core/extensions/transformer_engine.py
 import os
 import dataclasses
+import transformer_engine as te
 from typing import Any, Optional
 from packaging.version import Version as PkgVersion
...
@@ -19,7 +20,8 @@ from megatron.core.parallel_state import (
 )

-def te_dot_product_attention_init(
-    self,
-    config: TransformerConfig,
-    layer_number: int,
+class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
+    def __init__(
+        self,
+        config: TransformerConfig,
+        layer_number: int,
...
@@ -30,7 +32,7 @@ def te_dot_product_attention_init(
-    k_channels: Optional[int] = None,
-    v_channels: Optional[int] = None,
-    cp_comm_type: str = "p2p",
-):
-    self.config = config
-    self.te_forward_mask_type = False
-    self.qkv_format: str = 'sbhd'
+        k_channels: Optional[int] = None,
+        v_channels: Optional[int] = None,
+        cp_comm_type: str = "p2p",
+    ):
+        self.config = config
+        self.te_forward_mask_type = False
+        self.qkv_format: str = 'sbhd'
...
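The bulk of the +126/-124 churn in this file is the mechanical conversion visible above: the free function `te_dot_product_attention_init` becomes the `__init__` of a subclass of `te.pytorch.DotProductAttention`, and the adaptor grafts that `__init__` onto Megatron's `TEDotProductAttention`. A minimal sketch of why a graft like this works; only the class-and-graft shape is taken from the diff, and all names below are illustrative:

```python
# Base plays te.pytorch.DotProductAttention; Child plays Megatron's
# TEDotProductAttention; ChildPatch plays TEDotProductAttentionPatch.
class Base:
    def __init__(self, config):
        self.config = config

class Child(Base):
    pass

class ChildPatch(Base):
    def __init__(self, config):
        self.patched = True  # adapted setup added by the patch
        # Call the shared base explicitly: zero-argument super() would bind
        # to ChildPatch, and a grafted call on a Child instance would raise
        # TypeError, since Child is not a subclass of ChildPatch.
        Base.__init__(self, config)

Child.__init__ = ChildPatch.__init__  # the graft performed by the adaptor
obj = Child("cfg")
assert obj.patched and obj.config == "cfg"
```

Writing the patch as a subclass of the TE base, rather than keeping a bare function, keeps the replacement `__init__` next to the class whose interface it imitates and lets tools see it as an ordinary method.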