Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
evt_fugx1
dcu_megatron
Commits
23eb9b17
"vscode:/vscode.git/clone" did not exist on "915140fd18c9ff4193e994e6d756ea762a52240a"
Commit
23eb9b17
authored
May 06, 2025
by
dongcl
Browse files
modify TEGroupedLinear base
parent
43770f8e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
4 deletions
+7
-4
dcu_megatron/adaptor/megatron_adaptor.py
dcu_megatron/adaptor/megatron_adaptor.py
+7
-4
No files found.
dcu_megatron/adaptor/megatron_adaptor.py
View file @
23eb9b17
...
@@ -5,6 +5,8 @@ import types
...
@@ -5,6 +5,8 @@ import types
import
argparse
import
argparse
import
torch
import
torch
from
megatron.core.utils
import
is_te_min_version
class
MegatronAdaptation
:
class
MegatronAdaptation
:
"""
"""
...
@@ -132,12 +134,13 @@ class CoreAdaptation(MegatronAdaptationABC):
...
@@ -132,12 +134,13 @@ class CoreAdaptation(MegatronAdaptationABC):
from
..core.extensions.transformer_engine
import
TEDotProductAttentionPatch
from
..core.extensions.transformer_engine
import
TEDotProductAttentionPatch
from
megatron.core.extensions.transformer_engine
import
TEGroupedLinear
from
megatron.core.extensions.transformer_engine
import
TEGroupedLinear
# kv channels, te_min_version 1.10.0 -> 1.9.0
if
not
is_te_min_version
(
"1.10.0"
):
MegatronAdaptation
.
register
(
'megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__'
,
# kv channels, te_min_version 1.10.0 -> 1.9.0
TEDotProductAttentionPatch
.
__init__
)
MegatronAdaptation
.
register
(
'megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__'
,
TEDotProductAttentionPatch
.
__init__
)
if
int
(
os
.
getenv
(
"GROUPED_GEMM_BatchLinear"
,
'0'
)):
if
int
(
os
.
getenv
(
"GROUPED_GEMM_BatchLinear"
,
'0'
)):
TEGroupedLinear
.
__bases__
=
(
te
.
pytorch
.
BatchLinear
,)
TEGroupedLinear
.
__bases__
=
(
te
.
pytorch
.
BatchedLinear
if
is_te_min_version
(
"2.3.0"
)
else
te
.
pytorch
.
BatchLinear
,)
def
patch_tensor_parallel
(
self
):
def
patch_tensor_parallel
(
self
):
from
..core.tensor_parallel.cross_entropy
import
VocabParallelCrossEntropy
from
..core.tensor_parallel.cross_entropy
import
VocabParallelCrossEntropy
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment