evt_fugx1 / dcu_megatron · Commits · 4bb958ec

Commit 4bb958ec, authored May 07, 2025 by dongcl
Parent: 7c9dc3ec

    support a2a_overlap

Showing 11 changed files with 454 additions and 217 deletions.
dcu_megatron/adaptor/features_manager.py                   +60    -0
dcu_megatron/adaptor/megatron_adaptor.py                   +10   -18
dcu_megatron/core/extensions/transformer_engine.py        +159    -1
dcu_megatron/core/models/gpt/fine_grained_schedule.py       +0    -3
dcu_megatron/core/models/gpt/gpt_model.py                 +186  -185
dcu_megatron/core/pipeline_parallel/combined_1f1b.py        +1    -0
dcu_megatron/core/transformer/mlp.py                        +7    -0
dcu_megatron/core/transformer/moe/experts.py                +6    -0
dcu_megatron/core/transformer/moe/moe_layer.py              +7    -0
dcu_megatron/core/transformer/multi_latent_attention.py    +18    -0
dcu_megatron/core/transformer/transformer_block.py          +0   -10
dcu_megatron/adaptor/features_manager.py  (new file, mode 100644)

from megatron.core.utils import is_te_min_version


def a2a_overlap_adaptation(patches_manager):
    """
    patches_manager: MegatronPatchesManager
    """
    from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
    from ..core.transformer.transformer_block import TransformerBlock
    from ..core.transformer.transformer_layer import TransformerLayer
    from ..core.models.gpt.gpt_model import GPTModel
    from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, \
        forward_backward_pipelining_with_interleaving
    from ..core.extensions.transformer_engine import _get_extra_te_kwargs_wrapper, TELinear, \
        TELayerNormColumnParallelLinear
    from ..core.transformer.multi_latent_attention import MLASelfAttention
    from ..core.transformer.mlp import MLP
    from ..core.transformer.moe.experts import TEGroupedMLP
    from ..core.transformer.moe.moe_layer import MoELayer

    # num_warmup_microbatches + 1
    patches_manager.register_patch(
        'megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
        get_pp_rank_microbatches)
    # a2a_overlap
    patches_manager.register_patch(
        'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
        forward_backward_pipelining_with_interleaving)
    patches_manager.register_patch(
        'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
        MoEAlltoAllTokenDispatcher)
    patches_manager.register_patch(
        'megatron.core.transformer.transformer_block.TransformerBlock', TransformerBlock)
    patches_manager.register_patch(
        'megatron.core.transformer.transformer_layer.TransformerLayer', TransformerLayer)
    patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)

    # backward_dw
    patches_manager.register_patch(
        'megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
        _get_extra_te_kwargs_wrapper, apply_wrapper=True)
    patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear', TELinear)
    patches_manager.register_patch(
        'megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
        TELayerNormColumnParallelLinear)
    if is_te_min_version("1.9.0.dev0"):
        from ..core.extensions.transformer_engine import TEGroupedLinear
        patches_manager.register_patch(
            'megatron.core.extensions.transformer_engine.TEGroupedLinear', TEGroupedLinear)
    patches_manager.register_patch(
        'megatron.core.transformer.multi_latent_attention.MLASelfAttention', MLASelfAttention)
    patches_manager.register_patch('megatron.core.transformer.mlp.MLP', MLP)
    patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP', TEGroupedMLP)
    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer', MoELayer)
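With these registrations, everything a2a_overlap needs is queued on the patches manager; nothing is actually patched until MegatronPatchesManager.apply_patches() runs from megatron_adaptor.py below. The manager's internals are not part of this commit, so the snippet below is only a rough sketch under the assumption that it stores (dotted path, replacement) pairs and monkey-patches them onto the owning module. SimplePatchesManager is a hypothetical stand-in, not the real class.

# Hypothetical sketch (not the dcu_megatron implementation): how a manager like
# MegatronPatchesManager might resolve 'pkg.module.Attr' paths and swap in the
# replacement object when apply_patches() is called.
import importlib


class SimplePatchesManager:
    _patches = []

    @classmethod
    def register_patch(cls, orig_path, replacement, apply_wrapper=False):
        # Only record the request; the actual patch happens in apply_patches().
        cls._patches.append((orig_path, replacement, apply_wrapper))

    @classmethod
    def apply_patches(cls):
        for orig_path, replacement, apply_wrapper in cls._patches:
            module_path, attr = orig_path.rsplit('.', 1)
            # Paths like 'pkg.mod.Class.method' need one more split to find the class;
            # a real manager handles that case, omitted here for brevity.
            module = importlib.import_module(module_path)
            if apply_wrapper:
                # apply_wrapper=True means the replacement wraps the original callable.
                replacement = replacement(getattr(module, attr))
            setattr(module, attr, replacement)

Under that reading, apply_wrapper=True (used for _get_extra_te_kwargs above) passes the original function into the wrapper factory instead of replacing it outright.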
dcu_megatron/adaptor/megatron_adaptor.py

@@ -24,6 +24,13 @@ class MegatronAdaptation:
             adaptation.execute()
         MegatronAdaptation.apply()
 
+        # apply features
+        from .patch_utils import MegatronPatchesManager
+        from .features_manager import a2a_overlap_adaptation
+
+        a2a_overlap_adaptation(MegatronPatchesManager)
+        MegatronPatchesManager.apply_patches()
+
     @classmethod
     def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False,
                  apply_wrapper=False, remove_origin_wrappers=False):
         """
@@ -91,14 +98,14 @@ class CoreAdaptation(MegatronAdaptationABC):
         pass
 
     def patch_core_models(self):
-        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, GPTModel
+        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
 
         # GPT Model
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
                                     gpt_model_init_wrapper, apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
+                                    gpt_model_forward)
 
     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
@@ -142,18 +149,6 @@ class CoreAdaptation(MegatronAdaptationABC):
         if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
             TEGroupedLinear.__bases__ = (
                 te.pytorch.BatchedLinear if is_te_min_version("2.3.0.dev0") else te.pytorch.BatchLinear,
             )
 
-    def patch_pipeline_parallel(self):
-        from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, \
-            forward_backward_pipelining_with_interleaving
-
-        # num_warmup_microbatches + 1
-        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
-                                    get_pp_rank_microbatches)
-        # a2a_overlap
-        MegatronAdaptation.register(
-            'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
-            forward_backward_pipelining_with_interleaving)
-
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
@@ -190,9 +185,6 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register(
             "megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
             get_gpt_layer_with_flux_spec)
 
-    def patch_pipeline_parallel(self):
-        pass
-
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
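One retained context line above is easy to misread: TEGroupedLinear.__bases__ = (te.pytorch.BatchedLinear if is_te_min_version("2.3.0.dev0") else te.pytorch.BatchLinear,). It rebinds the wrapper's parent class at runtime depending on the installed Transformer Engine version. The toy classes below (plain stand-ins, not TE APIs) show the Python mechanism this relies on.

# Reassigning __bases__ swaps which inherited implementation a class (and its
# already-created instances) use. Illustrative classes only.
class OldBatchLinear:
    def gemm(self):
        return "old batched GEMM path"


class NewBatchedLinear:
    def gemm(self):
        return "new batched GEMM path"


class GroupedLinear(OldBatchLinear):
    pass


layer = GroupedLinear()
print(layer.gemm())                            # "old batched GEMM path"
GroupedLinear.__bases__ = (NewBatchedLinear,)  # same trick as the TEGroupedLinear line
print(layer.gemm())                            # "new batched GEMM path"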
dcu_megatron/core/extensions/transformer_engine.py

@@ -3,7 +3,7 @@ import torch
 import dataclasses
 import transformer_engine as te
 
-from typing import Any, Optional
+from typing import Any, Optional, Callable
 from packaging.version import Version as PkgVersion
 
 from megatron.core.packed_seq_params import PackedSeqParams
@@ -13,6 +13,9 @@ from megatron.core.extensions.transformer_engine import TEDotProductAttention
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.process_groups_config import ModelCommProcessGroups
+from megatron.core.model_parallel_config import ModelParallelConfig
+from megatron.core.extensions.transformer_engine import TELinear as MegatronCoreTELinear
+from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
 from megatron.core.parallel_state import (
     get_context_parallel_global_ranks,
@@ -22,6 +25,112 @@ from megatron.core.parallel_state import (
 )
 
 
+def _get_extra_te_kwargs_wrapper(fn):
+    @wraps(fn)
+    def wrapper(config: TransformerConfig):
+        extra_transformer_engine_kwargs = fn(config)
+        extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.get("split_bw", False)
+        return extra_transformer_engine_kwargs
+    return wrapper
+
+
+class TELinear(MegatronCoreTELinear):
+    """
+    Wrapper for the Transformer-Engine's `Linear` layer.
+
+    Note that if Megatron's parallel_state has not been initialized
+    yet, the tp_group passed to TE will be None and must be set later
+    via set_tensor_parallel_group().
+
+    parallel_mode currently supports 3 different values:
+        - "column": Split the weight matrix along output dimension (used in TEColumnParallelLinear)
+        - "row": Split the weight matrix along input dimension (used in TERowParallelLinear)
+        - "duplicated": No tensor parallelism and weight is duplicated across TP ranks
+        - Note: For expert linear layers, we will disable communication logic here
+          as TP communication is handled in token_dispatcher.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        *,
+        parallel_mode: Optional[str],
+        config: ModelParallelConfig,
+        init_method: Callable,
+        bias: bool,
+        skip_bias_add: bool,
+        skip_weight_param_allocation: bool,
+        tp_comm_buffer_name: Optional[str] = None,
+        is_expert: bool = False,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
+    ):
+        self.split_bw = config.get("split_bw", False)
+        assert not self.split_bw, "split_bw is currently not supported"
+        super().__init__(
+            input_size,
+            output_size,
+            parallel_mode=parallel_mode,
+            config=config,
+            init_method=init_method,
+            bias=bias,
+            skip_bias_add=skip_bias_add,
+            skip_weight_param_allocation=skip_weight_param_allocation,
+            tp_comm_buffer_name=tp_comm_buffer_name,
+            is_expert=is_expert,
+            tp_group=tp_group,
+        )
+
+    def backward_dw(self):
+        if not self.split_bw:
+            return
+
+
+class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
+    """
+    Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
+    layernorm and linear layers
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        *,
+        config: TransformerConfig,
+        init_method: Callable,
+        gather_output: bool,
+        bias: bool,
+        skip_bias_add: bool,
+        is_expert: bool,
+        skip_weight_param_allocation: bool = False,
+        tp_comm_buffer_name: Optional[str] = None,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
+    ):
+        self.split_bw = config.get("split_bw", False)
+        assert not self.split_bw, "split_bw is currently not supported"
+        super().__init__(
+            input_size,
+            output_size,
+            config=config,
+            init_method=init_method,
+            gather_output=gather_output,
+            bias=bias,
+            skip_bias_add=skip_bias_add,
+            is_expert=is_expert,
+            skip_weight_param_allocation=skip_weight_param_allocation,
+            tp_comm_buffer_name=tp_comm_buffer_name,
+            tp_group=tp_group,
+        )
+
+    def backward_dw(self):
+        if not self.split_bw:
+            return
+
+
 class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
     def __init__(
         self,
@@ -176,3 +285,52 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
             layer_number=layer_number,
             **extra_kwargs,
         )
+
+
+if is_te_min_version("1.9.0.dev0"):
+    from megatron.core.extensions.transformer_engine import TEGroupedLinear as MegatronCoreTEGroupedLinear
+
+    class TEGroupedLinear(MegatronCoreTEGroupedLinear):
+        """
+        Wrapper for the Transformer-Engine's `GroupedLinear` layer.
+
+        Note that if Megatron's parallel_state has not been initialized
+        yet, the tp_group passed to TE will be None and must be set later
+        via set_tensor_parallel_group().
+        """
+
+        def __init__(
+            self,
+            num_gemms: int,
+            input_size: int,
+            output_size: int,
+            *,
+            parallel_mode: Optional[str],
+            config: ModelParallelConfig,
+            init_method: Callable,
+            bias: bool,
+            skip_bias_add: bool,
+            is_expert: bool = False,
+            tp_comm_buffer_name: Optional[str] = None,
+            tp_group: Optional[torch.distributed.ProcessGroup] = None,
+        ):
+            self.split_bw = config.get("split_bw", False)
+            assert not self.split_bw, "split_bw is currently not supported"
+            super().__init__(
+                num_gemms,
+                input_size,
+                output_size,
+                parallel_mode=parallel_mode,
+                config=config,
+                init_method=init_method,
+                bias=bias,
+                skip_bias_add=skip_bias_add,
+                is_expert=is_expert,
+                tp_comm_buffer_name=tp_comm_buffer_name,
+                tp_group=tp_group,
+            )
+
+        def backward_dw(self):
+            if not self.split_bw:
+                return
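The delay_wgrad_compute kwarg and the backward_dw() stubs above are the hooks for a split backward pass: the data gradient is produced during the normal backward, while the weight gradient is deferred and flushed later so it can overlap with communication. This commit still asserts split_bw off, so the snippet below is only a self-contained illustration of the general idea with made-up names; it is not TE's delay_wgrad_compute machinery or the dcu_megatron code.

# Toy delayed-wgrad linear layer: backward() returns only the input gradient and
# stashes what is needed for the weight gradient, which backward_dw() computes later.
import torch


class _DelayedWgradLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, stash):
        ctx.save_for_backward(inp, weight)
        ctx.stash = stash
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        # Defer the weight gradient; return the data gradient immediately.
        ctx.stash.append((inp.detach(), grad_out.detach(), weight))
        return grad_out @ weight, None, None


class ToyLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self._pending = []

    def forward(self, x):
        return _DelayedWgradLinear.apply(x, self.weight, self._pending)

    def backward_dw(self):
        # Analogous role to TELinear.backward_dw: materialize deferred weight grads.
        for inp, grad_out, weight in self._pending:
            g = grad_out.reshape(-1, grad_out.shape[-1]).t() @ inp.reshape(-1, inp.shape[-1])
            weight.grad = g if weight.grad is None else weight.grad + g
        self._pending.clear()


if __name__ == "__main__":
    layer = ToyLinear(8, 4)
    x = torch.randn(2, 8, requires_grad=True)
    layer(x).sum().backward()          # input grad now, weight grad deferred
    assert layer.weight.grad is None
    layer.backward_dw()                # weight grad computed later, e.g. after comm
    assert layer.weight.grad is not None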
dcu_megatron/core/models/gpt/fine_grained_schedule.py

@@ -239,7 +239,6 @@ class PostProcessNode(ScheduleNode):
         return loss
 
 
 class TransformerLayerNode(ScheduleNode):
     def __init__(self, chunk_state, common_state, layer, stream, event, free_inputs=False):
@@ -598,8 +597,6 @@ def schedule_layer_1f1b(
         with f_context:
             f_input = f_layer.mlp.forward(f_input)
 
     def next_iter_pre_forward():
         if f_layer is not None:
             with f_context:
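The schedule_layer_1f1b hunk above interleaves one micro-batch's MLP forward with work from another micro-batch, which is the point of a2a_overlap: the MoE token-dispatch all-to-all is issued asynchronously and only waited on right before the expert MLP needs the tokens. A minimal sketch of that pattern follows; it assumes an initialized process group (e.g. launched with torchrun), and the names are illustrative rather than the dcu_megatron scheduler.

# Overlap pattern: start the dispatch all-to-all, do unrelated compute, then wait.
import torch
import torch.distributed as dist


def dispatch_with_overlap(tokens_to_send, other_work):
    recv = torch.empty_like(tokens_to_send)
    handle = dist.all_to_all_single(recv, tokens_to_send, async_op=True)
    other_work()   # e.g. attention of another micro-batch runs while tokens are in flight
    handle.wait()  # tokens are only needed now, by the expert MLP
    return recv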
dcu_megatron/core/models/gpt/gpt_model.py

@@ -46,18 +46,7 @@ def gpt_model_init_wrapper(fn):
     return wrapper
 
 
-class GPTModel(MegatronCoreGPTModel):
-    """
-    patch megatron GPTModel
-    """
-
-    def get_transformer_callables_by_layer(self, layer_number: int):
-        """
-        Get the callables for the layer at the given transformer layer number.
-        """
-        return self.decoder.get_layer_callables(layer_number)
-
-    def build_schedule_plan(
+def gpt_model_forward(
     self,
     input_ids: Tensor,
     position_ids: Tensor,
@@ -71,66 +60,7 @@ class GPTModel(MegatronCoreGPTModel):
     *,
     inference_params: Optional[BaseInferenceContext] = None,
     loss_mask: Optional[Tensor] = None,
-    ):
-        """Builds a computation schedule plan for the model.
-
-        This function creates a schedule plan for a model chunk, including
-        preprocessing, transformer layers, and postprocessing.
-        The schedule plan is used to optimize computation and memory usage
-        in distributed environments.
-
-        Args:
-            input_ids (Tensor): Input token IDs.
-            position_ids (Tensor): Position IDs.
-            attention_mask (Tensor): Attention mask.
-            decoder_input (Tensor, optional): Decoder input tensor. Defaults to None.
-            labels (Tensor, optional): Labels for loss computation. Defaults to None.
-            inference_params (InferenceParams, optional):
-                Parameters for inference. Defaults to None.
-            packed_seq_params (PackedSeqParams, optional):
-                Parameters for packed sequences. Defaults to None.
-            extra_block_kwargs (dict, optional):
-                Additional keyword arguments for blocks. Defaults to None.
-            runtime_gather_output (Optional[bool], optional):
-                Whether to gather output at runtime. Defaults to None.
-            loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
-
-        Returns:
-            ModelChunkSchedulePlan: The model chunk schedule plan.
-        """
-        from .fine_grained_schedule import build_model_chunk_schedule_plan
-
-        return build_model_chunk_schedule_plan(
-            self,
-            input_ids,
-            position_ids,
-            attention_mask,
-            decoder_input=decoder_input,
-            labels=labels,
-            inference_context=inference_context,
-            packed_seq_params=packed_seq_params,
-            extra_block_kwargs=extra_block_kwargs,
-            runtime_gather_output=runtime_gather_output,
-            inference_params=inference_params,
-            loss_mask=loss_mask,
-        )
-
-    def forward(
-        self,
-        input_ids: Tensor,
-        position_ids: Tensor,
-        attention_mask: Tensor,
-        decoder_input: Tensor = None,
-        labels: Tensor = None,
-        inference_context: BaseInferenceContext = None,
-        packed_seq_params: PackedSeqParams = None,
-        extra_block_kwargs: dict = None,
-        runtime_gather_output: Optional[bool] = None,
-        *,
-        inference_params: Optional[BaseInferenceContext] = None,
-        loss_mask: Optional[Tensor] = None,
-    ) -> Tensor:
+) -> Tensor:
     """Forward function of the GPT Model. This function passes the input tensors
     through the embedding layer, and then the decoder and finally into the post
     processing layer (optional).
@@ -300,3 +230,74 @@ class GPTModel(MegatronCoreGPTModel):
     loss = self.compute_language_model_loss(labels, logits)
 
     return loss
+
+
+class GPTModel(MegatronCoreGPTModel):
+    """
+    patch megatron GPTModel
+    """
+
+    def get_transformer_callables_by_layer(self, layer_number: int):
+        """
+        Get the callables for the layer at the given transformer layer number.
+        """
+        return self.decoder.get_layer_callables(layer_number)
+
+    def build_schedule_plan(
+        self,
+        input_ids: Tensor,
+        position_ids: Tensor,
+        attention_mask: Tensor,
+        decoder_input: Tensor = None,
+        labels: Tensor = None,
+        inference_context: BaseInferenceContext = None,
+        packed_seq_params: PackedSeqParams = None,
+        extra_block_kwargs: dict = None,
+        runtime_gather_output: Optional[bool] = None,
+        *,
+        inference_params: Optional[BaseInferenceContext] = None,
+        loss_mask: Optional[Tensor] = None,
+    ):
+        """Builds a computation schedule plan for the model.
+
+        This function creates a schedule plan for a model chunk, including
+        preprocessing, transformer layers, and postprocessing.
+        The schedule plan is used to optimize computation and memory usage
+        in distributed environments.
+
+        Args:
+            input_ids (Tensor): Input token IDs.
+            position_ids (Tensor): Position IDs.
+            attention_mask (Tensor): Attention mask.
+            decoder_input (Tensor, optional): Decoder input tensor. Defaults to None.
+            labels (Tensor, optional): Labels for loss computation. Defaults to None.
+            inference_params (InferenceParams, optional):
+                Parameters for inference. Defaults to None.
+            packed_seq_params (PackedSeqParams, optional):
+                Parameters for packed sequences. Defaults to None.
+            extra_block_kwargs (dict, optional):
+                Additional keyword arguments for blocks. Defaults to None.
+            runtime_gather_output (Optional[bool], optional):
+                Whether to gather output at runtime. Defaults to None.
+            loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
+
+        Returns:
+            ModelChunkSchedulePlan: The model chunk schedule plan.
+        """
+        from .fine_grained_schedule import build_model_chunk_schedule_plan
+
+        return build_model_chunk_schedule_plan(
+            self,
+            input_ids,
+            position_ids,
+            attention_mask,
+            decoder_input=decoder_input,
+            labels=labels,
+            inference_context=inference_context,
+            packed_seq_params=packed_seq_params,
+            extra_block_kwargs=extra_block_kwargs,
+            runtime_gather_output=runtime_gather_output,
+            inference_params=inference_params,
+            loss_mask=loss_mask,
+        )
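The net effect of this file together with the adaptor change is that the old GPTModel.forward method becomes the module-level gpt_model_forward and is registered under the dotted path 'megatron.core.models.gpt.gpt_model.GPTModel.forward'. That style of patch works because assigning a plain function onto a class attribute makes it behave as a bound method on instances; a toy demonstration with stand-in classes (not the real patch machinery) follows.

# Replacing a method with a free function at class level.
class Model:
    def forward(self, x):
        return x + 1


def patched_forward(self, x):
    # stand-in for gpt_model_forward
    return x * 2


m = Model()
assert m.forward(3) == 4
Model.forward = patched_forward   # what a 'GPTModel.forward' patch boils down to
assert m.forward(3) == 6          # existing instances pick up the new behavior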
dcu_megatron/core/pipeline_parallel/combined_1f1b.py

@@ -503,6 +503,7 @@ def get_default_cls_for_unwrap():
             pass
     return cls
 
 
 def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
     """unwrap_model DistributedDataParallel and Float16Module wrapped model"""
     return_list = True
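For context, unwrap_model conceptually peels wrapper modules such as DistributedDataParallel or Float16Module until the bare model remains, using the classes returned by get_default_cls_for_unwrap(). A minimal sketch of that idea, under the assumption that each wrapper exposes the wrapped model as .module (which DDP and Megatron's Float16Module do); this is not the combined_1f1b implementation.

# Peel known wrapper types until the underlying model is reached.
def unwrap(model, wrapper_types):
    while isinstance(model, wrapper_types):
        model = model.module
    return model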
dcu_megatron/core/transformer/mlp.py  (new file, mode 100644)

from megatron.core.transformer.mlp import MLP as MegatronCoreMLP


class MLP(MegatronCoreMLP):
    def backward_dw(self):
        self.linear_fc2.backward_dw()
        self.linear_fc1.backward_dw()
\ No newline at end of file
dcu_megatron/core/transformer/moe/experts.py  (new file, mode 100644)

from megatron.core.transformer.moe.experts import TEGroupedMLP as MegatronCoreTEGroupedMLP


class TEGroupedMLP(MegatronCoreTEGroupedMLP):
    def backward_dw(self):
        self.linear_fc2.backward_dw()
        self.linear_fc1.backward_dw()
dcu_megatron/core/transformer/moe/moe_layer.py  (new file, mode 100644)

from megatron.core.transformer.moe.moe_layer import MoELayer as MegatronCoreMoELayer


class MoELayer(MegatronCoreMoELayer):
    def backward_dw(self):
        self.experts.backward_dw()
        self.shared_experts.backward_dw()
dcu_megatron/core/transformer/multi_latent_attention.py  (new file, mode 100644)

from megatron.core.transformer.multi_latent_attention import MLASelfAttention as MegatronCoreMLASelfAttention


class MLASelfAttention(MegatronCoreMLASelfAttention):
    """MLA Self-attention layer class

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def backward_dw(self):
        self.linear_kv_up_proj.backward_dw()
        self.linear_kv_down_proj.backward_dw()
        if self.config.q_lora_rank is None:
            self.linear_q_proj.backward_dw()
        else:
            self.linear_q_down_proj.backward_dw()
            self.linear_q_up_proj.backward_dw()
        self.linear_proj.backward_dw()
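Taken together, the four small wrappers above (MLP, TEGroupedMLP, MoELayer, MLASelfAttention) give every composite module a backward_dw() that simply delegates to its children, so one call at the top of a layer flushes every deferred weight gradient. A self-contained toy version of that fan-out follows; the classes are illustrative stand-ins, not Megatron modules.

# Each container delegates backward_dw() to its children; fc2 runs before fc1,
# mirroring the reverse order of the forward pass.
calls = []


class Leaf:
    def __init__(self, name):
        self.name = name

    def backward_dw(self):
        calls.append(self.name)


class ToyMLP:
    def __init__(self):
        self.linear_fc1, self.linear_fc2 = Leaf("fc1"), Leaf("fc2")

    def backward_dw(self):
        self.linear_fc2.backward_dw()
        self.linear_fc1.backward_dw()


class ToyMoELayer:
    def __init__(self):
        self.experts, self.shared_experts = ToyMLP(), ToyMLP()

    def backward_dw(self):
        self.experts.backward_dw()
        self.shared_experts.backward_dw()


ToyMoELayer().backward_dw()
assert calls == ["fc2", "fc1", "fc2", "fc1"]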
dcu_megatron/core/transformer/transformer_block.py

@@ -17,16 +17,6 @@ def transformer_block_init_wrapper(fn):
 class TransformerBlock(MegatronCoreTransformerBlock):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # mtp requires separate layernorms for the main model and mtp modules, thus move finalnorm out of the block
-        config = args[0] if len(args) > 1 else kwargs['config']
-        if getattr(config, "mtp_num_layers", 0) > 0:
-            self.main_final_layernorm = self.final_layernorm
-            self.final_layernorm = None
-
     def get_layer_callables(self, layer_number: int):
         """