Project: evt_fugx1/dcu_megatron
Commit bf323343, authored Apr 21, 2025 by dongcl
Parent: 31e933a8

    support flux for mtp
Showing 3 changed files with 17 additions and 18 deletions:

    dcu_megatron/core/distributed/finalize_model_grads.py                    +0  -3
    dcu_megatron/core/models/common/embeddings/language_model_embedding.py   +0  -2
    dcu_megatron/core/transformer/mtp/multi_token_predictor.py               +17 -13
dcu_megatron/core/distributed/finalize_model_grads.py

    # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
    # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
    from typing import List
    import torch
    ...
dcu_megatron/core/models/common/embeddings/language_model_embedding.py

    # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
    # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
    from typing import Literal
    import torch
    ...
dcu_megatron/core/transformer/mtp/multi_token_predictor.py

    # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
    import os
    import logging
    from dataclasses import dataclass
    from typing import Union, Optional, Literal
    ...

@@ -137,18 +137,22 @@ class MultiTokenPredictor(MegatronModule):
         self.embedding_activation_buffer = None
         self.grad_output_buffer = None

-        self.output_layer = tensor_parallel.ColumnParallelLinear(
-            config.hidden_size,
-            self.vocab_size,
-            config=config,
-            init_method=config.init_method,
-            bias=self.add_output_layer_bias,
-            skip_bias_add=False,
-            gather_output=not self.parallel_output,
-            skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
-            embedding_activation_buffer=self.embedding_activation_buffer,
-            grad_output_buffer=self.grad_output_buffer,
-        )
+        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
+            column_parallel_linear_impl = FluxColumnParallelLinear
+        else:
+            column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
+
+        self.output_layer = column_parallel_linear_impl(
+            self.config.hidden_size,
+            self.vocab_size,
+            config=self.config,
+            init_method=self.config.init_method,
+            bias=False,
+            skip_bias_add=False,
+            gather_output=not self.parallel_output,
+            skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
+            embedding_activation_buffer=self.embedding_activation_buffer,
+            grad_output_buffer=self.grad_output_buffer,
+        )

     def forward(
         self,
         ...
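The substance of the commit is the hunk above: instead of always constructing the MTP output projection as tensor_parallel.ColumnParallelLinear, the implementation class is now chosen at construction time from the USE_FLUX_OVERLAP environment variable, selecting FluxColumnParallelLinear (presumably a Flux compute/communication-overlap variant) when the flag parses as a non-zero integer. Note the new call also hard-codes bias=False where the old one passed self.add_output_layer_bias. Below is a minimal, self-contained sketch of the same gating pattern; FluxLinear and nn.Linear are hypothetical stand-ins for FluxColumnParallelLinear and the tensor-parallel layer, not the real classes:

    import os

    import torch.nn as nn


    class FluxLinear(nn.Linear):
        """Hypothetical stand-in for a linear layer whose GEMM overlaps
        with tensor-parallel communication (the Flux path)."""


    def select_linear_impl() -> type:
        # Same gating as the commit: any non-zero integer string in
        # USE_FLUX_OVERLAP selects the Flux implementation; unset or
        # "0" falls back to the plain implementation.
        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
            return FluxLinear
        return nn.Linear


    # Usage: the variable must be set before the model is constructed,
    # since the class is chosen in __init__, not per forward pass.
    os.environ["USE_FLUX_OVERLAP"] = "1"
    layer = select_linear_impl()(1024, 4096, bias=False)  # bias=False as in the new path
    print(type(layer).__name__)  # -> FluxLinear

A consequence of this pattern, visible in the diff, is that both branches must accept the same constructor signature (config, init_method, the activation and gradient buffers, and so on), so FluxColumnParallelLinear is evidently intended as a drop-in replacement for ColumnParallelLinear.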