evt_fugx1 / dcu_megatron · Commits

Commit 390eac88
Authored Apr 17, 2025 by dongcl
Parent: ec7c8bc3

    bug fix

2 changed files with 7 additions and 7 deletions:

    dcu_megatron/core/tensor_parallel/layers.py           +3 -3
    dcu_megatron/core/transformer/mtp/mtp_spec.py         +4 -4
dcu_megatron/core/tensor_parallel/layers.py

@@ -1012,7 +1012,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
         return output, output_bias


-class FluxRowParallelLinear(torch.nn.Module):
+class FluxRowParallelLinear(RowParallelLinear):
     """Linear layer with row parallelism.

     The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X
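The docstring above describes the standard row-parallel contraction Y = XA + b, with the weight A split along its first dimension and the activations X along their second. A minimal sketch of why that sharding is exact, in plain PyTorch with two simulated tensor-parallel ranks (shapes are illustrative, and the final summation stands in for the all-reduce the real layer issues):

import torch

torch.manual_seed(0)
X = torch.randn(4, 6)   # activations: parallelized along the second dimension
A = torch.randn(6, 8)   # weight: parallelized along the first dimension
b = torch.randn(8)

# Simulate two tensor-parallel ranks by slicing the full tensors.
X0, X1 = X.split(3, dim=1)
A0, A1 = A.split(3, dim=0)

# Each rank computes a partial product; summing the partials is what
# the all-reduce in the real forward pass accomplishes.
partial_sum = X0 @ A0 + X1 @ A1
assert torch.allclose(partial_sum + b, X @ A + b, atol=1e-5)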
@@ -1064,7 +1064,7 @@ class FluxRowParallelLinear(torch.nn.Module):
         tp_comm_buffer_name: str = None,  # Not used
     ):
-        super(FluxRowParallelLinear, self)__init__(
+        super(FluxRowParallelLinear, self).__init__(
             input_size=input_size,
             output_size=output_size,
             config=config,
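The removed line is not valid Python: without the dot between super(...) and __init__, the subclass never delegated to its parent constructor. A toy reconstruction of the corrected pattern, with a hypothetical Base standing in for RowParallelLinear (only the three keyword arguments visible in this hunk are mirrored):

class Base:  # hypothetical stand-in for RowParallelLinear
    def __init__(self, input_size, output_size, config=None):
        self.input_size = input_size
        self.output_size = output_size
        self.config = config

class Derived(Base):  # hypothetical stand-in for FluxRowParallelLinear
    def __init__(self, input_size, output_size, config=None,
                 tp_comm_buffer_name=None):  # Not used, mirrored from the diff
        # The fixed line: explicit two-argument super() with the dot restored.
        super(Derived, self).__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
        )

layer = Derived(1024, 4096)
assert (layer.input_size, layer.output_size) == (1024, 4096)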
@@ -1161,7 +1161,7 @@ class FluxRowParallelLinear(torch.nn.Module):
             bias=self.bias if not self.skip_bias_add and self.sequence_parallel else None,
             gradient_accumulation_fusion=self.gradient_accumulation_fusion,
             allreduce_dgrad=False,
-            sequence_parallel=False if explicit_expert_comm else self.sequence_parallel,
+            sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
             grad_output_buffer=None,
             transpose_weight=self.flux_transpose_weight,
             fw_gemm_rs_op=self.fw_gemm_rs_op,
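The third hunk fixes a missing self.: a bare explicit_expert_comm inside a method is resolved as a local or global name and raises NameError at runtime, while the attribute assigned in __init__ must be read through self. A minimal reproduction with a hypothetical class, unrelated to the Flux internals:

class Layer:
    def __init__(self):
        self.explicit_expert_comm = True
        self.sequence_parallel = True

    def broken(self):
        # Pre-fix: the bare name is not defined in this scope.
        return False if explicit_expert_comm else self.sequence_parallel

    def fixed(self):
        # Post-fix: the attribute set in __init__ is read through self.
        return False if self.explicit_expert_comm else self.sequence_parallel

layer = Layer()
assert layer.fixed() is False
try:
    layer.broken()
except NameError as err:
    print(err)  # name 'explicit_expert_comm' is not defined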
dcu_megatron/core/transformer/mtp/mtp_spec.py

@@ -30,7 +30,7 @@ except ImportError:
     LNImpl = WrappedTorchNorm


-def get_mtp_spec(transformer_layer, use_te=False, use_flux=False):
+def get_mtp_spec(transformer_layer, use_te=False):
     """
     Multi Token Predication Layer Specification.
     """
@@ -39,11 +39,11 @@ def get_mtp_spec(transformer_layer, use_te=False, use_flux=False):
         module=MultiTokenPredictor,
         submodules=MultiTokenPredicationSubmodules(
             embedding=None,
-            enorm=TENorm if use_te or use_flux else LNImpl,
-            hnorm=TENorm if use_te or use_flux else LNImpl,
+            enorm=TENorm if use_te else LNImpl,
+            hnorm=TENorm if use_te else LNImpl,
             eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
             transformer_layer=transformer_layer,
-            final_layernorm=TENorm if use_te or use_flux else LNImpl,
+            final_layernorm=TENorm if use_te else LNImpl,
             output_layer=None,
         )
     )
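With use_flux dropped from the signature, the norm implementations now key on use_te alone. A hedged sketch of the selection pattern this hunk leaves behind, using stand-in classes for TENorm, LNImpl, and MultiTokenPredicationSubmodules since the real imports live elsewhere in the repo:

from dataclasses import dataclass

class TENorm: ...   # stand-in: the real class comes from Transformer Engine
class LNImpl: ...   # stand-in: the repo's fallback norm implementation

@dataclass
class Submodules:   # hypothetical stand-in for MultiTokenPredicationSubmodules
    enorm: type
    hnorm: type
    final_layernorm: type

def get_spec(use_te=False):  # sketch of get_mtp_spec's selection logic post-fix
    norm = TENorm if use_te else LNImpl
    return Submodules(enorm=norm, hnorm=norm, final_layernorm=norm)

assert get_spec(use_te=True).enorm is TENorm
assert get_spec(use_te=False).final_layernorm is LNImpl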