evt_fugx1 / dcu_megatron / Commits / 4bb958ec

Commit 4bb958ec, authored May 07, 2025 by dongcl

support a2a_overlap

Parent: 7c9dc3ec
Showing 11 changed files with 454 additions and 217 deletions (+454, -217)
dcu_megatron/adaptor/features_manager.py                  +60   -0
dcu_megatron/adaptor/megatron_adaptor.py                  +10   -18
dcu_megatron/core/extensions/transformer_engine.py        +159  -1
dcu_megatron/core/models/gpt/fine_grained_schedule.py     +0    -3
dcu_megatron/core/models/gpt/gpt_model.py                 +186  -185
dcu_megatron/core/pipeline_parallel/combined_1f1b.py      +1    -0
dcu_megatron/core/transformer/mlp.py                      +7    -0
dcu_megatron/core/transformer/moe/experts.py              +6    -0
dcu_megatron/core/transformer/moe/moe_layer.py            +7    -0
dcu_megatron/core/transformer/multi_latent_attention.py   +18   -0
dcu_megatron/core/transformer/transformer_block.py        +0    -10
dcu_megatron/adaptor/features_manager.py (new file, mode 100644)

from megatron.core.utils import is_te_min_version


def a2a_overlap_adaptation(patches_manager):
    """
    patches_manager: MegatronPatchesManager
    """
    from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
    from ..core.transformer.transformer_block import TransformerBlock
    from ..core.transformer.transformer_layer import TransformerLayer
    from ..core.models.gpt.gpt_model import GPTModel
    from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
    from ..core.extensions.transformer_engine import _get_extra_te_kwargs_wrapper, TELinear, TELayerNormColumnParallelLinear
    from ..core.transformer.multi_latent_attention import MLASelfAttention
    from ..core.transformer.mlp import MLP
    from ..core.transformer.moe.experts import TEGroupedMLP
    from ..core.transformer.moe.moe_layer import MoELayer

    # num_warmup_microbatches + 1
    patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
                                   get_pp_rank_microbatches)
    # a2a_overlap
    patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
                                   forward_backward_pipelining_with_interleaving)
    patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
                                   MoEAlltoAllTokenDispatcher)
    patches_manager.register_patch('megatron.core.transformer.transformer_block.TransformerBlock',
                                   TransformerBlock)
    patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
                                   TransformerLayer)
    patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel',
                                   GPTModel)

    # backward_dw
    patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
                                   _get_extra_te_kwargs_wrapper,
                                   apply_wrapper=True)
    patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
                                   TELinear)
    patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
                                   TELayerNormColumnParallelLinear)
    if is_te_min_version("1.9.0.dev0"):
        from ..core.extensions.transformer_engine import TEGroupedLinear
        patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
                                       TEGroupedLinear)
    patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention',
                                   MLASelfAttention)
    patches_manager.register_patch('megatron.core.transformer.mlp.MLP',
                                   MLP)
    patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP',
                                   TEGroupedMLP)
    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer',
                                   MoELayer)
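For orientation, the sketch below shows one way a registry-style patches manager could implement register_patch and apply_patches. This is a hypothetical illustration: the real MegatronPatchesManager lives in dcu_megatron/adaptor/patch_utils.py, is not part of this diff, and its actual interface may differ.

# Hypothetical sketch only -- the real MegatronPatchesManager
# (dcu_megatron/adaptor/patch_utils.py) is not shown in this commit.
import importlib


class SimplePatchesManager:
    _patches = []  # (dotted path, replacement object, apply_wrapper flag)

    @classmethod
    def register_patch(cls, orig_obj_path, new_obj, apply_wrapper=False):
        # Record the patch; nothing is modified until apply_patches() runs.
        cls._patches.append((orig_obj_path, new_obj, apply_wrapper))

    @classmethod
    def apply_patches(cls):
        # Resolve each dotted path to a module-level attribute and replace it.
        # (Only module-level targets are handled in this toy version.)
        for dotted_path, new_obj, apply_wrapper in cls._patches:
            module_path, attr_name = dotted_path.rsplit('.', 1)
            module = importlib.import_module(module_path)
            if apply_wrapper:
                # apply_wrapper=True treats new_obj as a decorator over the
                # original, matching how _get_extra_te_kwargs_wrapper is
                # registered above.
                new_obj = new_obj(getattr(module, attr_name))
            setattr(module, attr_name, new_obj)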
dcu_megatron/adaptor/megatron_adaptor.py

...
@@ -24,6 +24,13 @@ class MegatronAdaptation:
         adaptation.execute()
 
         MegatronAdaptation.apply()
 
+        # apply features
+        from .patch_utils import MegatronPatchesManager
+        from .features_manager import a2a_overlap_adaptation
+
+        a2a_overlap_adaptation(MegatronPatchesManager)
+        MegatronPatchesManager.apply_patches()
+
     @classmethod
     def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False,
                  apply_wrapper=False, remove_origin_wrappers=False):
         """
...
@@ -91,14 +98,14 @@ class CoreAdaptation(MegatronAdaptationABC):
         pass
 
     def patch_core_models(self):
-        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, GPTModel
+        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
 
         # GPT Model
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
                                     gpt_model_init_wrapper,
                                     apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel',
-                                    GPTModel)
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
+                                    gpt_model_forward)
 
     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
...
@@ -142,18 +149,6 @@ class CoreAdaptation(MegatronAdaptationABC):
         if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
             TEGroupedLinear.__bases__ = (te.pytorch.BatchedLinear if is_te_min_version("2.3.0.dev0")
                                          else te.pytorch.BatchLinear,)
 
-    def patch_pipeline_parallel(self):
-        from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
-
-        # num_warmup_microbatches + 1
-        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
-                                    get_pp_rank_microbatches)
-        # a2a_overlap
-        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
-                                    forward_backward_pipelining_with_interleaving)
-
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
...
@@ -190,9 +185,6 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                                     get_gpt_layer_with_flux_spec)
 
-    def patch_pipeline_parallel(self):
-        pass
-
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
...
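The apply_wrapper=True registrations above pass the original callable into the registered object and install the result, i.e. the new object is used as a decorator. Below is a minimal sketch of a wrapper written in that style; the wrapper name and the attribute it adds are invented for illustration, while gpt_model_init_wrapper and _get_extra_te_kwargs_wrapper are the real wrappers used by this commit.

from functools import wraps


# Invented example in the same shape as gpt_model_init_wrapper: it receives the
# original function and returns a replacement that delegates to it.
def example_init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)   # run the original __init__ first
        self.patched_by_dcu = True  # then attach extra state for later use
    return wrapper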
dcu_megatron/core/extensions/transformer_engine.py

...
@@ -3,7 +3,7 @@ import torch
 import dataclasses
 import transformer_engine as te
 
-from typing import Any, Optional
+from typing import Any, Optional, Callable
 from packaging.version import Version as PkgVersion
 
 from megatron.core.packed_seq_params import PackedSeqParams
...
@@ -13,6 +13,9 @@ from megatron.core.extensions.transformer_engine import TEDotProductAttention
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.process_groups_config import ModelCommProcessGroups
+from megatron.core.model_parallel_config import ModelParallelConfig
+from megatron.core.extensions.transformer_engine import TELinear as MegatronCoreTELinear
+from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
 from megatron.core.parallel_state import (
     get_context_parallel_global_ranks,
...
@@ -22,6 +25,112 @@ from megatron.core.parallel_state import (
 )
 
+def _get_extra_te_kwargs_wrapper(fn):
+    @wraps(fn)
+    def wrapper(config: TransformerConfig):
+        extra_transformer_engine_kwargs = fn(config)
+        extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.get("split_bw", False)
+        return extra_transformer_engine_kwargs
+    return wrapper
+
+
+class TELinear(MegatronCoreTELinear):
+    """
+    Wrapper for the Transformer-Engine's `Linear` layer.
+
+    Note that if Megatron's parallel_state has not been initialized
+    yet, the tp_group passed to TE will be None and must be set later
+    via set_tensor_parallel_group().
+
+    parallel_mode currently supports 3 different values:
+        - "column": Split the weight matrix along output dimension (used in TEColumnParallelLinear)
+        - "row": Split the weight matrix along input dimension (used in TERowParallelLinear)
+        - "duplicated": No tensor parallelism and weight is duplicated across TP ranks
+        - Note: For expert linear layers, we will disable communication logic here
+          as TP communication is handled in token_dispatcher.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        *,
+        parallel_mode: Optional[str],
+        config: ModelParallelConfig,
+        init_method: Callable,
+        bias: bool,
+        skip_bias_add: bool,
+        skip_weight_param_allocation: bool,
+        tp_comm_buffer_name: Optional[str] = None,
+        is_expert: bool = False,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
+    ):
+        self.split_bw = config.get("split_bw", False)
+        assert not self.split_bw, "split_bw is currently not supported"
+        super().__init__(
+            input_size,
+            output_size,
+            parallel_mode=parallel_mode,
+            config=config,
+            init_method=init_method,
+            bias=bias,
+            skip_bias_add=skip_bias_add,
+            skip_weight_param_allocation=skip_weight_param_allocation,
+            tp_comm_buffer_name=tp_comm_buffer_name,
+            is_expert=is_expert,
+            tp_group=tp_group,
+        )
+
+    def backward_dw(self):
+        if not self.split_bw:
+            return
+
+
+class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
+    """
+    Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
+    layernorm and linear layers
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        *,
+        config: TransformerConfig,
+        init_method: Callable,
+        gather_output: bool,
+        bias: bool,
+        skip_bias_add: bool,
+        is_expert: bool,
+        skip_weight_param_allocation: bool = False,
+        tp_comm_buffer_name: Optional[str] = None,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
+    ):
+        self.split_bw = config.get("split_bw", False)
+        assert not self.split_bw, "split_bw is currently not supported"
+        super().__init__(
+            input_size,
+            output_size,
+            config=config,
+            init_method=init_method,
+            gather_output=gather_output,
+            bias=bias,
+            skip_bias_add=skip_bias_add,
+            is_expert=is_expert,
+            skip_weight_param_allocation=skip_weight_param_allocation,
+            tp_comm_buffer_name=tp_comm_buffer_name,
+            tp_group=tp_group,
+        )
+
+    def backward_dw(self):
+        if not self.split_bw:
+            return
+
+
 class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
     def __init__(
         self,
...
@@ -176,3 +285,52 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
             layer_number=layer_number,
             **extra_kwargs,
         )
+
+
+if is_te_min_version("1.9.0.dev0"):
+    from megatron.core.extensions.transformer_engine import TEGroupedLinear as MegatronCoreTEGroupedLinear
+
+    class TEGroupedLinear(MegatronCoreTEGroupedLinear):
+        """
+        Wrapper for the Transformer-Engine's `GroupedLinear` layer.
+
+        Note that if Megatron's parallel_state has not been initialized
+        yet, the tp_group passed to TE will be None and must be set later
+        via set_tensor_parallel_group().
+        """
+
+        def __init__(
+            self,
+            num_gemms: int,
+            input_size: int,
+            output_size: int,
+            *,
+            parallel_mode: Optional[str],
+            config: ModelParallelConfig,
+            init_method: Callable,
+            bias: bool,
+            skip_bias_add: bool,
+            is_expert: bool = False,
+            tp_comm_buffer_name: Optional[str] = None,
+            tp_group: Optional[torch.distributed.ProcessGroup] = None,
+        ):
+            self.split_bw = config.get("split_bw", False)
+            assert not self.split_bw, "split_bw is currently not supported"
+            super().__init__(
+                num_gemms,
+                input_size,
+                output_size,
+                parallel_mode=parallel_mode,
+                config=config,
+                init_method=init_method,
+                bias=bias,
+                skip_bias_add=skip_bias_add,
+                is_expert=is_expert,
+                tp_comm_buffer_name=tp_comm_buffer_name,
+                tp_group=tp_group,
+            )
+
+        def backward_dw(self):
+            if not self.split_bw:
+                return
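The backward_dw methods above are placeholders for split weight-gradient computation (split_bw is still asserted off in this commit). The toy module below sketches the pattern they prepare for, under the assumption that activation gradients are produced during the normal backward pass while weight gradients are deferred until backward_dw() is called, e.g. so they can be overlapped with all-to-all communication. None of these names come from Transformer Engine; this is purely illustrative.

# Illustrative sketch only: a toy linear layer that mimics the split-backward
# ("delay_wgrad_compute" / backward_dw) pattern. Names and behaviour here are
# assumptions for illustration, not the Transformer Engine implementation.
import torch


class DelayedWgradLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self._saved = None  # (input, grad_output) captured during backward

    def forward(self, x):
        outer = self

        class _Fn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp, weight):
                ctx.save_for_backward(inp)
                return inp @ weight.t()

            @staticmethod
            def backward(ctx, grad_out):
                (inp,) = ctx.saved_tensors
                # Return only the input (activation) gradient now; stash what
                # is needed for the weight gradient later.
                outer._saved = (inp.detach(), grad_out.detach())
                return grad_out @ outer.weight, None

        return _Fn.apply(x, self.weight)

    def backward_dw(self):
        # Compute the delayed weight gradient, e.g. while all-to-all traffic
        # from another microbatch is in flight.
        if self._saved is None:
            return
        inp, grad_out = self._saved
        wgrad = grad_out.reshape(-1, grad_out.shape[-1]).t() @ inp.reshape(-1, inp.shape[-1])
        self.weight.grad = wgrad if self.weight.grad is None else self.weight.grad + wgrad
        self._saved = None


layer = DelayedWgradLinear(8, 4)
out = layer(torch.randn(2, 8))
out.sum().backward()          # produces the activation gradient only
assert layer.weight.grad is None
layer.backward_dw()           # the weight gradient materialises here
assert layer.weight.grad is not None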
dcu_megatron/core/models/gpt/fine_grained_schedule.py

...
@@ -239,7 +239,6 @@ class PostProcessNode(ScheduleNode):
         return loss
 
 
 class TransformerLayerNode(ScheduleNode):
     def __init__(self, chunk_state, common_state, layer, stream, event, free_inputs=False):
...
@@ -598,8 +597,6 @@ def schedule_layer_1f1b(
         with f_context:
             f_input = f_layer.mlp.forward(f_input)
 
     def next_iter_pre_forward():
         if f_layer is not None:
             with f_context:
...
dcu_megatron/core/models/gpt/gpt_model.py

...
@@ -46,18 +46,7 @@ def gpt_model_init_wrapper(fn):
     return wrapper
 
-class GPTModel(MegatronCoreGPTModel):
-    """
-    patch megatron GPTModel
-    """
-
-    def get_transformer_callables_by_layer(self, layer_number: int):
-        """
-        Get the callables for the layer at the given transformer layer number.
-        """
-        return self.decoder.get_layer_callables(layer_number)
-
-    def build_schedule_plan(
-        self,
-        input_ids: Tensor,
-        position_ids: Tensor,
+def gpt_model_forward(
+    self,
+    input_ids: Tensor,
+    position_ids: Tensor,
...
@@ -71,66 +60,7 @@ class GPTModel(MegatronCoreGPTModel):
-        *,
-        inference_params: Optional[BaseInferenceContext] = None,
-        loss_mask: Optional[Tensor] = None,
-    ):
-        """Builds a computation schedule plan for the model.
-
-        This function creates a schedule plan for a model chunk, including
-        preprocessing, transformer layers, and postprocessing.
-        The schedule plan is used to optimize computation and memory usage
-        in distributed environments.
-
-        Args:
-            input_ids (Tensor): Input token IDs.
-            position_ids (Tensor): Position IDs.
-            attention_mask (Tensor): Attention mask.
-            decoder_input (Tensor, optional): Decoder input tensor. Defaults to None.
-            labels (Tensor, optional): Labels for loss computation. Defaults to None.
-            inference_params (InferenceParams, optional):
-                Parameters for inference. Defaults to None.
-            packed_seq_params (PackedSeqParams, optional):
-                Parameters for packed sequences. Defaults to None.
-            extra_block_kwargs (dict, optional):
-                Additional keyword arguments for blocks. Defaults to None.
-            runtime_gather_output (Optional[bool], optional):
-                Whether to gather output at runtime. Defaults to None.
-            loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
-
-        Returns:
-            ModelChunkSchedulePlan: The model chunk schedule plan.
-        """
-        from .fine_grained_schedule import build_model_chunk_schedule_plan
-
-        return build_model_chunk_schedule_plan(
-            self,
-            input_ids,
-            position_ids,
-            attention_mask,
-            decoder_input=decoder_input,
-            labels=labels,
-            inference_context=inference_context,
-            packed_seq_params=packed_seq_params,
-            extra_block_kwargs=extra_block_kwargs,
-            runtime_gather_output=runtime_gather_output,
-            inference_params=inference_params,
-            loss_mask=loss_mask,
-        )
-
-    def forward(
-        self,
-        input_ids: Tensor,
-        position_ids: Tensor,
-        attention_mask: Tensor,
-        decoder_input: Tensor = None,
-        labels: Tensor = None,
-        inference_context: BaseInferenceContext = None,
-        packed_seq_params: PackedSeqParams = None,
-        extra_block_kwargs: dict = None,
-        runtime_gather_output: Optional[bool] = None,
-        *,
-        inference_params: Optional[BaseInferenceContext] = None,
-        loss_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """Forward function of the GPT Model This function passes the input tensors
-        through the embedding layer, and then the decoeder and finally into the post
-        processing layer (optional).
+    *,
+    inference_params: Optional[BaseInferenceContext] = None,
+    loss_mask: Optional[Tensor] = None,
+) -> Tensor:
+    """Forward function of the GPT Model This function passes the input tensors
+    through the embedding layer, and then the decoeder and finally into the post
+    processing layer (optional).
...
@@ -300,3 +230,74 @@ class GPTModel(MegatronCoreGPTModel):
             loss = self.compute_language_model_loss(labels, logits)
 
         return loss
+
+
+class GPTModel(MegatronCoreGPTModel):
+    """
+    patch megatron GPTModel
+    """
+
+    def get_transformer_callables_by_layer(self, layer_number: int):
+        """
+        Get the callables for the layer at the given transformer layer number.
+        """
+        return self.decoder.get_layer_callables(layer_number)
+
+    def build_schedule_plan(
+        self,
+        input_ids: Tensor,
+        position_ids: Tensor,
+        attention_mask: Tensor,
+        decoder_input: Tensor = None,
+        labels: Tensor = None,
+        inference_context: BaseInferenceContext = None,
+        packed_seq_params: PackedSeqParams = None,
+        extra_block_kwargs: dict = None,
+        runtime_gather_output: Optional[bool] = None,
+        *,
+        inference_params: Optional[BaseInferenceContext] = None,
+        loss_mask: Optional[Tensor] = None,
+    ):
+        """Builds a computation schedule plan for the model.
+
+        This function creates a schedule plan for a model chunk, including
+        preprocessing, transformer layers, and postprocessing.
+        The schedule plan is used to optimize computation and memory usage
+        in distributed environments.
+
+        Args:
+            input_ids (Tensor): Input token IDs.
+            position_ids (Tensor): Position IDs.
+            attention_mask (Tensor): Attention mask.
+            decoder_input (Tensor, optional): Decoder input tensor. Defaults to None.
+            labels (Tensor, optional): Labels for loss computation. Defaults to None.
+            inference_params (InferenceParams, optional):
+                Parameters for inference. Defaults to None.
+            packed_seq_params (PackedSeqParams, optional):
+                Parameters for packed sequences. Defaults to None.
+            extra_block_kwargs (dict, optional):
+                Additional keyword arguments for blocks. Defaults to None.
+            runtime_gather_output (Optional[bool], optional):
+                Whether to gather output at runtime. Defaults to None.
+            loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
+
+        Returns:
+            ModelChunkSchedulePlan: The model chunk schedule plan.
+        """
+        from .fine_grained_schedule import build_model_chunk_schedule_plan
+
+        return build_model_chunk_schedule_plan(
+            self,
+            input_ids,
+            position_ids,
+            attention_mask,
+            decoder_input=decoder_input,
+            labels=labels,
+            inference_context=inference_context,
+            packed_seq_params=packed_seq_params,
+            extra_block_kwargs=extra_block_kwargs,
+            runtime_gather_output=runtime_gather_output,
+            inference_params=inference_params,
+            loss_mask=loss_mask,
+        )
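After the patches from features_manager.py and megatron_adaptor.py are applied, megatron's GPTModel symbol resolves to the subclass defined at the end of this file and its forward is gpt_model_forward. A quick, hypothetical sanity check, assuming the patches have already been applied in the current process:

# Hypothetical post-patch check; assumes MegatronPatchesManager.apply_patches()
# has already run in this process.
import megatron.core.models.gpt.gpt_model as mcore_gpt_model

assert hasattr(mcore_gpt_model.GPTModel, "build_schedule_plan")
assert hasattr(mcore_gpt_model.GPTModel, "get_transformer_callables_by_layer")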
dcu_megatron/core/pipeline_parallel/combined_1f1b.py

...
@@ -503,6 +503,7 @@ def get_default_cls_for_unwrap():
             pass
 
     return cls
 
 
 def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
     """unwrap_model DistributedDataParallel and Float16Module wrapped model"""
     return_list = True
...
dcu_megatron/core/transformer/mlp.py (new file, mode 100644)

from megatron.core.transformer.mlp import MLP as MegatronCoreMLP


class MLP(MegatronCoreMLP):
    def backward_dw(self):
        self.linear_fc2.backward_dw()
        self.linear_fc1.backward_dw()
\ No newline at end of file
dcu_megatron/core/transformer/moe/experts.py (new file, mode 100644)

from megatron.core.transformer.moe.experts import TEGroupedMLP as MegatronCoreTEGroupedMLP


class TEGroupedMLP(MegatronCoreTEGroupedMLP):
    def backward_dw(self):
        self.linear_fc2.backward_dw()
        self.linear_fc1.backward_dw()
dcu_megatron/core/transformer/moe/moe_layer.py (new file, mode 100644)

from megatron.core.transformer.moe.moe_layer import MoELayer as MegatronCoreMoELayer


class MoELayer(MegatronCoreMoELayer):
    def backward_dw(self):
        self.experts.backward_dw()
        self.shared_experts.backward_dw()
dcu_megatron/core/transformer/multi_latent_attention.py (new file, mode 100644)

from megatron.core.transformer.multi_latent_attention import MLASelfAttention as MegatronCoreMLASelfAttention


class MLASelfAttention(MegatronCoreMLASelfAttention):
    """MLA Self-attention layer class

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def backward_dw(self):
        self.linear_kv_up_proj.backward_dw()
        self.linear_kv_down_proj.backward_dw()
        if self.config.q_lora_rank is None:
            self.linear_q_proj.backward_dw()
        else:
            self.linear_q_down_proj.backward_dw()
            self.linear_q_up_proj.backward_dw()
        self.linear_proj.backward_dw()
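These thin wrappers (MLP, TEGroupedMLP, MoELayer, MLASelfAttention) only chain backward_dw down to their child linear layers. The hypothetical driver below shows the intended call order at the transformer-layer level; the attribute names self_attention and mlp follow Megatron's TransformerLayer, but the helper itself is not part of this commit.

# Hypothetical helper, not part of this commit: drain delayed weight-gradient
# work for one transformer layer after its activation backward has finished.
def drain_delayed_wgrads(transformer_layer):
    # MLASelfAttention.backward_dw() then MLP/MoELayer.backward_dw(), each of
    # which fans out to its own linear sub-modules as defined above.
    transformer_layer.self_attention.backward_dw()
    transformer_layer.mlp.backward_dw()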
dcu_megatron/core/transformer/transformer_block.py

...
@@ -17,16 +17,6 @@ def transformer_block_init_wrapper(fn):
 
 class TransformerBlock(MegatronCoreTransformerBlock):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
-        config = args[0] if len(args) > 1 else kwargs['config']
-        if getattr(config, "mtp_num_layers", 0) > 0:
-            self.main_final_layernorm = self.final_layernorm
-            self.final_layernorm = None
-
     def get_layer_callables(self, layer_number: int):
         """
...