Commit 4bb958ec authored by dongcl

support a2a_overlap

parent 7c9dc3ec
from megatron.core.utils import is_te_min_version
def a2a_overlap_adaptation(patches_manager):
"""
patches_manager: MegatronPatchesManager
"""
from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from ..core.transformer.transformer_block import TransformerBlock
from ..core.transformer.transformer_layer import TransformerLayer
from ..core.models.gpt.gpt_model import GPTModel
from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
from ..core.extensions.transformer_engine import _get_extra_te_kwargs_wrapper, TELinear, TELayerNormColumnParallelLinear
from ..core.transformer.multi_latent_attention import MLASelfAttention
from ..core.transformer.mlp import MLP
from ..core.transformer.moe.experts import TEGroupedMLP
from ..core.transformer.moe.moe_layer import MoELayer
# num_warmup_microbatches + 1
patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
get_pp_rank_microbatches)
# a2a_overlap
patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
forward_backward_pipelining_with_interleaving)
patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
MoEAlltoAllTokenDispatcher)
patches_manager.register_patch('megatron.core.transformer.transformer_block.TransformerBlock',
TransformerBlock)
patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
TransformerLayer)
patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel',
GPTModel)
# backward_dw
patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
_get_extra_te_kwargs_wrapper,
apply_wrapper=True)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
TELinear)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
TELayerNormColumnParallelLinear)
if is_te_min_version("1.9.0.dev0"):
from ..core.extensions.transformer_engine import TEGroupedLinear
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
TEGroupedLinear)
patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention',
MLASelfAttention)
patches_manager.register_patch('megatron.core.transformer.mlp.MLP',
MLP)
patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP',
TEGroupedMLP)
patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer',
MoELayer)
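For orientation, a minimal sketch of what the `register_patch` / `apply_patches` pair above amounts to. The real `MegatronPatchesManager` is not shown in this commit, so the helper below is an assumption, and it only covers module-level attributes, not class-method paths such as `GPTModel.__init__`:

```python
# Minimal sketch, not the repository's implementation: resolve a dotted path and
# swap the attribute for the patched object, optionally wrapping the original.
import importlib

_pending_patches = []

def register_patch(dotted_path, replacement, apply_wrapper=False):
    """Record a patch; with apply_wrapper=True, `replacement` wraps the original."""
    _pending_patches.append((dotted_path, replacement, apply_wrapper))

def apply_patches():
    for dotted_path, replacement, apply_wrapper in _pending_patches:
        module_path, attr_name = dotted_path.rsplit('.', 1)
        module = importlib.import_module(module_path)      # module-level targets only
        original = getattr(module, attr_name)
        new_obj = replacement(original) if apply_wrapper else replacement
        setattr(module, attr_name, new_obj)
```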
......@@ -24,6 +24,13 @@ class MegatronAdaptation:
adaptation.execute()
MegatronAdaptation.apply()
# apply features
from .patch_utils import MegatronPatchesManager
from .features_manager import a2a_overlap_adaptation
a2a_overlap_adaptation(MegatronPatchesManager)
MegatronPatchesManager.apply_patches()
@classmethod
def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False, apply_wrapper=False, remove_origin_wrappers=False):
"""
......@@ -91,14 +98,14 @@ class CoreAdaptation(MegatronAdaptationABC):
pass
def patch_core_models(self):
from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, GPTModel
from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel',
GPTModel)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
gpt_model_forward)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
......@@ -142,18 +149,6 @@ class CoreAdaptation(MegatronAdaptationABC):
if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
TEGroupedLinear.__bases__ = (te.pytorch.BatchedLinear if is_te_min_version("2.3.0.dev0") else te.pytorch.BatchLinear,)
def patch_pipeline_parallel(self):
from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
# num_warmup_microbatches + 1
MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
get_pp_rank_microbatches)
# a2a_overlap
MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
forward_backward_pipelining_with_interleaving)
def patch_tensor_parallel(self):
from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
......@@ -190,9 +185,6 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
get_gpt_layer_with_flux_spec)
def patch_pipeline_parallel(self):
pass
def patch_training(self):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import _initialize_distributed
......
......@@ -3,7 +3,7 @@ import torch
import dataclasses
import transformer_engine as te
from typing import Any, Optional
from typing import Any, Optional, Callable
from packaging.version import Version as PkgVersion
from megatron.core.packed_seq_params import PackedSeqParams
......@@ -13,6 +13,9 @@ from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.process_groups_config import ModelCommProcessGroups
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.extensions.transformer_engine import TELinear as MegatronCoreTELinear
from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
......@@ -22,6 +25,112 @@ from megatron.core.parallel_state import (
)
def _get_extra_te_kwargs_wrapper(fn):
@wraps(fn)
def wrapper(config: TransformerConfig):
extra_transformer_engine_kwargs = fn(config)
extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.get("split_bw", False)
return extra_transformer_engine_kwargs
return wrapper
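The wrapper is applied through `register_patch(..., apply_wrapper=True)` above; hand-applied, the same patch would look like the sketch below (shown only to illustrate the mechanics):

```python
# Equivalent, hand-applied form of the apply_wrapper=True registration.
import megatron.core.extensions.transformer_engine as te_ext

te_ext._get_extra_te_kwargs = _get_extra_te_kwargs_wrapper(te_ext._get_extra_te_kwargs)

# After patching, every TE layer built via _get_extra_te_kwargs(config) also
# receives delay_wgrad_compute=<split_bw>, i.e. weight-gradient computation is
# deferred whenever split_bw is enabled on the config.
```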
class TELinear(MegatronCoreTELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
parallel_mode currently supports 3 different values:
- "column": Split the weight matrix along output dimension (used in TEColumnParallelLinear)
- "row": Split the weight matrix along input dimension (used in TERowParallelLinear)
- "duplicated": No tensor parallelism and weight is duplicated across TP ranks
- Note: For expert linear layers, we will disable communication logic here
as TP communication is handled in token_dispatcher.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
parallel_mode: Optional[str],
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: Optional[str] = None,
is_expert: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = getattr(config, "split_bw", False)
assert not self.split_bw, "split_bw is currently not supported"
super().__init__(
input_size,
output_size,
parallel_mode=parallel_mode,
config=config,
init_method=init_method,
bias=bias,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
is_expert=is_expert,
tp_group=tp_group,
)
def backward_dw(self):
# No-op until split_bw (delayed weight-gradient computation) is supported.
if not self.split_bw:
return
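`backward_dw` is currently a no-op because `split_bw` is asserted off, but the intended flow once it is supported would look roughly like the sketch below (the training loop and `model` are assumptions, not part of this commit):

```python
# Hedged sketch of the split-backward flow: with delay_wgrad_compute=True, TE
# layers compute only activation gradients during loss.backward(); the deferred
# weight gradients are flushed afterwards by calling backward_dw().
def train_step(model, batch, optimizer):
    loss = model(batch)
    loss.backward()                    # dgrad only; wgrad is deferred
    for module in model.modules():
        if hasattr(module, "backward_dw"):
            module.backward_dw()       # flush deferred weight gradients
    optimizer.step()
    optimizer.zero_grad()
```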
class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
"""
Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
layernorm and linear layers
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: TransformerConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = getattr(config, "split_bw", False)
assert not self.split_bw, "split_bw is currently not supported"
super().__init__(
input_size,
output_size,
config=config,
init_method=init_method,
gather_output=gather_output,
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
tp_group=tp_group,
)
def backward_dw(self):
if not self.split_bw:
return
class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
def __init__(
self,
......@@ -176,3 +285,52 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
layer_number=layer_number,
**extra_kwargs,
)
if is_te_min_version("1.9.0.dev0"):
from megatron.core.extensions.transformer_engine import TEGroupedLinear as MegatronCoreTEGroupedLinear
class TEGroupedLinear(MegatronCoreTEGroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
parallel_mode: Optional[str],
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = getattr(config, "split_bw", False)
assert not self.split_bw, "split_bw is currently not supported"
super().__init__(
num_gemms,
input_size,
output_size,
parallel_mode=parallel_mode,
config=config,
init_method=init_method,
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
tp_group=tp_group,
)
def backward_dw(self):
if not self.split_bw:
return
......@@ -239,7 +239,6 @@ class PostProcessNode(ScheduleNode):
return loss
class TransformerLayerNode(ScheduleNode):
def __init__(self, chunk_state, common_state, layer, stream, event, free_inputs=False):
......@@ -598,8 +597,6 @@ def schedule_layer_1f1b(
with f_context:
f_input = f_layer.mlp.forward(f_input)
def next_iter_pre_forward():
if f_layer is not None:
with f_context:
......
......@@ -46,6 +46,192 @@ def gpt_model_init_wrapper(fn):
return wrapper
def gpt_model_forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
inference_context = deprecate_inference_params(inference_context, inference_params)
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode and inference_context:
assert (
inference_context.is_static_batching()
), "GPTModel currently only supports static inference batching."
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_context.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_context, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
if self.training or not self.config.flash_decode:
rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
else:
# Flash decoding uses precomputed cos and sin for RoPE
raise NotImplementedError(
"Flash decoding uses precomputed cos and sin for RoPE, not implmented in "
"MultimodalRotaryEmbedding yet."
)
if (
(self.config.enable_cuda_graph or self.config.flash_decode)
and rotary_pos_cos is not None
and inference_context
and inference_context.is_static_batching()
and not self.training
):
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * inference_context.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None
# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if (
inference_context is not None
and not self.training
and not has_config_logger_enabled(self.config)
):
decoder_input = WrappedTensor(decoder_input)
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_context=inference_context,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
)
# Process inference output.
if inference_context and not inference_context.is_static_batching():
hidden_states = inference_context.last_token_logits(
hidden_states.squeeze(1).unsqueeze(0)
).unsqueeze(1)
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
if self.mtp_process:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
labels=labels,
loss_mask=loss_mask,
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
embedding=self.embedding,
output_layer=self.output_layer,
output_weight=output_weight,
runtime_gather_output=runtime_gather_output,
compute_language_model_loss=self.compute_language_model_loss,
**(extra_block_kwargs or {}),
)
if (
self.mtp_process is not None
and getattr(self.decoder, "main_final_layernorm", None) is not None
):
# move block main model final norms here
hidden_states = self.decoder.main_final_layernorm(hidden_states)
if not self.post_process:
return hidden_states
if (
not self.training
and inference_context is not None
and inference_context.is_static_batching()
and inference_context.materialize_only_last_token_logits
):
hidden_states = hidden_states[-1:, :, :]
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
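A usage sketch for the patched forward, following the return paths above (the tensor construction and the `model` instance are illustrative assumptions):

```python
# Without labels the forward returns logits transposed to [b, s, vocab]; with
# labels it returns the per-token loss from compute_language_model_loss.
import torch

b, s, vocab = 2, 1024, 32000
input_ids = torch.randint(0, vocab, (b, s), device="cuda")
position_ids = torch.arange(s, device="cuda").unsqueeze(0).expand(b, s)
attention_mask = None  # causal masking handled inside the attention backend

logits = model(input_ids, position_ids, attention_mask)                  # [b, s, vocab]
loss = model(input_ids, position_ids, attention_mask, labels=input_ids)  # [b, s]
```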
class GPTModel(MegatronCoreGPTModel):
"""
Patched Megatron-Core GPTModel.
......@@ -115,188 +301,3 @@ class GPTModel(MegatronCoreGPTModel):
inference_params=inference_params,
loss_mask=loss_mask,
)
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
inference_context = deprecate_inference_params(inference_context, inference_params)
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode and inference_context:
assert (
inference_context.is_static_batching()
), "GPTModel currently only supports static inference batching."
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_context.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_context, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
if self.training or not self.config.flash_decode:
rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
else:
# Flash decoding uses precomputed cos and sin for RoPE
raise NotImplementedError(
"Flash decoding uses precomputed cos and sin for RoPE, not implmented in "
"MultimodalRotaryEmbedding yet."
)
if (
(self.config.enable_cuda_graph or self.config.flash_decode)
and rotary_pos_cos is not None
and inference_context
and inference_context.is_static_batching()
and not self.training
):
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * inference_context.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None
# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if (
inference_context is not None
and not self.training
and not has_config_logger_enabled(self.config)
):
decoder_input = WrappedTensor(decoder_input)
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_context=inference_context,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
)
# Process inference output.
if inference_context and not inference_context.is_static_batching():
hidden_states = inference_context.last_token_logits(
hidden_states.squeeze(1).unsqueeze(0)
).unsqueeze(1)
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
if self.mtp_process:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
labels=labels,
loss_mask=loss_mask,
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
embedding=self.embedding,
output_layer=self.output_layer,
output_weight=output_weight,
runtime_gather_output=runtime_gather_output,
compute_language_model_loss=self.compute_language_model_loss,
**(extra_block_kwargs or {}),
)
if (
self.mtp_process is not None
and getattr(self.decoder, "main_final_layernorm", None) is not None
):
# move block main model final norms here
hidden_states = self.decoder.main_final_layernorm(hidden_states)
if not self.post_process:
return hidden_states
if (
not self.training
and inference_context is not None
and inference_context.is_static_batching()
and inference_context.materialize_only_last_token_logits
):
hidden_states = hidden_states[-1:, :, :]
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
......@@ -503,6 +503,7 @@ def get_default_cls_for_unwrap():
pass
return cls
def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
"""unwrap_model DistributedDataParallel and Float16Module wrapped model"""
return_list = True
......
from megatron.core.transformer.mlp import MLP as MegatronCoreMLP
class MLP(MegatronCoreMLP):
def backward_dw(self):
self.linear_fc2.backward_dw()
self.linear_fc1.backward_dw()
\ No newline at end of file
from megatron.core.transformer.moe.experts import TEGroupedMLP as MegatronCoreTEGroupedMLP
class TEGroupedMLP(MegatronCoreTEGroupedMLP):
def backward_dw(self):
self.linear_fc2.backward_dw()
self.linear_fc1.backward_dw()
from megatron.core.transformer.moe.moe_layer import MoELayer as MegatronCoreMoELayer
class MoELayer(MegatronCoreMoELayer):
def backward_dw(self):
self.experts.backward_dw()
if self.shared_experts is not None:
self.shared_experts.backward_dw()
from megatron.core.transformer.multi_latent_attention import MLASelfAttention as MegatronCoreMLASelfAttention
class MLASelfAttention(MegatronCoreMLASelfAttention):
"""MLA Self-attention layer class
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def backward_dw(self):
self.linear_kv_up_proj.backward_dw()
self.linear_kv_down_proj.backward_dw()
if self.config.q_lora_rank is None:
self.linear_q_proj.backward_dw()
else:
self.linear_q_down_proj.backward_dw()
self.linear_q_up_proj.backward_dw()
self.linear_proj.backward_dw()
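These classes all expose the same `backward_dw` cascade so a schedule can flush deferred weight gradients one layer at a time; a layer-level helper mirroring that pattern might look like this (an illustration, not part of this commit):

```python
# Assumption / illustration only: cascade deferred weight-gradient computation
# through one transformer layer, mirroring MLP / MoELayer / MLASelfAttention above.
def layer_backward_dw(layer):
    """Flush deferred weight gradients for one transformer layer."""
    layer.mlp.backward_dw()             # MoELayer, TEGroupedMLP-backed, or dense MLP
    layer.self_attention.backward_dw()  # e.g. MLASelfAttention
```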
......@@ -17,16 +17,6 @@ def transformer_block_init_wrapper(fn):
class TransformerBlock(MegatronCoreTransformerBlock):
def __init__(
self, *args, **kwargs
):
super().__init__(*args, **kwargs)
# MTP requires separate layernorms for the main model and the MTP modules, so move the final norm out of the block.
config = args[0] if args else kwargs['config']
if getattr(config, "mtp_num_layers", 0) > 0:
self.main_final_layernorm = self.final_layernorm
self.final_layernorm = None
def get_layer_callables(self, layer_number: int):
"""
......