evt_fugx1 / dcu_megatron / Commits / 4bb958ec

Commit 4bb958ec, authored May 07, 2025 by dongcl
support a2a_overlap
Parent: 7c9dc3ec

Showing 11 changed files with 454 additions and 217 deletions (+454 -217):
  dcu_megatron/adaptor/features_manager.py                   +60    -0
  dcu_megatron/adaptor/megatron_adaptor.py                    +10   -18
  dcu_megatron/core/extensions/transformer_engine.py         +159    -1
  dcu_megatron/core/models/gpt/fine_grained_schedule.py        +0    -3
  dcu_megatron/core/models/gpt/gpt_model.py                  +186  -185
  dcu_megatron/core/pipeline_parallel/combined_1f1b.py         +1    -0
  dcu_megatron/core/transformer/mlp.py                         +7    -0
  dcu_megatron/core/transformer/moe/experts.py                 +6    -0
  dcu_megatron/core/transformer/moe/moe_layer.py               +7    -0
  dcu_megatron/core/transformer/multi_latent_attention.py     +18    -0
  dcu_megatron/core/transformer/transformer_block.py           +0   -10
dcu_megatron/adaptor/features_manager.py (new file, 0 → 100644)

from megatron.core.utils import is_te_min_version


def a2a_overlap_adaptation(patches_manager):
    """
    patches_manager: MegatronPatchesManager
    """
    from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
    from ..core.transformer.transformer_block import TransformerBlock
    from ..core.transformer.transformer_layer import TransformerLayer
    from ..core.models.gpt.gpt_model import GPTModel
    from ..core.pipeline_parallel.schedules import (
        get_pp_rank_microbatches,
        forward_backward_pipelining_with_interleaving,
    )
    from ..core.extensions.transformer_engine import (
        _get_extra_te_kwargs_wrapper,
        TELinear,
        TELayerNormColumnParallelLinear,
    )
    from ..core.transformer.multi_latent_attention import MLASelfAttention
    from ..core.transformer.mlp import MLP
    from ..core.transformer.moe.experts import TEGroupedMLP
    from ..core.transformer.moe.moe_layer import MoELayer

    # num_warmup_microbatches + 1
    patches_manager.register_patch(
        'megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
        get_pp_rank_microbatches)
    # a2a_overlap
    patches_manager.register_patch(
        'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
        forward_backward_pipelining_with_interleaving)
    patches_manager.register_patch(
        'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
        MoEAlltoAllTokenDispatcher)
    patches_manager.register_patch(
        'megatron.core.transformer.transformer_block.TransformerBlock',
        TransformerBlock)
    patches_manager.register_patch(
        'megatron.core.transformer.transformer_layer.TransformerLayer',
        TransformerLayer)
    patches_manager.register_patch(
        'megatron.core.models.gpt.gpt_model.GPTModel',
        GPTModel)

    # backward_dw
    patches_manager.register_patch(
        'megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
        _get_extra_te_kwargs_wrapper,
        apply_wrapper=True)
    patches_manager.register_patch(
        'megatron.core.extensions.transformer_engine.TELinear',
        TELinear)
    patches_manager.register_patch(
        'megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
        TELayerNormColumnParallelLinear)
    if is_te_min_version("1.9.0.dev0"):
        from ..core.extensions.transformer_engine import TEGroupedLinear
        patches_manager.register_patch(
            'megatron.core.extensions.transformer_engine.TEGroupedLinear',
            TEGroupedLinear)
    patches_manager.register_patch(
        'megatron.core.transformer.multi_latent_attention.MLASelfAttention',
        MLASelfAttention)
    patches_manager.register_patch(
        'megatron.core.transformer.mlp.MLP',
        MLP)
    patches_manager.register_patch(
        'megatron.core.transformer.moe.experts.TEGroupedMLP',
        TEGroupedMLP)
    patches_manager.register_patch(
        'megatron.core.transformer.moe.moe_layer.MoELayer',
        MoELayer)
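a2a_overlap_adaptation only records patch targets; the rebinding happens when the manager's apply_patches() runs. MegatronPatchesManager itself is not part of this commit, so the following is a minimal illustrative sketch of what such a register/apply cycle could look like, assuming dotted paths are resolved with importlib and that apply_wrapper=True means the replacement wraps the original callable. None of these names are dcu_megatron APIs.

import importlib


class SimplePatchesManager:
    """Illustrative stand-in for a patches manager: it collects
    (dotted path, replacement) pairs and rebinds them on apply_patches()."""

    _patches = []

    @classmethod
    def register_patch(cls, orig_path, replacement, apply_wrapper=False):
        # apply_wrapper=True: `replacement` is a decorator that receives the
        # original callable instead of replacing it outright.
        cls._patches.append((orig_path, replacement, apply_wrapper))

    @classmethod
    def apply_patches(cls):
        for orig_path, replacement, apply_wrapper in cls._patches:
            holder, attr_name = cls._resolve(orig_path)
            target = replacement(getattr(holder, attr_name)) if apply_wrapper else replacement
            setattr(holder, attr_name, target)

    @staticmethod
    def _resolve(dotted_path):
        # Walk the dotted path: the longest importable prefix is the module,
        # any remaining segments (e.g. a class name) are resolved via getattr.
        parts = dotted_path.split('.')
        for i in range(len(parts) - 1, 0, -1):
            try:
                obj = importlib.import_module('.'.join(parts[:i]))
            except ModuleNotFoundError:
                continue
            for name in parts[i:-1]:
                obj = getattr(obj, name)
            return obj, parts[-1]
        raise ImportError(f"cannot resolve {dotted_path}")

Under those assumptions, a2a_overlap_adaptation(SimplePatchesManager) followed by SimplePatchesManager.apply_patches() would swap the listed Megatron symbols for the dcu_megatron variants.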
dcu_megatron/adaptor/megatron_adaptor.py

@@ -24,6 +24,13 @@ class MegatronAdaptation:
            adaptation.execute()

        MegatronAdaptation.apply()

        # apply features
        from .patch_utils import MegatronPatchesManager
        from .features_manager import a2a_overlap_adaptation
        a2a_overlap_adaptation(MegatronPatchesManager)
        MegatronPatchesManager.apply_patches()

    @classmethod
    def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False,
                 apply_wrapper=False, remove_origin_wrappers=False):
        """

@@ -91,14 +98,14 @@ class CoreAdaptation(MegatronAdaptationABC):
        pass

    def patch_core_models(self):
        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, GPTModel
        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward

        # GPT Model
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
                                    gpt_model_init_wrapper, apply_wrapper=True)
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
                                    gpt_model_forward)

    def patch_core_transformers(self):
        from ..core import transformer_block_init_wrapper

@@ -142,18 +149,6 @@ class CoreAdaptation(MegatronAdaptationABC):
        if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
            TEGroupedLinear.__bases__ = (
                te.pytorch.BatchedLinear if is_te_min_version("2.3.0.dev0") else te.pytorch.BatchLinear,
            )

    def patch_pipeline_parallel(self):
        from ..core.pipeline_parallel.schedules import (
            get_pp_rank_microbatches,
            forward_backward_pipelining_with_interleaving,
        )

        # num_warmup_microbatches + 1
        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
                                    get_pp_rank_microbatches)
        # a2a_overlap
        MegatronAdaptation.register(
            'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
            forward_backward_pipelining_with_interleaving)

    def patch_tensor_parallel(self):
        from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy

@@ -190,9 +185,3 @@ class CoreAdaptation(MegatronAdaptationABC):
        MegatronAdaptation.register(
            "megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
            get_gpt_layer_with_flux_spec)

    def patch_pipeline_parallel(self):
        pass

    def patch_training(self):
        from ..training.tokenizer import build_tokenizer
        from ..training.initialize import _initialize_distributed
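register() distinguishes plain replacement from apply_wrapper=True, where the new callable receives the original and returns a wrapped version, as _get_extra_te_kwargs_wrapper does in the next file. Below is a small self-contained sketch of that pattern with a stand-in for the original function; the real _get_extra_te_kwargs is not reproduced here, and the config attribute access is simplified to getattr.

from functools import wraps
from types import SimpleNamespace


def original_get_kwargs(config):
    # Stand-in for megatron.core.extensions.transformer_engine._get_extra_te_kwargs.
    return {"params_dtype": getattr(config, "params_dtype", None)}


def get_kwargs_wrapper(fn):
    # Same shape as _get_extra_te_kwargs_wrapper: keep the original behaviour
    # and add one key on top of its result.
    @wraps(fn)
    def wrapper(config):
        kwargs = fn(config)
        kwargs["delay_wgrad_compute"] = getattr(config, "split_bw", False)
        return kwargs
    return wrapper


# With apply_wrapper=True the manager passes the original callable into the
# wrapper instead of discarding it.
patched = get_kwargs_wrapper(original_get_kwargs)
assert patched(SimpleNamespace(split_bw=True))["delay_wgrad_compute"] is True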
dcu_megatron/core/extensions/transformer_engine.py

@@ -3,7 +3,7 @@ import torch
import dataclasses
import transformer_engine as te

from typing import Any, Optional
from typing import Any, Optional, Callable
from packaging.version import Version as PkgVersion

from megatron.core.packed_seq_params import PackedSeqParams

@@ -13,6 +13,9 @@ from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.process_groups_config import ModelCommProcessGroups
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.extensions.transformer_engine import TELinear as MegatronCoreTELinear
from megatron.core.extensions.transformer_engine import (
    TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear,
)
from megatron.core.parallel_state import (
    get_context_parallel_global_ranks,

@@ -22,6 +25,112 @@ from megatron.core.parallel_state import (
)


def _get_extra_te_kwargs_wrapper(fn):
    @wraps(fn)
    def wrapper(config: TransformerConfig):
        extra_transformer_engine_kwargs = fn(config)
        extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.get("split_bw", False)
        return extra_transformer_engine_kwargs
    return wrapper


class TELinear(MegatronCoreTELinear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer.

    Note that if Megatron's parallel_state has not been initialized
    yet, the tp_group passed to TE will be None and must be set later
    via set_tensor_parallel_group().

    parallel_mode currently supports 3 different values:
        - "column": Split the weight matrix along output dimension (used in TEColumnParallelLinear)
        - "row": Split the weight matrix along input dimension (used in TERowParallelLinear)
        - "duplicated": No tensor parallelism and weight is duplicated across TP ranks
        - Note: For expert linear layers, we will disable communication logic here
          as TP communication is handled in token_dispatcher.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        parallel_mode: Optional[str],
        config: ModelParallelConfig,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        skip_weight_param_allocation: bool,
        tp_comm_buffer_name: Optional[str] = None,
        is_expert: bool = False,
        tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        self.split_bw = config.get("split_bw", False)
        assert not self.split_bw, "split_bw is currently not supported"

        super().__init__(
            input_size,
            output_size,
            parallel_mode=parallel_mode,
            config=config,
            init_method=init_method,
            bias=bias,
            skip_bias_add=skip_bias_add,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
            is_expert=is_expert,
            tp_group=tp_group,
        )

    def backward_dw(self):
        if not self.split_bw:
            return


class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
    """
    Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
    layernorm and linear layers
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: TransformerConfig,
        init_method: Callable,
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool,
        skip_weight_param_allocation: bool = False,
        tp_comm_buffer_name: Optional[str] = None,
        tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        self.split_bw = config.get("split_bw", False)
        assert not self.split_bw, "split_bw is currently not supported"

        super().__init__(
            input_size,
            output_size,
            config=config,
            init_method=init_method,
            gather_output=gather_output,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
            tp_group=tp_group,
        )

    def backward_dw(self):
        if not self.split_bw:
            return


class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
    def __init__(
        self,

@@ -176,3 +285,52 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
            layer_number=layer_number,
            **extra_kwargs,
        )


if is_te_min_version("1.9.0.dev0"):
    from megatron.core.extensions.transformer_engine import TEGroupedLinear as MegatronCoreTEGroupedLinear

    class TEGroupedLinear(MegatronCoreTEGroupedLinear):
        """
        Wrapper for the Transformer-Engine's `GroupedLinear` layer.

        Note that if Megatron's parallel_state has not been initialized
        yet, the tp_group passed to TE will be None and must be set later
        via set_tensor_parallel_group().
        """

        def __init__(
            self,
            num_gemms: int,
            input_size: int,
            output_size: int,
            *,
            parallel_mode: Optional[str],
            config: ModelParallelConfig,
            init_method: Callable,
            bias: bool,
            skip_bias_add: bool,
            is_expert: bool = False,
            tp_comm_buffer_name: Optional[str] = None,
            tp_group: Optional[torch.distributed.ProcessGroup] = None,
        ):
            self.split_bw = config.get("split_bw", False)
            assert not self.split_bw, "split_bw is currently not supported"

            super().__init__(
                num_gemms,
                input_size,
                output_size,
                parallel_mode=parallel_mode,
                config=config,
                init_method=init_method,
                bias=bias,
                skip_bias_add=skip_bias_add,
                is_expert=is_expert,
                tp_comm_buffer_name=tp_comm_buffer_name,
                tp_group=tp_group,
            )

        def backward_dw(self):
            if not self.split_bw:
                return
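The backward_dw() hooks above are deliberate no-ops for now (split_bw is asserted off), but they mark where the deferred weight-gradient pass would go once TE's delay_wgrad_compute is enabled. As a rough, framework-free illustration of what splitting the backward pass means, the toy snippet below computes the activation gradient immediately and defers the weight gradient to a later backward_dw() call; it uses plain PyTorch tensors and is not TE's implementation.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(3, 8, requires_grad=True)

y = F.linear(x, w)            # y = x @ w.T
grad_out = torch.randn_like(y)

# Stage 1: dgrad only, so the gradient can keep flowing to earlier layers.
dgrad = grad_out @ w          # dL/dx

# Stage 2, deferred: wgrad, e.g. after communication has been overlapped.
def backward_dw():
    return grad_out.t() @ x   # dL/dw

wgrad = backward_dw()

# Sanity check against ordinary autograd.
y.backward(grad_out)
assert torch.allclose(dgrad, x.grad) and torch.allclose(wgrad, w.grad)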
dcu_megatron/core/models/gpt/fine_grained_schedule.py

@@ -239,7 +239,6 @@ class PostProcessNode(ScheduleNode):
        return loss


class TransformerLayerNode(ScheduleNode):
    def __init__(self, chunk_state, common_state, layer, stream, event, free_inputs=False):

@@ -598,8 +597,6 @@ def schedule_layer_1f1b(
        with f_context:
            f_input = f_layer.mlp.forward(f_input)

    def next_iter_pre_forward():
        if f_layer is not None:
            with f_context:
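TransformerLayerNode carries a dedicated stream and event, which is the usual recipe for overlapping the MoE all-to-all with computation: launch the communication on a side stream, keep computing on the default stream, and only wait on the event where the exchanged tokens are consumed. The sketch below is illustrative only; it assumes an initialized torch.distributed process group and does not reproduce schedule_layer_1f1b.

import torch
import torch.distributed as dist


def overlapped_a2a(compute_input, a2a_input, group=None):
    # Illustrative only: overlap an all-to-all with independent compute.
    if not (torch.cuda.is_available() and dist.is_initialized()):
        # Sequential fallback when there is no GPU or no process group.
        return compute_input * 2, a2a_input.clone()

    comm_stream = torch.cuda.Stream()
    a2a_done = torch.cuda.Event()
    a2a_output = torch.empty_like(a2a_input)

    comm_stream.wait_stream(torch.cuda.current_stream())  # a2a_input is ready
    with torch.cuda.stream(comm_stream):
        # Real dispatchers also call a2a_input.record_stream(comm_stream)
        # to keep the allocator from reusing the buffer too early.
        dist.all_to_all_single(a2a_output, a2a_input, group=group)
        a2a_done.record()

    # Independent compute runs on the default stream while the a2a is in flight.
    compute_output = compute_input * 2

    # Block the default stream only where the exchanged tokens are needed.
    torch.cuda.current_stream().wait_event(a2a_done)
    return compute_output, a2a_output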
dcu_megatron/core/models/gpt/gpt_model.py

@@ -46,6 +46,192 @@ def gpt_model_init_wrapper(fn):
    return wrapper


def gpt_model_forward(
    self,
    input_ids: Tensor,
    position_ids: Tensor,
    attention_mask: Tensor,
    decoder_input: Tensor = None,
    labels: Tensor = None,
    inference_context: BaseInferenceContext = None,
    packed_seq_params: PackedSeqParams = None,
    extra_block_kwargs: dict = None,
    runtime_gather_output: Optional[bool] = None,
    *,
    inference_params: Optional[BaseInferenceContext] = None,
    loss_mask: Optional[Tensor] = None,
) -> Tensor:
    """Forward function of the GPT Model. This function passes the input tensors
    through the embedding layer, then the decoder, and finally into the post
    processing layer (optional).

    It either returns the loss values if labels are given or the final hidden units.

    Args:
        runtime_gather_output (bool): Gather output at runtime. Default None means
            `parallel_output` arg in the constructor will be used.
    """
    # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
    # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
    inference_context = deprecate_inference_params(inference_context, inference_params)

    # Decoder embedding.
    if decoder_input is not None:
        pass
    elif self.pre_process:
        decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
    else:
        # intermediate stage of pipeline
        # decoder will get hidden_states from encoder.input_tensor
        decoder_input = None

    # Rotary positional embeddings (embedding is None for PP intermediate devices)
    rotary_pos_emb = None
    rotary_pos_cos = None
    rotary_pos_sin = None
    if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
        if not self.training and self.config.flash_decode and inference_context:
            assert (
                inference_context.is_static_batching()
            ), "GPTModel currently only supports static inference batching."
            # Flash decoding uses precomputed cos and sin for RoPE
            rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
                inference_context.max_sequence_length,
                self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
            )
        else:
            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                inference_context, self.decoder, decoder_input, self.config, packed_seq_params
            )
            rotary_pos_emb = self.rotary_pos_emb(
                rotary_seq_len,
                packed_seq=packed_seq_params is not None
                and packed_seq_params.qkv_format == 'thd',
            )
    elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
        if self.training or not self.config.flash_decode:
            rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
        else:
            # Flash decoding uses precomputed cos and sin for RoPE
            raise NotImplementedError(
                "Flash decoding uses precomputed cos and sin for RoPE, not implemented in "
                "MultimodalRotaryEmbedding yet."
            )

    if (
        (self.config.enable_cuda_graph or self.config.flash_decode)
        and rotary_pos_cos is not None
        and inference_context
        and inference_context.is_static_batching()
        and not self.training
    ):
        sequence_len_offset = torch.tensor(
            [inference_context.sequence_len_offset] * inference_context.current_batch_size,
            dtype=torch.int32,
            device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
        )
    else:
        sequence_len_offset = None

    # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
    # reference held by this caller function, enabling early garbage collection for
    # inference. Skip wrapping if decoder_input is logged after decoder completion.
    if (
        inference_context is not None
        and not self.training
        and not has_config_logger_enabled(self.config)
    ):
        decoder_input = WrappedTensor(decoder_input)

    # Run decoder.
    hidden_states = self.decoder(
        hidden_states=decoder_input,
        attention_mask=attention_mask,
        inference_context=inference_context,
        rotary_pos_emb=rotary_pos_emb,
        rotary_pos_cos=rotary_pos_cos,
        rotary_pos_sin=rotary_pos_sin,
        packed_seq_params=packed_seq_params,
        sequence_len_offset=sequence_len_offset,
        **(extra_block_kwargs or {}),
    )

    # Process inference output.
    if inference_context and not inference_context.is_static_batching():
        hidden_states = inference_context.last_token_logits(
            hidden_states.squeeze(1).unsqueeze(0)
        ).unsqueeze(1)

    # logits and loss
    output_weight = None
    if self.share_embeddings_and_output_weights:
        output_weight = self.shared_embedding_or_output_weight()

    if self.mtp_process:
        hidden_states = self.mtp(
            input_ids=input_ids,
            position_ids=position_ids,
            labels=labels,
            loss_mask=loss_mask,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            embedding=self.embedding,
            output_layer=self.output_layer,
            output_weight=output_weight,
            runtime_gather_output=runtime_gather_output,
            compute_language_model_loss=self.compute_language_model_loss,
            **(extra_block_kwargs or {}),
        )

    if (
        self.mtp_process is not None
        and getattr(self.decoder, "main_final_layernorm", None) is not None
    ):
        # move block main model final norms here
        hidden_states = self.decoder.main_final_layernorm(hidden_states)

    if not self.post_process:
        return hidden_states

    if (
        not self.training
        and inference_context is not None
        and inference_context.is_static_batching()
        and inference_context.materialize_only_last_token_logits
    ):
        hidden_states = hidden_states[-1:, :, :]

    logits, _ = self.output_layer(
        hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
    )

    if has_config_logger_enabled(self.config):
        payload = OrderedDict(
            {
                'input_ids': input_ids,
                'position_ids': position_ids,
                'attention_mask': attention_mask,
                'decoder_input': decoder_input,
                'logits': logits,
            }
        )
        log_config_to_disk(self.config, payload, prefix='input_and_logits')

    if labels is None:
        # [s b h] => [b s h]
        return logits.transpose(0, 1).contiguous()

    loss = self.compute_language_model_loss(labels, logits)

    return loss


class GPTModel(MegatronCoreGPTModel):
    """
    patch megatron GPTModel

@@ -115,188 +301,3 @@ class GPTModel(MegatronCoreGPTModel):
            inference_params=inference_params,
            loss_mask=loss_mask,
        )

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_context: BaseInferenceContext = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        *,
        inference_params: Optional[BaseInferenceContext] = None,
        loss_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Forward function of the GPT Model. This function passes the input tensors
        through the embedding layer, then the decoder, and finally into the post
        processing layer (optional).

        It either returns the loss values if labels are given or the final hidden units.

        Args:
            runtime_gather_output (bool): Gather output at runtime. Default None means
                `parallel_output` arg in the constructor will be used.
        """
        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
        inference_context = deprecate_inference_params(inference_context, inference_params)

        # Decoder embedding.
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
        else:
            # intermediate stage of pipeline
            # decoder will get hidden_states from encoder.input_tensor
            decoder_input = None

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        rotary_pos_cos = None
        rotary_pos_sin = None
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            if not self.training and self.config.flash_decode and inference_context:
                assert (
                    inference_context.is_static_batching()
                ), "GPTModel currently only supports static inference batching."
                # Flash decoding uses precomputed cos and sin for RoPE
                rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
                    inference_context.max_sequence_length,
                    self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
                )
            else:
                rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                    inference_context, self.decoder, decoder_input, self.config, packed_seq_params
                )
                rotary_pos_emb = self.rotary_pos_emb(
                    rotary_seq_len,
                    packed_seq=packed_seq_params is not None
                    and packed_seq_params.qkv_format == 'thd',
                )
        elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
            if self.training or not self.config.flash_decode:
                rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
            else:
                # Flash decoding uses precomputed cos and sin for RoPE
                raise NotImplementedError(
                    "Flash decoding uses precomputed cos and sin for RoPE, not implemented in "
                    "MultimodalRotaryEmbedding yet."
                )

        if (
            (self.config.enable_cuda_graph or self.config.flash_decode)
            and rotary_pos_cos is not None
            and inference_context
            and inference_context.is_static_batching()
            and not self.training
        ):
            sequence_len_offset = torch.tensor(
                [inference_context.sequence_len_offset] * inference_context.current_batch_size,
                dtype=torch.int32,
                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
            )
        else:
            sequence_len_offset = None

        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
        # reference held by this caller function, enabling early garbage collection for
        # inference. Skip wrapping if decoder_input is logged after decoder completion.
        if (
            inference_context is not None
            and not self.training
            and not has_config_logger_enabled(self.config)
        ):
            decoder_input = WrappedTensor(decoder_input)

        # Run decoder.
        hidden_states = self.decoder(
            hidden_states=decoder_input,
            attention_mask=attention_mask,
            inference_context=inference_context,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            **(extra_block_kwargs or {}),
        )

        # Process inference output.
        if inference_context and not inference_context.is_static_batching():
            hidden_states = inference_context.last_token_logits(
                hidden_states.squeeze(1).unsqueeze(0)
            ).unsqueeze(1)

        # logits and loss
        output_weight = None
        if self.share_embeddings_and_output_weights:
            output_weight = self.shared_embedding_or_output_weight()

        if self.mtp_process:
            hidden_states = self.mtp(
                input_ids=input_ids,
                position_ids=position_ids,
                labels=labels,
                loss_mask=loss_mask,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
                embedding=self.embedding,
                output_layer=self.output_layer,
                output_weight=output_weight,
                runtime_gather_output=runtime_gather_output,
                compute_language_model_loss=self.compute_language_model_loss,
                **(extra_block_kwargs or {}),
            )

        if (
            self.mtp_process is not None
            and getattr(self.decoder, "main_final_layernorm", None) is not None
        ):
            # move block main model final norms here
            hidden_states = self.decoder.main_final_layernorm(hidden_states)

        if not self.post_process:
            return hidden_states

        if (
            not self.training
            and inference_context is not None
            and inference_context.is_static_batching()
            and inference_context.materialize_only_last_token_logits
        ):
            hidden_states = hidden_states[-1:, :, :]

        logits, _ = self.output_layer(
            hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
        )

        if has_config_logger_enabled(self.config):
            payload = OrderedDict(
                {
                    'input_ids': input_ids,
                    'position_ids': position_ids,
                    'attention_mask': attention_mask,
                    'decoder_input': decoder_input,
                    'logits': logits,
                }
            )
            log_config_to_disk(self.config, payload, prefix='input_and_logits')

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)

        return loss
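The adaptor registers 'megatron.core.models.gpt.gpt_model.GPTModel.forward' against the module-level gpt_model_forward(self, ...); that works because assigning a plain function to a class attribute turns it into a method for every existing and future instance. A tiny generic sketch of that rebinding follows; the class and function here are placeholders, not dcu_megatron or Megatron APIs.

class Model:
    def forward(self, x):
        return x


def patched_forward(self, x):
    # Drop-in replacement written as a free function, exactly like
    # gpt_model_forward(self, ...) above: `self` is the first parameter.
    return x + 1


# Rebinding the class attribute patches every existing and future instance.
Model.forward = patched_forward
assert Model().forward(1) == 2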
dcu_megatron/core/pipeline_parallel/combined_1f1b.py

@@ -503,6 +503,7 @@ def get_default_cls_for_unwrap():
            pass
    return cls


def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
    """Unwrap DistributedDataParallel and Float16Module wrapped models."""
    return_list = True
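The body of unwrap_model is collapsed in this view. As a rough sketch of the usual unwrapping loop, assuming each wrapper exposes the inner model as .module (as DistributedDataParallel and Float16Module do), it might look like the following; this is an illustration, not the code added by the commit.

def unwrap_model_sketch(model, wrapper_classes=()):
    # Accept a single module or a list of modules, mirroring the truncated
    # return_list handling above.
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False

    unwrapped = []
    for m in model:
        # Peel nested wrappers, e.g. DistributedDataParallel(Float16Module(model)).
        while isinstance(m, wrapper_classes) and hasattr(m, "module"):
            m = m.module
        unwrapped.append(m)

    return unwrapped if return_list else unwrapped[0]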
dcu_megatron/core/transformer/mlp.py (new file, 0 → 100644)

from megatron.core.transformer.mlp import MLP as MegatronCoreMLP


class MLP(MegatronCoreMLP):
    def backward_dw(self):
        self.linear_fc2.backward_dw()
        self.linear_fc1.backward_dw()
\ No newline at end of file
dcu_megatron/core/transformer/moe/experts.py (new file, 0 → 100644)

from megatron.core.transformer.moe.experts import TEGroupedMLP as MegatronCoreTEGroupedMLP


class TEGroupedMLP(MegatronCoreTEGroupedMLP):
    def backward_dw(self):
        self.linear_fc2.backward_dw()
        self.linear_fc1.backward_dw()
dcu_megatron/core/transformer/moe/moe_layer.py (new file, 0 → 100644)

from megatron.core.transformer.moe.moe_layer import MoELayer as MegatronCoreMoELayer


class MoELayer(MegatronCoreMoELayer):
    def backward_dw(self):
        self.experts.backward_dw()
        self.shared_experts.backward_dw()
dcu_megatron/core/transformer/multi_latent_attention.py (new file, 0 → 100644)

from megatron.core.transformer.multi_latent_attention import MLASelfAttention as MegatronCoreMLASelfAttention


class MLASelfAttention(MegatronCoreMLASelfAttention):
    """MLA Self-attention layer class

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def backward_dw(self):
        self.linear_kv_up_proj.backward_dw()
        self.linear_kv_down_proj.backward_dw()
        if self.config.q_lora_rank is None:
            self.linear_q_proj.backward_dw()
        else:
            self.linear_q_down_proj.backward_dw()
            self.linear_q_up_proj.backward_dw()
        self.linear_proj.backward_dw()
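Taken together, TELinear/TEGroupedLinear (leaf hooks), MLP/TEGroupedMLP/MoELayer, and MLASelfAttention form a chain of backward_dw() calls from layer modules down to individual GEMMs. A hypothetical sketch of how a schedule node might trigger that chain for one transformer layer once its activation-gradient pass has finished; the self_attention/mlp attribute names follow Megatron's TransformerLayer, and this dispatch helper is not part of the commit.

def layer_backward_dw(layer):
    # Deferred weight-gradient pass for one transformer layer: attention first,
    # then the (MoE) MLP, mirroring the per-module backward_dw chains above.
    layer.self_attention.backward_dw()
    layer.mlp.backward_dw()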
dcu_megatron/core/transformer/transformer_block.py

@@ -17,16 +17,6 @@ def transformer_block_init_wrapper(fn):


class TransformerBlock(MegatronCoreTransformerBlock):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # mtp requires separate layernorms for the main model and mtp modules, thus move final norm out of block
        config = args[0] if len(args) > 1 else kwargs['config']
        if getattr(config, "mtp_num_layers", 0) > 0:
            self.main_final_layernorm = self.final_layernorm
            self.final_layernorm = None

    def get_layer_callables(self, layer_number: int):
        """
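The pattern shown here, and now handled in gpt_model_forward via decoder.main_final_layernorm, is to lift the block's final layernorm out so it can be applied after the MTP modules run. A self-contained toy version of that hand-off, using torch.nn only rather than the Megatron classes:

import torch


class TinyBlock(torch.nn.Module):
    # Mirrors the pattern above: when MTP layers are configured, the block
    # hands its final layernorm to the caller instead of applying it itself.
    def __init__(self, hidden, mtp_num_layers=0):
        super().__init__()
        self.final_layernorm = torch.nn.LayerNorm(hidden)
        self.main_final_layernorm = None
        if mtp_num_layers > 0:
            self.main_final_layernorm = self.final_layernorm
            self.final_layernorm = None

    def forward(self, x):
        return self.final_layernorm(x) if self.final_layernorm is not None else x


block = TinyBlock(hidden=16, mtp_num_layers=1)
x = torch.randn(2, 16)
hidden = block(x)                            # block no longer normalizes
hidden = block.main_final_layernorm(hidden)  # caller applies it after MTP, as in gpt_model_forward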