Commit f3581a8d authored by sdwldchl

patch for megatron 6ba97dd

parent bc3d72d1
@@ -89,12 +89,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         pass
 
     def patch_core_models(self):
-        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
+        from ..core.models.gpt.gpt_model import gpt_model_forward
 
         # GPT Model
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
-                                    gpt_model_init_wrapper,
-                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
                                     gpt_model_forward)
@@ -116,9 +113,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
-        #                             apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
@@ -174,6 +171,8 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     FluxRowParallelLinear)
         MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                                     get_gpt_layer_with_flux_spec)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear",
+                                    FluxColumnParallelLinear)
 
     def patch_pipeline_parallel(self):
         pass
...
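Note on the registration change above: the earlier approach (removed in the second file below) wrapped GPTModel.__init__ to rebuild output_layer with FluxColumnParallelLinear when USE_FLUX_OVERLAP was set; this commit instead registers FluxColumnParallelLinear as a class-level replacement for megatron.core.tensor_parallel.layers.ColumnParallelLinear. The sketch below only illustrates the general dotted-path patching idea that such a registry relies on; monkey_patch and its prefix-resolution loop are hypothetical names, not MegatronAdaptation's actual implementation, and the apply_wrapper=True cases presumably treat the registered object (for example torch.compile(...)) as a wrapper applied to the original callable rather than a plain replacement.

import importlib

def monkey_patch(dotted_path: str, replacement):
    """Illustrative only: resolve 'pkg.mod.Obj.attr' and overwrite the target.

    Assumes the final component is an attribute of a module or class, not a
    module itself; the real MegatronAdaptation.register() may defer, validate,
    or wrap patches differently.
    """
    parts = dotted_path.split(".")
    # Import the longest prefix that is an actual module.
    for i in range(len(parts) - 1, 0, -1):
        try:
            module = importlib.import_module(".".join(parts[:i]))
            break
        except ModuleNotFoundError:
            continue
    else:
        raise ImportError(f"no importable module prefix in {dotted_path!r}")
    # Walk the remaining attributes and replace the last one.
    *attr_chain, attr_name = parts[i:]
    target = module
    for name in attr_chain:
        target = getattr(target, name)
    setattr(target, attr_name, replacement)

# Conceptually, the new registration amounts to:
#   monkey_patch("megatron.core.tensor_parallel.layers.ColumnParallelLinear",
#                FluxColumnParallelLinear)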
-import os
 from collections import OrderedDict
 from typing import Optional
-from functools import wraps
 
 import torch
 from torch import Tensor
 
-from megatron.core import InferenceParams, tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.utils import deprecate_inference_params
-from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
 
-def gpt_model_init_wrapper(fn):
-    @wraps(fn)
-    def wrapper(self, *args, **kwargs):
-        fn(self, *args, **kwargs)
-
-        # Output
-        if self.post_process or self.mtp_process:
-            if int(os.getenv("USE_FLUX_OVERLAP", "0")):
-                parallel_linear_impl = FluxColumnParallelLinear
-            else:
-                parallel_linear_impl = tensor_parallel.ColumnParallelLinear
-            self.output_layer = parallel_linear_impl(
-                self.config.hidden_size,
-                self.vocab_size,
-                config=self.config,
-                init_method=self.config.init_method,
-                bias=False,
-                skip_bias_add=False,
-                gather_output=not self.parallel_output,
-                skip_weight_param_allocation=self.pre_process
-                and self.share_embeddings_and_output_weights,
-                embedding_activation_buffer=self.embedding_activation_buffer,
-                grad_output_buffer=self.grad_output_buffer,
-            )
-
-            if self.pre_process or self.post_process:
-                self.setup_embeddings_and_output_layer()
-
-    return wrapper
 
 def gpt_model_forward(
@@ -52,14 +17,16 @@ def gpt_model_forward(
     attention_mask: Tensor,
     decoder_input: Tensor = None,
     labels: Tensor = None,
-    inference_params: InferenceParams = None,
+    inference_context: BaseInferenceContext = None,
     packed_seq_params: PackedSeqParams = None,
    extra_block_kwargs: dict = None,
     runtime_gather_output: Optional[bool] = None,
+    *,
+    inference_params: Optional[BaseInferenceContext] = None,
     loss_mask: Optional[Tensor] = None,
 ) -> Tensor:
     """Forward function of the GPT Model This function passes the input tensors
     through the embedding layer, and then the decoder and finally into the post
     processing layer (optional).
 
     It either returns the Loss values if labels are given or the final hidden units
@@ -71,6 +38,8 @@ def gpt_model_forward(
     # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
     # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
 
+    inference_context = deprecate_inference_params(inference_context, inference_params)
+
     # Decoder embedding.
     if decoder_input is not None:
         pass
@@ -86,28 +55,43 @@ def gpt_model_forward(
     rotary_pos_cos = None
     rotary_pos_sin = None
     if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
-        if not self.training and self.config.flash_decode and inference_params:
+        if not self.training and self.config.flash_decode and inference_context:
+            assert (
+                inference_context.is_static_batching()
+            ), "GPTModel currently only supports static inference batching."
             # Flash decoding uses precomputed cos and sin for RoPE
             rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
-                inference_params.max_sequence_length,
-                self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
+                inference_context.max_sequence_length,
+                self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
             )
         else:
             rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
-                inference_params, self.decoder, decoder_input, self.config, packed_seq_params
+                inference_context, self.decoder, decoder_input, self.config, packed_seq_params
             )
             rotary_pos_emb = self.rotary_pos_emb(
                 rotary_seq_len,
                 packed_seq=packed_seq_params is not None
                 and packed_seq_params.qkv_format == 'thd',
             )
+    elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
+        if self.training or not self.config.flash_decode:
+            rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
+        else:
+            # Flash decoding uses precomputed cos and sin for RoPE
+            raise NotImplementedError(
+                "Flash decoding uses precomputed cos and sin for RoPE, not implmented in "
+                "MultimodalRotaryEmbedding yet."
+            )
+
     if (
         (self.config.enable_cuda_graph or self.config.flash_decode)
         and rotary_pos_cos is not None
-        and inference_params
+        and inference_context
+        and inference_context.is_static_batching()
+        and not self.training
     ):
         sequence_len_offset = torch.tensor(
-            [inference_params.sequence_len_offset] * inference_params.current_batch_size,
+            [inference_context.sequence_len_offset] * inference_context.current_batch_size,
             dtype=torch.int32,
             device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
         )
@@ -118,7 +102,7 @@ def gpt_model_forward(
     hidden_states = self.decoder(
         hidden_states=decoder_input,
         attention_mask=attention_mask,
-        inference_params=inference_params,
+        inference_context=inference_context,
         rotary_pos_emb=rotary_pos_emb,
         rotary_pos_cos=rotary_pos_cos,
         rotary_pos_sin=rotary_pos_sin,
@@ -127,6 +111,12 @@ def gpt_model_forward(
         **(extra_block_kwargs or {}),
     )
 
+    # Process inference output.
+    if inference_context and not inference_context.is_static_batching():
+        hidden_states = inference_context.last_token_logits(
+            hidden_states.squeeze(1).unsqueeze(0)
+        ).unsqueeze(1)
+
     # logits and loss
     output_weight = None
     if self.share_embeddings_and_output_weights:
@@ -164,6 +154,13 @@ def gpt_model_forward(
     if not self.post_process:
         return hidden_states
 
+    if (
+        not self.training
+        and inference_context is not None
+        and inference_context.is_static_batching()
+        and inference_context.materialize_only_last_token_logits
+    ):
+        hidden_states = hidden_states[-1:, :, :]
     logits, _ = self.output_layer(
         hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
     )
...
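The gpt_model_forward changes above track the upstream Megatron rename of inference_params to inference_context at 6ba97dd: the positional argument is now a BaseInferenceContext, the legacy name moves behind a keyword-only marker, and deprecate_inference_params maps old call sites onto the new argument before the decoder runs. The shim below is a rough sketch of that mapping, assuming the helper simply prefers the new argument and warns when the legacy one is used; the actual megatron.core.utils.deprecate_inference_params may differ in its warning text and edge-case handling.

import warnings

def deprecate_inference_params_shim(inference_context, inference_params):
    """Approximate the bridging behavior: accept the legacy `inference_params`
    keyword with a deprecation warning, otherwise pass the new argument through.
    Sketch only, not the upstream implementation."""
    if inference_params is not None and inference_context is None:
        warnings.warn(
            "`inference_params` is deprecated; pass `inference_context` instead.",
            DeprecationWarning,
        )
        return inference_params
    return inference_context

# Old call sites keep working because the legacy name is now keyword-only:
#   model.forward(..., inference_params=ctx)   # still accepted, mapped to inference_context
#   model.forward(..., inference_context=ctx)  # preferred going forward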