evt_fugx1 / dcu_megatron

Commit 66c3d7c9
authored Apr 26, 2025 by sdwldchl

rewrite mtp

parent 1f7b14ab
Showing 4 changed files with 0 additions and 472 deletions (+0 -472)
dcu_megatron/core/models/common/embeddings/language_model_embedding.py  +0 -133
dcu_megatron/core/transformer/mtp/mtp_spec.py                           +0 -51
dcu_megatron/core/transformer/mtp/multi_token_predictor.py              +0 -258
dcu_megatron/core/utils.py                                              +0 -30
dcu_megatron/core/models/common/embeddings/language_model_embedding.py  deleted 100644 → 0
from typing import Literal

import torch
from torch import Tensor

from megatron.core import tensor_parallel
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding


def language_model_embedding_init_func(
    self,
    config: TransformerConfig,
    vocab_size: int,
    max_sequence_length: int,
    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
    num_tokentypes: int = 0,
    scatter_to_sequence_parallel: bool = True,
    skip_weight_param_allocation: bool = False,
):
    """Patch language model embeddings init."""
    super(LanguageModelEmbedding, self).__init__(config=config)

    self.config: TransformerConfig = config
    self.vocab_size: int = vocab_size
    self.max_sequence_length: int = max_sequence_length
    self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
    self.num_tokentypes = num_tokentypes
    self.scatter_to_sequence_parallel = scatter_to_sequence_parallel
    self.reduce_scatter_embeddings = (
        (not self.add_position_embedding)
        and self.num_tokentypes <= 0
        and self.config.sequence_parallel
        and self.scatter_to_sequence_parallel
    )

    # Word embeddings (parallel).
    self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
        num_embeddings=self.vocab_size,
        embedding_dim=self.config.hidden_size,
        init_method=self.config.init_method,
        reduce_scatter_embeddings=self.reduce_scatter_embeddings,
        config=self.config,
        skip_weight_param_allocation=skip_weight_param_allocation,
    )

    # Position embedding (serial).
    if self.add_position_embedding:
        self.position_embeddings = torch.nn.Embedding(
            self.max_sequence_length, self.config.hidden_size
        )

        # Initialize the position embeddings.
        if self.config.perform_initialization:
            self.config.init_method(self.position_embeddings.weight)

    if self.num_tokentypes > 0:
        self.tokentype_embeddings = torch.nn.Embedding(
            self.num_tokentypes, self.config.hidden_size
        )
        # Initialize the token-type embeddings.
        if self.config.perform_initialization:
            self.config.init_method(self.tokentype_embeddings.weight)
    else:
        self.tokentype_embeddings = None

    # Embeddings dropout
    self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)


def language_model_embedding_forward(
    self,
    input_ids: Tensor,
    position_ids: Tensor,
    tokentype_ids: int = None,
    weight: Tensor = None,
) -> Tensor:
    """Patch forward pass of the embedding module.

    Args:
        input_ids (Tensor): The input tokens
        position_ids (Tensor): The position ids used to calculate position embeddings
        tokentype_ids (int): The token type ids. Used when args.bert_binary_head is
            set to True. Defaults to None
        weight (Tensor): embedding weight

    Returns:
        Tensor: The output embeddings
    """
    if weight is None:
        if self.word_embeddings.weight is None:
            raise RuntimeError(
                "weight was not supplied to VocabParallelEmbedding forward pass "
                "and skip_weight_param_allocation is True."
            )
        weight = self.word_embeddings.weight

    word_embeddings = self.word_embeddings(input_ids, weight)
    if self.add_position_embedding:
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = word_embeddings + position_embeddings
    else:
        embeddings = word_embeddings

    if not self.reduce_scatter_embeddings:
        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

    if tokentype_ids is not None:
        assert self.tokentype_embeddings is not None
        # [b s h] -> [s b h] (So that it can be added with embeddings)
        tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
        embeddings = embeddings + tokentype_embedding
    else:
        assert self.tokentype_embeddings is None

    # If the input flag for fp32 residual connection is set, convert for float.
    if self.config.fp32_residual_connection:
        embeddings = embeddings.float()

    # Dropout.
    if self.config.sequence_parallel:
        if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
        # `scatter_to_sequence_parallel_region` returns a view, which prevents
        # the original tensor from being garbage collected. Clone to facilitate GC.
        # Has a small runtime cost (~0.5%).
        if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
            embeddings = embeddings.clone()
        with tensor_parallel.get_cuda_rng_tracker().fork():
            embeddings = self.embedding_dropout(embeddings)
    else:
        embeddings = self.embedding_dropout(embeddings)

    return embeddings
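For context, patch functions like the two above are normally installed by overwriting the corresponding methods on the upstream class so that every LanguageModelEmbedding instance picks up the extra scatter_to_sequence_parallel and skip_weight_param_allocation arguments. The sketch below is an assumption about how that wiring could look before this commit removed the module; the helper name apply_language_model_embedding_patch is hypothetical and does not appear in the diff.

# Hypothetical sketch: apply the patched init/forward to the upstream class.
from megatron.core.models.common.embeddings.language_model_embedding import (
    LanguageModelEmbedding,
)

def apply_language_model_embedding_patch():
    """Hypothetical helper (not in this repo): monkey-patch the upstream methods."""
    # The patched functions above keep the original signatures plus the new kwargs,
    # so assigning them directly is enough for existing call sites to keep working.
    LanguageModelEmbedding.__init__ = language_model_embedding_init_func
    LanguageModelEmbedding.forward = language_model_embedding_forward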
dcu_megatron/core/transformer/mtp/mtp_spec.py  deleted 100644 → 0
import warnings

from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer import ModuleSpec

from .multi_token_predictor import (
    MultiTokenPredicationSubmodules,
    MultiTokenPredictor
)

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelLinear,
        TENorm
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

try:
    import apex

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    LNImpl = FusedLayerNorm
except ImportError:
    from megatron.core.transformer.torch_norm import WrappedTorchNorm

    warnings.warn('Apex is not installed. Falling back to Torch Norm')
    LNImpl = WrappedTorchNorm


def get_mtp_spec(transformer_layer, use_te=False):
    """
    Multi Token Prediction Layer Specification.
    """
    use_te = use_te & HAVE_TE

    mtp_spec = ModuleSpec(
        module=MultiTokenPredictor,
        submodules=MultiTokenPredicationSubmodules(
            embedding=None,
            enorm=TENorm if use_te else LNImpl,
            hnorm=TENorm if use_te else LNImpl,
            eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
            transformer_layer=transformer_layer,
            final_layernorm=TENorm if use_te else LNImpl,
            output_layer=None,
        )
    )
    return mtp_spec
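A hedged usage sketch, not taken from the diff: the ModuleSpec returned by get_mtp_spec can be turned into a MultiTokenPredictor instance with megatron-core's build_module (the same helper the deleted predictor module imports). The choice of get_gpt_layer_with_transformer_engine_spec for the inner transformer layer and the helper name build_mtp_module are illustrative assumptions.

# Illustrative only: consume the spec with megatron-core's build_module.
from megatron.core.transformer import TransformerConfig, build_module
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)

def build_mtp_module(config: TransformerConfig, vocab_size: int, seq_len: int):
    """Hypothetical helper showing how the MTP spec could be consumed."""
    layer_spec = get_gpt_layer_with_transformer_engine_spec()  # assumed inner layer spec
    mtp_spec = get_mtp_spec(layer_spec, use_te=True)
    # build_module forwards the spec's submodules plus the kwargs below to
    # MultiTokenPredictor.__init__.
    return build_module(
        mtp_spec,
        config=config,
        vocab_size=vocab_size,
        max_sequence_length=seq_len,
        position_embedding_type='rope',
    )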
dcu_megatron/core/transformer/mtp/multi_token_predictor.py  deleted 100644 → 0
import os
import logging
from dataclasses import dataclass
from typing import Union, Optional, Literal

import torch
from torch import Tensor

from megatron.core import tensor_parallel, InferenceParams
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.module import MegatronModule
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module

from ...tensor_parallel.random import CheckpointWithoutOutput
from ...tensor_parallel import FluxColumnParallelLinear


@dataclass
class MultiTokenPredicationSubmodules:
    embedding: Union[ModuleSpec, type] = None
    output_layer: Union[ModuleSpec, type] = None
    eh_proj: Union[ModuleSpec, type] = None
    enorm: Union[ModuleSpec, type] = None
    hnorm: Union[ModuleSpec, type] = None
    transformer_layer: Union[ModuleSpec, type] = None
    final_layernorm: Union[ModuleSpec, type] = None


class MultiTokenPredictor(MegatronModule):
    def __init__(
        self,
        config: TransformerConfig,
        submodules: MultiTokenPredicationSubmodules,
        vocab_size: int,
        max_sequence_length: int,
        layer_number: int = 1,
        hidden_dropout: float = None,
        pre_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        rotary_base: int = 10000,
        seq_len_interpolation_factor: Optional[float] = None,
        recompute_mtp_norm=False,
        recompute_mtp_layer=False,
        add_output_layer_bias=False,
    ):
        super().__init__(config=config)

        self.config = config
        self.submodules = submodules
        self.layer_number = layer_number
        self.hidden_dropout = hidden_dropout
        self.hidden_size = self.config.hidden_size
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.position_embedding_type = position_embedding_type
        self.recompute_layer_norm = recompute_mtp_norm
        self.recompute_mtp_layer = recompute_mtp_layer
        self.add_output_layer_bias = add_output_layer_bias

        if self.position_embedding_type == 'rope':
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
                use_cpu_initialization=self.config.use_cpu_initialization,
            )

        self.enorm = build_module(
            self.submodules.enorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        self.hnorm = build_module(
            self.submodules.hnorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        self.eh_proj = build_module(
            self.submodules.eh_proj,
            self.hidden_size + self.hidden_size,
            self.hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=False,
            tp_comm_buffer_name='eh',
        )

        self.transformer_layer = build_module(
            self.submodules.transformer_layer,
            config=self.config,
        )

        if self.submodules.final_layernorm:
            self.final_layernorm = build_module(
                self.submodules.final_layernorm,
                config=self.config,
                hidden_size=self.config.hidden_size,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.final_layernorm = None

        if self.config.defer_embedding_wgrad_compute:
            self.embedding_activation_buffer = []
            self.grad_output_buffer = []
        else:
            self.embedding_activation_buffer = None
            self.grad_output_buffer = None

    def forward(
        self,
        hidden_input_ids: Tensor,
        embed_input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        embedding_layer=None,
        output_layer=None,
        output_weight=None,
    ):
        """Forward function of the MTP module"""
        # Decoder embedding.
        decoder_input = embedding_layer(
            input_ids=embed_input_ids,
            position_ids=position_ids,
        )

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            if inference_params is not None:
                rotary_seq_len = inference_params.max_sequence_length
            else:
                rotary_seq_len = decoder_input.size(0)
                if self.config.sequence_parallel:
                    rotary_seq_len *= self.config.tensor_model_parallel_size
            rotary_seq_len *= self.config.context_parallel_size
            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

        if self.recompute_layer_norm:
            self.enorm_ckpt = CheckpointWithoutOutput()
            enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input)
            self.hnorm_ckpt = CheckpointWithoutOutput()
            hnorm_output = self.hnorm_ckpt.checkpoint(self.hnorm, False, hidden_input_ids)
        else:
            enorm_output = self.enorm(decoder_input)
            hnorm_output = self.hnorm(hidden_input_ids)

        # [s, b, h] -> [s, b, 2h]
        hidden_states = torch.concat([hnorm_output, enorm_output], dim=-1)
        if self.recompute_layer_norm:
            self.enorm_ckpt.discard_output()
            self.hnorm_ckpt.discard_output()
            hidden_states.register_hook(self.enorm_ckpt.recompute)
            hidden_states.register_hook(self.hnorm_ckpt.recompute)

        # hidden_states -> [s, b, h]
        hidden_states, _ = self.eh_proj(hidden_states)
        if self.config.tensor_model_parallel_size > 1:
            hidden_states = tensor_parallel.gather_from_tensor_model_parallel_region(hidden_states)
            if self.config.sequence_parallel:
                hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states)

        if self.recompute_mtp_layer:
            hidden_states, context = tensor_parallel.checkpoint(
                self.transformer_layer,
                self.config.distribute_saved_activations,
                hidden_states,
                attention_mask,
                None,
                None,
                rotary_pos_emb,
                inference_params,
                packed_seq_params,
            )
        else:
            hidden_states, _ = self.transformer_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                rotary_pos_emb=rotary_pos_emb,
                inference_params=inference_params,
                packed_seq_params=packed_seq_params,
                **(extra_block_kwargs or {}),
            )

        # Final layer norm.
        if self.final_layernorm is not None:
            if self.recompute_layer_norm:
                self.finalnorm_ckpt = CheckpointWithoutOutput()
                finalnorm_output = self.finalnorm_ckpt.checkpoint(
                    self.final_layernorm, False, hidden_states
                )
            else:
                finalnorm_output = self.final_layernorm(hidden_states)
        else:
            finalnorm_output = hidden_states

        logits, _ = output_layer(finalnorm_output, weight=output_weight)
        if self.recompute_layer_norm:
            self.finalnorm_ckpt.discard_output()
            logits.register_hook(self.finalnorm_ckpt.recompute)

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)
        return hidden_states, loss

    def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
        """Computes the language model loss (cross entropy across vocabulary).

        Args:
            labels (Tensor): The labels of dimension [batch size, seq length]
            logits (Tensor): The final logits returned by the output layer of the transformer model

        Returns:
            Tensor: Loss tensor of dimensions [batch size, sequence_length]
        """
        # [b s] => [s b]
        labels = labels.transpose(0, 1).contiguous()
        if self.config.cross_entropy_loss_fusion:
            loss = fused_vocab_parallel_cross_entropy(logits, labels)
        else:
            loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)

        # [s b] => [b, s]
        loss = loss.transpose(0, 1).contiguous()
        return loss
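To make the data flow of the deleted predictor easier to follow, here is a simplified, single-device illustration of what MultiTokenPredictor.forward computes: normalize the embedding of the shifted tokens and the main model's hidden states, concatenate them, project 2h back to h, run one extra transformer layer, then reuse the shared output head. Plain torch.nn modules stand in for the Megatron parallel layers, and tensor/sequence parallelism and activation recomputation are ignored; this is an explanatory sketch, not the implementation above.

# Explanatory sketch only; names and module choices are illustrative.
import torch
from torch import nn

class TinyMTPBlock(nn.Module):
    def __init__(self, hidden_size: int, nhead: int = 8):
        super().__init__()
        self.enorm = nn.LayerNorm(hidden_size)   # norm over the shifted-token embedding
        self.hnorm = nn.LayerNorm(hidden_size)   # norm over the main model's hidden state
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.block = nn.TransformerEncoderLayer(hidden_size, nhead, batch_first=False)
        self.final_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, shifted_embeddings, output_head):
        # hidden_states, shifted_embeddings: [s, b, h]; output_head is the shared LM head.
        merged = torch.cat(
            [self.hnorm(hidden_states), self.enorm(shifted_embeddings)], dim=-1
        )                                          # [s, b, 2h]
        h = self.eh_proj(merged)                   # [s, b, h]
        h = self.block(h)                          # one extra transformer layer
        logits = output_head(self.final_norm(h))   # [s, b, vocab]
        return h, logits

# Usage with toy shapes: seq=16, batch=2, hidden=256.
head = nn.Linear(256, 32000, bias=False)
mtp = TinyMTPBlock(hidden_size=256)
h, logits = mtp(torch.randn(16, 2, 256), torch.randn(16, 2, 256), head)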
dcu_megatron/core/utils.py

@@ -30,33 +30,3 @@ def is_flux_min_version(version, check_equality=True):
    if check_equality:
        return get_flux_version() >= PkgVersion(version)
    return get_flux_version() > PkgVersion(version)


def tensor_slide(
    tensor: Optional[torch.Tensor],
    num_slice: int,
    dims: Union[int, List[int]] = -1,
    step: int = 1,
    return_first=False,
) -> List[Union[torch.Tensor, None]]:
    """Generic sliding-window helper that supports arbitrary dimensions."""
    if tensor is None:
        # return `List[None]` to avoid NoneType Error
        return [None] * (num_slice + 1)
    if num_slice == 0:
        return [tensor]

    window_size = tensor.shape[-1] - num_slice
    dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True)

    # Slide over the requested dimensions with a fixed-size window.
    slices = []
    for i in range(0, tensor.size(dims[-1]) - window_size + 1, step):
        slice_obj = [slice(None)] * tensor.dim()
        for dim in dims:
            slice_obj[dim] = slice(i, i + window_size)
        slices.append(tensor[tuple(slice_obj)])

    if return_first:
        return slices
    return slices
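A quick check, not part of the diff, of what the removed tensor_slide returned, assuming it is imported from dcu_megatron.core.utils as it existed before this commit: for num_slice=1 on a last dimension of length 8, it yields num_slice + 1 = 2 overlapping windows of length 7, each shifted by one position — the kind of shifted views a multi-token-prediction path consumes.

# Illustrative usage of the function above (pre-commit import path assumed).
import torch
from dcu_megatron.core.utils import tensor_slide

tokens = torch.arange(16).reshape(2, 8)         # [batch=2, seq=8]
windows = tensor_slide(tokens, num_slice=1)     # window_size = 8 - 1 = 7
assert len(windows) == 2
assert windows[0].shape == (2, 7)               # equals tokens[:, 0:7]
assert torch.equal(windows[1], tokens[:, 1:8])  # shifted by one token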