Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
evt_fugx1
dcu_megatron
Commits
72aeb0f3
Commit
72aeb0f3
authored
Apr 16, 2025
by
dongcl
Browse files
bug fix
parent
d46a984e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
9 deletions
+6
-9
dcu_megatron/adaptor/megatron_adaptor.py
dcu_megatron/adaptor/megatron_adaptor.py
+1
-4
dcu_megatron/core/tensor_parallel/layers.py
dcu_megatron/core/tensor_parallel/layers.py
+1
-1
dcu_megatron/core/transformer/mtp/mtp_spec.py
dcu_megatron/core/transformer/mtp/mtp_spec.py
+4
-4
No files found.
dcu_megatron/adaptor/megatron_adaptor.py
View file @
72aeb0f3
...
@@ -5,8 +5,6 @@ import types
...
@@ -5,8 +5,6 @@ import types
import
argparse
import
argparse
import
torch
import
torch
from
megatron.training
import
get_args
class
MegatronAdaptation
:
class
MegatronAdaptation
:
"""
"""
...
@@ -191,8 +189,7 @@ class CoreAdaptation(MegatronAdaptationABC):
...
@@ -191,8 +189,7 @@ class CoreAdaptation(MegatronAdaptationABC):
apply_wrapper
=
True
)
apply_wrapper
=
True
)
# flux
# flux
args
=
get_args
()
if
os
.
getenv
(
"USE_FLUX_OVERLAP"
,
0
):
if
args
.
use_flux
:
import
flux
import
flux
from
..core.tensor_parallel
import
(
from
..core.tensor_parallel
import
(
...
...
dcu_megatron/core/tensor_parallel/layers.py
View file @
72aeb0f3
...
@@ -284,7 +284,7 @@ class AGLinear(torch.autograd.Function):
...
@@ -284,7 +284,7 @@ class AGLinear(torch.autograd.Function):
)
)
torch
.
cuda
.
current_stream
().
synchronize
()
torch
.
cuda
.
current_stream
().
synchronize
()
grad_input
=
grad_input
.
view
(
sequence_len
//
get_tensor_model_parallel_
world_size
()
,
batch_size
,
-
1
)
grad_input
=
grad_input
.
view
(
sequence_len
//
world_size
,
batch_size
,
-
1
)
else
:
else
:
grad_input
=
grad_output
.
matmul
(
weight
)
grad_input
=
grad_output
.
matmul
(
weight
)
...
...
dcu_megatron/core/transformer/mtp/mtp_spec.py
View file @
72aeb0f3
...
@@ -30,7 +30,7 @@ except ImportError:
...
@@ -30,7 +30,7 @@ except ImportError:
LNImpl
=
WrappedTorchNorm
LNImpl
=
WrappedTorchNorm
def
get_mtp_spec
(
transformer_layer
,
use_te
=
False
):
def
get_mtp_spec
(
transformer_layer
,
use_te
=
False
,
use_flux
=
False
):
"""
"""
Multi Token Predication Layer Specification.
Multi Token Predication Layer Specification.
"""
"""
...
@@ -39,11 +39,11 @@ def get_mtp_spec(transformer_layer, use_te=False):
...
@@ -39,11 +39,11 @@ def get_mtp_spec(transformer_layer, use_te=False):
module
=
MultiTokenPredictor
,
module
=
MultiTokenPredictor
,
submodules
=
MultiTokenPredicationSubmodules
(
submodules
=
MultiTokenPredicationSubmodules
(
embedding
=
None
,
embedding
=
None
,
enorm
=
TENorm
if
use_te
else
LNImpl
,
enorm
=
TENorm
if
use_te
or
use_flux
else
LNImpl
,
hnorm
=
TENorm
if
use_te
else
LNImpl
,
hnorm
=
TENorm
if
use_te
or
use_flux
else
LNImpl
,
eh_proj
=
TEColumnParallelLinear
if
use_te
else
ColumnParallelLinear
,
eh_proj
=
TEColumnParallelLinear
if
use_te
else
ColumnParallelLinear
,
transformer_layer
=
transformer_layer
,
transformer_layer
=
transformer_layer
,
final_layernorm
=
TENorm
if
use_te
else
LNImpl
,
final_layernorm
=
TENorm
if
use_te
or
use_flux
else
LNImpl
,
output_layer
=
None
,
output_layer
=
None
,
)
)
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment