Commit 688448db authored by silencealiang

Update code

parent a02a5490
Pipeline #2503 passed
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Any, Dict
import torch
from megatron.core import parallel_state
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
GPTInferenceWrapper,
)
from megatron.core.inference_params import InferenceParams
# pylint: disable=line-too-long
class VLMInferenceWrapper(GPTInferenceWrapper):
"""Inference wrapper for VLMs"""
def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
"""A utility function for preparing model for inference
The function gets called once before the auto regressive inference loop.
It puts the model in eval mode.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
"""
super().prep_model_for_inference(prompts_tokens)
# For a TP-only model, both is_pipeline_first_stage and is_pipeline_last_stage return True
self.model_is_pipeline_parallel = not (
parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
)
self._recv_only_vision_embeds = False
pp_rank = parallel_state.get_pipeline_model_parallel_rank()
# Checks if the previous stage only has a vision encoder, and that the current stage
# has part of the LM decoder. In this case, the current stage should only receive
# vision embeddings.
if pp_rank > 0:
self._recv_only_vision_embeds = (
parallel_state.is_inside_encoder(pp_rank - 1)
and (not parallel_state.is_inside_decoder(pp_rank - 1))
and parallel_state.is_inside_decoder()
)
# Checks if the current stage only has a vision encoder
self._encoder_only = (
parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder()
)
def prep_inference_input(
self,
prompts_tokens: torch.Tensor,
num_img_embeddings_per_tile: int,
images: torch.Tensor,
num_tiles: torch.Tensor,
decoder_seq_length: int,
):
"""Prepares the inference input data.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
num_img_embeddings_per_tile (int): The number of image embeddings per tile
images (torch.Tensor): The image embeddings
num_tiles (torch.Tensor): The number of tiles for each input image
decoder_seq_length (int): The decoder sequence length
"""
inference_input = super().prep_inference_input(prompts_tokens)
total_num_tiles = torch.sum(num_tiles).item()
num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles
batch_size, max_sequence_length = prompts_tokens.shape
self.inference_params = InferenceParams(
batch_size, max_sequence_length + num_img_embeddings
)
inference_input["images"] = images
inference_input["num_tiles"] = num_tiles
inference_input["num_img_embeddings"] = num_img_embeddings
inference_input["decoder_seq_length"] = decoder_seq_length
return inference_input
def get_batch_for_context_window(
self,
inference_input: Dict[str, Any],
context_start_position: int,
context_end_position: int,
) -> Dict[str, Any]:
"""Returns the inference data given context window
This function gets called iteratively in a loop. Given the start and end context positions, it extracts the appropriate data.
Args:
inference_input (Dict[str, Any]): The inference input for the batch.
context_start_position (int): Start of the context window. During the first inference step it is mostly 0
context_end_position (int): End of the context window. During the last inference step it will mostly be the max generated sequence length.
Returns:
Dict[str, Any]: A dict of inputs that will be used by your model in the forward step
"""
tokens = inference_input["tokens"]
position_ids = inference_input["position_ids"]
images = inference_input["images"]
num_tiles = inference_input["num_tiles"]
num_img_embeddings = inference_input["num_img_embeddings"]
decoder_seq_length = inference_input["decoder_seq_length"]
tokens2use = tokens[:, context_start_position:context_end_position]
positions2use = position_ids[:, context_start_position:context_end_position]
return {
"tokens": tokens2use,
"position_ids": positions2use,
"images": images,
"num_tiles": num_tiles,
"num_img_embeddings": num_img_embeddings,
"decoder_seq_length": decoder_seq_length,
}
def _forward(self, inference_input: Dict[str, Any]):
"""Runs a forward pass of the model.
Args:
inference_input(Dict[str, Any]): The input data.
Returns:
The model output logits.
"""
images = inference_input["images"]
tokens = inference_input["tokens"]
position_ids = inference_input["position_ids"]
num_image_tiles = inference_input["num_tiles"]
output = self.model(
images,
tokens,
position_ids=position_ids,
attention_mask=None,
inference_params=self.inference_params,
num_image_tiles=num_image_tiles,
runtime_gather_output=True,
)
if isinstance(output, tuple):
logits, _ = output
else:
logits = output
return logits
def run_one_forward_step(self, inference_input: Dict[str, Any]) -> torch.Tensor:
tokens = inference_input["tokens"]
num_image_tokens = (tokens == self.model.module.image_token_index).sum().item()
num_img_embeddings = inference_input["num_img_embeddings"]
decoder_seq_length = inference_input["decoder_seq_length"]
num_tokens = tokens.size(1)
recv_buffer_seq_len = None
if num_image_tokens > 0:
# When there are image tokens and this stage only receives vision embeddings,
# adjust the recv buffer seq length to match the image embeddings sequence length.
# If there are image tokens and this stage receives full embeddings, make sure we
# compensate for expansion of image tokens.
# Note that this will set a recv_buffer_seq_len for the encoder stage,
# this length is irrelevant since that recv buffer is never allocated.
if self._recv_only_vision_embeds:
recv_buffer_seq_len = num_img_embeddings
else:
recv_buffer_seq_len = min(
num_img_embeddings + num_tokens - num_image_tokens, decoder_seq_length
)
elif self._recv_only_vision_embeds:
# If this stage only receives vision embeddings and there are no image tokens
# we won't run the encoder and therefore shouldn't try to recv.
recv_buffer_seq_len = 0
# If the pipeline stage only has a vision encoder, then it only needs to
# run when there are image tokens
if not (self._encoder_only and num_image_tokens == 0):
output = super().run_one_forward_step(
inference_input, recv_buffer_seq_len=recv_buffer_seq_len
)
else:
output = None
logits = output
# On the first inference iteration, we compute image tokens.
# On every PP stage (although inference params should only matter for the decoder),
# update the sequence length offset by the number of image tokens.
if num_tokens > 1 and num_image_tokens > 0:
if "image_tokens_count" not in self.inference_params.key_value_memory_dict:
self.inference_params.key_value_memory_dict["image_tokens_count"] = (
num_img_embeddings
)
if num_img_embeddings + num_tokens - num_image_tokens > decoder_seq_length:
self.inference_params.sequence_len_offset += decoder_seq_length - num_tokens
else:
self.inference_params.sequence_len_offset += (
self.inference_params.key_value_memory_dict["image_tokens_count"]
- num_image_tokens
)
return logits
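# --- Illustrative usage sketch (not part of the original file) ---
# A minimal sketch of how VLMInferenceWrapper might be driven for one forward step,
# assuming `vlm_model`, `inference_wrapper_config`, tokenized `prompts_tokens`,
# `images`, and `num_tiles` already exist; the names and numeric values below are
# placeholders, not part of this change.
def _example_vlm_forward_step(vlm_model, inference_wrapper_config, prompts_tokens, images, num_tiles):
    wrapper = VLMInferenceWrapper(vlm_model, inference_wrapper_config)
    wrapper.prep_model_for_inference(prompts_tokens)
    inference_input = wrapper.prep_inference_input(
        prompts_tokens,
        num_img_embeddings_per_tile=576,  # placeholder value
        images=images,
        num_tiles=num_tiles,
        decoder_seq_length=4096,  # placeholder value
    )
    # First iteration of the autoregressive loop: the context window covers the full prompt.
    batch = wrapper.get_batch_for_context_window(inference_input, 0, prompts_tokens.size(1))
    # Logits of shape [batch_size, context_length, vocab_size] on the last pipeline stage.
    return wrapper.run_one_forward_step(batch)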
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from collections import deque
from typing import Any, List, Tuple
import numpy
import torch
from megatron.core import tensor_parallel
from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from megatron.core.models.T5 import T5Model
# pylint: disable=line-too-long
class T5InferenceWrapper(AbstractModelInferenceWrapper):
"""Constructor for the model inference wrapper
The wrapper prepares the model for inference, provides the required input
data, and runs the forward pass
Args:
model (T5Model): The T5 model (MCore or legacy)
inference_wrapper_config (InferenceWrapperConfig): The command line arguments that were passed
use_local (bool): Whether the T5 model's transformer impl
is local (vs transformer_engine)
"""
def __init__(
self,
model: T5Model,
inference_wrapper_config: InferenceWrapperConfig,
use_local: bool = False,
):
super().__init__(model, inference_wrapper_config)
self.use_local = use_local
def prep_model_for_inference(
self, prompts_tokens: torch.Tensor, encoder_prompts: List[str] = None, tokenizer: Any = None
):
"""A utility function for preparing model for inference
This function is called before the forward pass. It puts the model in eval mode, builds
position ids, and creates attention masks so that required slices can be extracted during
the forward pass.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
encoder_prompts (List[str]): A list of encoder input prompt strings
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text
"""
super().prep_model_for_inference(prompts_tokens=prompts_tokens)
# get max_sequence_length
if hasattr(self.model, "module"): # if self.model is Float16Module
max_sequence_length = self.model.module.max_sequence_length
else:
max_sequence_length = self.model.max_sequence_length
encoder_prompts_tokens_list = [
self.tokenize_encoder_prompt(encoder_prompt, tokenizer)
for encoder_prompt in encoder_prompts
]
self.batch_encoder_prompts_tokens = self.pad_encoder_prompts_tokens(
encoder_prompts_tokens_list, max_sequence_length, tokenizer
)
# create batch mask for encoder_prompt (self.batch_input_tokens) and
# decoder_input (self.prompts_tokens), similar to megatron/core/datasets/t5_dataset.py
decoder_prompts_tokens = self.prompts_tokens.cpu().numpy()
encoder_prompts_tokens = self.batch_encoder_prompts_tokens.cpu().numpy()
self.batch_mask_encoder = []
self.batch_mask_decoder = []
for i in range(len(self.prompts_tokens)):
mask_encoder = encoder_prompts_tokens[i] == tokenizer.pad
mask_decoder = decoder_prompts_tokens[i] == tokenizer.pad
self.batch_mask_encoder.append(mask_encoder)
self.batch_mask_decoder.append(mask_decoder)
self.batch_mask_encoder = torch.tensor(numpy.array(self.batch_mask_encoder)).cuda()
self.batch_mask_decoder = torch.tensor(numpy.array(self.batch_mask_decoder)).cuda()
def tokenize_encoder_prompt(
self, encoder_prompt: str, tokenizer
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility to tokenize the encoder_prompt
Args:
encoder_prompt (str): The encoder_prompt
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing string
Returns:
torch.Tensor: Returns the tokenized prompt
"""
# If the word "<mask>" appears in the prompt, replace it with an additional special token,
# similar to the processing step in megatron/core/datasets/t5_dataset.py
divided_encoder_prompt_list = encoder_prompt.split("<mask>")
masks_count = len(divided_encoder_prompt_list) - 1
sentinels = deque(tokenizer.additional_special_tokens_ids)
encoder_prompt_tokens = []
for divided_encoder_prompt in divided_encoder_prompt_list:
divided_encoder_prompt_tokens = tokenizer.tokenize(divided_encoder_prompt)
encoder_prompt_tokens.extend(divided_encoder_prompt_tokens)
if masks_count > 0:
sentinel = sentinels.popleft()
encoder_prompt_tokens.extend([sentinel])
masks_count -= 1
return encoder_prompt_tokens
def pad_encoder_prompts_tokens(
self, encoder_prompts_tokens_list: List[List[int]], max_sequence_length: int, tokenizer
) -> torch.Tensor:
"""Method to pad input prompts
Given a list of prompts, pad them all to uniform length
Args:
encoder_prompts_tokens_list (List[List[int]]): A list containing the
encoder_input_tokens
max_sequence_length (int): The maximum length of the encoder input tokens
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text
Returns:
torch.Tensor: A torch tensor of shape [bs, max_sequence_length]
"""
for encoder_prompt_tokens in encoder_prompts_tokens_list:
padding_size = max_sequence_length - len(encoder_prompt_tokens)
encoder_prompt_tokens.extend([tokenizer.pad] * padding_size)
return torch.tensor(encoder_prompts_tokens_list).cuda()
def get_batch_for_context_window(
self, context_start_position: int, context_end_position: int
) -> List:
"""Returns the inference data given context window
This function gets called iteratively in a loop. Given the start and end context
positions, it extracts the appropriate data.
Args:
context_start_position (int): Start of the context window. During
the first inference step it is mostly 0
context_end_position (int): End of the context window. During the
last inference step it will mostly be the max generated sequence length.
Returns:
List: A list of inputs that will be used by your model in the forward step
"""
# T5 inference does not yet support kv_cache
encoder_tokens2use = self.batch_encoder_prompts_tokens
decoder_tokens2use = self.prompts_tokens[:, :context_end_position]
encoder_mask2use = self.batch_mask_encoder
decoder_mask2use = self.batch_mask_decoder[:, :context_end_position]
# Configure attention mask based on different conditions
# (e.g., transformer-impl, TE versions, TE backends)
[encoder_mask2use, decoder_mask2use, encoder_decoder_mask2use] = (
T5MaskedWordPieceDataset.config_attention_mask(
encoder_tokens2use,
decoder_tokens2use,
encoder_mask2use,
decoder_mask2use,
self.use_local,
)
)
data_at_step_idx = [
encoder_tokens2use,
decoder_tokens2use,
encoder_mask2use,
decoder_mask2use,
encoder_decoder_mask2use,
]
return data_at_step_idx
def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor:
"""Utility to carry out simple forward pass for TP or no model parallel models
Runs a very simple forward pass for model. Used in the case of models without
any parallelism or only tensor parallelism.
Args:
inference_input (List): A list containing the inputs for the T5 model
[encoder tokens, decoder tokens, encoder mask, decoder mask, encoder-decoder mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
"""
[encoder_tokens, decoder_tokens, encoder_mask, decoder_mask, encoder_decoder_mask] = (
inference_input
)
tokens = decoder_tokens
# T5 inference does not yet support kv_cache
logits = self.model(
encoder_tokens,
decoder_tokens,
encoder_mask,
decoder_mask,
encoder_decoder_mask,
inference_params=None,
)
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
return logits
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from collections import deque
from typing import Any, Dict, List, Optional
import numpy
import torch
from megatron.core import tensor_parallel
from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from megatron.core.models.T5 import T5Model
from megatron.core.utils import get_attr_wrapped_model
# pylint: disable=line-too-long
class T5InferenceWrapper(AbstractModelInferenceWrapper):
"""Constructor for the model inference wrapper
The wrapper prepares the model for inference, provides the required input
data, and runs the forward pass
Args:
model (T5Model): The T5 model (MCore or legacy)
inference_wrapper_config (InferenceWrapperConfig): The command line arguments that were passed
use_local (bool): Whether the T5 model's transformer impl
is local (vs transformer_engine)
"""
def __init__(
self,
model: T5Model,
inference_wrapper_config: InferenceWrapperConfig,
use_local: bool = False,
):
super().__init__(model, inference_wrapper_config)
self.use_local = use_local
def prep_inference_input(
self,
prompts_tokens: torch.Tensor,
encoder_prompts: Optional[List[str]] = None,
tokenizer: Any = None,
) -> Dict[str, Any]:
"""Prepares the inference input data.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
encoder_prompts (List[str]): A list of encoder input prompt strings
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text
Returns:
A dict with all the inference input needed for the batch.
"""
# get max_sequence_length
max_sequence_length = get_attr_wrapped_model(self.model, "max_sequence_length")
encoder_prompts_tokens_list = [
self.tokenize_encoder_prompt(encoder_prompt, tokenizer)
for encoder_prompt in encoder_prompts
]
batch_encoder_prompts_tokens = self.pad_encoder_prompts_tokens(
encoder_prompts_tokens_list, max_sequence_length, tokenizer
)
# create batch mask for encoder_prompt (self.batch_input_tokens) and
# decoder_input (prompts_tokens), similar to megatron/core/datasets/t5_dataset.py
decoder_prompts_tokens = prompts_tokens
encoder_prompts_tokens = batch_encoder_prompts_tokens
decoder_prompts_tokens_numpy = decoder_prompts_tokens.cpu().numpy()
encoder_prompts_tokens_numpy = encoder_prompts_tokens.cpu().numpy()
batch_mask_encoder = []
batch_mask_decoder = []
for i in range(len(prompts_tokens)):
mask_encoder = encoder_prompts_tokens_numpy[i] == tokenizer.pad
mask_decoder = decoder_prompts_tokens_numpy[i] == tokenizer.pad
batch_mask_encoder.append(mask_encoder)
batch_mask_decoder.append(mask_decoder)
batch_mask_encoder = torch.tensor(numpy.array(batch_mask_encoder)).cuda()
batch_mask_decoder = torch.tensor(numpy.array(batch_mask_decoder)).cuda()
return {
"encoder_tokens": encoder_prompts_tokens,
"decoder_tokens": decoder_prompts_tokens,
"encoder_mask": batch_mask_encoder,
"decoder_mask": batch_mask_decoder,
}
def tokenize_encoder_prompt(self, encoder_prompt: str, tokenizer) -> torch.Tensor:
"""Utility to tokenize the encoder_prompt
Args:
encoder_prompt (str): The encoder_prompt
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing string
Returns:
torch.Tensor: Returns the tokenized prompt
"""
# If the word "<mask>" appears in the prompt, replace it with an additional special token,
# similar to the processing step in megatron/core/datasets/t5_dataset.py
divided_encoder_prompt_list = encoder_prompt.split("<mask>")
masks_count = len(divided_encoder_prompt_list) - 1
sentinels = deque(tokenizer.additional_special_tokens_ids)
encoder_prompt_tokens = []
for divided_encoder_prompt in divided_encoder_prompt_list:
divided_encoder_prompt_tokens = tokenizer.tokenize(divided_encoder_prompt)
encoder_prompt_tokens.extend(divided_encoder_prompt_tokens)
if masks_count > 0:
sentinel = sentinels.popleft()
encoder_prompt_tokens.extend([sentinel])
masks_count -= 1
return encoder_prompt_tokens
def pad_encoder_prompts_tokens(
self, encoder_prompts_tokens_list: List[List[int]], max_sequence_length: int, tokenizer
) -> torch.Tensor:
"""Method to pad input prompts
Given a list of prompts, pad them all to uniform length
Args:
encoder_prompts_tokens_list (List[List[int]]): A list containing the
encoder_input_tokens
max_sequence_length (int): The maximum length of the encoder input tokens
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text
Returns:
torch.Tensor: A torch tensor of shape [bs, max_sequence_length]
"""
for encoder_prompt_tokens in encoder_prompts_tokens_list:
padding_size = max_sequence_length - len(encoder_prompt_tokens)
encoder_prompt_tokens.extend([tokenizer.pad] * padding_size)
return torch.tensor(encoder_prompts_tokens_list).cuda()
def get_batch_for_context_window(
self,
inference_input: Dict[str, Any],
context_start_position: int,
context_end_position: int,
) -> Dict[str, Any]:
"""Returns the inference data given context window
This function gets called iteratively in a loop. Given the start and end context
positions, it extracts the appropriate data.
Args:
inference_input (Dict[str, Any]): The inference input for the batch.
context_start_position (int): Start of the context window. During
the first inference step it is mostly 0
context_end_position (int): End of the context window. During the
last inference step it will mostly be the max generated sequence length.
Returns:
Dict: A dict of inputs that will be used by your model in the forward step
"""
# T5 inference does not yet support kv_cache
encoder_tokens2use = inference_input["encoder_tokens"]
decoder_tokens2use = inference_input["decoder_tokens"][:, :context_end_position]
encoder_mask2use = inference_input["encoder_mask"]
decoder_mask2use = inference_input["decoder_mask"][:, :context_end_position]
# Configure attention mask based on different conditions
# (e.g., transformer-impl, TE versions, TE backends)
[encoder_mask2use, decoder_mask2use, encoder_decoder_mask2use] = (
T5MaskedWordPieceDataset.config_attention_mask(
encoder_tokens2use,
decoder_tokens2use,
encoder_mask2use,
decoder_mask2use,
self.use_local,
)
)
return {
"encoder_tokens": encoder_tokens2use,
"decoder_tokens": decoder_tokens2use,
"encoder_mask": encoder_mask2use,
"decoder_mask": decoder_mask2use,
"encoder_decoder_mask": encoder_decoder_mask2use,
}
def forward_pass_without_pipeline_parallel(
self, inference_input: Dict[str, Any]
) -> torch.Tensor:
"""Utility to carry out simple forward pass for TP or no model parallel models
Runs a very simple forward pass for model. Used in the case of models without
any parallelism or only tensor parallelism.
Args:
inference_input (Dict[str, Any]): A dict containing the inputs for the T5 model
(encoder tokens, decoder tokens, encoder mask, decoder mask, encoder-decoder mask)
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
"""
encoder_tokens = inference_input["encoder_tokens"]
decoder_tokens = inference_input["decoder_tokens"]
encoder_mask = inference_input["encoder_mask"]
decoder_mask = inference_input["decoder_mask"]
encoder_decoder_mask = inference_input["encoder_decoder_mask"]
tokens = decoder_tokens
# T5 inference does not yet support kv_cache
logits = self.model(
encoder_tokens,
decoder_tokens,
encoder_mask,
decoder_mask,
encoder_decoder_mask,
inference_params=None,
)
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
return logits
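# --- Illustrative usage sketch (not part of the original file) ---
# A minimal sketch of one T5 forward step using the dict-based API above, assuming
# `t5_model`, `inference_wrapper_config`, a `tokenizer`, padded decoder `prompts_tokens`,
# and a list of `encoder_prompts` already exist; these names are placeholders.
def _example_t5_forward_step(t5_model, inference_wrapper_config, tokenizer, prompts_tokens, encoder_prompts):
    wrapper = T5InferenceWrapper(t5_model, inference_wrapper_config, use_local=False)
    wrapper.prep_model_for_inference(prompts_tokens)  # inherited from the abstract wrapper
    inference_input = wrapper.prep_inference_input(prompts_tokens, encoder_prompts, tokenizer)
    # T5 inference does not yet use the KV cache, so the whole decoder context is passed each step.
    batch = wrapper.get_batch_for_context_window(inference_input, 0, 1)
    # Logits of shape [batch_size, decoder_context_length, padded_vocab_size].
    return wrapper.forward_pass_without_pipeline_parallel(batch)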
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Integrations with NVIDIA TensorRT Model Optimizer (referred as ModelOpt).
ModelOpt is a library comprising state-of-the-art model optimization techniques including quantization and sparsity to
compress model for efficient inference on NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless
experience for users to optimize their Megatron-core models for inference. More details on ModelOpt including
installation and usage can be found at https://github.com/NVIDIA/TensorRT-Model-Optimizer.
"""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Integrations with NVIDIA TensorRT Model Optimizer (referred as ModelOpt).
ModelOpt is a library comprising state-of-the-art model optimization techniques
including quantization and sparsity to compress model for efficient inference on
NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless
experience for users to optimize their Megatron-core models for inference.
More details on ModelOpt including installation and usage can be found at
https://github.com/NVIDIA/TensorRT-Model-Optimizer.
"""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_gpt_layer_modelopt_spec(
num_experts: int = None,
moe_grouped_gemm: bool = False,
remap_te_layernorm: bool = False,
qk_layernorm: bool = False,
) -> ModuleSpec:
"""Mix the native spec with TENorm.
This is essentially the native local spec, except that the layernorm implementation
uses TENorm from Transformer Engine. The issue is that FusedLayerNorm from Apex
has stopped supporting the RMSNorm needed by Llama.
"""
mlp = _get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False
)
sharded_state_dict_keys_map = {}
if remap_te_layernorm:
if num_experts:
sharded_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'
}
else:
sharded_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
}
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
k_layernorm=TENorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
# Map TE-layernorm-fusion keys back
sharded_state_dict_keys_map=sharded_state_dict_keys_map,
),
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional
from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_gpt_layer_modelopt_spec(
num_experts: Optional[int] = None,
local_core_attention: bool = False,
moe_grouped_gemm: bool = False,
remap_te_layernorm: bool = False,
qk_layernorm: bool = False,
) -> ModuleSpec:
"""Mix the native spec with TENorm.
This is essentially the native local spec, except that the layernorm implementation
uses TENorm from Transformer Engine. The issue is that FusedLayerNorm from Apex
has stopped supporting the RMSNorm needed by Llama.
"""
core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
mlp = get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False
)
sharded_state_dict_keys_map = {}
if remap_te_layernorm:
if num_experts:
sharded_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'
}
else:
sharded_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
}
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=core_attention,
linear_proj=RowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
k_layernorm=TENorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
# Map TE-layernorm-fusion keys back
sharded_state_dict_keys_map=sharded_state_dict_keys_map,
),
)
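# --- Illustrative usage sketch (not part of the original file) ---
# A minimal sketch of how the layer spec above might be plugged into a GPTModel for
# ModelOpt PTQ / TensorRT-LLM export; `transformer_config` and the vocab/sequence
# sizes below are placeholders, not part of this change.
def _example_build_modelopt_gpt(transformer_config):
    from megatron.core.models.gpt import GPTModel

    layer_spec = get_gpt_layer_modelopt_spec(
        num_experts=None,
        local_core_attention=False,  # keep TEDotProductAttention
        remap_te_layernorm=True,     # remap keys when loading TE-layernorm-fusion checkpoints
        qk_layernorm=False,
    )
    model = GPTModel(
        config=transformer_config,
        transformer_layer_spec=layer_spec,
        vocab_size=32000,          # placeholder
        max_sequence_length=4096,  # placeholder
    )
    return model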
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_mamba_stack_modelopt_spec(
local_core_attention: bool = False, remap_te_layernorm: bool = False
) -> ModuleSpec:
"""Mix the native spec with TENorm.
This is essentially the native local spec, except that the layernorm implementation
uses TENorm from Transformer Engine.
"""
mamba_state_dict_keys_map = {}
transformer_state_dict_keys_map = {}
if remap_te_layernorm:
mamba_state_dict_keys_map = {'norm.': 'mixer.in_proj.layer_norm_'}
transformer_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
}
mamba_layer = ModuleSpec(
module=MambaLayer,
submodules=MambaLayerSubmodules(
norm=TENorm,
mixer=ModuleSpec(
module=MambaMixer,
submodules=MambaMixerSubmodules(
in_proj=ColumnParallelLinear, out_proj=RowParallelLinear
),
),
mamba_bda=get_bias_dropout_add,
sharded_state_dict_keys_map=mamba_state_dict_keys_map,
),
)
core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
attention_layer = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=core_attention,
linear_proj=RowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
sharded_state_dict_keys_map=transformer_state_dict_keys_map,
),
)
mlp_layer = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
pre_mlp_layernorm=TENorm,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map=transformer_state_dict_keys_map,
),
)
return ModuleSpec(
module=MambaStack,
submodules=MambaStackSubmodules(
mamba_layer=mamba_layer, attention_layer=attention_layer, mlp_layer=mlp_layer
),
)
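# --- Illustrative usage sketch (not part of the original file) ---
# A minimal sketch of how the Mamba stack spec above might be consumed, assuming a
# MambaModel constructor similar to GPTModel; the import path and constructor
# arguments here are assumptions and may differ from the actual API.
def _example_build_modelopt_mamba(transformer_config):
    from megatron.core.models.mamba import MambaModel  # assumed import path

    stack_spec = get_mamba_stack_modelopt_spec(
        local_core_attention=False, remap_te_layernorm=True
    )
    model = MambaModel(
        config=transformer_config,
        mamba_stack_spec=stack_spec,
        vocab_size=32000,          # placeholder
        max_sequence_length=4096,  # placeholder
    )
    return model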
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
@dataclass
class SamplingParams:
"""Inference parameters sent along with the prompts.
This class contains request-level attributes that control the sampling techniques used when
generating text. This is distinct from megatron.core.InferenceParams, which sets model-level
inference attributes such as the maximum sequence length, and contains the KV cache.
For an explanation of these parameters refer to this blog
https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-
temperature-parameters-ed6a31313910
"""
temperature: float = 1.0
top_k: int = 0
top_p: float = 0.0
return_log_probs: bool = False
num_tokens_to_generate: int = 30
def add_attributes(self, attribute_value_pair: dict):
"""Utility to add more attributes to sampling params
Use this method to pass in a custom dictionary to add more sampling parameter attributes.
c = SamplingParams()
c.add_attributes({'min_length':4, 'eod_id':153})
Args:
attribute_value_pair (dict): A dictionary containing attributes as the key names and
their values as the values.
"""
for key, value in attribute_value_pair.items():
setattr(self, key, value)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
@dataclass
class SamplingParams:
"""Inference parameters sent along with the prompts.
This class contains request-level attributes that control the sampling techniques used when
generating text. This is distinct from megatron.core.InferenceParams, which sets model-level
inference attributes such as the maximum sequence length, and contains the KV cache.
For an explanation of these parameters refer to this blog
https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-
temperature-parameters-ed6a31313910
"""
temperature: float = 1.0
top_k: int = 0
top_p: float = 0.0
return_log_probs: bool = False
return_segments: bool = False # Whether to return individually detokenized tokens
num_tokens_to_generate: int = 30
def add_attributes(self, attribute_value_pair: dict):
"""Utility to add more attributes to sampling params
Use this method to pass in a custom dictionary to add more sampling parameter attributes.
c = SamplingParams()
c.add_attributes({'min_length':4, 'eod_id':153})
Args:
attribute_value_pair (dict): A dictionary containing attributes as the key names and
their values as the values.
"""
for key, value in attribute_value_pair.items():
setattr(self, key, value)
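# --- Illustrative usage sketch (not part of the original file) ---
# A minimal sketch of constructing SamplingParams and extending it with extra
# attributes via add_attributes(); the extra keys below are arbitrary examples.
def _example_sampling_params():
    params = SamplingParams(
        temperature=0.8, top_p=0.9, return_log_probs=True, num_tokens_to_generate=64
    )
    # add_attributes() simply setattr()s each key/value pair onto the instance.
    params.add_attributes({"min_length": 4, "eod_id": 153})
    return params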
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import time
import typing
from collections import OrderedDict
from typing import Dict
import torch
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import Counter
class Scheduler:
"""Scheduler for handling requests to inference engine
This class is responsible for handling all the incoming requests
Args:
max_batch_size (int): The max batch size that we can pass to the
inference engine at a time.
"""
def __init__(self, max_batch_size: int):
self.max_batch_size = max_batch_size
self.active_request_pool: Dict[int, InferenceRequest] = OrderedDict()
self.waiting_request_pool: Dict[int, InferenceRequest] = OrderedDict()
self.completed_request_pool: Dict[int, InferenceRequest] = OrderedDict()
self.request_counter = Counter()
def add_request(
self,
prompt: str,
prompt_tokens: torch.Tensor,
encoder_prompt: str = None,
inference_parameters: SamplingParams = None,
arrival_time: float = None,
):
"""Add an incoming request
This method will add the request to either the active pool or the waiting pool
depending on the batch size.
Args:
prompt (str): Input prompt string
prompt_tokens (torch.Tensor): A torch tensor having the input prompts tokenized
encoder_prompt (str): Encoder input string
inference_parameters (SamplingParams): The inference parameters
arrival_time (float, optional): The incoming request time. Defaults to None.
"""
request_id = str(next(self.request_counter))
if arrival_time is None:
arrival_time = time.time()
status = (
Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
if len(self.active_request_pool) < self.max_batch_size
else Status.WAITING_IN_QUEUE
)
inference_request = InferenceRequest(
request_id=request_id,
prompt=prompt,
inference_parameters=inference_parameters,
arrival_time=arrival_time,
prompt_tokens=prompt_tokens,
status=status,
encoder_prompt=encoder_prompt,
)
if status == Status.ACTIVE_BUT_NOT_GENERATING_TOKENS:
self.active_request_pool[request_id] = inference_request
else:
self.waiting_request_pool[request_id] = inference_request
def have_requests_pending(self) -> bool:
"""Method to check if there are requests pending
This method returns False only when there are no active requests or waiting requests.
"""
num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool)
return num_requests_pending > 0
def add_earliest_waiting_request_to_active_pool(self):
"""Utility to add the waiting request to active pool
This method will add the earliest request (FIFO) that is in the waiting request
pool to the active request pool.
"""
assert (
len(self.active_request_pool) < self.max_batch_size
), "Active request pool is already full. Cant add any more requests"
if len(self.waiting_request_pool) > 0:
(earliest_waiting_request_request_id, earliest_waiting_request) = (
self.waiting_request_pool.popitem(last=False)
)
earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
self.active_request_pool[earliest_waiting_request_request_id] = earliest_waiting_request
def update_requests_pools(self, result_dict: typing.OrderedDict[int, InferenceRequest] = None):
"""Update request pool status
This method will fill up the active request pool, if it has fewer than max batch size
elements, from the waiting request pool.
If provided with a result dict, it will move the completed requests into the completed
request pool and add waiting requests into the active pool.
Args:
result_dict (typing.OrderedDict[int, InferenceRequest], optional): The result returned
by the engine. A dictionary with keys as the request ids, and values as the
requests. Defaults to None
"""
for result_request_id in list(result_dict.keys()):
active_request = self.active_request_pool[result_request_id]
# If a request has completed put it into the completed request pool.
if active_request.status == Status.COMPLETED:
completed_request = self.active_request_pool.pop(result_request_id)
self.completed_request_pool[result_request_id] = completed_request
# If the active request pool is not full, add waiting requests in FIFO order
while (
len(self.active_request_pool) < self.max_batch_size
and len(self.waiting_request_pool) > 0
):
self.add_earliest_waiting_request_to_active_pool()
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import functools
import time
import typing
from collections import OrderedDict
from typing import Dict, Optional, Type, Union
import torch
from megatron.core.inference.async_stream import AsyncStream
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import Counter
class Scheduler:
"""Scheduler for handling requests to inference engine
This class is responsible for handling all the incoming requests
Args:
max_batch_size (int): The max batch size that we can pass to the
inference engine at a time.
request_type (InferenceRequest): The class to use for instantiating new requests.
"""
def __init__(self, max_batch_size):
self.max_batch_size = max_batch_size
self.requests: Dict[str, InferenceRequest] = OrderedDict()
self.streams: Dict[str, AsyncStream] = OrderedDict()
self.active_request_pool: Dict[str, InferenceRequest] = OrderedDict()
self.waiting_request_pool: Dict[str, InferenceRequest] = OrderedDict()
self.completed_request_pool: Dict[str, InferenceRequest] = OrderedDict()
self.request_counter = Counter()
def get_new_request_id(self) -> str:
"""Gets a new request id"""
request_id = str(next(self.request_counter))
return request_id
def add_request(
self,
prompt: Optional[str] = None,
prompt_tokens: Optional[torch.Tensor] = None,
encoder_prompt: Optional[str] = None,
inference_parameters: Optional[SamplingParams] = None,
arrival_time: Optional[float] = None,
streaming: bool = False,
inference_request: Optional[InferenceRequest] = None,
) -> str:
"""Add an incoming request
This method will add the request to either the active pool or the waiting pool
depending on the batch size.
Args:
prompt (str): Input prompt string
prompt_tokens (torch.Tensor): A torch tensor having the input prompts tokenized
encoder_prompt (str): Encoder input string
inference_parameters (SamplingParams): The inference parameters
arrival_time (float, optional): The incoming request time. Defaults to None.
streaming (bool, optional): Whether to asynchronously stream tokens for this request.
inference_request (InferenceRequest, optional): A fully constructed request.
Defaults to None.
Returns:
The request_id for the new request.
"""
status = (
Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
if len(self.active_request_pool) < self.max_batch_size
else Status.WAITING_IN_QUEUE
)
if inference_request is None:
assert prompt is not None
assert prompt_tokens is not None
request_id = self.get_new_request_id()
if arrival_time is None:
arrival_time = time.time()
inference_request = InferenceRequest(
request_id=request_id,
prompt=prompt,
inference_parameters=inference_parameters,
arrival_time=arrival_time,
prompt_tokens=prompt_tokens,
status=status,
encoder_prompt=encoder_prompt,
)
else:
request_id = inference_request.request_id
inference_request.status = status
if inference_request.arrival_time is None:
inference_request.arrival_time = time.time()
self.requests[request_id] = inference_request
if streaming:
abort_request = functools.partial(self.abort_request, request_id=request_id)
self.streams[request_id] = AsyncStream(request_id, abort_request)
if status == Status.ACTIVE_BUT_NOT_GENERATING_TOKENS:
self.active_request_pool[request_id] = inference_request
else:
self.waiting_request_pool[request_id] = inference_request
return request_id
def have_requests_pending(self) -> bool:
"""Method to check if there are requests pending
This method returns False only when there are no active requests or waiting requests.
"""
num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool)
return num_requests_pending > 0
def add_earliest_waiting_request_to_active_pool(self):
"""Utility to add the waiting request to active pool
This method will add the earliest request (FIFO) that is in the waiting request
pool to the active request pool.
"""
assert (
len(self.active_request_pool) < self.max_batch_size
), "Active request pool is already full. Cant add any more requests"
if len(self.waiting_request_pool) > 0:
(earliest_waiting_request_request_id, earliest_waiting_request) = (
self.waiting_request_pool.popitem(last=False)
)
earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
self.active_request_pool[earliest_waiting_request_request_id] = earliest_waiting_request
def update_requests_pools(
self, result_dict: Optional[typing.OrderedDict[str, InferenceRequest]] = None
):
"""Update request pool status
This method will fill up the active request pool, if it has fewer than max batch size
elements, from the waiting request pool.
If provided with a result dict, it will move the completed requests into the completed
request pool and add waiting requests into the active pool.
Args:
result_dict (typing.OrderedDict[str, InferenceRequest], optional): The result returned
by the engine. A dictionary with keys as the request ids, and values as the
requests. Defaults to None
"""
for result_request_id in list(result_dict.keys()):
active_request = self.active_request_pool[result_request_id]
# If a request has completed put it into the completed request pool.
if active_request.status == Status.COMPLETED:
completed_request = self.active_request_pool.pop(result_request_id)
self.completed_request_pool[result_request_id] = completed_request
# If the active request pool is not full, add waiting requests in FIFO order
while (
len(self.active_request_pool) < self.max_batch_size
and len(self.waiting_request_pool) > 0
):
self.add_earliest_waiting_request_to_active_pool()
def abort_request(
self,
request_id: str,
*,
exception: Optional[Union[BaseException, Type[BaseException]]] = None
):
"""Cancels the given request"""
stream = self.streams.get(request_id, None)
if stream is not None:
stream.finish(exception=exception)
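# --- Illustrative usage sketch (not part of the original file) ---
# A minimal sketch of pushing one request through the Scheduler above and letting it
# rotate the pools; `prompt_tokens` and the manual COMPLETED status update are
# placeholders standing in for what the inference engine would normally do.
def _example_scheduler_round_trip(prompt_tokens: torch.Tensor):
    scheduler = Scheduler(max_batch_size=8)
    request_id = scheduler.add_request(
        prompt="Hello world",
        prompt_tokens=prompt_tokens,
        inference_parameters=SamplingParams(num_tokens_to_generate=16),
    )
    # ... the engine would run forward steps here and eventually mark the request done ...
    scheduler.requests[request_id].status = Status.COMPLETED
    scheduler.update_requests_pools(result_dict=scheduler.active_request_pool)
    assert request_id in scheduler.completed_request_pool
    return scheduler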
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import OrderedDict
import torch
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
class EncoderDecoderTextGenerationController(TextGenerationController):
"""The text generation controller for encoder-decoder architecture
This class inherits from TextGenerationController, adding features
relating to encoder input encoder_prompt
"""
def prep_model_for_inference(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
):
"""Preparing batch for inference, using respective wrapper's prep_model_for_inference method
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[int, InferenceRequest]): The input active requests
"""
encoder_prompts = list(
map(lambda request: request.encoder_prompt, active_requests.values())
)
self.inference_wrapped_model.prep_model_for_inference(
prompts_tokens=prompts_tokens, encoder_prompts=encoder_prompts, tokenizer=self.tokenizer
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Any, Dict, OrderedDict
import torch
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
class EncoderDecoderTextGenerationController(TextGenerationController):
"""The text generation controller for encoder-decoder architecture
This class inherits from TextGenerationController, adding features
relating to encoder input encoder_prompt
"""
def prep_inference_input(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
) -> Dict[str, Any]:
"""Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[str, InferenceRequest]): The input active requests
Returns:
A dict of the inference input for the current batch.
"""
encoder_prompts = list(
map(lambda request: request.encoder_prompt, active_requests.values())
)
return self.inference_wrapped_model.prep_inference_input(
prompts_tokens, encoder_prompts, tokenizer=self.tokenizer
)
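# --- Illustrative usage sketch (not part of the original file) ---
# A minimal sketch wiring the controller above to an encoder-decoder inference wrapper,
# assuming `t5_wrapper` (e.g. a T5InferenceWrapper), a `tokenizer`, padded
# `prompts_tokens`, and the scheduler's `active_requests` pool already exist.
def _example_encoder_decoder_prep(t5_wrapper, tokenizer, prompts_tokens, active_requests):
    controller = EncoderDecoderTextGenerationController(
        inference_wrapped_model=t5_wrapper, tokenizer=tokenizer
    )
    # Returns a dict with encoder/decoder tokens and masks for the current batch.
    return controller.prep_inference_input(prompts_tokens, active_requests)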
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.text_generation_controllers.text_generation_controller import ( # noqa: F401 # pylint: disable=unused-import
TextGenerationController as SimpleTextGenerationController,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.text_generation_controllers.text_generation_controller import ( # noqa: F401 # pylint: disable=unused-import
TextGenerationController as SimpleTextGenerationController,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, OrderedDict, Tuple
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
from megatron.core.inference.sampling_params import SamplingParams
class TextGenerationController:
"""The text generation controller (the main sampling loop)
This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.
Args:
inference_wrapped_model (AbstractModelInferenceWrapper): A model that
is wrapped using the specs given in the abstract_model_inference_wrapper.py
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
"""
def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer):
self.inference_wrapped_model = inference_wrapped_model
self.tokenizer = tokenizer
# For models without pipeline parallelism, is_pipeline_first_stage and is_pipeline_last_stage return True
self.model_is_pipeline_parallel = not (
parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
)
def tokenize_prompt(
self, prompt: str, add_BOS: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility to tokenize the input prompts
Args:
prompt (str): The input prompt
Returns:
torch.Tensor: Returns the tokenized prompt
"""
prompt_tokens = self.tokenizer.tokenize(prompt)
if add_BOS:
prompt_tokens = [self.tokenizer.bos] + prompt_tokens
return prompt_tokens
def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str:
"""Detokenize the output generations
Args:
prompt_tokens_with_generated_tokens (torch.Tensor): The input prompt
tokens plus the generated tokens
Returns:
str: The detokenized output
"""
tokens = prompt_tokens_with_generated_tokens.cpu().numpy().tolist()
return self.tokenizer.detokenize(tokens)
def sample_from_logits(
self,
last_token_logits: torch.Tensor,
sampling_params: SamplingParams = None,
vocab_size: int = None,
**kwargs
) -> torch.Tensor:
"""Samples the logits to generate outputs
Given the logits of the last token, this function samples it
according to the parameters defined in sampling_params
and returns the samples
Args:
last_token_logits (torch.Tensor): The last token logits. A tensor of
size [batch_size, vocab_size]
sampling_params (SamplingParams): The parameters to use for inference.
vocab_size (int): Obtained from the tokenizer. Defaults to None
Returns:
torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements
"""
if kwargs.get('common_inference_params'):
sampling_params = kwargs['common_inference_params']
top_p = sampling_params.top_p
top_k = sampling_params.top_k
temperature = sampling_params.temperature
assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
assert top_p <= 1.0, 'top-p should be in (0,1]'
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Filtering based on the cumulative sum.
filter_ = cumulative_probs > top_p
# This shift by 1 is weird and I cannot justify it. This existed
# in the original implementation:
# https://github.com/ari-holtzman/degen/blob/master/gen.py
# and I guess it is needed so keeping it for now.
filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
# Fill in the filtered part
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
# Greedy sampling
if top_k == 1:
sampled_logits = torch.argmax(last_token_logits, dim=-1)
else:
last_token_logits = last_token_logits.clone()
if temperature != 1.0:
last_token_logits.div_(temperature)
if top_k > 1:
assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(last_token_logits, top_k)
elif top_p > 0.0:
modify_logits_for_top_p_filtering(last_token_logits, top_p)
# After filtering, we need to recalculate the distribution.
probabilities = last_token_logits.softmax(dim=-1)
sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)
# If vocab size is provided, make sure the samples are in the range [0, vocab_size).
if vocab_size:
sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
return sampled_logits
def update_generation_status(
self,
updated_prompts_tokens: torch.Tensor,
generation_started: torch.Tensor,
current_context_end_position: int,
is_generation_done_tensor: torch.Tensor,
generated_sequence_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Checks which prompts have reached an end condition
We check which prompts have reached an end condition and set the corresponding
flags of the is_generation_done_tensor to True. The generated sequence lengths
increase as we keep generating, until that prompt hits an end condition. The
generation_started tensor determines which prompts have started generating.
Args:
updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
generated tokens. A tensor of shape [batch_size, max_seq_len]
(i.e max_seq_len = max_prompt_len + tokens_to_generate)
generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
indicates the prompt at that index has started generating tokens.
current_context_end_position (int): An integer indicating which position to
extract from the prompts tokens to get the latest generated tokens.
is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
True indicates the prompt at that index has reached end condition.
generated_sequence_lengths (torch.Tensor): A int tensor of shape [batch_size].
Each value represents the generated sequence lengths for that prompt.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Returns the boolean
is_generation_done_tensor and the generated_sequence_lengths after updating it
"""
latest_samples = updated_prompts_tokens[:, current_context_end_position]
# Make sure we are checking eod criterion only for prompts that have started generating
# (i.e.) We only look at the generated tokens and not the input tokens.
reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
is_generation_done_tensor = is_generation_done_tensor | reached_eod
# We increment generated sequence lengths when that prompt has not hit the
# EOD and generation has started
generated_sequence_lengths += ~is_generation_done_tensor & generation_started
return is_generation_done_tensor, generated_sequence_lengths
def pad_input_prompt_tokens(
self,
batch_prompt_tokens_list: List[List[int]],
max_prompt_length_in_batch: int,
num_tokens_to_generate: int,
) -> torch.Tensor:
"""Method to pad input prompts
Given a list of prompts, pad them all to uniform length
Args:
batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
max_prompt_length_in_batch (int): The maximum length of the input prompt tokens
num_tokens_to_generate (int): The number of tokens to generate for each prompt
Returns:
torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e)
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate,
with extra indices for each tensor padded with mask id.
"""
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
for prompt_tokens in batch_prompt_tokens_list:
padding_size = max_seq_len - len(prompt_tokens)
prompt_tokens.extend([self.tokenizer.eod] * padding_size)
return torch.tensor(batch_prompt_tokens_list).cuda()
def generate_output_tokens_dynamic_batch(
self, active_requests: OrderedDict[int, InferenceRequest]
) -> OrderedDict[int, InferenceRequest]:
"""Utility to generate the output tokens and probabilities for the prompts
This utility generates the output tokens for a dynamic batch. It will run one forward step
at a time, and pass control back to the engine, which will update the request pool and call
this method again.
Args:
active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
Returns:
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
after running one forward step.
"""
raise Exception("Not implemented yet")
def generate_all_output_tokens_static_batch(
self, active_requests: OrderedDict[int, InferenceRequest]
) -> OrderedDict[int, InferenceRequest]:
"""Utility to generate the all the output tokens and probabilities for the prompts .
This utility generates the output tokens for a static batch. It runs the forward steps till
all prompts complete generation, updates the status of these requests to completed, adds
the generated result and returns these requests
Args:
active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
Returns:
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
"""
batch_prompt_tokens_list = list(
map(lambda request: request.prompt_tokens, active_requests.values())
)
prompt_lengths_in_batch = torch.tensor(
[len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list]
).cuda()
max_prompt_length_in_batch = max(prompt_lengths_in_batch)
min_prompt_length_in_batch = min(prompt_lengths_in_batch)
# For batch inference the inference params are the same for all requests
sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters
# max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
batch_prompt_tokens = self.pad_input_prompt_tokens(
batch_prompt_tokens_list,
max_prompt_length_in_batch=max_prompt_length_in_batch,
num_tokens_to_generate=sampling_params.num_tokens_to_generate,
)
batch_size, max_sequence_length = batch_prompt_tokens.shape
        # Pre-allocate the log probs tensor
output_log_probs = None
if sampling_params.return_log_probs:
output_log_probs = torch.empty(
(batch_size, max_sequence_length - 1), dtype=torch.float32
).cuda()
# An array to check which of the prompts have reached end of generation condition
is_generation_done_tensor = torch.zeros(batch_size, dtype=torch.bool).cuda()
# An array to act as a counter to keep track of generated sequence lengths
generated_sequence_lengths = torch.zeros(batch_size).cuda()
with torch.no_grad():
self.prep_model_for_inference(
prompts_tokens=batch_prompt_tokens, active_requests=active_requests
)
context_start_position = 0
# Pick the context window that we need to pass through the network.
for context_end_position in range(min_prompt_length_in_batch, max_sequence_length):
inference_input = self.inference_wrapped_model.get_batch_for_context_window(
context_start_position, context_end_position
)
# Returns the final logits of shape [batch_size, context_length, vocab_size]
# Note: This is returned in all TP ranks or last PP stage in PP models
logits = self.inference_wrapped_model.run_one_forward_step(inference_input)
if self.model_is_pipeline_parallel:
context_length = context_end_position - context_start_position
logits = broadcast_from_last_pipeline_stage(
[batch_size, context_length, self.tokenizer.vocab_size],
dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
tensor=logits,
)
                # Indicates which of the input prompts have started generating tokens.
                # A 1D boolean tensor with [batch_size] elements, i.e., the shortest
                # prompts start generating first, and so on.
generation_started = prompt_lengths_in_batch <= context_end_position
last_token_logits = logits[:, -1, :]
sampled_logits = self.sample_from_logits(
last_token_logits, sampling_params, self.tokenizer.vocab_size
)
                # Substitute the sampled logits only for the prompts that
                # have started generating tokens
batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
generation_started
]
if sampling_params.return_log_probs:
log_probs = F.log_softmax(logits, dim=2)
indices = torch.unsqueeze(
batch_prompt_tokens[
:, (context_start_position + 1) : (context_end_position + 1)
],
2,
)
# Get the log probabilities for only the prompt tokens
output_log_probs[:, context_start_position:context_end_position] = torch.gather(
log_probs, 2, indices
).squeeze(2)
context_start_position = context_end_position
                # Check the end-of-generation status for each prompt
                # and update the generated sequence lengths
(is_generation_done_tensor, generated_sequence_lengths) = (
self.update_generation_status(
updated_prompts_tokens=batch_prompt_tokens,
generation_started=generation_started,
current_context_end_position=context_end_position,
is_generation_done_tensor=is_generation_done_tensor,
generated_sequence_lengths=generated_sequence_lengths,
)
)
# Boolean flag indicating if all prompts are finished
all_prompts_done = torch.all(is_generation_done_tensor)
if all_prompts_done:
break
# Include all the generated tokens
batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
if sampling_params.return_log_probs:
output_log_probs = output_log_probs[:, :context_end_position]
generated_sequence_lengths[
generated_sequence_lengths > sampling_params.num_tokens_to_generate
] = sampling_params.num_tokens_to_generate
for idx, request in enumerate(active_requests.values()):
input_prompt_length = int(prompt_lengths_in_batch[idx])
            # Shorter prompts might have generated more tokens than required, so we trim them down
required_sequence_length = int(
min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate)
)
# Extract only the generated tokens
required_result_tokens = batch_prompt_tokens_with_generations[
idx, input_prompt_length : (input_prompt_length + required_sequence_length)
]
request.generated_length = required_sequence_length
request.generated_tokens = required_result_tokens
            request.generated_log_probs = (
                None
                if output_log_probs is None
                # log prob at index i corresponds to the token at position i + 1
                else output_log_probs[
                    idx,
                    input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1),
                ]
            )
request.status = Status.COMPLETED
request.generated_text = self.detokenize_generations(required_result_tokens)
return active_requests
def prep_model_for_inference(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
):
"""Preparing batch for inference, using respective wrapper's prep_model_for_inference method
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[int, InferenceRequest]): The input active requests
"""
self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=prompts_tokens)
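# --- Illustrative sketch (not part of the original file) ---
# A minimal, self-contained demo of the bookkeeping performed by update_generation_status
# above: EOD hits are only counted for prompts that have started generating, and generated
# lengths only grow for prompts that are not yet done. The eod id (0) and the tensors below
# are made-up toy values for illustration.
if __name__ == "__main__":
    import torch

    eod = 0
    latest_samples = torch.tensor([5, 0, 7])            # token sampled this step, per prompt
    generation_started = torch.tensor([True, True, False])
    is_done = torch.tensor([False, False, False])
    generated_lengths = torch.zeros(3)

    reached_eod = (latest_samples == eod) & generation_started
    is_done = is_done | reached_eod
    generated_lengths += ~is_done & generation_started
    print(is_done.tolist(), generated_lengths.tolist())  # [False, True, False] [1.0, 0.0, 0.0]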
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import concurrent
import copy
import functools
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.inference.async_stream import AsyncStream
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.utils import get_model_config
class TextGenerationController:
"""The text generation controller (the main sampling loop)
This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.
Args:
inference_wrapped_model (AbstractModelInferenceWrapper): A model that
is wrapped using the specs given in the abstract_model_inference_wrapper.py
        tokenizer: Tokenizer used for tokenizing and detokenizing the prompts
"""
def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer):
self.inference_wrapped_model = inference_wrapped_model
self.tokenizer = tokenizer
# For models without pipeline parallelism, is_first_stage and is_last_stage returns True
self.model_is_pipeline_parallel = not (
parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
)
def tokenize_prompt(
self, prompt: str, add_BOS: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility to tokenize the input prompts
        Args:
            prompt (str): The input prompt
            add_BOS (bool): Whether to prepend the BOS token. Defaults to False.
        Returns:
            torch.Tensor: Returns the tokenized prompt
"""
prompt_tokens = self.tokenizer.tokenize(prompt)
if add_BOS:
prompt_tokens = [self.tokenizer.bos] + prompt_tokens
return prompt_tokens
def detokenize_generations(
self,
tokens_gpu_tensor: torch.Tensor,
lengths_gpu_tensor: torch.Tensor,
detokenize_segments: bool,
) -> tuple[str, Optional[List[List[str]]]]:
"""Detokenize the generated tokens.
Args:
tokens_gpu_tensor (torch.Tensor): Tensor containing the tokens
lengths_gpu_tensor (torch.Tensor): Tensor containing the lengths of each sequence
detokenize_segments (bool): If True, returns individually detokenized tokens. If False,
returns None as second element. Helpful for understanding per-token boundaries in
generated text.
Returns:
tuple[str, List[str] | None]: A tuple containing:
- str: The complete detokenized text
- List[str] | None: List of segmented tokens if detokenize_segments is True, else None
"""
# TODO(helenn): Unify with `detokenize_generations` from legacy textgen path
if not detokenize_segments:
tokens = tokens_gpu_tensor.cpu().numpy().tolist()
return self.tokenizer.detokenize(tokens), None
prompts_plus_generations: List[str] = []
prompts_plus_generations_segments: List[List[str]] = []
tokens_gpu_tensor = torch.unsqueeze(tokens_gpu_tensor, 0)
tokens = tokens_gpu_tensor.cpu().numpy().tolist()
lengths = lengths_gpu_tensor.cpu().numpy().tolist()
for sequence_tokens, length in zip(tokens, lengths):
sequence_tokens = sequence_tokens[:length]
detok_str = self.tokenizer.detokenize(sequence_tokens)
prompts_plus_generations.append(detok_str)
offsets = self.tokenizer.offsets(sequence_tokens, detok_str)
words = [
detok_str[start:end] for start, end in zip(offsets, offsets[1:] + [len(detok_str)])
]
prompts_plus_generations_segments.append(words)
text = self.tokenizer.detokenize(tokens[0])
return text, prompts_plus_generations_segments
def sample_from_logits(
self,
last_token_logits: torch.Tensor,
sampling_params: Optional[SamplingParams] = None,
vocab_size: Optional[int] = None,
**kwargs,
) -> torch.Tensor:
"""Samples the logits to generate outputs
Given the logits of the last token, this function samples it
according to the parameters defined in sampling_params
and returns the samples
Args:
last_token_logits (torch.Tensor): The last token logits. A tensor of
size [batch_size, vocab_size]
sampling_params (SamplingParams): The parameters to use for inference.
vocab_size (int): Obtained from the tokenizer. Defaults to None
Returns:
torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements
"""
if kwargs.get('common_inference_params'):
sampling_params = kwargs['common_inference_params']
top_p = sampling_params.top_p
top_k = sampling_params.top_k
temperature = sampling_params.temperature
assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
assert top_p <= 1.0, 'top-p should be in (0,1]'
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
            # Filter based on the cumulative sum of probabilities.
filter_ = cumulative_probs > top_p
            # Shift the mask right by one so that the first token whose cumulative
            # probability crosses top_p is still kept. This matches the original
            # implementation: https://github.com/ari-holtzman/degen/blob/master/gen.py
filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
# Fill in the filtered part
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
# Greedy sampling
if top_k == 1:
sampled_logits = torch.argmax(last_token_logits, dim=-1)
else:
last_token_logits = last_token_logits.clone()
if temperature != 1.0:
last_token_logits.div_(temperature)
if top_k > 1:
assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(last_token_logits, top_k)
elif top_p > 0.0:
modify_logits_for_top_p_filtering(last_token_logits, top_p)
# After filtering, we need to recalculate the distribution.
probabilities = last_token_logits.softmax(dim=-1)
sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)
        # If vocab size is provided, make sure the samples are in the range [0, vocab_size).
if vocab_size:
sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
return sampled_logits
def update_generation_status(
self,
updated_prompts_tokens: torch.Tensor,
generation_started: torch.Tensor,
current_context_end_position: int,
is_generation_done_tensor: torch.Tensor,
generated_sequence_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Checks which prompts have reached an end condition
We check which prompts have reached an end condition and set the corresponding
flags of the is_generation_done_tensor to True. The generated sequence lengths
        increase as we keep generating, until that prompt hits an end condition. The
generation_started tensor determines which prompts have started generating.
Args:
updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
generated tokens. A tensor of shape [batch_size, max_seq_len]
(i.e max_seq_len = max_prompt_len + tokens_to_generate)
generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
indicates the prompt at that index has started generating tokens.
current_context_end_position (int): An integer indicating which position to
extract from the prompts tokens to get the latest generated tokens.
is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
True indicates the prompt at that index has reached end condition.
            generated_sequence_lengths (torch.Tensor): An int tensor of shape [batch_size].
Each value represents the generated sequence lengths for that prompt.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Returns the boolean
is_generation_done_tensor and the generated_sequence_lengths after updating it
"""
latest_samples = updated_prompts_tokens[:, current_context_end_position]
# Make sure we are checking eod criterion only for prompts that have started generating
        # i.e., we only look at the generated tokens and not the input tokens.
reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
is_generation_done_tensor = is_generation_done_tensor | reached_eod
# We increment generated sequence lengths when that prompt has not hit the
# EOD and generation has started
generated_sequence_lengths += ~is_generation_done_tensor & generation_started
return is_generation_done_tensor, generated_sequence_lengths.int()
def pad_input_prompt_tokens(
self,
batch_prompt_tokens_list: List[List[int]],
max_prompt_length_in_batch: int,
num_tokens_to_generate: int,
) -> torch.Tensor:
"""Method to pad input prompts
Given a list of prompts, pad them all to uniform length
Args:
batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens
            num_tokens_to_generate (int): The number of tokens to generate for each prompt
Returns:
            torch.Tensor: A torch tensor of shape [bs, max_seq_len], where
                max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate,
                with the extra positions of each prompt padded with the EOD token id.
"""
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
for prompt_tokens in batch_prompt_tokens_list:
padding_size = max_seq_len - len(prompt_tokens)
prompt_tokens.extend([self.tokenizer.eod] * padding_size)
return torch.tensor(batch_prompt_tokens_list, device=torch.cuda.current_device())
def generate_output_tokens_dynamic_batch(
self, active_requests: OrderedDict[str, InferenceRequest]
) -> OrderedDict[str, InferenceRequest]:
"""Utility to generate the output tokens and probabilities for the prompts
This utility generates the output tokens for a dynamic batch. It will run one forward step
at a time, and pass control back to the engine, which will update the request pool and call
this method again.
Args:
active_requests (OrderedDict[str, InferenceRequest]): The input active requests.
Returns:
OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
after running one forward step.
"""
raise Exception("Not implemented yet")
def generate_all_output_tokens_static_batch(
self,
active_requests: OrderedDict[str, InferenceRequest],
active_streams: Optional[OrderedDict[str, AsyncStream]] = None,
) -> OrderedDict[str, InferenceRequest]:
"""Utility to generate the all the output tokens and probabilities for the prompts .
This utility generates the output tokens for a static batch. It runs the forward steps till
all prompts complete generation, updates the status of these requests to completed, adds
the generated result and returns these requests
Args:
active_requests (OrderedDict[str, InferenceRequest]): The input active requests.
Returns:
OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
"""
assert all(request.prompt_tokens is not None for request in active_requests.values())
# Perform a deep copy so that the request prompt tokens do not get modified.
batch_prompt_tokens_list: List[List[int]] = list(
map(
lambda request: copy.deepcopy(request.prompt_tokens), # type: ignore[arg-type]
active_requests.values(),
)
)
prompt_lengths_in_batch = torch.tensor(
[len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list],
device=torch.cuda.current_device(),
)
max_prompt_length_in_batch = max(prompt_lengths_in_batch)
min_prompt_length_in_batch = min(prompt_lengths_in_batch)
        # For batch inference, the inference params are the same for all requests
sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters
# max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
batch_prompt_tokens = self.pad_input_prompt_tokens(
batch_prompt_tokens_list,
max_prompt_length_in_batch=max_prompt_length_in_batch,
num_tokens_to_generate=sampling_params.num_tokens_to_generate,
)
batch_size, max_sequence_length = batch_prompt_tokens.shape
# Verify that output sequence length is within configured limit
# TODO(ksanthanam): Raise TokenOverflowError once !2518 is merged
inference_max_sequence_length = (
self.inference_wrapped_model.inference_wrapper_config.inference_max_seq_length
)
assert max_sequence_length <= inference_max_sequence_length, (
f"Maximum allowed sequence length was set to {inference_max_sequence_length} tokens "
f"but requested generation of {max_sequence_length} tokens"
)
        # Pre-allocate the log probs tensor
output_log_probs = None
if sampling_params.return_log_probs:
output_log_probs = torch.empty(
(batch_size, max_sequence_length - 1),
dtype=torch.float32,
device=torch.cuda.current_device(),
)
# An array to check which of the prompts have reached end of generation condition
is_generation_done_tensor = torch.zeros(
batch_size, dtype=torch.bool, device=torch.cuda.current_device()
)
# An array to act as a counter to keep track of generated sequence lengths
        generated_sequence_lengths = torch.zeros(
            batch_size, device=torch.cuda.current_device()
        )
# Use padded vocab size because tokenizer vocab size might not include padding
# to nearest power of 2
vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size
# Check whether CUDA graphs are enabled
enable_cuda_graph = get_model_config(self.inference_wrapped_model.model).enable_cuda_graph
streaming_enabled = active_streams is not None and len(active_streams) > 0
if streaming_enabled:
# Start a separate thread for streaming tokens to avoid blocking the
# main computation
streaming_idx: List[int] = [
i
for (i, request_id) in enumerate(active_requests.keys())
if request_id in active_streams
]
streaming_request_ids: List[str] = list(active_streams.keys())
streams: List[AsyncStream] = list(active_streams.values())
streaming_requests: List[InferenceRequest] = [
active_requests[request_id] for request_id in streaming_request_ids
]
streaming_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
stream_tokens = functools.partial(self.stream_tokens, sampling_params)
with torch.no_grad():
self.inference_wrapped_model.prep_model_for_inference(
prompts_tokens=batch_prompt_tokens
)
inference_input: Dict[str, Any] = self.prep_inference_input(
prompts_tokens=batch_prompt_tokens, active_requests=active_requests
)
assert (
not self.inference_wrapped_model.inference_params.decode_mode
), f"Generation must start in prefill mode"
context_start_position = 0
# Pick the context window that we need to pass through the network.
for context_end_position in range(min_prompt_length_in_batch, max_sequence_length):
inference_input_for_context_window: Dict[str, Any] = (
self.inference_wrapped_model.get_batch_for_context_window(
inference_input, context_start_position, context_end_position
)
)
# Disable attention mask when using CUDA graphs for decode
if (
enable_cuda_graph
and self.inference_wrapped_model.inference_params.decode_mode
and "attention_mask" in inference_input_for_context_window
):
inference_input_for_context_window["attention_mask"] = None
# Returns the final logits of shape [batch_size, context_length, vocab_size]
# Note: This is returned in all TP ranks or last PP stage in PP models
logits = self.inference_wrapped_model.run_one_forward_step(
inference_input_for_context_window
)
if enable_cuda_graph:
create_cudagraphs()
if self.model_is_pipeline_parallel:
context_length = context_end_position - context_start_position
logits = broadcast_from_last_pipeline_stage(
[batch_size, context_length, vocab_size],
dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
tensor=logits,
)
                # Indicates which of the input prompts have started generating tokens.
                # A 1D boolean tensor with [batch_size] elements, i.e., the shortest
                # prompts start generating first, and so on.
generation_started = prompt_lengths_in_batch <= context_end_position
last_token_logits = logits[:, -1, :]
sampled_logits = self.sample_from_logits(
last_token_logits, sampling_params, vocab_size
)
# Substitute the sampled logits only for the prompts that
# have started generating tokens
batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
generation_started
]
if sampling_params.return_log_probs:
log_probs = F.log_softmax(logits, dim=2)
indices = torch.unsqueeze(
batch_prompt_tokens[
:, (context_start_position + 1) : (context_end_position + 1)
],
2,
)
# Get the log probabilities for only the prompt tokens
assert output_log_probs is not None
output_log_probs[:, context_start_position:context_end_position] = torch.gather(
log_probs, 2, indices
).squeeze(2)
context_start_position = context_end_position
                # Check the end-of-generation status for each prompt
                # and update the generated sequence lengths
(is_generation_done_tensor, generated_sequence_lengths) = (
self.update_generation_status(
updated_prompts_tokens=batch_prompt_tokens,
generation_started=generation_started,
current_context_end_position=context_end_position,
is_generation_done_tensor=is_generation_done_tensor,
generated_sequence_lengths=generated_sequence_lengths,
)
)
# Stream intermediate outputs
if streaming_enabled:
streaming_executor.submit(
stream_tokens,
streaming_request_ids,
streaming_requests,
streams,
generation_started[streaming_idx].cpu(),
is_generation_done_tensor[streaming_idx].cpu(),
batch_prompt_tokens[streaming_idx].cpu(),
prompt_lengths_in_batch[streaming_idx].cpu(),
generated_sequence_lengths[streaming_idx].cpu(),
(
output_log_probs[streaming_idx].cpu()
if output_log_probs is not None
else [None] * len(streaming_idx)
),
)
# Boolean flag indicating if all prompts are finished
all_prompts_done = torch.all(is_generation_done_tensor)
if all_prompts_done:
break
# Change to decode mode if all prefill is complete
if torch.all(generation_started):
self.inference_wrapped_model.inference_params.enable_decode_mode()
# Close all streams
if streaming_enabled:
streaming_executor.shutdown()
for stream in streams:
stream.finish()
# Include all the generated tokens
batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
if sampling_params.return_log_probs:
assert output_log_probs is not None
output_log_probs = output_log_probs[:, :context_end_position]
generated_sequence_lengths[
generated_sequence_lengths > sampling_params.num_tokens_to_generate
] = sampling_params.num_tokens_to_generate
for idx, request in enumerate(active_requests.values()):
input_prompt_length = int(prompt_lengths_in_batch[idx])
            # Shorter prompts might have generated more tokens than required, so we trim them down
required_sequence_length = int(
min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate)
)
# Extract only the generated tokens
required_result_tokens = batch_prompt_tokens_with_generations[
idx, input_prompt_length : (input_prompt_length + required_sequence_length)
]
            generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32)
            request.generated_sequence_lengths = generated_sequence_lengths
request.generated_length = required_sequence_length
request.generated_tokens = required_result_tokens
request.prompt_log_probs = (
None
if output_log_probs is None
else output_log_probs[idx, :input_prompt_length].cpu().numpy().tolist()
)
request.generated_log_probs = (
None
if output_log_probs is None
else output_log_probs[
idx,
input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1),
]
.cpu()
.numpy()
.tolist()
)
request.status = Status.COMPLETED
text, segments = self.detokenize_generations(
batch_prompt_tokens_with_generations[idx],
                input_prompt_length + generated_sequence_lengths[idx : idx + 1],
sampling_params.return_segments,
)
request.text = text # Inference server returns prompts & generations together
if sampling_params.return_segments:
request.segments = segments[0]
request.generated_text = text[len(request.prompt) :]
return active_requests
def prep_inference_input(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
) -> Dict[str, Any]:
"""Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[str, InferenceRequest]): The input active requests
Returns:
A dict of the inference input for the current batch.
"""
return self.inference_wrapped_model.prep_inference_input(prompts_tokens)
def stream_tokens(
self,
sampling_params: SamplingParams,
request_ids: List[str],
requests: List[InferenceRequest],
streams: List[AsyncStream],
generation_started: List[bool],
is_generation_done: List[bool],
tokens: torch.Tensor,
prompt_lengths: List[int],
generated_lengths: List[int],
output_log_probs: Union[torch.Tensor, None],
):
"""Asynchronously streams tokens for the given requests.
Args:
sampling_params (SamplingParams): The sampling parameters.
request_ids (List[str]): The request IDs.
            requests (List[InferenceRequest]): The requests.
            streams (List[AsyncStream]): The streams over which to send tokens.
generation_started (List[bool]): Whether the decode step has started.
is_generation_done (List[bool]): Whether generation has completed.
tokens (torch.Tensor): The tokens for this request.
prompt_lengths (List[int]): The number of prompt tokens for each request.
generated_lengths (List[int]): The number of output tokens for each request.
output_log_probs (torch.Tensor, optional): The log probs for each request.
"""
def stream_token(
request_id: str,
request: InferenceRequest,
stream: AsyncStream,
generation_started: bool,
is_generation_done: bool,
tokens: torch.Tensor,
prompt_length: int,
generated_length: int,
output_log_probs: Union[torch.Tensor, None],
):
"""Asynchronously streams a token for the given request."""
if not generation_started or stream.finished:
return
num_tokens_to_generate = sampling_params.num_tokens_to_generate
return_segments = sampling_params.return_segments
detokenize_streaming_text = not getattr(
sampling_params, "no_detokenize_streaming_text", False
)
generated_tokens = tokens[prompt_length : prompt_length + generated_length]
if detokenize_streaming_text:
generated_text, generated_segments = self.detokenize_generations(
generated_tokens, prompt_length + generated_length, return_segments
)
else:
generated_text = ""
generated_segments = []
if output_log_probs is not None:
generated_log_probs = (
output_log_probs[prompt_length - 1 : prompt_length + generated_length - 1]
.cpu()
.numpy()
.tolist()
)
else:
generated_log_probs = None
stream.put(
InferenceRequest(
request_id=request_id,
prompt=request.prompt,
inference_parameters=request.inference_parameters,
prompt_tokens=request.prompt_tokens,
arrival_time=request.arrival_time,
status=request.status,
encoder_prompt=request.encoder_prompt,
generated_text=generated_text,
generated_segments=generated_segments,
generated_tokens=generated_tokens,
generated_log_probs=generated_log_probs,
generated_length=generated_length,
)
)
if is_generation_done or generated_length == num_tokens_to_generate:
stream.finish()
ret = map(
stream_token,
request_ids,
requests,
streams,
generation_started,
is_generation_done,
tokens,
prompt_lengths,
generated_lengths,
output_log_probs,
)
list(ret)
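# --- Illustrative sketch (not part of the original file) ---
# A standalone demo of the top-k / top-p filtering used in sample_from_logits above.
# The logits and thresholds are toy values; the filtering mirrors the nested helper
# functions (mask everything outside the top-k, or outside the smallest set whose
# cumulative probability exceeds top_p) before renormalizing and sampling.
if __name__ == "__main__":
    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

    # top-k = 2: keep only the two largest logits.
    top_k_logits = logits.clone()
    kth_value = torch.topk(top_k_logits, k=2)[0][..., -1, None]
    top_k_logits.masked_fill_(top_k_logits < kth_value, float('-Inf'))

    # top-p = 0.9: keep the smallest prefix of sorted tokens whose cumulative
    # probability exceeds 0.9 (the shift keeps the token that crosses the threshold).
    top_p_logits = logits.clone()
    sorted_logits, sorted_indices = torch.sort(top_p_logits, descending=True)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    filter_ = cumulative_probs > 0.9
    filter_[:, 1:] = filter_[:, :-1].clone()
    filter_[..., 0] = 0
    filter_ = filter_.scatter(1, sorted_indices, filter_)
    top_p_logits.masked_fill_(filter_, float('-Inf'))

    # Renormalize the surviving logits and draw one sample per row.
    sample = torch.multinomial(top_p_logits.softmax(dim=-1), num_samples=1)
    print(top_k_logits, top_p_logits, sample)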
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import OrderedDict
import torch
from megatron.core.inference.inference_request import InferenceRequest, VLMInferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
class VLMTextGenerationController(TextGenerationController):
"""The text generation controller for VLMs"""
def prep_inference_input(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
):
"""Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long
Currently only supports batch size 1 inference.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[str, InferenceRequest]): The input active requests
"""
assert len(active_requests) == 1, f"VLM inference currently only supports batch size 1"
request = list(active_requests.values())[0]
assert isinstance(
request, VLMInferenceRequest
), f"Found inference request of type {type(request)}, expected VLMInferenceRequest"
return self.inference_wrapped_model.prep_inference_input(
prompts_tokens,
request.num_img_embeddings_per_tile,
request.imgs,
request.num_tiles,
request.decoder_seq_length,
)
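# --- Illustrative sketch (not part of the original file) ---
# A toy illustration of the single-request contract enforced above: the controller expects
# exactly one VLM request whose image metadata (num_img_embeddings_per_tile, imgs,
# num_tiles, decoder_seq_length) is forwarded to the wrapped model. The SimpleNamespace
# request and all values below are made up for illustration only.
if __name__ == "__main__":
    from collections import OrderedDict
    from types import SimpleNamespace

    import torch

    fake_request = SimpleNamespace(
        num_img_embeddings_per_tile=576,       # hypothetical value
        imgs=torch.zeros(1, 3, 336, 336),      # one tile of a dummy image
        num_tiles=torch.tensor([1]),
        decoder_seq_length=4096,               # hypothetical value
    )
    active_requests = OrderedDict({"request-0": fake_request})
    assert len(active_requests) == 1, "VLM inference currently only supports batch size 1"
    request = list(active_requests.values())[0]
    print(request.num_tiles, request.decoder_seq_length)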
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
class InferenceParams:
"""Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
def swap_key_value_dict(self, batch_idx):
"swap between batches"
if len(self.key_value_memory_dict) == 0:
raise ValueError("should not swap when dict in empty")
for layer_number in self.key_value_memory_dict.keys():
inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
assert (
len(batch_idx) == inference_key_memory.shape[1]
) # make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_idx]
new_inference_value_memory = inference_value_memory[:, batch_idx]
self.key_value_memory_dict[layer_number] = (
new_inference_key_memory,
new_inference_value_memory,
)
def __str__(self):
return f"InferenceParams(max_seq_len = {self.max_sequence_length}, max_batch_size = {self.max_batch_size}, sequence_len_offset = {self.sequence_len_offset}, batch_size_offset = {self.batch_size_offset}, key_value_memory_dict = {self.key_value_memory_dict.keys()})"
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
class InferenceParams:
"""Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.current_batch_size = max_batch_size # Required for bookkeeping variable-sized batches
self.sequence_len_offset = 0
self.batch_size_offset = 0
        self.decode_mode = False
        self.key_value_memory_dict = {}
def swap_key_value_dict(self, batch_idx):
"swap between batches"
if len(self.key_value_memory_dict) == 0:
raise ValueError("should not swap when dict in empty")
for layer_number in self.key_value_memory_dict.keys():
inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
assert (
len(batch_idx) == inference_key_memory.shape[1]
) # make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_idx]
new_inference_value_memory = inference_value_memory[:, batch_idx]
self.key_value_memory_dict[layer_number] = (
new_inference_key_memory,
new_inference_value_memory,
)
def enable_prefill_mode(self):
"""
Indicates the generation loop is in the prefill phase (still processing
input prompt tokens). This should be enabled if the generation loop is
encoding prompt tokens for *any* request in a batch.
"""
self.decode_mode = False
def enable_decode_mode(self):
"""
Indicates the generation loop is in the decode phase (generating new output
tokens). This should only be enabled if the generation loop has fully encoded
the prompts for *all* requests in a batch.
"""
self.decode_mode = True
def reset(self):
"""Resets the inference state for a new batch."""
self.current_batch_size = self.max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.enable_prefill_mode()
def __str__(self):
return (
f"InferenceParams(max_seq_len = {self.max_sequence_length}, "
f"max_batch_size = {self.max_batch_size}, "
f"current_batch_size = {self.current_batch_size}, "
f"sequence_len_offset = {self.sequence_len_offset}, "
f"batch_size_offset = {self.batch_size_offset}, "
f"key_value_memory_dict = {self.key_value_memory_dict.keys()})"
f"decode_mode = {self.decode_mode}"
)
def __eq__(self, other):
if not isinstance(other, InferenceParams):
return False
# Check all attributes match
basic_attrs = [
'max_sequence_length',
'max_batch_size',
'current_batch_size',
'sequence_len_offset',
'batch_size_offset',
]
if not all(hasattr(other, attr) for attr in basic_attrs):
return False
# Check dictionary keys match; i.e. the same number of layers are cached
if self.key_value_memory_dict.keys() != other.key_value_memory_dict.keys():
return False
# Check each tensor tuple in the dictionary
for key in self.key_value_memory_dict:
self_tensors = self.key_value_memory_dict[key]
other_tensors = other.key_value_memory_dict[key]
# Compare each key, value tensor in the tuple
for self_tensor, other_tensor in zip(self_tensors, other_tensors):
if (
self_tensor.data_ptr() != other_tensor.data_ptr()
or self_tensor.shape != other_tensor.shape
):
return False
return True
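# --- Illustrative sketch (not part of the original file) ---
# A minimal demo of swap_key_value_dict: the cached key/value tensors are laid out as
# [sequence, batch, ...], so re-indexing dim 1 with batch_idx reorders (or repeats) the
# per-request cache entries. The shapes and layer count below are toy values.
if __name__ == "__main__":
    import torch

    params = InferenceParams(max_batch_size=2, max_sequence_length=8)
    params.key_value_memory_dict[1] = (
        torch.arange(16.0).reshape(8, 2, 1),  # toy key cache: [seq, batch, hidden]
        torch.zeros(8, 2, 1),                 # toy value cache
    )
    params.swap_key_value_dict(batch_idx=[1, 0])  # swap the two batch entries
    print(params.key_value_memory_dict[1][0][:, 0, 0])  # now holds what was batch entry 1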
import torch
from megatron.core.utils import is_torch_min_version
jit_fuser = torch.jit.script
# nvFuser is deprecated in PyTorch JIT starting from 2.2
if is_torch_min_version("2.2.0a0"):
jit_fuser = torch.compile(mode='max-autotune-no-cudagraphs')
# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
import torch._dynamo
if torch.__version__ >= "2.1":
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
f, recursive=recursive
)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
jit_fuser = torch.compile
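# --- Illustrative sketch (not part of the original file) ---
# A small usage sketch for the jit_fuser alias defined above: whichever backend it resolves
# to (torch.jit.script or torch.compile), it is applied as a decorator to a pointwise
# function so the elementwise ops can be fused. bias_gelu_approx is a made-up example
# function, not part of Megatron.
if __name__ == "__main__":
    import torch

    @jit_fuser
    def bias_gelu_approx(bias: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Tanh-approximated GELU applied to (x + bias), a typical fusion candidate.
        y = x + bias
        return y * 0.5 * (1.0 + torch.tanh(0.79788456 * y * (1.0 + 0.044715 * y * y)))

    out = bias_gelu_approx(torch.zeros(4), torch.randn(4))
    print(out.shape)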
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable, ContextManager, Optional
import torch
@dataclass
class ModelParallelConfig:
"""Base configuration for Megatron Core
The initialization function has an argument for each parameter.
"""
###################
# Model parallelism
###################
tensor_model_parallel_size: int = 1
"""Intra-layer model parallelism. Splits tensors across GPU ranks."""
pipeline_model_parallel_size: int = 1
"""Inter-layer model parallelism. Splits transformer layers across GPU ranks."""
virtual_pipeline_model_parallel_size: Optional[int] = None
"""Interleaved pipeline parallelism is used to improve performance by reducing the pipeline
bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
The number of virtual blocks per pipeline model parallel rank is the virtual model parallel
size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM:
arxiv.org/pdf/2104.04473.pdf for more details.
"""
sequence_parallel: bool = False
"""Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms
and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models
(https://arxiv.org/abs/2205.05198) for more details.
"""
context_parallel_size: int = 1
"""Splits network input along sequence dimension across GPU ranks."""
hierarchical_context_parallel_sizes: Optional[list[int]] = None
"""Degrees of the hierarchical context parallelism. Users should provide a list to specify
    the sizes for different levels. Taking the a2a+p2p cp comm type as an example, it contains
groups of two levels, so the first value of the list indicates the group size of the a2a
communication type, and the second value indicates the group size of the p2p communication
type.
"""
expert_model_parallel_size: int = 1
"""Distributes Moe Experts across sub data parallel dimension."""
expert_tensor_parallel_size: Optional[int] = None
"""Intra-layer tensor model parallelsm for expert layer. Splits tensors across GPU ranks."""
moe_extended_tp: bool = False
"""NOTE: Deprecated from MCore v0.10. This flag is ignored.
Its functionality is replaced by expert_tensor_parallel_size.
"""
###################
# Initialization
###################
perform_initialization: bool = True
"""If true, weights are initialized. This option can be useful when you know you are going to
load values from a checkpoint.
"""
use_cpu_initialization: bool = False
"""When set to False, we initialize the weights directly on the GPU. CPU initialization is the
same regardless of tensor model parallelism, but GPU initialization is not. Transferring
weights from CPU to GPU can take a significant amount of time for large models.
"""
###################
# Training
###################
fp16: bool = False
"""If true, train with fp16 mixed precision training."""
bf16: bool = False
"""If true, train with bf16 mixed precision training."""
params_dtype: torch.dtype = torch.float32
"""dtype used when intializing the weights."""
timers: Optional[Callable] = None
"""Timers object to call for various timing functions. See megatron.core.timers.Timers"""
finalize_model_grads_func: Optional[Callable] = None
"""Function that finalizes gradients on all workers. Could include ensuring that grads are
all-reduced across data parallelism, pipeline parallelism, and sequence parallelism
dimensions.
"""
grad_scale_func: Optional[Callable] = None
"""If using loss scaling, this function should take the loss and return the scaled loss. If
None, no function is called on the loss.
"""
no_sync_func: Optional[Callable] = None
"""Function that creates a context that suppresses asynchronous data-parallel communication. If
the model is an instance of core.distributed.DistributedDataParallel, the default is to use
core.distributed.DistributedDataParallel.no_sync.
"""
grad_sync_func: Optional[Callable] = None
"""Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient
reduce-scatters). The function should take one argument: an iterable of parameters whose
gradients are to be synchronized.
"""
param_sync_func: Optional[Callable] = None
"""Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer
parameter all-gathers). The function should take one argument: an iterable of parameters to
be synchronized.
"""
deterministic_mode: bool = False
"""If true, code that has deterministic execution will be chosen. This usually
means slower execution, but is good for debugging and testing. Defaults to False."""
enable_autocast: bool = False
"""If true runs the forward step function inside torch.autocast context."""
autocast_dtype: Optional[torch.dtype] = None
"""dtype to pass to torch.amp.autocast when enabled. If None, is set to pipeline_dtype."""
num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
"""If int, set the number of microbatches where not all of the layers will be checkpointed and
recomputed. The rest of the microbatches within the window of maximum outstanding
microbatches will recompute all layers (either full recompute or selective recompute). If
None, the checkpoint and recompute will be left up to the forward_step function.
"""
###################
# Optimizations
###################
gradient_accumulation_fusion: bool = False
"""If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension
fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install
APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\"
--global-option=\"--cuda_ext\" ". Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion.
"""
async_tensor_model_parallel_allreduce: bool = False
"""NOTE: Deprecated. This flag is ignored."""
use_te_rng_tracker: bool = False
"""If true, uses RNG state tracker in TransformerEngine if exists.
"""
tp_comm_overlap: bool = False
"""If true, allows overlapping of Linear layer execution with tensor parallel communication
collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever
possible during the forward and the backward pass.
"""
tp_comm_bulk_wgrad: bool = True
"""If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if
tp_comm_overlap is False.
"""
tp_comm_bulk_dgrad: bool = True
"""If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't care if
tp_comm_overlap is False.
"""
tp_comm_overlap_ag: bool = True
"""If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather.
Don't care if tp_comm_overlap is False.
"""
tp_comm_overlap_rs: bool = True
"""If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter.
Don't care if tp_comm_overlap is False.
"""
tp_comm_overlap_rs_dgrad: bool = False
"""If true, allows Reduce-Scatter overlap with DGRAD GEMM by pipelining the
GEMM and Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
"""
tp_comm_split_ag: bool = True
"""Deprecated from TransformerEngine v1.6.0.
If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
splits. Don't care if tp_comm_overlap is False.
"""
tp_comm_atomic_ag: bool = False
"""Deprecated from TransformerEngine v1.6.0.
If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
both done atomically. Don't care if tp_comm_overlap is False.
"""
tp_comm_split_rs: bool = True
"""Deprecated from TransformerEngine v1.6.0.
If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
"""
tp_comm_atomic_rs: bool = False
"""Deprecated from TransformerEngine v1.6.0.
If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False.
"""
cross_entropy_loss_fusion: bool = False
"""If this is enabled, the fused cross entropy implementation would be used.
Defaults to False.
"""
tp_comm_overlap_disable_qkv: bool = False
"""
If true, the AllGather -> Gemm overlap for QKV gets disabled
"""
tp_comm_overlap_disable_fc1: bool = False
"""
If true, the AllGather -> Gemm overlap for FC1 layer of MLP gets disabled
"""
tp_comm_bootstrap_backend: str = 'nccl'
"""
Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo'
"""
###################
# Pipeline Parallel
###################
pipeline_dtype: torch.dtype = None
"""dtype used in p2p communication, usually params_dtype"""
variable_seq_lengths: bool = False
"""Support for variable sequence lengths across microbatches. Setting this communicates the size
of tensors during pipeline parallelism communication, because of this extra overhead it
should only be set if the sequence length varies by microbatch within a global batch.
"""
overlap_p2p_comm: bool = False
"""When True some of the peer to peer communication for pipeline parallelism will overlap with
computation. Must be False if batch_p2p_comm is true.
"""
batch_p2p_comm: bool = True
"""Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if
overlap_p2p_comm is True.
"""
batch_p2p_sync: bool = True
"""When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in
    older versions of PyTorch.
"""
use_ring_exchange_p2p: bool = False
"""Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires
custom built torch with torch.distributed.ring_exchange.
"""
deallocate_pipeline_outputs: bool = False
"""If True, output data is deallocated after the tensor is sent to the next pipeline stage.
Helps with saving memory, does nothing when pipeline parallel is not used.
"""
defer_embedding_wgrad_compute: bool = False
"""If true, defers the embedding WGRAD GEMMs while pipeline flush is
taking place enabling us to hide pipeline flush latency. Defaults to False.
"""
wgrad_deferral_limit: int = 0
"""This value tunes the number of micro-batches for which the embedding weight gradient compute
needs to be deferred to pipeline flush, this argument is invalid if
`defer_embedding_wgrad_compute` is False.
Defaults to 0, which means all micro-batches are deferred.
"""
pipeline_model_parallel_split_rank: Optional[int] = None
"""If int, rank where encoder and decoder should be split in cases where the model has both an
encoder and decoder (e.g., T5). Ignored if None.
"""
overlap_p2p_comm_warmup_flush: bool = False
"""If true, overlap communication and computation in warm up and flush phase.
Only valid when overlap_p2p_comm is True and batch_p2p_comm is False.
Defaults to False.
"""
microbatch_group_size_per_vp_stage: Optional[int] = None
"""This value specifies the number of micro-batches that are executed
at a time for a given virtual stage (both forward and backward).
    Defaults (in the __post_init__() method below) to pipeline_model_parallel_size,
which specifies a depth-first schedule.
Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2,
num_microbatches = 4, we have
rank 0 | 0 1 0 1 2 3 2 3
rank 1 | 0 1 0 1 2 3 2 3
When microbatch_group_size_per_vp_stage=3, num_microbatches = 5,
we have
rank 0 | 0 1 2 0 1 2 3 4 3 4
rank 1 | 0 1 2 0 1 2 3 4 3 4
"""
###################
# CPU Offloading
###################
cpu_offloading: bool = False
"""When set to True, all the activations are offloaded to the CPU asynchronously."""
cpu_offloading_num_layers: int = 0
"""Tells the number of transformer layers for which activations has to be offloaded."""
_cpu_offloading_context: Optional[ContextManager] = (
None
# Used for internal use only, not to be set by a user.
# TODO: Need to move to the 'right' place when possible.
)
"""For internal use only, do not set."""
cpu_offloading_activations: bool = True
"""If True, offloads the activations to CPU."""
cpu_offloading_weights: bool = True
"""If True, offloads the weights to CPU."""
###################
# Timing
###################
barrier_with_L1_time: bool = True
"""If true, use barrier with level 1 time measurements. It is up to the user to make sure
calling barrier with their timers will not result in hangs. This can happen if for example
the user adds a level 1 timer that is not called by all ranks.
"""
def __post_init__(self):
"""Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
details.
"""
if self.sequence_parallel:
if self.tensor_model_parallel_size <= 1:
raise ValueError("Can not use sequence paralllelism without tensor parallelism")
if self.expert_tensor_parallel_size is None:
self.expert_tensor_parallel_size = self.tensor_model_parallel_size
if self.pipeline_model_parallel_size > 1:
if self.pipeline_dtype is None:
raise ValueError(
"When using pipeline parallelism, pipeline_dtype must be specified"
)
if self.autocast_dtype is None:
self.autocast_dtype = self.params_dtype
if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1:
raise ValueError(
"Cannot defer embedding wgrad compute when pipeline model parallel is not used"
)
if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion:
raise ValueError(
"Cannot defer embedding wgrad compute when gradient accumulation fusion is not used"
)
if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0:
raise ValueError(
"Wgrad deferral limit should be greater than or equal to 0 when it is enabled!"
)
if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
if self.sequence_parallel is False:
raise ValueError(
"When using expert parallelism and tensor parallelism, "
"sequence parallelism must be used"
)
if self.microbatch_group_size_per_vp_stage is None:
self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size
if self.overlap_p2p_comm_warmup_flush:
if not self.overlap_p2p_comm or self.batch_p2p_comm:
raise ValueError(
"Pipeline parallel communication overlapping in warmup and flush is only "
"compatible with overlap_p2p_comm but not batch_p2p_comm."
)
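# --- Illustrative sketch (not part of the original file) ---
# A small sketch of how the __post_init__ validation above behaves: enabling pipeline
# parallelism without pipeline_dtype raises, while a consistent TP + SP configuration fills
# in expert_tensor_parallel_size from tensor_model_parallel_size. The values are arbitrary
# examples, not recommended settings.
if __name__ == "__main__":
    import torch

    try:
        ModelParallelConfig(pipeline_model_parallel_size=2)  # missing pipeline_dtype
    except ValueError as err:
        print(err)

    cfg = ModelParallelConfig(
        tensor_model_parallel_size=2,
        sequence_parallel=True,
        pipeline_model_parallel_size=2,
        pipeline_dtype=torch.bfloat16,
    )
    print(cfg.expert_tensor_parallel_size)  # -> 2, inherited from tensor_model_parallel_size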
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable, ContextManager, Optional
import torch
@dataclass
class ModelParallelConfig:
"""Base configuration for Megatron Core
The initialization function has an argument for each parameter.
"""
###################
# Model parallelism
###################
tensor_model_parallel_size: int = 1
"""Intra-layer model parallelism. Splits tensors across GPU ranks."""
pipeline_model_parallel_comm_backend: Optional[str] = None
"""Configuring backend option of pipeline parallel communication (e.g., nccl, ucc)
If None, the default backend will be used.
"""
pipeline_model_parallel_size: int = 1
"""Inter-layer model parallelism. Splits transformer layers across GPU ranks."""
virtual_pipeline_model_parallel_size: Optional[int] = None
"""Interleaved pipeline parallelism is used to improve performance by reducing the pipeline
bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
The number of virtual blocks per pipeline model parallel rank is the virtual model parallel
size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM:
arxiv.org/pdf/2104.04473.pdf for more details.
"""
sequence_parallel: bool = False
"""Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms
and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models
(https://arxiv.org/abs/2205.05198) for more details.
"""
context_parallel_size: int = 1
"""Splits network input along sequence dimension across GPU ranks."""
hierarchical_context_parallel_sizes: Optional[list[int]] = None
"""Degrees of the hierarchical context parallelism. Users should provide a list to specify
    the sizes for different levels. Taking the a2a+p2p cp comm type as an example, it contains
groups of two levels, so the first value of the list indicates the group size of the a2a
communication type, and the second value indicates the group size of the p2p communication
type.
"""
expert_model_parallel_size: int = 1
"""Distributes Moe Experts across sub data parallel dimension."""
expert_tensor_parallel_size: Optional[int] = None
"""Intra-layer tensor model parallelsm for expert layer. Splits tensors across GPU ranks."""
moe_extended_tp: bool = False
"""NOTE: Deprecated from MCore v0.10. This flag is ignored.
Its functionality is replaced by expert_tensor_parallel_size.
"""
###################
# Initialization
###################
perform_initialization: bool = True
"""If true, weights are initialized. This option can be useful when you know you are going to
load values from a checkpoint.
"""
use_cpu_initialization: bool = False
"""When set to False, we initialize the weights directly on the GPU. CPU initialization is the
same regardless of tensor model parallelism, but GPU initialization is not. Transferring
weights from CPU to GPU can take a significant amount of time for large models.
"""
###################
# Training
###################
fp16: bool = False
"""If true, train with fp16 mixed precision training."""
bf16: bool = False
"""If true, train with bf16 mixed precision training."""
params_dtype: torch.dtype = torch.float32
"""dtype used when intializing the weights."""
timers: Optional[Callable] = None
"""Timers object to call for various timing functions. See megatron.core.timers.Timers"""
finalize_model_grads_func: Optional[Callable] = None
"""Function that finalizes gradients on all workers. Could include ensuring that grads are
all-reduced across data parallelism, pipeline parallelism, and sequence parallelism
dimensions.
"""
grad_scale_func: Optional[Callable] = None
"""If using loss scaling, this function should take the loss and return the scaled loss. If
None, no function is called on the loss.
"""
no_sync_func: Optional[Callable] = None
"""Function that creates a context that suppresses asynchronous data-parallel communication. If
the model is an instance of core.distributed.DistributedDataParallel, the default is to use
core.distributed.DistributedDataParallel.no_sync.
"""
grad_sync_func: Optional[Callable] = None
"""Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient
reduce-scatters). The function should take one argument: an iterable of parameters whose
gradients are to be synchronized.
"""
param_sync_func: Optional[Callable] = None
"""Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer
parameter all-gathers). The function should take one argument: an iterable of parameters to
be synchronized.
"""
deterministic_mode: bool = False
"""If true, code that has deterministic execution will be chosen. This usually
means slower execution, but is good for debugging and testing. Defaults to False."""
enable_autocast: bool = False
"""If true runs the forward step function inside torch.autocast context."""
autocast_dtype: Optional[torch.dtype] = None
"""dtype to pass to torch.amp.autocast when enabled. If None, is set to pipeline_dtype."""
num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
"""If int, set the number of microbatches where not all of the layers will be checkpointed and
recomputed. The rest of the microbatches within the window of maximum outstanding
microbatches will recompute all layers (either full recompute or selective recompute). If
None, the checkpoint and recompute will be left up to the forward_step function.
"""
###################
# Optimizations
###################
gradient_accumulation_fusion: bool = False
"""If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension
fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install
APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\"
--global-option=\"--cuda_ext\" ". Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion.
"""
async_tensor_model_parallel_allreduce: bool = False
"""NOTE: Deprecated. This flag is ignored."""
use_te_rng_tracker: bool = False
"""If true, uses RNG state tracker in TransformerEngine if exists.
"""
tp_comm_overlap: bool = False
"""If true, allows overlapping of Linear layer execution with tensor parallel communication
collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever
possible during the forward and the backward pass.
"""
tp_comm_bulk_wgrad: bool = True
"""If true, allows All-Gather overlap with Bprop activation gradient GEMM. Ignored if
tp_comm_overlap is False.
"""
tp_comm_bulk_dgrad: bool = True
"""If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Ignored if
tp_comm_overlap is False.
"""
tp_comm_overlap_ag: bool = True
"""If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather.
Ignored if tp_comm_overlap is False.
"""
tp_comm_overlap_rs: bool = True
"""If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter.
Ignored if tp_comm_overlap is False.
"""
tp_comm_overlap_rs_dgrad: bool = False
"""If true, allows Reduce-Scatter overlap with DGRAD GEMM by pipelining the
GEMM and Reduce-Scatter splits. Ignored if tp_comm_overlap is False.
"""
tp_comm_split_ag: bool = True
"""Deprecated from TransformerEngine v1.6.0.
If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
splits. Ignored if tp_comm_overlap is False.
"""
tp_comm_atomic_ag: bool = False
"""Deprecated from TransformerEngine v1.6.0.
If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather,
both done atomically. Ignored if tp_comm_overlap is False.
"""
tp_comm_split_rs: bool = True
"""Deprecated from TransformerEngine v1.6.0.
If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
Reduce-Scatter splits. Ignored if tp_comm_overlap is False.
"""
tp_comm_atomic_rs: bool = False
"""Deprecated from TransformerEngine v1.6.0.
If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
Reduce-Scatter, both done atomically. Ignored if tp_comm_overlap is False.
"""
cross_entropy_loss_fusion: bool = False
"""If enabled, the fused cross entropy loss implementation is used.
Defaults to False.
"""
tp_comm_overlap_disable_qkv: bool = False
"""
If true, disables the AllGather -> GEMM overlap for the QKV projection.
"""
tp_comm_overlap_disable_fc1: bool = False
"""
If true, disables the AllGather -> GEMM overlap for the FC1 layer of the MLP.
"""
tp_comm_bootstrap_backend: str = 'nccl'
"""
Set the bootstrapping backend; one of 'nccl', 'mpi', or 'gloo'.
"""
###################
# Pipeline Parallel
###################
pipeline_dtype: torch.dtype = None
"""dtype used in p2p communication, usually params_dtype"""
variable_seq_lengths: bool = False
"""Support for variable sequence lengths across microbatches. Setting this communicates the size
of tensors during pipeline parallelism communication; because of this extra overhead, it
should only be set if the sequence length varies by microbatch within a global batch.
"""
overlap_p2p_comm: bool = False
"""When True, some of the peer-to-peer communication for pipeline parallelism will overlap with
computation. Must be False if batch_p2p_comm is True.
"""
batch_p2p_comm: bool = True
"""Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if
overlap_p2p_comm is True.
"""
batch_p2p_sync: bool = True
"""When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in
older versions of PyTorch.
"""
use_ring_exchange_p2p: bool = False
"""Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires
custom built torch with torch.distributed.ring_exchange.
"""
deallocate_pipeline_outputs: bool = False
"""If True, output data is deallocated after the tensor is sent to the next pipeline stage.
Helps save memory; does nothing when pipeline parallelism is not used.
"""
defer_embedding_wgrad_compute: bool = False
"""If true, defers the embedding WGRAD GEMMs while the pipeline flush is
taking place, enabling us to hide pipeline flush latency. Defaults to False.
"""
wgrad_deferral_limit: int = 0
"""This value tunes the number of micro-batches for which the embedding weight gradient compute
needs to be deferred until the pipeline flush; this argument is invalid if
`defer_embedding_wgrad_compute` is False.
Defaults to 0, which means all micro-batches are deferred.
"""
pipeline_model_parallel_split_rank: Optional[int] = None
"""If int, rank where encoder and decoder should be split in cases where the model has both an
encoder and decoder (e.g., T5). Ignored if None.
"""
overlap_p2p_comm_warmup_flush: bool = False
"""If true, overlap communication and computation in warm up and flush phase.
Only valid when overlap_p2p_comm is True and batch_p2p_comm is False.
Defaults to False.
"""
microbatch_group_size_per_vp_stage: Optional[int] = None
"""This value specifies the number of micro-batches that are executed
at a time for a given virtual stage (both forward and backward).
Default (in __post_init__() method below) to pipeline_parallel_size
which specifies a depth-first schedule.
Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2,
num_microbatches = 4, we have
rank 0 | 0 1 0 1 2 3 2 3
rank 1 | 0 1 0 1 2 3 2 3
When microbatch_group_size_per_vp_stage=3, num_microbatches = 5,
we have
rank 0 | 0 1 2 0 1 2 3 4 3 4
rank 1 | 0 1 2 0 1 2 3 4 3 4
"""
###################
# CPU Offloading
###################
cpu_offloading: bool = False
"""When set to True, all the activations are offloaded to the CPU asynchronously."""
cpu_offloading_num_layers: int = 0
"""Specifies the number of transformer layers for which activations have to be offloaded."""
_cpu_offloading_context: Optional[ContextManager] = (
None
# Used for internal use only, not to be set by a user.
# TODO: Need to move to the 'right' place when possible.
)
"""For internal use only, do not set."""
cpu_offloading_activations: bool = True
"""If True, offloads the activations to CPU."""
cpu_offloading_weights: bool = True
"""If True, offloads the weights to CPU."""
###################
# Timing
###################
barrier_with_L1_time: bool = True
"""If true, use barrier with level 1 time measurements. It is up to the user to make sure
calling barrier with their timers will not result in hangs. This can happen if for example
the user adds a level 1 timer that is not called by all ranks.
"""
def __post_init__(self):
"""Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
details.
"""
if self.sequence_parallel:
if self.tensor_model_parallel_size <= 1:
raise ValueError("Can not use sequence paralllelism without tensor parallelism")
if self.expert_tensor_parallel_size is None:
self.expert_tensor_parallel_size = self.tensor_model_parallel_size
if self.pipeline_model_parallel_size > 1:
if self.pipeline_dtype is None:
raise ValueError(
"When using pipeline parallelism, pipeline_dtype must be specified"
)
if self.autocast_dtype is None:
self.autocast_dtype = self.params_dtype
if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1:
raise ValueError(
"Cannot defer embedding wgrad compute when pipeline model parallel is not used"
)
if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion:
raise ValueError(
"Cannot defer embedding wgrad compute when gradient accumulation fusion is not used"
)
if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0:
raise ValueError(
"Wgrad deferral limit should be greater than or equal to 0 when it is enabled!"
)
if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
if self.sequence_parallel is False:
raise ValueError(
"When using expert parallelism and tensor parallelism, "
"sequence parallelism must be used"
)
if self.microbatch_group_size_per_vp_stage is None:
self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size
if self.overlap_p2p_comm_warmup_flush:
if not self.overlap_p2p_comm or self.batch_p2p_comm:
raise ValueError(
"Pipeline parallel communication overlapping in warmup and flush is only "
"compatible with overlap_p2p_comm but not batch_p2p_comm."
)
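# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file). Assuming the enclosing
# dataclass is Megatron-Core's ModelParallelConfig, the validation rules in
# __post_init__ above can be exercised roughly as follows:
#
#     import torch
#     from megatron.core.model_parallel_config import ModelParallelConfig
#
#     cfg = ModelParallelConfig(
#         tensor_model_parallel_size=2,
#         pipeline_model_parallel_size=2,
#         sequence_parallel=True,
#         pipeline_dtype=torch.bfloat16,  # required because PP size > 1
#     )
#
# Omitting pipeline_dtype while pipeline_model_parallel_size > 1, or enabling
# sequence_parallel with tensor_model_parallel_size == 1, raises ValueError.
# ---------------------------------------------------------------------------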
@@ -10,9 +10,11 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.enums import ModelType
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.relative_pos_embedding import RelativePositionEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
@@ -135,9 +137,13 @@ class T5Model(LanguageModule):
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
position_embedding_type: Literal[
'learned_absolute', 'rope', 'relative'
] = 'learned_absolute',
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[float] = None,
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
add_encoder: bool = True,
add_decoder: bool = True,
):
@@ -193,6 +199,23 @@ class T5Model(LanguageModule):
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Relative Position Embeddings
if self.position_embedding_type == 'relative':
self.encoder_relative_pos_emb = RelativePositionEmbedding(
bidirectional=True,
init_method=self.config.init_method,
num_attention_heads=self.config.num_attention_heads,
relative_attention_num_buckets=relative_attention_num_buckets,
relative_attention_max_distance=relative_attention_max_distance,
)
self.decoder_relative_pos_emb = RelativePositionEmbedding(
bidirectional=False,
init_method=self.config.init_method,
num_attention_heads=self.config.num_attention_heads,
relative_attention_num_buckets=relative_attention_num_buckets,
relative_attention_max_distance=relative_attention_max_distance,
)
# Transformer encoder
encoder_spec, decoder_spec = (
self.transformer_encoder_layer_spec,
@@ -284,6 +307,27 @@ class T5Model(LanguageModule):
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Relative positional embeddings
encoder_attention_bias_parallel = None
if self.position_embedding_type == 'relative':
query_seq_length = RelativePositionEmbedding.get_relative_seq_len(
inference_params, self.encoder, encoder_input, self.config
)
key_seq_length = query_seq_length
attention_bias = self.encoder_relative_pos_emb(query_seq_length, key_seq_length)
# Scatter attention_bias to TP ranks.
# First, permute [1, num_head, seqlen_q, seqlen_kv] to
# [1, seqlen_q, seqlen_kv, num_head] so it can be scattered along
# the last (num_heads) dimension.
attention_bias = torch.permute(attention_bias, (0, 2, 3, 1))
# Then, scatter to TP region
attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias)
# Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv]
encoder_attention_bias_parallel = torch.permute(
attention_bias_parallel, (0, 3, 1, 2)
)
# Run encoder.
if self.add_encoder:
encoder_hidden_states = self.encoder(
@@ -291,6 +335,7 @@ class T5Model(LanguageModule):
attention_mask=encoder_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
attention_bias=encoder_attention_bias_parallel,
)
else:
encoder_hidden_states = self.encoder_hidden_state
@@ -315,10 +360,29 @@ class T5Model(LanguageModule):
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.encoder, encoder_input, self.config, packed_seq_params
inference_params, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Relative positional embeddings
decoder_attention_bias_parallel = None
if self.position_embedding_type == 'relative':
query_seq_length = RelativePositionEmbedding.get_relative_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
key_seq_length = query_seq_length
attention_bias = self.decoder_relative_pos_emb(query_seq_length, key_seq_length)
# Scatter attention_bias to TP ranks.
# First, permute [1, num_head, seqlen_q, seqlen_kv] to
# [1, seqlen_q, seqlen_kv, num_head] so it can be scattered along
# the last (num_heads) dimension.
attention_bias = torch.permute(attention_bias, (0, 2, 3, 1))
# Then, scatter to TP region
attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias)
# Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv]
decoder_attention_bias_parallel = torch.permute(attention_bias_parallel, (0, 3, 1, 2))
# Run decoder.
decoder_hidden_states = self.decoder(
hidden_states=decoder_input,
@@ -327,12 +391,15 @@ class T5Model(LanguageModule):
context_mask=encoder_decoder_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
attention_bias=decoder_attention_bias_parallel,
)
if self.post_process:
lm_logits = self.lm_head(
decoder_hidden_states, self.shared_embedding_or_output_weight()
)
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
lm_logits = self.lm_head(decoder_hidden_states, word_embeddings_weight=output_weight)
if lm_labels is None:
# [s b h] => [b s h]
return lm_logits.transpose(0, 1).contiguous()
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
import logging
import math
from typing import Callable
import torch
from torch import Tensor, nn
from megatron.core.inference_params import InferenceParams
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
logger = logging.getLogger(__name__)
__all__ = ['RelativePositionEmbedding']
class RelativePositionEmbedding(nn.Module):
"""Relative Position Embedding for language model.
Args:
"""
def __init__(
self,
bidirectional: bool,
init_method: Callable,
num_attention_heads: int,
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
) -> None:
super().__init__()
self.bidirectional = bidirectional
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
self.relative_attention_bias = torch.nn.Embedding(
self.relative_attention_num_buckets, num_attention_heads
)
init_method(self.relative_attention_bias.weight)
def _relative_position_bucket(
self, relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
"""
Adapted from HuggingFace T5 Model:
https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/
src/transformers/models/t5/modeling_t5.py#L397
Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e. the
distance in tokens from the attending position to the attended-to position.
If bidirectional=False, then positive relative positions are invalid. We use
smaller buckets for small absolute relative_position and larger buckets for
larger absolute relative_positions. All relative positions >=max_distance map
to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the
model has been trained on.
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position,
containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger
# bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
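# Rough worked example (not part of the original source), assuming the default
# settings bidirectional=True, num_buckets=32, max_distance=128:
#   relative_position  0 -> bucket 0   (exact bucket for "same position")
#   relative_position -1 -> bucket 1   (key one step before the query)
#   relative_position +1 -> bucket 17  (positive offsets use the upper half of the buckets)
#   |relative_position| >= max_distance saturates at the last bucket of its half
#   (15 for negative offsets, 31 for positive offsets).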
def _compute_bias(self, query_length, key_length):
"""
Adapted from HuggingFace T5 Model
https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/
src/transformers/models/t5/modeling_t5.py#L444C9-L444C21
Compute binned relative position bias
Args:
query_length (int): The length of the query sequence
(e.g., the input sequence in attention).
key_length (int): The length of the key sequence
(e.g., the sequence to compare against in attention).
Returns:
torch.Tensor: A tensor representing the relative position bias, with shape
(1, num_heads, query_length, key_length).
"""
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape(query_length,key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=self.bidirectional,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(
relative_position_bucket
) # shape(query_length,key_length,num_heads)
values = values.permute([2, 0, 1]).unsqueeze(
0
) # shape(1, num_heads,query_length,key_length)
return values
@staticmethod
def get_relative_seq_len(
inference_params: InferenceParams,
transformer: TransformerBlock,
transformer_input: Tensor,
transformer_config: TransformerConfig,
) -> float:
"""Function to get the rotary sequence length.
Args:
inference_params : Used during Inference time
transformer (TransformerBlock): The transformer block (decoder/encoder) used
by the model
transformer_input (Tensor): Input tensor to the transformer
transformer_config (TransformerConfig): Transformer config used by the model
Returns:
float: The rotary sequence length
"""
if inference_params is not None:
relative_seq_len = inference_params.max_sequence_length
else:
if transformer.input_tensor is not None:
relative_seq_len = transformer.input_tensor.size(0)
else:
relative_seq_len = transformer_input.size(0)
if transformer_config.sequence_parallel:
relative_seq_len *= transformer_config.tensor_model_parallel_size
return relative_seq_len
def forward(self, query_seq_length, key_seq_length):
"""
Args:
Returns:
"""
return self._compute_bias(query_seq_length, key_seq_length)
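# Minimal usage sketch (not part of the original file); the init_method below is a
# stand-in for the config's initializer and the sizes are arbitrary. It runs on CPU and
# does not require initialized megatron parallel state.
if __name__ == "__main__":
    rel_pos_emb = RelativePositionEmbedding(
        bidirectional=True,
        init_method=torch.nn.init.normal_,
        num_attention_heads=12,
        relative_attention_num_buckets=32,
        relative_attention_max_distance=128,
    )
    bias = rel_pos_emb(8, 8)
    # bias has shape [1, num_heads, query_len, key_len] = [1, 12, 8, 8]. The T5 forward
    # pass above permutes it to [1, q, kv, heads], scatters it along the heads dimension
    # across tensor-parallel ranks, and permutes it back before handing it to attention.
    print(bias.shape)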
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.inference_params import InferenceParams
from megatron.core.packed_seq_params import PackedSeqParams
import logging
import math
from functools import lru_cache
import torch
from torch import Tensor, nn
from megatron.core import parallel_state
from megatron.core.models.common.embeddings.rope_utils import ( # for backward compatibility; pylint: disable=unused-import
_apply_rotary_pos_emb_bshd,
_apply_rotary_pos_emb_thd,
_rotate_half,
apply_rotary_pos_emb,
get_pos_emb_on_this_cp_rank,
)
logger = logging.getLogger(__name__)
__all__ = ['RotaryEmbedding']
class RotaryEmbedding(nn.Module):
"""Rotary Embedding for language model.
Args:
kv_channels (int): Projection weights dimension in multi-head attention. Obtained
from transformer config
rotary_percent (float): Percent of rotary dimension to use for rotary position
embeddings.
rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings.
Defaults to False.
seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE
for longer sequences. The value must be a float larger than 1.0. Defaults to None
rotary_base (int, optional): Base period for rotary position embeddings. Defaults to
10000.
rope_scaling (bool, optional): Apply rope scaling as used in llama 3.1
use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly
on the GPU. Defaults to False
"""
def __init__(
self,
kv_channels: int,
rotary_percent: float,
rotary_interleaved: bool = False,
seq_len_interpolation_factor: float = None,
rotary_base: int = 10000,
rope_scaling: bool = False,
use_cpu_initialization: bool = False,
) -> None:
super().__init__()
dim = kv_channels
if rotary_percent < 1.0:
dim = int(dim * rotary_percent)
self.rotary_interleaved = rotary_interleaved
self.seq_len_interpolation_factor = seq_len_interpolation_factor
device = 'cpu' if use_cpu_initialization else torch.cuda.current_device()
self.inv_freq = 1.0 / (
rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
)
if rope_scaling:
self.inv_freq = self._apply_scaling(self.inv_freq)
def _apply_scaling(
self,
freqs,
factor=8,
low_freq_factor=1,
high_freq_factor=4,
original_max_position_embeddings=8192,
):
# This implementation is adapted from:
# https://github.com/huggingface/transformers/blob/2a5a6ad18aa22e98429bb5ecb880660328030ea0/src/transformers/modeling_rope_utils.py#L303-L343
factor = factor # `8` in the original implementation
low_freq_factor = low_freq_factor # `1` in the original implementation
high_freq_factor = high_freq_factor # `4` in the original implementation
old_context_len = original_max_position_embeddings # `8192` in the original implementation
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * math.pi / freqs
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, freqs / factor, freqs)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smoothed_inv_freq = (
1 - smooth_factor
) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
return inv_freq_llama
def get_freqs_non_repeated(self, max_seq_len: int, offset: int = 0) -> Tensor:
"""Generates matrix of frequencies based on positions in the sequence,
used to create positional encodings"""
seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ offset
)
if self.seq_len_interpolation_factor is not None:
seq *= 1 / self.seq_len_interpolation_factor
freqs = torch.outer(seq, self.inv_freq) # [seq len, dim]
return freqs
def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor):
"""Cosine and sine values for RoPE are precomputed for all positions up to the maximum
sequence length"""
freqs = self.get_freqs_non_repeated(max_seq_len, offset)
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return cos, sin
@lru_cache(maxsize=32)
def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
"""Forward pass of RoPE embedding.
Args:
max_seq_len (int): Maximum size of sequence
offset (int, optional): RoPE offset. Defaults to 0.
packed_seq (bool, optional): Whether to use packed sequence. Defaults to False.
Returns:
Tensor: Embeddings after applying RoPE.
"""
if self.inv_freq.device.type == 'cpu':
# move `inv_freq` to GPU once at the first micro-batch forward pass
self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device())
freqs = self.get_freqs_non_repeated(max_seq_len, offset)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
if not self.rotary_interleaved:
emb = torch.cat((freqs, freqs), dim=-1)
else:
emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
freqs.shape[0], -1
)
# emb [seq_length, .., dim]
emb = emb[:, None, None, :]
if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq:
# slice rotary_pos_emb along the sequence dimension and select the partition of the
# current CP rank
emb = get_pos_emb_on_this_cp_rank(emb, 0)
return emb
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
state_dict.pop(f'{prefix}inv_freq', None)
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def get_rotary_seq_len(
self,
inference_params: InferenceParams,
transformer: TransformerBlock,
transformer_input: Tensor,
transformer_config: TransformerConfig,
packed_seq_params: PackedSeqParams,
) -> float:
"""Function to get the rotary sequence length.
Args:
inference_params : Used during Inference time
transformer (TransformerBlock): The transformer block (decoder/encoder) used
by the model
transformer_input (Tensor): Input tensor to the transformer
transformer_config (TransformerConfig): Transformer config used by the model
packed_seq_params (PackedSeqParams): Packed sequence params
Returns:
float: The rotary sequence length
"""
if packed_seq_params is not None:
# max_seqlen is the max sequence length in the packed sequence before being divided
# by the TP and CP size.
return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv)
elif inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
if transformer.input_tensor is not None:
rotary_seq_len = transformer.input_tensor.size(0)
else:
rotary_seq_len = transformer_input.size(0)
if transformer_config.sequence_parallel:
rotary_seq_len *= transformer_config.tensor_model_parallel_size
rotary_seq_len *= transformer_config.context_parallel_size
return rotary_seq_len
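# Minimal usage sketch (not part of the original file). It sticks to get_cos_sin() with
# CPU initialization so it does not require a GPU or initialized megatron parallel state.
if __name__ == "__main__":
    rope = RotaryEmbedding(kv_channels=128, rotary_percent=1.0, use_cpu_initialization=True)
    cos, sin = rope.get_cos_sin(max_seq_len=16)
    # One angle per position and per frequency pair: cos and sin both have shape [16, 64].
    print(cos.shape, sin.shape)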
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.inference_params import InferenceParams
from megatron.core.packed_seq_params import PackedSeqParams
import logging
import math
from functools import lru_cache
import torch
from torch import Tensor, nn
from megatron.core import parallel_state
from megatron.core.models.common.embeddings.rope_utils import ( # for backward compatibility; pylint: disable=unused-import
_apply_rotary_pos_emb_bshd,
_apply_rotary_pos_emb_thd,
_rotate_half,
apply_rotary_pos_emb,
get_pos_emb_on_this_cp_rank,
)
logger = logging.getLogger(__name__)
__all__ = ['RotaryEmbedding']
class RotaryEmbedding(nn.Module):
"""Rotary Embedding for language model.
Args:
kv_channels (int): Projection weights dimension in multi-head attention. Obtained
from transformer config
rotary_percent (float): Percent of rotary dimension to use for rotary position
embeddings.
rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings.
Defaults to False.
seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE
for longer sequences. The value must be a float larger than 1.0. Defaults to None
rotary_base (int, optional): Base period for rotary position embeddings. Defaults to
10000.
rope_scaling (bool, optional): Apply rope scaling as used in llama 3.x.
rope_scaling_factor (float, optional): rope scaling factor in llama 3.x. Defaults to 8.
use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly
on the GPU. Defaults to False
"""
def __init__(
self,
kv_channels: int,
rotary_percent: float,
rotary_interleaved: bool = False,
seq_len_interpolation_factor: float = None,
rotary_base: int = 10000,
rope_scaling: bool = False,
rope_scaling_factor: float = 8.0,
use_cpu_initialization: bool = False,
) -> None:
super().__init__()
dim = kv_channels
if rotary_percent < 1.0:
dim = int(dim * rotary_percent)
self.rotary_interleaved = rotary_interleaved
self.seq_len_interpolation_factor = seq_len_interpolation_factor
device = 'cpu' if use_cpu_initialization else torch.cuda.current_device()
self.inv_freq = 1.0 / (
rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
)
if rope_scaling:
self.inv_freq = self._apply_scaling(self.inv_freq, factor=rope_scaling_factor)
def _apply_scaling(
self,
freqs,
factor=8,
low_freq_factor=1,
high_freq_factor=4,
original_max_position_embeddings=8192,
):
# This implementation is adapted from:
# https://github.com/huggingface/transformers/blob/2a5a6ad18aa22e98429bb5ecb880660328030ea0/src/transformers/modeling_rope_utils.py#L303-L343
factor = factor # `8` in the original implementation
low_freq_factor = low_freq_factor # `1` in the original implementation
high_freq_factor = high_freq_factor # `4` in the original implementation
old_context_len = original_max_position_embeddings # `8192` in the original implementation
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * math.pi / freqs
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, freqs / factor, freqs)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smoothed_inv_freq = (
1 - smooth_factor
) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
return inv_freq_llama
def get_freqs_non_repeated(self, max_seq_len: int, offset: int = 0) -> Tensor:
"""Generates matrix of frequencies based on positions in the sequence,
used to create positional encodings"""
seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ offset
)
if self.seq_len_interpolation_factor is not None:
seq *= 1 / self.seq_len_interpolation_factor
freqs = torch.outer(seq, self.inv_freq) # [seq len, dim]
return freqs
def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor):
"""Cosine and sine values for RoPE are precomputed for all positions up to the maximum
sequence length"""
freqs = self.get_freqs_non_repeated(max_seq_len, offset)
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return cos, sin
@lru_cache(maxsize=32)
def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
"""Forward pass of RoPE embedding.
Args:
max_seq_len (int): Maximum size of sequence
offset (int, optional): RoPE offset. Defaults to 0.
packed_seq (bool, optional): Whether to use packed sequence. Defaults to False.
Returns:
Tensor: Embeddings after applying RoPE.
"""
if self.inv_freq.device.type == 'cpu':
# move `inv_freq` to GPU once at the first micro-batch forward pass
self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device())
freqs = self.get_freqs_non_repeated(max_seq_len, offset)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
if not self.rotary_interleaved:
emb = torch.cat((freqs, freqs), dim=-1)
else:
emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
freqs.shape[0], -1
)
# emb [seq_length, .., dim]
emb = emb[:, None, None, :]
if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq:
# slice rotary_pos_emb along the sequence dimension and select the partition of the
# current CP rank
emb = get_pos_emb_on_this_cp_rank(emb, 0)
return emb
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
state_dict.pop(f'{prefix}inv_freq', None)
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def get_rotary_seq_len(
self,
inference_params: InferenceParams,
transformer: TransformerBlock,
transformer_input: Tensor,
transformer_config: TransformerConfig,
packed_seq_params: PackedSeqParams,
) -> float:
"""Function to get the rotary sequence length.
Args:
inference_params : Used during Inference time
transformer (TransformerBlock): The transformer block (decoder/encoder) used
by the model
transformer_input (Tensor): Input tensor to the transformer
transformer_config (TransformerConfig): Transformer config used by the model
packed_seq_params (PackedSeqParams): Packed sequence params
Returns:
float: The rotary sequence length
"""
if packed_seq_params is not None:
# max_seqlen is the max sequence length in the packed sequence before being divided
# by the TP and CP size.
return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv)
elif inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
if transformer is not None and transformer.input_tensor is not None:
rotary_seq_len = transformer.input_tensor.size(0)
else:
rotary_seq_len = transformer_input.size(0)
if transformer_config.sequence_parallel:
rotary_seq_len *= transformer_config.tensor_model_parallel_size
rotary_seq_len *= transformer_config.context_parallel_size
return rotary_seq_len
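# Illustrative sketch (not part of the original file) of the llama-3.x style scaling added
# in this version: low-frequency components are divided by rope_scaling_factor, the highest
# frequencies are left untouched, and the band in between is smoothly interpolated.
if __name__ == "__main__":
    base = RotaryEmbedding(kv_channels=128, rotary_percent=1.0, use_cpu_initialization=True)
    scaled = RotaryEmbedding(
        kv_channels=128,
        rotary_percent=1.0,
        rope_scaling=True,
        rope_scaling_factor=8.0,
        use_cpu_initialization=True,
    )
    # The first (high-frequency) entries should match; the last (low-frequency) entries of
    # the scaled variant should be roughly 8x smaller.
    print(base.inv_freq[:2], scaled.inv_freq[:2])
    print(base.inv_freq[-2:], scaled.inv_freq[-2:])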
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import warnings
from typing import Optional
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.multi_latent_attention import (
MLASelfAttention,
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
get_num_layers_to_build,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.utils import is_te_min_version
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def get_gpt_layer_with_transformer_engine_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None,  # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Returns:
ModuleSpec: Module specification with TE modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
mlp = _get_mlp_module_spec(
use_te=True,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
if multi_latent_attention:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
linear_q_proj=TEColumnParallelLinear,
linear_q_down_proj=TEColumnParallelLinear,
linear_q_up_proj=TEColumnParallelLinear,
linear_kv_down_proj=TEColumnParallelLinear,
linear_kv_up_proj=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
kv_layernorm=TENorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
else:
# TENorm significantly harms convergence when used
# for QKLayerNorm if TE Version < 1.9;
# we instead use the Apex implementation.
qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=qk_norm if qk_layernorm else IdentityOp,
k_layernorm=qk_norm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
def get_gpt_layer_local_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None,  # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec for an implementation using only modules in Megatron-Core.
Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Returns:
ModuleSpec: Module specification with Megatron-Core modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
mlp = _get_mlp_module_spec(
use_te=False,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
if multi_latent_attention:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
linear_q_proj=ColumnParallelLinear,
linear_q_down_proj=ColumnParallelLinear,
linear_q_up_proj=ColumnParallelLinear,
linear_kv_down_proj=ColumnParallelLinear,
linear_kv_up_proj=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=LNImpl if qk_layernorm else IdentityOp,
kv_layernorm=LNImpl if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
else:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=LNImpl if qk_layernorm else IdentityOp,
k_layernorm=LNImpl if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
def _get_mlp_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None,  # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Helper function to get module spec for MLP/MoE"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
if num_experts is None:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
else:
# Mixture of experts with modules in megatron core.
return get_moe_module_spec(
use_te=use_te,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
def get_gpt_decoder_block_spec(
config: TransformerConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
"""GPT block spec."""
if use_transformer_engine:
layer_norm_impl = TENorm
else:
layer_norm_impl = LNImpl
# Layer specs.
dense_layer_spec = (
get_gpt_layer_with_transformer_engine_spec(
num_experts=None,
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=None,
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)
moe_layer_spec = (
get_gpt_layer_with_transformer_engine_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)
# Parse config.moe_layer_freq to determine the pattern of expert/dense layers.
# 0 stands for dense layers, 1 stands for expert layers.
# For an integer N: creates a pattern with one expert layer every N layers.
# For a list pattern: uses the list directly (e.g. [1,0,1] for alternating expert/dense).
if isinstance(config.moe_layer_freq, int):
moe_layer_pattern = [
1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers)
]
elif isinstance(config.moe_layer_freq, list):
moe_layer_pattern = config.moe_layer_freq
assert len(moe_layer_pattern) == config.num_layers, (
f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
f"expected {config.num_layers}, "
f"current moe layer pattern: {config.moe_layer_freq}"
)
else:
raise ValueError(
f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}"
)
# Create the layer specs for the model.
layer_specs = []
for layer_number in range(config.num_layers):
if moe_layer_pattern[layer_number] == 1:
layer_specs.append(moe_layer_spec)
elif moe_layer_pattern[layer_number] == 0:
layer_specs.append(dense_layer_spec)
else:
raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}")
# Slice the layer specs to only include the layers that are built in this pipeline stage.
# Note: MCore layer_number starts at 1
offset = TransformerLayer._get_layer_offset(config)
num_layers_to_build = get_num_layers_to_build(config)
layer_specs = layer_specs[offset : offset + num_layers_to_build]
# Block spec.
block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl)
return block_spec
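# Minimal usage sketch (not part of the original file). It uses the local (Megatron-Core-only)
# spec so the example does not depend on Transformer Engine being installed.
if __name__ == "__main__":
    spec = get_gpt_layer_local_spec(num_experts=None, qk_layernorm=True)
    # The spec wraps TransformerLayer; its submodules name the attention/MLP implementations
    # that TransformerBlock will instantiate for each layer.
    print(spec.module.__name__, type(spec.submodules).__name__)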
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import warnings
from typing import Optional
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.multi_latent_attention import (
MLASelfAttention,
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
get_num_layers_to_build,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
TransformerLayer,
TransformerLayerSubmodules,
get_transformer_layer_offset,
)
from megatron.core.utils import is_te_min_version
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def get_gpt_layer_with_transformer_engine_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None,  # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Returns:
ModuleSpec: Module specification with TE modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
mlp = get_mlp_module_spec(
use_te=True,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
if multi_latent_attention:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
linear_q_proj=TEColumnParallelLinear,
linear_q_down_proj=TEColumnParallelLinear,
linear_q_up_proj=(
TELayerNormColumnParallelLinear
if qk_layernorm
else TEColumnParallelLinear
),
linear_kv_down_proj=TEColumnParallelLinear,
linear_kv_up_proj=(
TELayerNormColumnParallelLinear
if qk_layernorm
else TEColumnParallelLinear
),
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
kv_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
else:
# TENorm significantly harms convergence when used
# for QKLayerNorm if TE Version < 1.9;
# we instead use the Apex implementation.
qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=qk_norm if qk_layernorm else IdentityOp,
k_layernorm=qk_norm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
def get_gpt_layer_local_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None,  # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec for an implementation using only modules in Megatron-Core.
Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Returns:
ModuleSpec: Module specification with Megatron-Core modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
mlp = get_mlp_module_spec(
use_te=False,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
if multi_latent_attention:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
linear_q_proj=ColumnParallelLinear,
linear_q_down_proj=ColumnParallelLinear,
linear_q_up_proj=ColumnParallelLinear,
linear_kv_down_proj=ColumnParallelLinear,
linear_kv_up_proj=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=LNImpl if qk_layernorm else IdentityOp,
kv_layernorm=LNImpl if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
else:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=LNImpl if qk_layernorm else IdentityOp,
k_layernorm=LNImpl if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
def _get_mlp_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None,  # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
):
warnings.warn(
"""This private function is on a deprecation track. Please switch to `get_mlp_module_spec`
since it will be removed in a future release."""
)
return get_mlp_module_spec(
use_te=use_te,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
fp8=fp8,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
def get_mlp_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None,  # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Helper function to get module spec for MLP/MoE"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
if num_experts is None:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
else:
# Mixture of experts with modules in megatron core.
return get_moe_module_spec(
use_te=use_te,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
def get_gpt_decoder_block_spec(
config: TransformerConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
"""GPT block spec."""
if use_transformer_engine:
layer_norm_impl = TENorm
else:
layer_norm_impl = LNImpl
# Layer specs.
dense_layer_spec = (
get_gpt_layer_with_transformer_engine_spec(
num_experts=None,
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=None,
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)
moe_layer_spec = (
get_gpt_layer_with_transformer_engine_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)
# Parse config.moe_layer_freq to determine the pattern of expert/dense layers.
# 0 stands for dense layers, 1 stands for expert layers.
# For an integer N: creates a pattern with one expert layer every N layers.
# For a list pattern: uses the list directly (e.g. [1,0,1] for alternating expert/dense).
if isinstance(config.moe_layer_freq, int):
moe_layer_pattern = [
1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers)
]
elif isinstance(config.moe_layer_freq, list):
moe_layer_pattern = config.moe_layer_freq
assert len(moe_layer_pattern) == config.num_layers, (
f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
f"expected {config.num_layers}, "
f"current moe layer pattern: {config.moe_layer_freq}"
)
else:
raise ValueError(
f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}"
)
# Create the layer specs for the model.
layer_specs = []
for layer_number in range(config.num_layers):
if moe_layer_pattern[layer_number] == 1:
layer_specs.append(moe_layer_spec)
elif moe_layer_pattern[layer_number] == 0:
layer_specs.append(dense_layer_spec)
else:
raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}")
# Slice the layer specs to only include the layers that are built in this pipeline stage.
# Note: MCore layer_number starts at 1
offset = get_transformer_layer_offset(config)
num_layers_to_build = get_num_layers_to_build(config)
layer_specs = layer_specs[offset : offset + num_layers_to_build]
# Block spec.
block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl)
return block_spec
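# Illustrative sketch (not part of the original file) of how get_gpt_decoder_block_spec above
# expands an integer moe_layer_freq into a per-layer pattern (1 = MoE layer, 0 = dense layer).
if __name__ == "__main__":
    num_layers, moe_layer_freq = 8, 2
    pattern = [1 if (i % moe_layer_freq == 0) else 0 for i in range(num_layers)]
    print(pattern)  # -> [1, 0, 1, 0, 1, 0, 1, 0]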
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import OrderedDict
from typing import Dict, Literal, Optional
from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
class GPTModel(LanguageModule):
"""GPT Transformer language model.
Args:
config (TransformerConfig):
Transformer config
transformer_layer_spec (ModuleSpec):
Specifies module to use for transformer layers
vocab_size (int):
Vocabulary size
max_sequence_length (int):
Maximum sequence length. This is used for the positional embeddings.
pre_process (bool, optional):
Include embedding layer (used with pipeline parallelism). Defaults to True.
post_process (bool, optional):
Include an output layer (used with pipeline parallelism). Defaults to True.
fp16_lm_cross_entropy (bool, optional):
If True, compute the unreduced language-model cross-entropy loss in fp16 instead of fp32. Defaults to False.
parallel_output (bool, optional):
Do not gather the outputs, keep them split across tensor
parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional):
When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (Literal['learned_absolute', 'rope', 'none'], optional):
Position embedding type. Defaults to 'learned_absolute'.
rotary_percent (float, optional):
Percent of rotary dimension to use for rotary position embeddings.
Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional):
Base period for rotary position embeddings. Ignored unless
position_embedding_type is 'rope'.
Defaults to 10000.
rope_scaling (bool, optional):
Toggle RoPE scaling. Defaults to False.
scatter_embedding_sequence_parallel (bool, optional):
Whether embeddings should be scattered across sequence parallel
region or not. Defaults to True.
seq_len_interpolation_factor (Optional[float], optional):
scale of linearly interpolating RoPE for longer sequences.
The value must be a float larger than 1.0. Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
rope_scaling: bool = False,
scatter_embedding_sequence_parallel: bool = True,
seq_len_interpolation_factor: Optional[float] = None,
) -> None:
super().__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# These 4 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
self.rotary_base = rotary_base
self.rotary_scaling = rope_scaling
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
)
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=rope_scaling,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
# Output
if post_process:
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
if has_config_logger_enabled(self.config):
log_config_to_disk(
self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
)
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode and inference_params:
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cos_sin(
inference_params.max_sequence_length
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
) -> ShardedStateDict:
"""Sharded state dict implementation for GPTModel backward-compatibility
(removing extra state).
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]): metadata controlling sharded state dict creation.
Returns:
ShardedStateDict: sharded state dict for the GPTModel
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
output_layer_extra_state_key = f'{prefix}output_layer._extra_state'
# Old GPT checkpoints only stored the output layer weight key. So we remove the
# _extra_state key but check that it doesn't contain any data anyway
output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
assert not (
output_extra_state and output_extra_state.data
), f'Expected output layer extra state to be empty, got: {output_extra_state}'
return sharded_state_dict
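# Minimal forward-pass sketch (assumptions: a single-process setup with tensor/pipeline
# model-parallel state already initialized to size 1, a layer spec obtained from
# get_gpt_layer_local_spec(), and import paths matching upstream Megatron-Core):
#
#   import torch
#   from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
#
#   config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4,
#                              use_cpu_initialization=True)
#   model = GPTModel(config=config, transformer_layer_spec=get_gpt_layer_local_spec(),
#                    vocab_size=128, max_sequence_length=32)
#   tokens = torch.randint(0, 128, (1, 32))
#   position_ids = torch.arange(32).unsqueeze(0)
#   attention_mask = torch.ones(1, 1, 32, 32, dtype=torch.bool)
#   logits = model(tokens, position_ids, attention_mask)  # [batch, seq, vocab] when labels is None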
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import OrderedDict
from typing import Dict, Literal, Optional
import torch
from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
class GPTModel(LanguageModule):
"""GPT Transformer language model.
Args:
config (TransformerConfig):
Transformer config
transformer_layer_spec (ModuleSpec):
Specifies module to use for transformer layers
vocab_size (int):
Vocabulary size
max_sequence_length (int):
Maximum sequence length. This is used for the positional embeddings.
pre_process (bool, optional):
Include embedding layer (used with pipeline parallelism). Defaults to True.
post_process (bool, optional):
Include an output layer (used with pipeline parallelism). Defaults to True.
fp16_lm_cross_entropy (bool, optional):
If True, compute the unreduced language-model cross-entropy loss in fp16 instead of fp32. Defaults to False.
parallel_output (bool, optional):
Do not gather the outputs, keep them split across tensor
parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional):
When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (Literal['learned_absolute', 'rope', 'none'], optional):
Position embedding type. Defaults to 'learned_absolute'.
rotary_percent (float, optional):
Percent of rotary dimension to use for rotary position embeddings.
Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional):
Base period for rotary position embeddings. Ignored unless
position_embedding_type is 'rope'.
Defaults to 10000.
rope_scaling (bool, optional): Toggle RoPE scaling. Defaults to False.
rope_scaling_factor (float, optional): RoPE scaling factor. Defaults to 8.0.
scatter_embedding_sequence_parallel (bool, optional):
Whether embeddings should be scattered across sequence parallel
region or not. Defaults to True.
seq_len_interpolation_factor (Optional[float], optional):
scale of linearly interpolating RoPE for longer sequences.
The value must be a float larger than 1.0. Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
rope_scaling: bool = False,
rope_scaling_factor: float = 8.0,
scatter_embedding_sequence_parallel: bool = True,
seq_len_interpolation_factor: Optional[float] = None,
) -> None:
super().__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# These 4 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
self.rotary_base = rotary_base
self.rotary_scaling = rope_scaling
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
)
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=rope_scaling,
rope_scaling_factor=rope_scaling_factor,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Cache for RoPE tensors which do not change between iterations.
self.rotary_pos_emb_cache = {}
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
# Output
if post_process:
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
if has_config_logger_enabled(self.config):
log_config_to_disk(
self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
)
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode and inference_params:
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
if (
(self.config.enable_cuda_graph or self.config.flash_decode)
and rotary_pos_cos is not None
and inference_params
):
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
) -> ShardedStateDict:
"""Sharded state dict implementation for GPTModel backward-compatibility
(removing extra state).
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]): metadata controlling sharded state dict creation.
Returns:
ShardedStateDict: sharded state dict for the GPTModel
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
output_layer_extra_state_key = f'{prefix}output_layer._extra_state'
# Old GPT checkpoints only stored the output layer weight key. So we remove the
# _extra_state key but check that it doesn't contain any data anyway
output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
assert not (
output_extra_state and output_extra_state.data
), f'Expected output layer extra state to be empty, got: {output_extra_state}'
return sharded_state_dict
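# Illustrative incremental-decoding sketch (assumptions: a model built as in the sketch
# above, flash_decode enabled in its TransformerConfig, and weights already loaded):
#
#   inference_params = InferenceParams(max_batch_size=1, max_sequence_length=64)
#   logits = model(tokens, position_ids, attention_mask, inference_params=inference_params)
#
# Subsequent decode steps advance inference_params.sequence_len_offset, so the decoder
# reuses its cached key/value tensors and the precomputed RoPE cos/sin are fetched from
# self.rotary_pos_emb_cache rather than recomputed each step.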