Commit 688448db authored by silencealiang

Update code

parent a02a5490
Pipeline #2503 passed
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Any, Dict

import torch

from megatron.core import parallel_state
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)
from megatron.core.inference_params import InferenceParams


# pylint: disable=line-too-long
class VLMInferenceWrapper(GPTInferenceWrapper):
    """Inference wrapper for VLMs"""

    def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
        """A utility function for preparing the model for inference

        The function gets called once before the auto-regressive inference loop.
        It puts the model in eval mode.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
        """
        super().prep_model_for_inference(prompts_tokens)

        # For a TP-only model, both is_pipeline_first_stage and is_pipeline_last_stage return True
        self.model_is_pipeline_parallel = not (
            parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
        )

        self._recv_only_vision_embeds = False
        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
        # Checks if the previous stage only has a vision encoder, and that the current stage
        # has part of the LM decoder. In this case, the current stage should only receive
        # vision embeddings.
        if pp_rank > 0:
            self._recv_only_vision_embeds = (
                parallel_state.is_inside_encoder(pp_rank - 1)
                and (not parallel_state.is_inside_decoder(pp_rank - 1))
                and parallel_state.is_inside_decoder()
            )

        # Checks if the current stage only has a vision encoder
        self._encoder_only = (
            parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder()
        )

    def prep_inference_input(
        self,
        prompts_tokens: torch.Tensor,
        num_img_embeddings_per_tile: int,
        images: torch.Tensor,
        num_tiles: torch.Tensor,
        decoder_seq_length: int,
    ):
        """Prepares the inference input data.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
            num_img_embeddings_per_tile (int): The number of image embeddings per tile
            images (torch.Tensor): The image embeddings
            num_tiles (torch.Tensor): The number of tiles for each input image
            decoder_seq_length (int): The decoder sequence length
        """
        inference_input = super().prep_inference_input(prompts_tokens)

        total_num_tiles = torch.sum(num_tiles).item()
        num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles

        batch_size, max_sequence_length = prompts_tokens.shape
        self.inference_params = InferenceParams(
            batch_size, max_sequence_length + num_img_embeddings
        )

        inference_input["images"] = images
        inference_input["num_tiles"] = num_tiles
        inference_input["num_img_embeddings"] = num_img_embeddings
        inference_input["decoder_seq_length"] = decoder_seq_length

        return inference_input

    def get_batch_for_context_window(
        self,
        inference_input: Dict[str, Any],
        context_start_position: int,
        context_end_position: int,
    ) -> Dict[str, Any]:
        """Returns the inference data given the context window

        This function gets called iteratively in a loop. Given the start and end context
        positions, it extracts the appropriate data.

        Args:
            inference_input (Dict[str, Any]): The inference input for the batch.
            context_start_position (int): Start of the context window. During the first
                inference step it is mostly 0.
            context_end_position (int): End of the context window. During the last inference
                step it will mostly be the max generated sequence length.

        Returns:
            Dict[str, Any]: A dict of inputs that will be used by your model in the forward step
        """
        tokens = inference_input["tokens"]
        position_ids = inference_input["position_ids"]
        images = inference_input["images"]
        num_tiles = inference_input["num_tiles"]
        num_img_embeddings = inference_input["num_img_embeddings"]
        decoder_seq_length = inference_input["decoder_seq_length"]

        tokens2use = tokens[:, context_start_position:context_end_position]
        positions2use = position_ids[:, context_start_position:context_end_position]

        return {
            "tokens": tokens2use,
            "position_ids": positions2use,
            "images": images,
            "num_tiles": num_tiles,
            "num_img_embeddings": num_img_embeddings,
            "decoder_seq_length": decoder_seq_length,
        }

    def _forward(self, inference_input: Dict[str, Any]):
        """Runs a forward pass of the model.

        Args:
            inference_input (Dict[str, Any]): The input data.

        Returns:
            The model output logits.
        """
        images = inference_input["images"]
        tokens = inference_input["tokens"]
        position_ids = inference_input["position_ids"]
        num_image_tiles = inference_input["num_tiles"]

        output = self.model(
            images,
            tokens,
            position_ids=position_ids,
            attention_mask=None,
            inference_params=self.inference_params,
            num_image_tiles=num_image_tiles,
            runtime_gather_output=True,
        )
        if isinstance(output, tuple):
            logits, _ = output
        else:
            logits = output
        return logits

    def run_one_forward_step(self, inference_input: Dict[str, Any]) -> torch.Tensor:
        """Runs one forward step, adjusting the recv buffer length for image embeddings."""
        tokens = inference_input["tokens"]
        num_image_tokens = (tokens == self.model.module.image_token_index).sum().item()
        num_img_embeddings = inference_input["num_img_embeddings"]
        decoder_seq_length = inference_input["decoder_seq_length"]

        num_tokens = tokens.size(1)
        recv_buffer_seq_len = None
        if num_image_tokens > 0:
            # When there are image tokens and this stage only receives vision embeddings,
            # adjust the recv buffer seq length to match the image embeddings sequence length.
            # If there are image tokens and this stage receives full embeddings, make sure we
            # compensate for the expansion of image tokens.
            # Note that this will set a recv_buffer_seq_len for the encoder stage;
            # this length is irrelevant since that recv buffer is never allocated.
            if self._recv_only_vision_embeds:
                recv_buffer_seq_len = num_img_embeddings
            else:
                recv_buffer_seq_len = min(
                    num_img_embeddings + num_tokens - num_image_tokens, decoder_seq_length
                )
        elif self._recv_only_vision_embeds:
            # If this stage only receives vision embeddings and there are no image tokens,
            # we won't run the encoder and therefore shouldn't try to recv.
            recv_buffer_seq_len = 0

        # If the pipeline stage only has a vision encoder, then it only needs to
        # run when there are image tokens
        if not (self._encoder_only and num_image_tokens == 0):
            output = super().run_one_forward_step(
                inference_input, recv_buffer_seq_len=recv_buffer_seq_len
            )
        else:
            output = None
        logits = output

        # On the first inference iteration, we compute image tokens.
        # On every PP stage (although inference params should only matter for the decoder),
        # update the sequence length offset by the number of image tokens.
        if num_tokens > 1 and num_image_tokens > 0:
            if "image_tokens_count" not in self.inference_params.key_value_memory_dict:
                self.inference_params.key_value_memory_dict["image_tokens_count"] = (
                    num_img_embeddings
                )

            if num_img_embeddings + num_tokens - num_image_tokens > decoder_seq_length:
                self.inference_params.sequence_len_offset += decoder_seq_length - num_tokens
            else:
                self.inference_params.sequence_len_offset += (
                    self.inference_params.key_value_memory_dict["image_tokens_count"]
                    - num_image_tokens
                )

        return logits
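
A minimal usage sketch of the wrapper above (not part of this commit): a controller is expected to call prep_model_for_inference once, build the inference input, and then slide the context window while generating. The tensors, the embeddings-per-tile value, and the sampling step are placeholders.

wrapper = VLMInferenceWrapper(model, inference_wrapper_config)  # model/config assumed to exist
wrapper.prep_model_for_inference(prompts_tokens)
inference_input = wrapper.prep_inference_input(
    prompts_tokens,
    num_img_embeddings_per_tile=576,  # placeholder, depends on the vision encoder
    images=images,
    num_tiles=num_tiles,
    decoder_seq_length=4096,  # placeholder
)
context_start = 0
for context_end in range(min_prompt_length, prompts_tokens.size(1)):
    batch = wrapper.get_batch_for_context_window(inference_input, context_start, context_end)
    logits = wrapper.run_one_forward_step(batch)
    # ... sample the next token from logits and write it into inference_input["tokens"] ...
    context_start = context_end
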
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from collections import deque
from typing import Any, Dict, List, Optional

import numpy
import torch

from megatron.core import tensor_parallel
from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
    AbstractModelInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
    InferenceWrapperConfig,
)
from megatron.core.models.T5 import T5Model
from megatron.core.utils import get_attr_wrapped_model


# pylint: disable=line-too-long
class T5InferenceWrapper(AbstractModelInferenceWrapper):
    """Constructor for the model inference wrapper

    The wrapper prepares the model for inference, provides the required input
    data, and runs the forward pass.

    Args:
        model (T5Model): The T5 model (MCore or legacy)
        inference_wrapper_config (InferenceWrapperConfig): The command line arguments that were passed
        use_local (bool): Whether the T5 model's transformer impl
            is local (vs transformer_engine)
    """

    def __init__(
        self,
        model: T5Model,
        inference_wrapper_config: InferenceWrapperConfig,
        use_local: bool = False,
    ):
        super().__init__(model, inference_wrapper_config)
        self.use_local = use_local

    def prep_inference_input(
        self,
        prompts_tokens: torch.Tensor,
        encoder_prompts: Optional[List[str]] = None,
        tokenizer: Any = None,
    ) -> Dict[str, Any]:
        """Prepares the inference input data.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
            encoder_prompts (List[str]): A list of encoder input prompt strings
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text

        Returns:
            A dict with all the inference input needed for the batch.
        """
        # get max_sequence_length
        max_sequence_length = get_attr_wrapped_model(self.model, "max_sequence_length")

        encoder_prompts_tokens_list = [
            self.tokenize_encoder_prompt(encoder_prompt, tokenizer)
            for encoder_prompt in encoder_prompts
        ]
        batch_encoder_prompts_tokens = self.pad_encoder_prompts_tokens(
            encoder_prompts_tokens_list, max_sequence_length, tokenizer
        )

        # create batch mask for encoder_prompt (batch_encoder_prompts_tokens) and
        # decoder_input (prompts_tokens), similar to megatron/core/datasets/t5_dataset.py
        decoder_prompts_tokens = prompts_tokens
        encoder_prompts_tokens = batch_encoder_prompts_tokens
        decoder_prompts_tokens_numpy = decoder_prompts_tokens.cpu().numpy()
        encoder_prompts_tokens_numpy = encoder_prompts_tokens.cpu().numpy()
        batch_mask_encoder = []
        batch_mask_decoder = []
        for i in range(len(prompts_tokens)):
            mask_encoder = encoder_prompts_tokens_numpy[i] == tokenizer.pad
            mask_decoder = decoder_prompts_tokens_numpy[i] == tokenizer.pad
            batch_mask_encoder.append(mask_encoder)
            batch_mask_decoder.append(mask_decoder)
        batch_mask_encoder = torch.tensor(numpy.array(batch_mask_encoder)).cuda()
        batch_mask_decoder = torch.tensor(numpy.array(batch_mask_decoder)).cuda()

        return {
            "encoder_tokens": encoder_prompts_tokens,
            "decoder_tokens": decoder_prompts_tokens,
            "encoder_mask": batch_mask_encoder,
            "decoder_mask": batch_mask_decoder,
        }

    def tokenize_encoder_prompt(self, encoder_prompt: str, tokenizer) -> torch.Tensor:
        """Utility to tokenize the encoder_prompt

        Args:
            encoder_prompt (str): The encoder_prompt
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing string

        Returns:
            torch.Tensor: Returns the tokenized prompt
        """
        # if there is the word "<mask>" in the prompt, replace it with a special_additional_token,
        # similar to the processing step in megatron/core/datasets/t5_dataset.py
        divided_encoder_prompt_list = encoder_prompt.split("<mask>")
        masks_count = len(divided_encoder_prompt_list) - 1
        sentinels = deque(tokenizer.additional_special_tokens_ids)

        encoder_prompt_tokens = []
        for divided_encoder_prompt in divided_encoder_prompt_list:
            divided_encoder_prompt_tokens = tokenizer.tokenize(divided_encoder_prompt)
            encoder_prompt_tokens.extend(divided_encoder_prompt_tokens)
            if masks_count > 0:
                sentinel = sentinels.popleft()
                encoder_prompt_tokens.extend([sentinel])
                masks_count -= 1

        return encoder_prompt_tokens

    def pad_encoder_prompts_tokens(
        self, encoder_prompts_tokens_list: List[List[int]], max_sequence_length: int, tokenizer
    ) -> torch.Tensor:
        """Method to pad input prompts

        Given a list of prompts, pad them all to uniform length.

        Args:
            encoder_prompts_tokens_list (List[List[int]]): A list containing the
                encoder_input_tokens
            max_sequence_length (int): Maximum length of the encoder input tokens
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text

        Returns:
            torch.Tensor: A torch tensor of shape [bs, max_sequence_length]
        """
        for encoder_prompt_tokens in encoder_prompts_tokens_list:
            padding_size = max_sequence_length - len(encoder_prompt_tokens)
            encoder_prompt_tokens.extend([tokenizer.pad] * padding_size)

        return torch.tensor(encoder_prompts_tokens_list).cuda()

    def get_batch_for_context_window(
        self,
        inference_input: Dict[str, Any],
        context_start_position: int,
        context_end_position: int,
    ) -> Dict[str, Any]:
        """Returns the inference data given the context window

        This function gets called iteratively in a loop. Given the start and end context
        positions, it extracts the appropriate data.

        Args:
            inference_input (Dict[str, Any]): The inference input for the batch.
            context_start_position (int): Start of the context window. During
                the first inference step it is mostly 0.
            context_end_position (int): End of the context window. During the
                last inference step it will mostly be the max generated sequence length.

        Returns:
            Dict: A dict of inputs that will be used by your model in the forward step
        """
        # T5 inference does not yet support kv_cache
        encoder_tokens2use = inference_input["encoder_tokens"]
        decoder_tokens2use = inference_input["decoder_tokens"][:, :context_end_position]
        encoder_mask2use = inference_input["encoder_mask"]
        decoder_mask2use = inference_input["decoder_mask"][:, :context_end_position]

        # Configure attention mask based on different conditions
        # (e.g., transformer-impl, TE versions, TE backends)
        [encoder_mask2use, decoder_mask2use, encoder_decoder_mask2use] = (
            T5MaskedWordPieceDataset.config_attention_mask(
                encoder_tokens2use,
                decoder_tokens2use,
                encoder_mask2use,
                decoder_mask2use,
                self.use_local,
            )
        )

        return {
            "encoder_tokens": encoder_tokens2use,
            "decoder_tokens": decoder_tokens2use,
            "encoder_mask": encoder_mask2use,
            "decoder_mask": decoder_mask2use,
            "encoder_decoder_mask": encoder_decoder_mask2use,
        }

    def forward_pass_without_pipeline_parallel(
        self, inference_input: Dict[str, Any]
    ) -> torch.Tensor:
        """Utility to carry out a simple forward pass for TP or no model parallel models

        Runs a very simple forward pass for the model. Used in the case of models without
        any parallelism or only tensor parallelism.

        Args:
            inference_input (Dict[str, Any]): A dict containing the inputs for the T5
                model [encoder/decoder tokens, encoder/decoder/encoder-decoder masks]

        Returns:
            torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
        """
        encoder_tokens = inference_input["encoder_tokens"]
        decoder_tokens = inference_input["decoder_tokens"]
        encoder_mask = inference_input["encoder_mask"]
        decoder_mask = inference_input["decoder_mask"]
        encoder_decoder_mask = inference_input["encoder_decoder_mask"]
        tokens = decoder_tokens

        # T5 inference does not yet support kv_cache
        logits = self.model(
            encoder_tokens,
            decoder_tokens,
            encoder_mask,
            decoder_mask,
            encoder_decoder_mask,
            inference_params=None,
        )
        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

        return logits
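
A minimal usage sketch of the reworked T5 wrapper (not part of this commit), assuming model, inference_wrapper_config, tokenizer, and the prompt tensors already exist; the new API returns a per-batch dict instead of storing state on the wrapper.

wrapper = T5InferenceWrapper(model, inference_wrapper_config, use_local=False)
wrapper.prep_model_for_inference(prompts_tokens)
inference_input = wrapper.prep_inference_input(
    prompts_tokens, encoder_prompts=["translate English to German: <mask>"], tokenizer=tokenizer
)
# T5 inference does not use a KV cache yet, so every step re-runs the full decoder prefix.
batch = wrapper.get_batch_for_context_window(inference_input, 0, context_end_position=1)
logits = wrapper.forward_pass_without_pipeline_parallel(batch)
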
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Integrations with NVIDIA TensorRT Model Optimizer (referred to as ModelOpt).

ModelOpt is a library comprising state-of-the-art model optimization techniques,
including quantization and sparsity, to compress models for efficient inference on
NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless
experience for users to optimize their Megatron-core models for inference.
More details on ModelOpt, including installation and usage, can be found at
https://github.com/NVIDIA/TensorRT-Model-Optimizer.
"""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import Optional

from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_gpt_layer_modelopt_spec(
    num_experts: Optional[int] = None,
    local_core_attention: bool = False,
    moe_grouped_gemm: bool = False,
    remap_te_layernorm: bool = False,
    qk_layernorm: bool = False,
) -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec, except that the layernorm implementation
    uses TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex
    has stopped supporting the RMSNorm needed by llama.
    """
    core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
    mlp = get_mlp_module_spec(
        use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False
    )
    sharded_state_dict_keys_map = {}
    if remap_te_layernorm:
        if num_experts:
            sharded_state_dict_keys_map = {
                'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'
            }
        else:
            sharded_state_dict_keys_map = {
                'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
                'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
            }
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=TENorm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=core_attention,
                    linear_proj=RowParallelLinear,
                    q_layernorm=TENorm if qk_layernorm else IdentityOp,
                    k_layernorm=TENorm if qk_layernorm else IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=TENorm,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
            # Map TE-layernorm-fusion keys back
            sharded_state_dict_keys_map=sharded_state_dict_keys_map,
        ),
    )
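
A minimal sketch of how this spec might be plugged into a GPT model (not part of this commit); the TransformerConfig sizes and vocab values are placeholders.

from megatron.core.models.gpt import GPTModel
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=8)  # placeholder sizes
layer_spec = get_gpt_layer_modelopt_spec(remap_te_layernorm=True, qk_layernorm=False)
model = GPTModel(
    config=config,
    transformer_layer_spec=layer_spec,
    vocab_size=32000,          # placeholder
    max_sequence_length=2048,  # placeholder
)
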
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_mamba_stack_modelopt_spec(
    local_core_attention: bool = False, remap_te_layernorm: bool = False
) -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec, except that the layernorm implementation
    uses TENorm from Transformer-Engine.
    """
    mamba_state_dict_keys_map = {}
    transformer_state_dict_keys_map = {}
    if remap_te_layernorm:
        mamba_state_dict_keys_map = {'norm.': 'mixer.in_proj.layer_norm_'}
        transformer_state_dict_keys_map = {
            'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
            'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
        }

    mamba_layer = ModuleSpec(
        module=MambaLayer,
        submodules=MambaLayerSubmodules(
            norm=TENorm,
            mixer=ModuleSpec(
                module=MambaMixer,
                submodules=MambaMixerSubmodules(
                    in_proj=ColumnParallelLinear, out_proj=RowParallelLinear
                ),
            ),
            mamba_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=mamba_state_dict_keys_map,
        ),
    )

    core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
    attention_layer = ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=TENorm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=core_attention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=transformer_state_dict_keys_map,
        ),
    )

    mlp_layer = ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            pre_mlp_layernorm=TENorm,
            mlp=ModuleSpec(
                module=MLP,
                submodules=MLPSubmodules(
                    linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
                ),
            ),
            mlp_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=transformer_state_dict_keys_map,
        ),
    )

    return ModuleSpec(
        module=MambaStack,
        submodules=MambaStackSubmodules(
            mamba_layer=mamba_layer, attention_layer=attention_layer, mlp_layer=mlp_layer
        ),
    )
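
Analogously, a hedged sketch (not part of this commit) of passing the stack spec to a hybrid Mamba model; the constructor argument names follow megatron.core.models.mamba.MambaModel as assumed here, and the sizes are placeholders.

from megatron.core.models.mamba import MambaModel
from megatron.core.transformer.transformer_config import TransformerConfig

stack_spec = get_mamba_stack_modelopt_spec(local_core_attention=False, remap_te_layernorm=True)
config = TransformerConfig(num_layers=4, hidden_size=256, num_attention_heads=8)  # placeholder sizes
model = MambaModel(
    config=config,
    mamba_stack_spec=stack_spec,  # assumed parameter name
    vocab_size=32000,             # placeholder
    max_sequence_length=2048,     # placeholder
)
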
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass


@dataclass
class SamplingParams:
    """Inference parameters sent along with the prompts.

    This class contains request-level attributes that control the sampling techniques used when
    generating text. This is distinct from megatron.core.InferenceParams, which sets model-level
    inference attributes such as the maximum sequence length, and contains the KV cache.

    For an explanation of these parameters refer to this blog
    https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-
    temperature-parameters-ed6a31313910
    """

    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.0
    return_log_probs: bool = False
    return_segments: bool = False  # Whether to return individually detokenized tokens
    num_tokens_to_generate: int = 30

    def add_attributes(self, attribute_value_pair: dict):
        """Utility to add more attributes to sampling params

        Use this method to pass in a custom dictionary to add more sampling parameter attributes.
        c = SamplingParams()
        c.add_attributes({'min_length':4, 'eod_id':153})

        Args:
            attribute_value_pair (dict): A dictionary containing attributes as the key names and
                their values as the values.
        """
        for key, value in attribute_value_pair.items():
            setattr(self, key, value)
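
A short usage sketch (not part of this commit) showing typical request-level configurations; the attribute values are arbitrary examples.

greedy = SamplingParams(temperature=1.0, top_k=1, num_tokens_to_generate=64)
nucleus = SamplingParams(temperature=0.8, top_p=0.9, return_log_probs=True, return_segments=True)
# Extra attributes can be attached after construction:
nucleus.add_attributes({'min_length': 4, 'eod_id': 153})
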
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import functools
import time
import typing
from collections import OrderedDict
from typing import Dict, Optional, Type, Union

import torch

from megatron.core.inference.async_stream import AsyncStream
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import Counter


class Scheduler:
    """Scheduler for handling requests to the inference engine

    This class is responsible for handling all the incoming requests.

    Args:
        max_batch_size (int): The max batch size that we can pass to the
            inference engine at a time.
    """

    def __init__(self, max_batch_size):
        self.max_batch_size = max_batch_size
        self.requests: Dict[str, InferenceRequest] = OrderedDict()
        self.streams: Dict[str, AsyncStream] = OrderedDict()
        self.active_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.waiting_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.completed_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.request_counter = Counter()

    def get_new_request_id(self) -> str:
        """Gets a new request id"""
        request_id = str(next(self.request_counter))
        return request_id

    def add_request(
        self,
        prompt: Optional[str] = None,
        prompt_tokens: Optional[torch.Tensor] = None,
        encoder_prompt: Optional[str] = None,
        inference_parameters: Optional[SamplingParams] = None,
        arrival_time: Optional[float] = None,
        streaming: bool = False,
        inference_request: Optional[InferenceRequest] = None,
    ) -> str:
        """Add an incoming request

        This method will add the request to either the active pool or the waiting pool
        depending on the batch size.

        Args:
            prompt (str): Input prompt string
            prompt_tokens (torch.Tensor): A torch tensor having the input prompts tokenized
            encoder_prompt (str): Encoder input string
            inference_parameters (SamplingParams): The inference parameters
            arrival_time (float, optional): The incoming request time. Defaults to None.
            streaming (bool, optional): Whether to asynchronously stream tokens for this request.
            inference_request (InferenceRequest, optional): A fully constructed request.
                Defaults to None.

        Returns:
            The request_id for the new request.
        """
        status = (
            Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
            if len(self.active_request_pool) < self.max_batch_size
            else Status.WAITING_IN_QUEUE
        )

        if inference_request is None:
            assert prompt is not None
            assert prompt_tokens is not None

            request_id = self.get_new_request_id()

            if arrival_time is None:
                arrival_time = time.time()

            inference_request = InferenceRequest(
                request_id=request_id,
                prompt=prompt,
                inference_parameters=inference_parameters,
                arrival_time=arrival_time,
                prompt_tokens=prompt_tokens,
                status=status,
                encoder_prompt=encoder_prompt,
            )
        else:
            request_id = inference_request.request_id
            inference_request.status = status
            if inference_request.arrival_time is None:
                inference_request.arrival_time = time.time()

        self.requests[request_id] = inference_request

        if streaming:
            abort_request = functools.partial(self.abort_request, request_id=request_id)
            self.streams[request_id] = AsyncStream(request_id, abort_request)

        if status == status.ACTIVE_BUT_NOT_GENERATING_TOKENS:
            self.active_request_pool[request_id] = inference_request
        else:
            self.waiting_request_pool[request_id] = inference_request

        return request_id

    def have_requests_pending(self) -> bool:
        """Method to check if there are requests pending

        This method returns False only when there are no active requests or waiting requests.
        """
        num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool)
        return num_requests_pending > 0

    def add_earliest_waiting_request_to_active_pool(self):
        """Utility to add the waiting request to the active pool

        This method will add the earliest request (FIFO) that is in the waiting request
        pool to the active request pool.
        """
        assert (
            len(self.active_request_pool) < self.max_batch_size
        ), "Active request pool is already full. Cannot add any more requests"
        if len(self.waiting_request_pool) > 0:
            (earliest_waiting_request_request_id, earliest_waiting_request) = (
                self.waiting_request_pool.popitem(last=False)
            )
            earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
            self.active_request_pool[earliest_waiting_request_request_id] = earliest_waiting_request

    def update_requests_pools(
        self, result_dict: Optional[typing.OrderedDict[str, InferenceRequest]] = None
    ):
        """Update request pool status

        This method will fill up the active request pool from the waiting request pool
        if it has fewer than max batch size elements.
        If provided with a request dict, it will put the completed requests into the completed
        request pool and add waiting requests into the active pool.

        Args:
            result_dict (typing.OrderedDict[str, InferenceRequest], optional): The result returned
                by the engine. A dictionary with keys as the request ids, and values as the
                requests. Defaults to None
        """
        for result_request_id in list(result_dict.keys()):
            active_request = self.active_request_pool[result_request_id]

            # If a request has completed put it into the completed request pool.
            if active_request.status == Status.COMPLETED:
                completed_request = self.active_request_pool.pop(result_request_id)
                self.completed_request_pool[result_request_id] = completed_request

        # If the active request pool is not full, add waiting requests in FIFO order
        while (
            len(self.active_request_pool) < self.max_batch_size
            and len(self.waiting_request_pool) > 0
        ):
            self.add_earliest_waiting_request_to_active_pool()

    def abort_request(
        self,
        request_id: str,
        *,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ):
        """Cancels the given request"""
        stream = self.streams.get(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)
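
A minimal sketch (not part of this commit) of how an engine loop might drive the scheduler; run_generation_step and the tokenizer are hypothetical stand-ins for the engine's own machinery.

scheduler = Scheduler(max_batch_size=8)
request_id = scheduler.add_request(
    prompt="Hello world",
    prompt_tokens=torch.tensor(tokenizer.tokenize("Hello world")),
    inference_parameters=SamplingParams(num_tokens_to_generate=32),
    streaming=False,
)
while scheduler.have_requests_pending():
    result_dict = run_generation_step(scheduler.active_request_pool)  # hypothetical engine call
    scheduler.update_requests_pools(result_dict=result_dict)
completed = scheduler.completed_request_pool[request_id]
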
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Any, Dict, OrderedDict

import torch

from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)


class EncoderDecoderTextGenerationController(TextGenerationController):
    """The text generation controller for encoder-decoder architecture

    This class inherits from TextGenerationController, adding features
    relating to the encoder input encoder_prompt.
    """

    def prep_inference_input(
        self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
    ) -> Dict[str, Any]:
        """Prepares the input data for inference, using the respective wrapper's prep_inference_input method

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests

        Returns:
            A dict of the inference input for the current batch.
        """
        encoder_prompts = list(
            map(lambda request: request.encoder_prompt, active_requests.values())
        )

        return self.inference_wrapped_model.prep_inference_input(
            prompts_tokens, encoder_prompts, tokenizer=self.tokenizer
        )
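
A brief sketch (not part of this commit): the controller only overrides input preparation, pulling the encoder prompts off the active requests and delegating to the wrapped model; t5_wrapper, tokenizer, prompts_tokens, and active_requests are assumed to exist.

controller = EncoderDecoderTextGenerationController(
    inference_wrapped_model=t5_wrapper, tokenizer=tokenizer
)
inference_input = controller.prep_inference_input(prompts_tokens, active_requests)
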
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.text_generation_controllers.text_generation_controller import (  # noqa: F401 # pylint: disable=unused-import
    TextGenerationController as SimpleTextGenerationController,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import concurrent
import copy
import functools
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union

import torch
import torch.nn.functional as F

from megatron.core import parallel_state
from megatron.core.inference.async_stream import AsyncStream
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
    AbstractModelInferenceWrapper,
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.utils import get_model_config


class TextGenerationController:
    """The text generation controller (the main sampling loop)

    This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.

    Args:
        inference_wrapped_model (AbstractModelInferenceWrapper): A model that
            is wrapped using the specs given in the abstract_model_inference_wrapper.py
        tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
    """

    def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer):
        self.inference_wrapped_model = inference_wrapped_model
        self.tokenizer = tokenizer

        # For models without pipeline parallelism, is_first_stage and is_last_stage return True
        self.model_is_pipeline_parallel = not (
            parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
        )

    def tokenize_prompt(
        self, prompt: str, add_BOS: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Utility to tokenize the input prompts

        Args:
            prompt (str): The input prompt

        Returns:
            torch.Tensor: Returns the tokenized prompt
        """
        prompt_tokens = self.tokenizer.tokenize(prompt)

        if add_BOS:
            prompt_tokens = [self.tokenizer.bos] + prompt_tokens

        return prompt_tokens

    def detokenize_generations(
        self,
        tokens_gpu_tensor: torch.Tensor,
        lengths_gpu_tensor: torch.Tensor,
        detokenize_segments: bool,
    ) -> tuple[str, Optional[List[List[str]]]]:
        """Detokenize the generated tokens.

        Args:
            tokens_gpu_tensor (torch.Tensor): Tensor containing the tokens
            lengths_gpu_tensor (torch.Tensor): Tensor containing the lengths of each sequence
            detokenize_segments (bool): If True, returns individually detokenized tokens. If False,
                returns None as the second element. Helpful for understanding per-token boundaries in
                generated text.

        Returns:
            tuple[str, List[str] | None]: A tuple containing:
                - str: The complete detokenized text
                - List[str] | None: List of segmented tokens if detokenize_segments is True, else None
        """
        # TODO(helenn): Unify with `detokenize_generations` from legacy textgen path

        if not detokenize_segments:
            tokens = tokens_gpu_tensor.cpu().numpy().tolist()
            return self.tokenizer.detokenize(tokens), None

        prompts_plus_generations: List[str] = []
        prompts_plus_generations_segments: List[List[str]] = []

        tokens_gpu_tensor = torch.unsqueeze(tokens_gpu_tensor, 0)
        tokens = tokens_gpu_tensor.cpu().numpy().tolist()
        lengths = lengths_gpu_tensor.cpu().numpy().tolist()

        for sequence_tokens, length in zip(tokens, lengths):
            sequence_tokens = sequence_tokens[:length]
            detok_str = self.tokenizer.detokenize(sequence_tokens)
            prompts_plus_generations.append(detok_str)
            offsets = self.tokenizer.offsets(sequence_tokens, detok_str)
            words = [
                detok_str[start:end] for start, end in zip(offsets, offsets[1:] + [len(detok_str)])
            ]

            prompts_plus_generations_segments.append(words)

        text = self.tokenizer.detokenize(tokens[0])

        return text, prompts_plus_generations_segments

    def sample_from_logits(
        self,
        last_token_logits: torch.Tensor,
        sampling_params: Optional[SamplingParams] = None,
        vocab_size: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Samples the logits to generate outputs

        Given the logits of the last token, this function samples them
        according to the parameters defined in sampling_params
        and returns the samples.

        Args:
            last_token_logits (torch.Tensor): The last token logits. A tensor of
                size [batch_size, vocab_size]
            sampling_params (SamplingParams): The parameters to use for inference.
            vocab_size (int): Obtained from the tokenizer. Defaults to None

        Returns:
            torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements
        """
        if kwargs.get('common_inference_params'):
            sampling_params = kwargs['common_inference_params']

        top_p = sampling_params.top_p
        top_k = sampling_params.top_k
        temperature = sampling_params.temperature

        assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
        assert top_p <= 1.0, 'top-p should be in (0,1]'

        def modify_logits_for_top_k_filtering(logits, top_k):
            """Set the logits for non top-k values to -inf."""
            filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits.masked_fill_(filter_, float('-Inf'))

        def modify_logits_for_top_p_filtering(logits, top_p):
            """Set the logits for non top-p values to -inf."""
            # First sort and calculate cumulative sum of probabilities.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

            # Filtration based on the cumulative sum.
            filter_ = cumulative_probs > top_p
            # This shift by 1 is weird and I cannot justify it. This existed
            # in the original implementation:
            # https://github.com/ari-holtzman/degen/blob/master/gen.py
            # and I guess it is needed so keeping it for now.
            filter_[:, 1:] = filter_[:, :-1].clone()
            # Make sure we at least have one token to select from.
            filter_[..., 0] = 0

            # Fill in the filtered part
            filter_ = filter_.scatter(1, sorted_indices, filter_)
            logits.masked_fill_(filter_, float('-Inf'))

        # Greedy sampling
        if top_k == 1:
            sampled_logits = torch.argmax(last_token_logits, dim=-1)
        else:
            last_token_logits = last_token_logits.clone()
            if temperature != 1.0:
                last_token_logits.div_(temperature)

            if top_k > 1:
                assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
                if vocab_size:
                    assert top_k < vocab_size, 'top-k is larger than vocab size.'
                modify_logits_for_top_k_filtering(last_token_logits, top_k)

            elif top_p > 0.0:
                modify_logits_for_top_p_filtering(last_token_logits, top_p)

            # After filtering, we need to recalculate the distribution.
            probabilities = last_token_logits.softmax(dim=-1)
            sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)

        # If vocab size is provided, make sure the samples are in the range [0, vocab-size).
        if vocab_size:
            sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
        return sampled_logits

    def update_generation_status(
        self,
        updated_prompts_tokens: torch.Tensor,
        generation_started: torch.Tensor,
        current_context_end_position: int,
        is_generation_done_tensor: torch.Tensor,
        generated_sequence_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Checks which prompts have reached an end condition

        We check which prompts have reached an end condition and set the corresponding
        flags of the is_generation_done_tensor to True. The generated sequence lengths
        increase as we keep generating, until that prompt hits an end condition. The
        generation_started tensor determines which prompts have started generating.

        Args:
            updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
                generated tokens. A tensor of shape [batch_size, max_seq_len]
                (i.e. max_seq_len = max_prompt_len + tokens_to_generate)
            generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
                indicates the prompt at that index has started generating tokens.
            current_context_end_position (int): An integer indicating which position to
                extract from the prompts tokens to get the latest generated tokens.
            is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
                True indicates the prompt at that index has reached its end condition.
            generated_sequence_lengths (torch.Tensor): An int tensor of shape [batch_size].
                Each value represents the generated sequence length for that prompt.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Returns the boolean
                is_generation_done_tensor and the generated_sequence_lengths after updating them
        """
        latest_samples = updated_prompts_tokens[:, current_context_end_position]
        # Make sure we are checking the eod criterion only for prompts that have started generating
        # (i.e. we only look at the generated tokens and not the input tokens).
        reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
        is_generation_done_tensor = is_generation_done_tensor | reached_eod
        # We increment generated sequence lengths when that prompt has not hit the
        # EOD and generation has started
        generated_sequence_lengths += ~is_generation_done_tensor & generation_started
return is_generation_done_tensor, generated_sequence_lengths def update_generation_status(
self,
def pad_input_prompt_tokens( updated_prompts_tokens: torch.Tensor,
self, generation_started: torch.Tensor,
batch_prompt_tokens_list: List[List[int]], current_context_end_position: int,
max_prompt_length_in_batch: int, is_generation_done_tensor: torch.Tensor,
num_tokens_to_generate: int, generated_sequence_lengths: torch.Tensor,
) -> torch.Tensor: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Method to pad input prompts """Checks which prompts have reached an end condition
Given a list of prompts, pad them all to uniform length We check which prompts have reached an end condition and set the corresponding
flags of the is_generation_done_tensor to True. The generated sequence lengths
Args: increase as we keep generating, until that prompts hits an end condition. The
batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens generation_started tensor determines which prompts have started generating.
max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens
num_tokens_togenerate (int): The number of tokens to generate for each prompt Args:
updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
Returns: generated tokens. A tensor of shape [batch_size, max_seq_len]
torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e) (i.e max_seq_len = max_prompt_len + tokens_to_generate)
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate, generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
with extra indices for each tensor padded with mask id. indicates the prompt at that index has started generating tokens.
""" current_context_end_position (int): An integer indicating which position to
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate extract from the prompts tokens to get the latest generated tokens.
is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
for prompt_tokens in batch_prompt_tokens_list: True indicates the prompt at that index has reached end condition.
padding_size = max_seq_len - len(prompt_tokens) generated_sequence_lengths (torch.Tensor): A int tensor of shape [batch_size].
prompt_tokens.extend([self.tokenizer.eod] * padding_size) Each value represents the generated sequence lengths for that prompt.
return torch.tensor(batch_prompt_tokens_list).cuda() Returns:
Tuple[torch.Tensor, torch.Tensor]: Returns the boolean
def generate_output_tokens_dynamic_batch( is_generation_done_tensor and the generated_sequence_lengths after updating it
self, active_requests: OrderedDict[int, InferenceRequest] """
) -> OrderedDict[int, InferenceRequest]: latest_samples = updated_prompts_tokens[:, current_context_end_position]
"""Utility to generate the output tokens and probabilities for the prompts # Make sure we are checking eod criterion only for prompts that have started generating
# (i.e) We only look at the generated tokenns and not the input tokens.
This utility generates the output tokens for a dynamic batch. It will run one forward step reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
at a time, and pass control back to the engine, which will update the request pool and call is_generation_done_tensor = is_generation_done_tensor | reached_eod
this method again. # We increment generated sequence lengths when that prompt has not hit the
# EOD and generation has started
Args: generated_sequence_lengths += ~is_generation_done_tensor & generation_started
active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
return is_generation_done_tensor, generated_sequence_lengths.int()
Returns:
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests def pad_input_prompt_tokens(
after running one forward step. self,
""" batch_prompt_tokens_list: List[List[int]],
raise Exception("Not implemented yet") max_prompt_length_in_batch: int,
num_tokens_to_generate: int,
def generate_all_output_tokens_static_batch( ) -> torch.Tensor:
self, active_requests: OrderedDict[int, InferenceRequest] """Method to pad input prompts
) -> OrderedDict[int, InferenceRequest]:
"""Utility to generate the all the output tokens and probabilities for the prompts . Given a list of prompts, pad them all to uniform length
This utility generates the output tokens for a static batch. It runs the forward steps till Args:
all prompts complete generation, updates the status of these requests to completed, adds batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
the generated result and returns these requests max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens
num_tokens_togenerate (int): The number of tokens to generate for each prompt
Args:
active_requests (OrderedDict[int, InferenceRequest]): The input active requests. Returns:
torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e)
Returns: max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate,
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests """
""" max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
batch_prompt_tokens_list = list(
map(lambda request: request.prompt_tokens, active_requests.values()) for prompt_tokens in batch_prompt_tokens_list:
) padding_size = max_seq_len - len(prompt_tokens)
prompt_lengths_in_batch = torch.tensor( prompt_tokens.extend([self.tokenizer.eod] * padding_size)
[len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list]
).cuda() return torch.tensor(batch_prompt_tokens_list, device=torch.cuda.current_device())
max_prompt_length_in_batch = max(prompt_lengths_in_batch)
min_prompt_length_in_batch = min(prompt_lengths_in_batch) def generate_output_tokens_dynamic_batch(
self, active_requests: OrderedDict[str, InferenceRequest]
# For batch inference the inference params are the same for all request ) -> OrderedDict[str, InferenceRequest]:
sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters """Utility to generate the output tokens and probabilities for the prompts
# max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate This utility generates the output tokens for a dynamic batch. It will run one forward step
batch_prompt_tokens = self.pad_input_prompt_tokens( at a time, and pass control back to the engine, which will update the request pool and call
batch_prompt_tokens_list, this method again.
max_prompt_length_in_batch=max_prompt_length_in_batch,
num_tokens_to_generate=sampling_params.num_tokens_to_generate, Args:
) active_requests (OrderedDict[str, InferenceRequest]): The input active requests.
batch_size, max_sequence_length = batch_prompt_tokens.shape
Returns:
# Pre allocate log probs tensor OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
output_log_probs = None after running one forward step.
if sampling_params.return_log_probs: """
output_log_probs = torch.empty( raise Exception("Not implemented yet")
(batch_size, max_sequence_length - 1), dtype=torch.float32
).cuda() def generate_all_output_tokens_static_batch(
self,
# An array to check which of the prompts have reached end of generation condition active_requests: OrderedDict[str, InferenceRequest],
is_generation_done_tensor = torch.zeros(batch_size, dtype=torch.bool).cuda() active_streams: Optional[OrderedDict[str, AsyncStream]] = None,
) -> OrderedDict[str, InferenceRequest]:
# An array to act as a counter to keep track of generated sequence lengths """Utility to generate the all the output tokens and probabilities for the prompts .
generated_sequence_lengths = torch.zeros(batch_size).cuda()
This utility generates the output tokens for a static batch. It runs the forward steps till
with torch.no_grad(): all prompts complete generation, updates the status of these requests to completed, adds
the generated result and returns these requests
self.prep_model_for_inference(
prompts_tokens=batch_prompt_tokens, active_requests=active_requests Args:
) active_requests (OrderedDict[str, InferenceRequest]): The input active requests.
context_start_position = 0 Returns:
# Pick the context window that we need to pass through the network. OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
for context_end_position in range(min_prompt_length_in_batch, max_sequence_length): """
assert all(request.prompt_tokens is not None for request in active_requests.values())
inference_input = self.inference_wrapped_model.get_batch_for_context_window(
context_start_position, context_end_position # Perform a deep copy so that the request prompt tokens do not get modified.
) batch_prompt_tokens_list: List[List[int]] = list(
map(
# Returns the final logits of shape [batch_size, context_length, vocab_size] lambda request: copy.deepcopy(request.prompt_tokens), # type: ignore[arg-type]
# Note: This is returned in all TP ranks or last PP stage in PP models active_requests.values(),
logits = self.inference_wrapped_model.run_one_forward_step(inference_input) )
if self.model_is_pipeline_parallel: )
context_length = context_end_position - context_start_position prompt_lengths_in_batch = torch.tensor(
logits = broadcast_from_last_pipeline_stage( [len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list],
[batch_size, context_length, self.tokenizer.vocab_size], device=torch.cuda.current_device(),
dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype, )
tensor=logits, max_prompt_length_in_batch = max(prompt_lengths_in_batch)
) min_prompt_length_in_batch = min(prompt_lengths_in_batch)
# Indicates which of the input prompts have started generating tokens. # For batch inference the inference params are the same for all request
# A 1D boolean tensor with [batch_size] elements (i.e) The shortest sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters
# prompts will start generating first and so on
generation_started = prompt_lengths_in_batch <= context_end_position # max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
last_token_logits = logits[:, -1, :] batch_prompt_tokens = self.pad_input_prompt_tokens(
sampled_logits = self.sample_from_logits( batch_prompt_tokens_list,
last_token_logits, sampling_params, self.tokenizer.vocab_size max_prompt_length_in_batch=max_prompt_length_in_batch,
) num_tokens_to_generate=sampling_params.num_tokens_to_generate,
)
# Substitute the sampled logits only for only the prompts that batch_size, max_sequence_length = batch_prompt_tokens.shape
# have started generating tokens
batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[ # Verify that output sequence length is within configured limit
generation_started # TODO(ksanthanam): Raise TokenOverflowError once !2518 is merged
] inference_max_sequence_length = (
self.inference_wrapped_model.inference_wrapper_config.inference_max_seq_length
if sampling_params.return_log_probs: )
log_probs = F.log_softmax(logits, dim=2) assert max_sequence_length <= inference_max_sequence_length, (
indices = torch.unsqueeze( f"Maximum allowed sequence length was set to {inference_max_sequence_length} tokens "
batch_prompt_tokens[ f"but requested generation of {max_sequence_length} tokens"
:, (context_start_position + 1) : (context_end_position + 1) )
],
2, # Pre allocate log probs tensor
) output_log_probs = None
# Get the log probabilities for only the prompt tokens if sampling_params.return_log_probs:
output_log_probs[:, context_start_position:context_end_position] = torch.gather( output_log_probs = torch.empty(
log_probs, 2, indices (batch_size, max_sequence_length - 1),
).squeeze(2) dtype=torch.float32,
device=torch.cuda.current_device(),
context_start_position = context_end_position )
# Check end of generation status for each tensor # An array to check which of the prompts have reached end of generation condition
# and update generated sequence lengths is_generation_done_tensor = torch.zeros(
(is_generation_done_tensor, generated_sequence_lengths) = ( batch_size, dtype=torch.bool, device=torch.cuda.current_device()
self.update_generation_status( )
updated_prompts_tokens=batch_prompt_tokens,
generation_started=generation_started, # An array to act as a counter to keep track of generated sequence lengths
current_context_end_position=context_end_position, generated_sequence_lengths = torch.zeros(
is_generation_done_tensor=is_generation_done_tensor, batch_size, device=torch.cuda.current_device()
generated_sequence_lengths=generated_sequence_lengths, ).cuda()
)
) # Use padded vocab size because tokenizer vocab size might not include padding
# Boolean flag indicating if all prompts are finished # to nearest power of 2
all_prompts_done = torch.all(is_generation_done_tensor) vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size
if all_prompts_done:
break # Check whether CUDA graphs are enabled
enable_cuda_graph = get_model_config(self.inference_wrapped_model.model).enable_cuda_graph
# Include all the generated tokens
batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)] streaming_enabled = active_streams is not None and len(active_streams) > 0
if sampling_params.return_log_probs: if streaming_enabled:
output_log_probs = output_log_probs[:, :context_end_position] # Start a separate thread for streaming tokens to avoid blocking the
# main computation
generated_sequence_lengths[ streaming_idx: List[int] = [
generated_sequence_lengths > sampling_params.num_tokens_to_generate i
] = sampling_params.num_tokens_to_generate for (i, request_id) in enumerate(active_requests.keys())
if request_id in active_streams
for idx, request in enumerate(active_requests.values()): ]
input_prompt_length = int(prompt_lengths_in_batch[idx]) streaming_request_ids: List[str] = list(active_streams.keys())
# Shorter prompts might have generated more than required tokens. So we trim them down streams: List[AsyncStream] = list(active_streams.values())
required_sequence_length = int( streaming_requests: List[InferenceRequest] = [
min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate) active_requests[request_id] for request_id in streaming_request_ids
) ]
# Extract only the generated tokens streaming_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
required_result_tokens = batch_prompt_tokens_with_generations[ stream_tokens = functools.partial(self.stream_tokens, sampling_params)
idx, input_prompt_length : (input_prompt_length + required_sequence_length)
] with torch.no_grad():
request.generated_length = required_sequence_length self.inference_wrapped_model.prep_model_for_inference(
request.generated_tokens = required_result_tokens prompts_tokens=batch_prompt_tokens
request.generated_log_probs = ( )
None
if output_log_probs is None inference_input: Dict[str, Any] = self.prep_inference_input(
else output_log_probs[idx, input_prompt_length:required_sequence_length] prompts_tokens=batch_prompt_tokens, active_requests=active_requests
) )
request.status = Status.COMPLETED
request.generated_text = self.detokenize_generations(required_result_tokens) assert (
not self.inference_wrapped_model.inference_params.decode_mode
return active_requests ), f"Generation must start in prefill mode"
def prep_model_for_inference( context_start_position = 0
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest] # Pick the context window that we need to pass through the network.
): for context_end_position in range(min_prompt_length_in_batch, max_sequence_length):
"""Preparing batch for inference, using respective wrapper's prep_model_for_inference method
inference_input_for_context_window: Dict[str, Any] = (
Args: self.inference_wrapped_model.get_batch_for_context_window(
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] inference_input, context_start_position, context_end_position
active_requests (OrderedDict[int, InferenceRequest]): The input active requests )
""" )
self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=prompts_tokens)
# Disable attention mask when using CUDA graphs for decode
if (
enable_cuda_graph
and self.inference_wrapped_model.inference_params.decode_mode
and "attention_mask" in inference_input_for_context_window
):
inference_input_for_context_window["attention_mask"] = None
# Returns the final logits of shape [batch_size, context_length, vocab_size]
# Note: This is returned in all TP ranks or last PP stage in PP models
logits = self.inference_wrapped_model.run_one_forward_step(
inference_input_for_context_window
)
if enable_cuda_graph:
create_cudagraphs()
if self.model_is_pipeline_parallel:
context_length = context_end_position - context_start_position
logits = broadcast_from_last_pipeline_stage(
[batch_size, context_length, vocab_size],
dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
tensor=logits,
)
# Indicates which of the input prompts have started generating tokens.
# A 1D boolean tensor with [batch_size] elements (i.e) The shortest
# prompts will start generating first and so on
generation_started = prompt_lengths_in_batch <= context_end_position
last_token_logits = logits[:, -1, :]
sampled_logits = self.sample_from_logits(
last_token_logits, sampling_params, vocab_size
)
# Substitute the sampled logits only for the prompts that
# have started generating tokens
batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
generation_started
]
if sampling_params.return_log_probs:
log_probs = F.log_softmax(logits, dim=2)
indices = torch.unsqueeze(
batch_prompt_tokens[
:, (context_start_position + 1) : (context_end_position + 1)
],
2,
)
# Get the log probabilities for only the prompt tokens
assert output_log_probs is not None
output_log_probs[:, context_start_position:context_end_position] = torch.gather(
log_probs, 2, indices
).squeeze(2)
context_start_position = context_end_position
# Check end of generation status for each tensor
# and update generated sequence lengths
(is_generation_done_tensor, generated_sequence_lengths) = (
self.update_generation_status(
updated_prompts_tokens=batch_prompt_tokens,
generation_started=generation_started,
current_context_end_position=context_end_position,
is_generation_done_tensor=is_generation_done_tensor,
generated_sequence_lengths=generated_sequence_lengths,
)
)
# Stream intermediate outputs
if streaming_enabled:
streaming_executor.submit(
stream_tokens,
streaming_request_ids,
streaming_requests,
streams,
generation_started[streaming_idx].cpu(),
is_generation_done_tensor[streaming_idx].cpu(),
batch_prompt_tokens[streaming_idx].cpu(),
prompt_lengths_in_batch[streaming_idx].cpu(),
generated_sequence_lengths[streaming_idx].cpu(),
(
output_log_probs[streaming_idx].cpu()
if output_log_probs is not None
else [None] * len(streaming_idx)
),
)
# Boolean flag indicating if all prompts are finished
all_prompts_done = torch.all(is_generation_done_tensor)
if all_prompts_done:
break
# Change to decode mode if all prefill is complete
if torch.all(generation_started):
self.inference_wrapped_model.inference_params.enable_decode_mode()
# Close all streams
if streaming_enabled:
streaming_executor.shutdown()
for stream in streams:
stream.finish()
# Include all the generated tokens
batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
if sampling_params.return_log_probs:
assert output_log_probs is not None
output_log_probs = output_log_probs[:, :context_end_position]
generated_sequence_lengths[
generated_sequence_lengths > sampling_params.num_tokens_to_generate
] = sampling_params.num_tokens_to_generate
for idx, request in enumerate(active_requests.values()):
input_prompt_length = int(prompt_lengths_in_batch[idx])
# Shorter prompts might have generated more than required tokens. So we trim them down
required_sequence_length = int(
min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate)
)
# Extract only the generated tokens
required_result_tokens = batch_prompt_tokens_with_generations[
idx, input_prompt_length : (input_prompt_length + required_sequence_length)
]
generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32)
request.generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32)
request.generated_length = required_sequence_length
request.generated_tokens = required_result_tokens
request.prompt_log_probs = (
None
if output_log_probs is None
else output_log_probs[idx, :input_prompt_length].cpu().numpy().tolist()
)
request.generated_log_probs = (
None
if output_log_probs is None
else output_log_probs[
idx,
input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1),
]
.cpu()
.numpy()
.tolist()
)
request.status = Status.COMPLETED
text, segments = self.detokenize_generations(
batch_prompt_tokens_with_generations[idx],
input_prompt_length + generated_sequence_lengths,
sampling_params.return_segments,
)
request.text = text # Inference server returns prompts & generations together
if sampling_params.return_segments:
request.segments = segments[0]
request.generated_text = text[len(request.prompt) :]
return active_requests
def prep_inference_input(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
) -> Dict[str, Any]:
"""Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[str, InferenceRequest]): The input active requests
Returns:
A dict of the inference input for the current batch.
"""
return self.inference_wrapped_model.prep_inference_input(prompts_tokens)
def stream_tokens(
self,
sampling_params: SamplingParams,
request_ids: List[str],
requests: List[InferenceRequest],
streams: List[AsyncStream],
generation_started: List[bool],
is_generation_done: List[bool],
tokens: torch.Tensor,
prompt_lengths: List[int],
generated_lengths: List[int],
output_log_probs: Union[torch.Tensor, None],
):
"""Asynchronously streams tokens for the given requests.
Args:
sampling_params (SamplingParams): The sampling parameters.
request_ids (List[str]): The request IDs.
            requests (List[InferenceRequest]): The requests.
            streams (List[AsyncStream]): The streams over which to send tokens.
generation_started (List[bool]): Whether the decode step has started.
is_generation_done (List[bool]): Whether generation has completed.
tokens (torch.Tensor): The tokens for this request.
prompt_lengths (List[int]): The number of prompt tokens for each request.
generated_lengths (List[int]): The number of output tokens for each request.
output_log_probs (torch.Tensor, optional): The log probs for each request.
"""
def stream_token(
request_id: str,
request: InferenceRequest,
stream: AsyncStream,
generation_started: bool,
is_generation_done: bool,
tokens: torch.Tensor,
prompt_length: int,
generated_length: int,
output_log_probs: Union[torch.Tensor, None],
):
"""Asynchronously streams a token for the given request."""
if not generation_started or stream.finished:
return
num_tokens_to_generate = sampling_params.num_tokens_to_generate
return_segments = sampling_params.return_segments
detokenize_streaming_text = not getattr(
sampling_params, "no_detokenize_streaming_text", False
)
generated_tokens = tokens[prompt_length : prompt_length + generated_length]
if detokenize_streaming_text:
generated_text, generated_segments = self.detokenize_generations(
generated_tokens, prompt_length + generated_length, return_segments
)
else:
generated_text = ""
generated_segments = []
if output_log_probs is not None:
generated_log_probs = (
output_log_probs[prompt_length - 1 : prompt_length + generated_length - 1]
.cpu()
.numpy()
.tolist()
)
else:
generated_log_probs = None
stream.put(
InferenceRequest(
request_id=request_id,
prompt=request.prompt,
inference_parameters=request.inference_parameters,
prompt_tokens=request.prompt_tokens,
arrival_time=request.arrival_time,
status=request.status,
encoder_prompt=request.encoder_prompt,
generated_text=generated_text,
generated_segments=generated_segments,
generated_tokens=generated_tokens,
generated_log_probs=generated_log_probs,
generated_length=generated_length,
)
)
if is_generation_done or generated_length == num_tokens_to_generate:
stream.finish()
ret = map(
stream_token,
request_ids,
requests,
streams,
generation_started,
is_generation_done,
tokens,
prompt_lengths,
generated_lengths,
output_log_probs,
)
list(ret)
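# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): a standalone view of what the
# nested top-k / top-p helpers inside sample_from_logits() above do to a row of
# logits before multinomial sampling. The helper names and the toy logits below
# are made up for this example; only torch is required.
# ---------------------------------------------------------------------------
import torch


def top_k_filter(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep the top_k largest logits per row and set the rest to -inf."""
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_largest, float('-Inf'))


def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Keep the smallest set of logits whose probabilities sum to at least top_p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    filter_ = cumulative_probs > top_p
    filter_[:, 1:] = filter_[:, :-1].clone()  # shift so the boundary token survives
    filter_[..., 0] = 0  # always keep at least one token
    filter_ = filter_.scatter(1, sorted_indices, filter_)
    return logits.masked_fill(filter_, float('-Inf'))


if __name__ == "__main__":
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
    print(top_k_filter(logits.clone(), top_k=2))    # everything but the two largest -> -inf
    print(top_p_filter(logits.clone(), top_p=0.9))  # smallest nucleus covering 0.9 survives
    probabilities = top_k_filter(logits.clone(), top_k=2).softmax(dim=-1)
    print(torch.multinomial(probabilities, num_samples=1).view(-1))  # as sample_from_logits does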
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import OrderedDict
import torch
from megatron.core.inference.inference_request import InferenceRequest, VLMInferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
class VLMTextGenerationController(TextGenerationController):
"""The text generation controller for VLMs"""
def prep_inference_input(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
):
"""Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long
Currently only supports batch size 1 inference.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[str, InferenceRequest]): The input active requests
"""
assert len(active_requests) == 1, f"VLM inference currently only supports batch size 1"
request = list(active_requests.values())[0]
assert isinstance(
request, VLMInferenceRequest
), f"Found inference request of type {type(request)}, expected VLMInferenceRequest"
return self.inference_wrapped_model.prep_inference_input(
prompts_tokens,
request.num_img_embeddings_per_tile,
request.imgs,
request.num_tiles,
request.decoder_seq_length,
)
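# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): the end-of-generation
# bookkeeping performed by update_generation_status() in the text generation
# controller above, run on made-up token ids with eod = 0 so the effect is
# easy to see.
# ---------------------------------------------------------------------------
import torch

eod = 0
updated_prompts_tokens = torch.tensor([[5, 7, 0, 0], [9, 4, 6, 0]])  # two padded sequences
generation_started = torch.tensor([True, False])  # only row 0 has started generating
is_generation_done = torch.zeros(2, dtype=torch.bool)
generated_sequence_lengths = torch.zeros(2)

current_context_end_position = 2
latest_samples = updated_prompts_tokens[:, current_context_end_position]
reached_eod = (latest_samples == eod) & generation_started  # row 0 just emitted eod
is_generation_done = is_generation_done | reached_eod
generated_sequence_lengths += ~is_generation_done & generation_started

# Row 0 is now marked done and its generated length stops incrementing; row 1 is
# untouched because its prompt has not finished prefill yet.
print(is_generation_done, generated_sequence_lengths.int())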
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.


class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.current_batch_size = max_batch_size  # Required for bookkeeping variable-sized batches
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}
        self.decode_mode = False

    def swap_key_value_dict(self, batch_idx):
        """Swap between batches."""
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
            assert (
                len(batch_idx) == inference_key_memory.shape[1]
            )  # make sure batch size is the same
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )

def enable_prefill_mode(self):
"""
Indicates the generation loop is in the prefill phase (still processing
input prompt tokens). This should be enabled if the generation loop is
encoding prompt tokens for *any* request in a batch.
"""
self.decode_mode = False
def enable_decode_mode(self):
"""
Indicates the generation loop is in the decode phase (generating new output
tokens). This should only be enabled if the generation loop has fully encoded
the prompts for *all* requests in a batch.
"""
self.decode_mode = True
def reset(self):
"""Resets the inference state for a new batch."""
self.current_batch_size = self.max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.enable_prefill_mode()
def __str__(self):
return (
f"InferenceParams(max_seq_len = {self.max_sequence_length}, "
f"max_batch_size = {self.max_batch_size}, "
f"current_batch_size = {self.current_batch_size}, "
f"sequence_len_offset = {self.sequence_len_offset}, "
f"batch_size_offset = {self.batch_size_offset}, "
f"key_value_memory_dict = {self.key_value_memory_dict.keys()})"
f"decode_mode = {self.decode_mode}"
)
def __eq__(self, other):
if not isinstance(other, InferenceParams):
return False
# Check all attributes match
basic_attrs = [
'max_sequence_length',
'max_batch_size',
'current_batch_size',
'sequence_len_offset',
'batch_size_offset',
]
        if not all(
            hasattr(other, attr) and getattr(self, attr) == getattr(other, attr)
            for attr in basic_attrs
        ):
            return False
# Check dictionary keys match; i.e. the same number of layers are cached
if self.key_value_memory_dict.keys() != other.key_value_memory_dict.keys():
return False
# Check each tensor tuple in the dictionary
for key in self.key_value_memory_dict:
self_tensors = self.key_value_memory_dict[key]
other_tensors = other.key_value_memory_dict[key]
# Compare each key, value tensor in the tuple
for self_tensor, other_tensor in zip(self_tensors, other_tensors):
if (
self_tensor.data_ptr() != other_tensor.data_ptr()
or self_tensor.shape != other_tensor.shape
):
return False
return True
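# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): the prefill -> decode -> reset
# bookkeeping a generation loop performs on the InferenceParams object defined
# above. The prompt/step counts are made up; only the attributes and methods of
# the class above are used.
# ---------------------------------------------------------------------------
params = InferenceParams(max_batch_size=4, max_sequence_length=512)

prompt_length, tokens_to_generate = 32, 8

params.enable_prefill_mode()                 # still encoding prompt tokens
params.sequence_len_offset += prompt_length  # KV cache now holds the prompt

params.enable_decode_mode()                  # every prompt in the batch is encoded
for _ in range(tokens_to_generate):
    params.sequence_len_offset += 1          # one new token appended to the cache per step

params.reset()                               # ready for the next batch
assert params.sequence_len_offset == 0 and not params.decode_mode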
from megatron.core.utils import is_torch_min_version

jit_fuser = torch.jit.script
# nvFuser is deprecated in PyTorch JIT starting from 2.2
if is_torch_min_version("2.2.0a0"):
    jit_fuser = torch.compile
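# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): jit_fuser is used as a
# decorator on small pointwise functions so the chain of elementwise ops can be
# fused. fused_bias_gelu below is a made-up example, and the local jit_fuser
# fallback merely stands in for the module-level definition above.
# ---------------------------------------------------------------------------
import torch

jit_fuser = torch.compile if hasattr(torch, "compile") else torch.jit.script


@jit_fuser
def fused_bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Tanh-approximated GeLU applied to (y + bias): a typical fusion target.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))


out = fused_bias_gelu(torch.zeros(8), torch.randn(8))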
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Callable, ContextManager, Optional

import torch


@dataclass
class ModelParallelConfig:
    """Base configuration for Megatron Core

    The initialization function has an argument for each parameter.
    """

    ###################
    # Model parallelism
    ###################
    tensor_model_parallel_size: int = 1
    """Intra-layer model parallelism. Splits tensors across GPU ranks."""

    pipeline_model_parallel_comm_backend: Optional[str] = None
    """Configures the backend for pipeline parallel communication (e.g., nccl, ucc).
    If None, the default backend will be used.
    """

    pipeline_model_parallel_size: int = 1
    """Inter-layer model parallelism. Splits transformer layers across GPU ranks."""

    virtual_pipeline_model_parallel_size: Optional[int] = None
    """Interleaved pipeline parallelism is used to improve performance by reducing the pipeline
    bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
    The number of virtual blocks per pipeline model parallel rank is the virtual model parallel
    size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM:
    arxiv.org/pdf/2104.04473.pdf for more details.
    """

    sequence_parallel: bool = False
    """Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms
    and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models
    (https://arxiv.org/abs/2205.05198) for more details.
    """

    context_parallel_size: int = 1
    """Splits network input along sequence dimension across GPU ranks."""

    hierarchical_context_parallel_sizes: Optional[list[int]] = None
    """Degrees of the hierarchical context parallelism. Users should provide a list to specify
    the sizes for different levels. Taking the a2a+p2p cp comm type as an example, it contains
    groups of two levels, so the first value of the list indicates the group size of the a2a
    communication type, and the second value indicates the group size of the p2p communication
    type.
    """

    expert_model_parallel_size: int = 1
    """Distributes MoE experts across the sub data parallel dimension."""

    expert_tensor_parallel_size: Optional[int] = None
    """Intra-layer tensor model parallelism for expert layer. Splits tensors across GPU ranks."""

    moe_extended_tp: bool = False
    """NOTE: Deprecated from MCore v0.10. This flag is ignored.
    Its functionality is replaced by expert_tensor_parallel_size.
    """

    ###################
    # Initialization
    ###################
    perform_initialization: bool = True
    """If true, weights are initialized. This option can be useful when you know you are going to
    load values from a checkpoint.
    """

    use_cpu_initialization: bool = False
    """When set to False, we initialize the weights directly on the GPU. CPU initialization is the
    same regardless of tensor model parallelism, but GPU initialization is not. Transferring
    weights from CPU to GPU can take a significant amount of time for large models.
    """

    ###################
    # Training
    ###################
    fp16: bool = False
    """If true, train with fp16 mixed precision training."""

    bf16: bool = False
    """If true, train with bf16 mixed precision training."""

    params_dtype: torch.dtype = torch.float32
    """dtype used when initializing the weights."""

    timers: Optional[Callable] = None
    """Timers object to call for various timing functions. See megatron.core.timers.Timers"""

    finalize_model_grads_func: Optional[Callable] = None
    """Function that finalizes gradients on all workers. Could include ensuring that grads are
    all-reduced across data parallelism, pipeline parallelism, and sequence parallelism
    dimensions.
    """

    grad_scale_func: Optional[Callable] = None
    """If using loss scaling, this function should take the loss and return the scaled loss. If
    None, no function is called on the loss.
    """

    no_sync_func: Optional[Callable] = None
    """Function that creates a context that suppresses asynchronous data-parallel communication. If
    the model is an instance of core.distributed.DistributedDataParallel, the default is to use
    core.distributed.DistributedDataParallel.no_sync.
    """

    grad_sync_func: Optional[Callable] = None
    """Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient
    reduce-scatters). The function should take one argument: an iterable of parameters whose
    gradients are to be synchronized.
    """

    param_sync_func: Optional[Callable] = None
    """Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer
    parameter all-gathers). The function should take one argument: an iterable of parameters to
    be synchronized.
    """

    deterministic_mode: bool = False
    """If true, code that has deterministic execution will be chosen. This usually
    means slower execution, but is good for debugging and testing. Defaults to False."""

    enable_autocast: bool = False
    """If true, runs the forward step function inside torch.autocast context."""

    autocast_dtype: Optional[torch.dtype] = None
    """dtype to pass to torch.amp.autocast when enabled. If None, is set to pipeline_dtype."""

    num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
    """If int, set the number of microbatches where not all of the layers will be checkpointed and
    recomputed. The rest of the microbatches within the window of maximum outstanding
    microbatches will recompute all layers (either full recompute or selective recompute). If
    None, the checkpoint and recompute will be left up to the forward_step function.
    """

    ###################
    # Optimizations
    ###################
    gradient_accumulation_fusion: bool = False
    """If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension
    fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install
    APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\"
    --global-option=\"--cuda_ext\" ". Note that the extension requires CUDA>=11. Otherwise, you
    must turn off gradient accumulation fusion.
    """

    async_tensor_model_parallel_allreduce: bool = False
    """NOTE: Deprecated. This flag is ignored."""

    use_te_rng_tracker: bool = False
    """If true, uses the RNG state tracker in TransformerEngine if it exists.
    """

    tp_comm_overlap: bool = False
    """If true, allows overlapping of Linear layer execution with tensor parallel communication
    collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever
    possible during the forward and the backward pass.
    """

    tp_comm_bulk_wgrad: bool = True
    """If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if
    tp_comm_overlap is False.
    """

    tp_comm_bulk_dgrad: bool = True
    """If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't care if
    tp_comm_overlap is False.
    """

    tp_comm_overlap_ag: bool = True
    """If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather.
    Don't care if tp_comm_overlap is False.
    """

    tp_comm_overlap_rs: bool = True
    """If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter.
    Don't care if tp_comm_overlap is False.
    """

    tp_comm_overlap_rs_dgrad: bool = False
    """If true, allows Reduce-Scatter overlap with DGRAD GEMM by pipelining the
    GEMM and Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_split_ag: bool = True
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
    splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_atomic_ag: bool = False
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
    both done atomically. Don't care if tp_comm_overlap is False.
    """

    tp_comm_split_rs: bool = True
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
    Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_atomic_rs: bool = False
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
    Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False.
    """

    cross_entropy_loss_fusion: bool = False
    """If this is enabled, the fused cross entropy implementation would be used.
    Defaults to False.
    """

    tp_comm_overlap_disable_qkv: bool = False
    """
    If true, the AllGather -> Gemm overlap for QKV gets disabled
    """

    tp_comm_overlap_disable_fc1: bool = False
    """
    If true, the AllGather -> Gemm overlap for FC1 layer of MLP gets disabled
    """

    tp_comm_bootstrap_backend: str = 'nccl'
    """
    Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo'
    """

    ###################
    # Pipeline Parallel
    ###################
    pipeline_dtype: torch.dtype = None
    """dtype used in p2p communication, usually params_dtype"""

    variable_seq_lengths: bool = False
    """Support for variable sequence lengths across microbatches. Setting this communicates the size
    of tensors during pipeline parallelism communication; because of this extra overhead, it
    should only be set if the sequence length varies by microbatch within a global batch.
    """

    overlap_p2p_comm: bool = False
    """When True some of the peer to peer communication for pipeline parallelism will overlap with
    computation. Must be False if batch_p2p_comm is true.
    """

    batch_p2p_comm: bool = True
    """Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if
    overlap_p2p_comm is True.
    """

    batch_p2p_sync: bool = True
    """When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in
    older versions of PyTorch.
    """

    use_ring_exchange_p2p: bool = False
    """Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires
    custom built torch with torch.distributed.ring_exchange.
    """

    deallocate_pipeline_outputs: bool = False
    """If True, output data is deallocated after the tensor is sent to the next pipeline stage.
    Helps with saving memory, does nothing when pipeline parallel is not used.
    """

    defer_embedding_wgrad_compute: bool = False
    """If true, defers the embedding WGRAD GEMMs while pipeline flush is
    taking place enabling us to hide pipeline flush latency. Defaults to False.
    """

    wgrad_deferral_limit: int = 0
    """This value tunes the number of micro-batches for which the embedding weight gradient compute
    needs to be deferred to pipeline flush; this argument is invalid if
    `defer_embedding_wgrad_compute` is False.
    Defaults to 0, which means all micro-batches are deferred.
    """

    pipeline_model_parallel_split_rank: Optional[int] = None
    """If int, rank where encoder and decoder should be split in cases where the model has both an
    encoder and decoder (e.g., T5). Ignored if None.
    """

    overlap_p2p_comm_warmup_flush: bool = False
    """If true, overlap communication and computation in warm up and flush phase.
    Only valid when overlap_p2p_comm is True and batch_p2p_comm is False.
    Defaults to False.
    """

    microbatch_group_size_per_vp_stage: Optional[int] = None
    """This value specifies the number of micro-batches that are executed
    at a time for a given virtual stage (both forward and backward).
    Defaults (in the __post_init__() method below) to pipeline_parallel_size,
    which specifies a depth-first schedule.
    Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2,
    num_microbatches = 4, we have
    rank 0 | 0 1 0 1 2 3 2 3
    rank 1 | 0 1 0 1 2 3 2 3
    When microbatch_group_size_per_vp_stage=3, num_microbatches = 5,
    we have
    rank 0 | 0 1 2 0 1 2 3 4 3 4
    rank 1 | 0 1 2 0 1 2 3 4 3 4
    """

    ###################
    # CPU Offloading
    ###################
    cpu_offloading: bool = False
    """When set to True, all the activations are offloaded to the CPU asynchronously."""

    cpu_offloading_num_layers: int = 0
    """Tells the number of transformer layers for which activations have to be offloaded."""

    _cpu_offloading_context: Optional[ContextManager] = (
        None
        # Used for internal use only, not to be set by a user.
        # TODO: Need to move to the 'right' place when possible.
    )
    """For internal use only, do not set."""

    cpu_offloading_activations: bool = True
    """If True, offloads the activations to CPU."""

    cpu_offloading_weights: bool = True
    """If True, offloads the weights to CPU."""

    ###################
    # Timing
    ###################
    barrier_with_L1_time: bool = True
    """If true, use barrier with level 1 time measurements. It is up to the user to make sure
    calling barrier with their timers will not result in hangs. This can happen if for example
    the user adds a level 1 timer that is not called by all ranks.
    """

    def __post_init__(self):
        """Python dataclass method that is used to modify attributes after initialization.
        See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
        details.
        """
        if self.sequence_parallel:
            if self.tensor_model_parallel_size <= 1:
                raise ValueError("Can not use sequence parallelism without tensor parallelism")
details.
if self.expert_tensor_parallel_size is None: """
self.expert_tensor_parallel_size = self.tensor_model_parallel_size if self.sequence_parallel:
if self.tensor_model_parallel_size <= 1:
if self.pipeline_model_parallel_size > 1: raise ValueError("Can not use sequence paralllelism without tensor parallelism")
if self.pipeline_dtype is None:
raise ValueError( if self.expert_tensor_parallel_size is None:
"When using pipeline parallelism, pipeline_dtype must be specified" self.expert_tensor_parallel_size = self.tensor_model_parallel_size
)
if self.pipeline_model_parallel_size > 1:
if self.autocast_dtype is None: if self.pipeline_dtype is None:
self.autocast_dtype = self.params_dtype raise ValueError(
"When using pipeline parallelism, pipeline_dtype must be specified"
if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1: )
raise ValueError(
"Cannot defer embedding wgrad compute when pipeline model parallel is not used" if self.autocast_dtype is None:
) self.autocast_dtype = self.params_dtype
if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion: if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1:
raise ValueError( raise ValueError(
"Cannot defer embedding wgrad compute when gradient accumulation fusion is not used" "Cannot defer embedding wgrad compute when pipeline model parallel is not used"
) )
if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0: if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion:
raise ValueError( raise ValueError(
"Wgrad deferral limit should be greater than or equal to 0 when it is enabled!" "Cannot defer embedding wgrad compute when gradient accumulation fusion is not used"
) )
if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1: if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0:
if self.sequence_parallel is False: raise ValueError(
raise ValueError( "Wgrad deferral limit should be greater than or equal to 0 when it is enabled!"
"When using expert parallelism and tensor parallelism, " )
"sequence parallelism must be used"
) if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
if self.sequence_parallel is False:
if self.microbatch_group_size_per_vp_stage is None: raise ValueError(
self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size "When using expert parallelism and tensor parallelism, "
"sequence parallelism must be used"
if self.overlap_p2p_comm_warmup_flush: )
if not self.overlap_p2p_comm or self.batch_p2p_comm:
raise ValueError( if self.microbatch_group_size_per_vp_stage is None:
"Pipeline parallel communication overlapping in warmup and flush is only " self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size
"compatible with overlap_p2p_comm but not batch_p2p_comm."
) if self.overlap_p2p_comm_warmup_flush:
if not self.overlap_p2p_comm or self.batch_p2p_comm:
raise ValueError(
"Pipeline parallel communication overlapping in warmup and flush is only "
"compatible with overlap_p2p_comm but not batch_p2p_comm."
)
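For reference, a minimal sketch (not part of this commit) of how the __post_init__ checks above behave. It assumes this dataclass is ModelParallelConfig from megatron.core.model_parallel_config; the field defaults follow the code above.

import torch
from megatron.core.model_parallel_config import ModelParallelConfig  # assumed import path

# Pipeline parallelism requires an explicit pipeline_dtype, and sequence parallelism
# requires tensor parallelism, per the checks in __post_init__ above.
cfg = ModelParallelConfig(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
    pipeline_dtype=torch.bfloat16,
    sequence_parallel=True,
)
# Left unset, microbatch_group_size_per_vp_stage defaults to pipeline_model_parallel_size
# (a depth-first schedule).
assert cfg.microbatch_group_size_per_vp_stage == 2

try:
    ModelParallelConfig(tensor_model_parallel_size=1, sequence_parallel=True)
except ValueError as err:
    print(err)  # sequence parallelism without tensor parallelism is rejected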
...@@ -10,9 +10,11 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.enums import ModelType
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.relative_pos_embedding import RelativePositionEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
...@@ -135,9 +137,13 @@ class T5Model(LanguageModule):
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal[
            'learned_absolute', 'rope', 'relative'
        ] = 'learned_absolute',
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[float] = None,
        relative_attention_num_buckets: int = 32,
        relative_attention_max_distance: int = 128,
        add_encoder: bool = True,
        add_decoder: bool = True,
    ):
...@@ -193,6 +199,23 @@ class T5Model(LanguageModule):
                use_cpu_initialization=self.config.use_cpu_initialization,
            )

        # Relative Position Embeddings
        if self.position_embedding_type == 'relative':
            self.encoder_relative_pos_emb = RelativePositionEmbedding(
                bidirectional=True,
                init_method=self.config.init_method,
                num_attention_heads=self.config.num_attention_heads,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
            )
            self.decoder_relative_pos_emb = RelativePositionEmbedding(
                bidirectional=False,
                init_method=self.config.init_method,
                num_attention_heads=self.config.num_attention_heads,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
            )

        # Transformer encoder
        encoder_spec, decoder_spec = (
            self.transformer_encoder_layer_spec,
...@@ -284,6 +307,27 @@ class T5Model(LanguageModule):
            )
            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

        # Relative positional embeddings
        encoder_attention_bias_parallel = None
        if self.position_embedding_type == 'relative':
            query_seq_length = RelativePositionEmbedding.get_relative_seq_len(
                inference_params, self.encoder, encoder_input, self.config
            )
            key_seq_length = query_seq_length
            attention_bias = self.encoder_relative_pos_emb(query_seq_length, key_seq_length)
            # Scatter attention_bias to TP ranks
            # First, reshape [1, num_head, seqlen_q, seqlen_kv] to
            # [1, seqlen_q, seqlen_kv, num_head] to be scattered along
            # the last (num_heads) dimension
            attention_bias = torch.permute(attention_bias, (0, 2, 3, 1))
            # Then, scatter to TP region
            attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias)
            # Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv]
            encoder_attention_bias_parallel = torch.permute(
                attention_bias_parallel, (0, 3, 1, 2)
            )

        # Run encoder.
        if self.add_encoder:
            encoder_hidden_states = self.encoder(
...@@ -291,6 +335,7 @@ class T5Model(LanguageModule):
                attention_mask=encoder_attn_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                attention_bias=encoder_attention_bias_parallel,
            )
        else:
            encoder_hidden_states = self.encoder_hidden_state
...@@ -315,10 +360,29 @@ class T5Model(LanguageModule):
        rotary_pos_emb = None
        if self.position_embedding_type == 'rope':
            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                inference_params, self.decoder, decoder_input, self.config, packed_seq_params
            )
            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

        # Relative positional embeddings
        decoder_attention_bias_parallel = None
        if self.position_embedding_type == 'relative':
            query_seq_length = RelativePositionEmbedding.get_relative_seq_len(
                inference_params, self.decoder, decoder_input, self.config
            )
            key_seq_length = query_seq_length
            attention_bias = self.decoder_relative_pos_emb(query_seq_length, key_seq_length)
            # Scatter attention_bias to TP ranks
            # First, reshape [1, num_head, seqlen_q, seqlen_kv] to
            # [1, seqlen_q, seqlen_kv, num_head] to be scattered along
            # the last (num_heads) dimension
            attention_bias = torch.permute(attention_bias, (0, 2, 3, 1))
            # Then, scatter to TP region
            attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias)
            # Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv]
            decoder_attention_bias_parallel = torch.permute(attention_bias_parallel, (0, 3, 1, 2))

        # Run decoder.
        decoder_hidden_states = self.decoder(
            hidden_states=decoder_input,
...@@ -327,12 +391,15 @@ class T5Model(LanguageModule):
            context_mask=encoder_decoder_attn_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            attention_bias=decoder_attention_bias_parallel,
        )

        if self.post_process:
            output_weight = None
            if self.share_embeddings_and_output_weights:
                output_weight = self.shared_embedding_or_output_weight()
            lm_logits = self.lm_head(decoder_hidden_states, word_embeddings_weight=output_weight)

            if lm_labels is None:
                # [s b h] => [b s h]
                return lm_logits.transpose(0, 1).contiguous()
...
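The two relative-bias blocks added above share one sharding pattern: permute the [1, num_heads, seqlen_q, seqlen_kv] bias so the heads dimension is last, scatter that dimension across tensor-parallel ranks, then permute back. A small standalone shape sketch of that round trip follows (scatter_to_tensor_model_parallel_region itself needs an initialized tensor-model-parallel group, so the scatter call is only indicated in a comment).

import torch

num_heads, seqlen_q, seqlen_kv = 8, 16, 16
attention_bias = torch.randn(1, num_heads, seqlen_q, seqlen_kv)

# [1, num_heads, q, kv] -> [1, q, kv, num_heads], so each TP rank keeps its own slice of heads.
bias_heads_last = torch.permute(attention_bias, (0, 2, 3, 1))
# bias_parallel = scatter_to_tensor_model_parallel_region(bias_heads_last)  # inside the model
# Revert to [1, num_heads(_per_rank), q, kv] before it is passed as attention_bias.
bias_restored = torch.permute(bias_heads_last, (0, 3, 1, 2))
assert torch.equal(bias_restored, attention_bias)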
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
import logging
import math
from typing import Callable
import torch
from torch import Tensor, nn
from megatron.core.inference_params import InferenceParams
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
logger = logging.getLogger(__name__)
__all__ = ['RelativePositionEmbedding']
class RelativePositionEmbedding(nn.Module):
"""Relative Position Embedding for language model.
Args:
"""
def __init__(
self,
bidirectional: bool,
init_method: Callable,
num_attention_heads: int,
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
) -> None:
super().__init__()
self.bidirectional = bidirectional
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
self.relative_attention_bias = torch.nn.Embedding(
self.relative_attention_num_buckets, num_attention_heads
)
init_method(self.relative_attention_bias.weight)
def _relative_position_bucket(
self, relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
"""
Adapted from HuggingFace T5 Model:
https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/
src/transformers/models/t5/modeling_t5.py#L397
Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e. the
distance in tokens from the attending position to the attended-to position.
If bidirectional=False, then positive relative positions are invalid. We use
smaller buckets for small absolute relative_position and larger buckets for
larger absolute relative_positions. All relative positions >=max_distance map
to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the
model has been trained on.
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position,
containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger
# bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def _compute_bias(self, query_length, key_length):
"""
Adapted from HuggingFace T5 Model
https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/
src/transformers/models/t5/modeling_t5.py#L444C9-L444C21
Compute binned relative position bias
Args:
query_length (int): The length of the query sequence
(e.g., the input sequence in attention).
key_length (int): The length of the key sequence
(e.g., the sequence to compare against in attention).
Returns:
torch.Tensor: A tensor representing the relative position bias, with shape
(1, num_heads, query_length, key_length).
"""
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape(query_length,key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=self.bidirectional,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(
relative_position_bucket
) # shape(query_length,key_length,num_heads)
values = values.permute([2, 0, 1]).unsqueeze(
0
) # shape(1, num_heads,query_length,key_length)
return values
@staticmethod
def get_relative_seq_len(
inference_params: InferenceParams,
transformer: TransformerBlock,
transformer_input: Tensor,
transformer_config: TransformerConfig,
) -> float:
"""Function to get the rotary sequence length.
Args:
inference_params : Used during Inference time
transformer (TransformerBlock): The transformer block (decoder/encoder) used
by the model
transformer_input (Tensor): Input tensor to the transformer
transformer_config (TransformerConfig): Transformer config used by the model
Returns:
            float: The relative sequence length
"""
if inference_params is not None:
relative_seq_len = inference_params.max_sequence_length
else:
if transformer.input_tensor is not None:
relative_seq_len = transformer.input_tensor.size(0)
else:
relative_seq_len = transformer_input.size(0)
if transformer_config.sequence_parallel:
relative_seq_len *= transformer_config.tensor_model_parallel_size
return relative_seq_len
def forward(self, query_seq_length, key_seq_length):
"""
Args:
Returns:
"""
return self._compute_bias(query_seq_length, key_seq_length)
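An illustrative, CPU-only usage sketch of the RelativePositionEmbedding module defined above (not part of the commit). It assumes the class is importable in scope; the init_method is an arbitrary normal initializer chosen for the example.

import torch
from torch.nn.init import normal_

rel_pos_emb = RelativePositionEmbedding(
    bidirectional=True,
    init_method=lambda w: normal_(w, mean=0.0, std=0.02),
    num_attention_heads=8,
    relative_attention_num_buckets=32,
    relative_attention_max_distance=128,
)
bias = rel_pos_emb(query_seq_length=16, key_seq_length=16)
assert bias.shape == (1, 8, 16, 16)  # (1, num_heads, q_len, kv_len)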
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from megatron.core.transformer.transformer_config import TransformerConfig
    from megatron.core.transformer.transformer_block import TransformerBlock
    from megatron.core.inference_params import InferenceParams
    from megatron.core.packed_seq_params import PackedSeqParams

import logging
import math
from functools import lru_cache

import torch
from torch import Tensor, nn

from megatron.core import parallel_state
from megatron.core.models.common.embeddings.rope_utils import (  # for backward compatibility; pylint: disable=unused-import
    _apply_rotary_pos_emb_bshd,
    _apply_rotary_pos_emb_thd,
    _rotate_half,
    apply_rotary_pos_emb,
    get_pos_emb_on_this_cp_rank,
)

logger = logging.getLogger(__name__)

__all__ = ['RotaryEmbedding']


class RotaryEmbedding(nn.Module):
    """Rotary Embedding for language model.

    Args:
        kv_channels (int): Projection weights dimension in multi-head attention. Obtained
            from transformer config
        rotary_percent (float): Percent of rotary dimension to use for rotary position
            embeddings.
        rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings.
            Defaults to False.
        seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE
            for longer sequences. The value must be a float larger than 1.0. Defaults to None
        rotary_base (int, optional): Base period for rotary position embeddings. Defaults to
            10000.
        rope_scaling (bool, optional): Apply rope scaling as used in llama 3.x.
        rope_scaling_factor (float, optional): rope scaling factor in llama 3.x. Defaults to 8.
        use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly
            on the GPU. Defaults to False
    """

    def __init__(
        self,
        kv_channels: int,
        rotary_percent: float,
        rotary_interleaved: bool = False,
        seq_len_interpolation_factor: float = None,
        rotary_base: int = 10000,
        rope_scaling: bool = False,
        rope_scaling_factor: float = 8.0,
        use_cpu_initialization: bool = False,
    ) -> None:
        super().__init__()

        dim = kv_channels
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.rotary_interleaved = rotary_interleaved

        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        device = 'cpu' if use_cpu_initialization else torch.cuda.current_device()
        self.inv_freq = 1.0 / (
            rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )

        if rope_scaling:
            self.inv_freq = self._apply_scaling(self.inv_freq, factor=rope_scaling_factor)

    def _apply_scaling(
        self,
        freqs,
        factor=8,
        low_freq_factor=1,
        high_freq_factor=4,
        original_max_position_embeddings=8192,
    ):
        # This implementation is adapted from:
        # https://github.com/huggingface/transformers/blob/2a5a6ad18aa22e98429bb5ecb880660328030ea0/src/transformers/modeling_rope_utils.py#L303-L343

        factor = factor  # `8` in the original implementation
        low_freq_factor = low_freq_factor  # `1` in the original implementation
        high_freq_factor = high_freq_factor  # `4` in the original implementation
        old_context_len = original_max_position_embeddings  # `8192` in the original implementation

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        wavelen = 2 * math.pi / freqs
        # wavelen < high_freq_wavelen: do nothing
        # wavelen > low_freq_wavelen: divide by factor
        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, freqs / factor, freqs)
        # otherwise: interpolate between the two, using a smooth factor
        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        smoothed_inv_freq = (
            1 - smooth_factor
        ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

        return inv_freq_llama

    def get_freqs_non_repeated(self, max_seq_len: int, offset: int = 0) -> Tensor:
        """Generates matrix of frequencies based on positions in the sequence,
        used to create positional encodings"""
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        if self.seq_len_interpolation_factor is not None:
            seq *= 1 / self.seq_len_interpolation_factor

        freqs = torch.outer(seq, self.inv_freq)  # [seq len, dim]

        return freqs

    def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor):
        """Cosine and sine values for RoPE are precomputed for all positions up to the maximum
        sequence length"""
        freqs = self.get_freqs_non_repeated(max_seq_len, offset)
        cos = torch.cos(freqs)
        sin = torch.sin(freqs)
        return cos, sin

    @lru_cache(maxsize=32)
    def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
        """Forward pass of RoPE embedding.

        Args:
            max_seq_len (int): Maximum size of sequence
            offset (int, optional): RoPE offset. Defaults to 0.
            packed_seq (bool, optional): Whether to use packed sequence. Defaults to False.

        Returns:
            Tensor: Embeddings after applying RoPE.
        """
        if self.inv_freq.device.type == 'cpu':
            # move `inv_freq` to GPU once at the first micro-batch forward pass
            self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device())

        freqs = self.get_freqs_non_repeated(max_seq_len, offset)
        # first part even vector components, second part odd vector components,
        # 2 * dim in dimension size
        if not self.rotary_interleaved:
            emb = torch.cat((freqs, freqs), dim=-1)
        else:
            emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
                freqs.shape[0], -1
            )
        # emb [seq_length, .., dim]
        emb = emb[:, None, None, :]
        if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq:
            # slice rotary_pos_emb along sequence dimension and select the parition of the current
            # CP rank
            emb = get_pos_emb_on_this_cp_rank(emb, 0)
        return emb

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        state_dict.pop(f'{prefix}inv_freq', None)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def get_rotary_seq_len(
        self,
        inference_params: InferenceParams,
        transformer: TransformerBlock,
        transformer_input: Tensor,
        transformer_config: TransformerConfig,
        packed_seq_params: PackedSeqParams,
    ) -> float:
        """Function to get the rotary sequence length.

        Args:
            inference_params : Used during Inference time
            transformer (TransformerBlock): The transformer block (decoder/encoder) used
                by the model
            transformer_input (Tensor): Input tensor to the transformer
            transformer_config (TransformerConfig): Transformer config used by the model
            packed_seq_params (PackedSeqParams): Packed sequence params

        Returns:
            float: The rotary sequence length
        """
        if packed_seq_params is not None:
            # max_seqlen are the max sequence length in the packed sequence before being divived
            # by the tp and cp size.
            return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv)
        elif inference_params is not None:
            rotary_seq_len = inference_params.max_sequence_length
        else:
            if transformer is not None and transformer.input_tensor is not None:
                rotary_seq_len = transformer.input_tensor.size(0)
            else:
                rotary_seq_len = transformer_input.size(0)

            if transformer_config.sequence_parallel:
                rotary_seq_len *= transformer_config.tensor_model_parallel_size

        rotary_seq_len *= transformer_config.context_parallel_size

        return rotary_seq_len
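A small CPU-only sketch (not part of the commit) of constructing the RotaryEmbedding above with the new rope_scaling_factor argument; kv_channels and the sequence length are arbitrary values for illustration, and the class is assumed to be importable in scope.

import torch

rope = RotaryEmbedding(
    kv_channels=128,
    rotary_percent=1.0,
    rotary_base=10000,
    rope_scaling=True,
    rope_scaling_factor=8.0,
    use_cpu_initialization=True,
)
# inv_freq holds one frequency per pair of channels; llama-3.x scaling rewrites the
# low-frequency band in place.
assert rope.inv_freq.shape == (128 // 2,)
freqs = rope.get_freqs_non_repeated(max_seq_len=32)  # [32, 64] rotation angles per position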
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import warnings import warnings
from typing import Optional from typing import Optional
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.multi_latent_attention import ( from megatron.core.transformer.multi_latent_attention import (
MLASelfAttention, MLASelfAttention,
MLASelfAttentionSubmodules, MLASelfAttentionSubmodules,
) )
from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import ( from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules, TransformerBlockSubmodules,
get_num_layers_to_build, get_num_layers_to_build,
) )
from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules from megatron.core.transformer.transformer_layer import (
from megatron.core.utils import is_te_min_version TransformerLayer,
TransformerLayerSubmodules,
try: get_transformer_layer_offset,
from megatron.core.extensions.transformer_engine import ( )
TEColumnParallelLinear, from megatron.core.utils import is_te_min_version
TEDotProductAttention,
TELayerNormColumnParallelLinear, try:
TENorm, from megatron.core.extensions.transformer_engine import (
TERowParallelLinear, TEColumnParallelLinear,
) TEDotProductAttention,
TELayerNormColumnParallelLinear,
HAVE_TE = True TENorm,
except ImportError: TERowParallelLinear,
HAVE_TE = False )
try: HAVE_TE = True
import apex # pylint: disable=unused-import except ImportError:
HAVE_TE = False
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
try:
HAVE_APEX = True import apex # pylint: disable=unused-import
LNImpl = FusedLayerNorm
except ImportError: from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.transformer.torch_norm import WrappedTorchNorm
HAVE_APEX = True
warnings.warn('Apex is not installed. Falling back to Torch Norm') LNImpl = FusedLayerNorm
LNImpl = WrappedTorchNorm except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
def get_gpt_layer_with_transformer_engine_spec( warnings.warn('Apex is not installed. Falling back to Torch Norm')
num_experts: Optional[int] = None, LNImpl = WrappedTorchNorm
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False, def get_gpt_layer_with_transformer_engine_spec(
fp8: Optional[str] = None, # pylint: disable=unused-arguments num_experts: Optional[int] = None,
moe_use_legacy_grouped_gemm: Optional[bool] = False, moe_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec: qk_layernorm: Optional[bool] = False,
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training). multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
Args: ) -> ModuleSpec:
num_experts (int, optional): Number of experts. Defaults to None. """Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility. Args:
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. num_experts (int, optional): Number of experts. Defaults to None.
Defaults to False. moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
Returns: fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
ModuleSpec: Module specification with TE modules moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
""" Defaults to False.
if fp8 is not None:
warnings.warn( Returns:
'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated' ModuleSpec: Module specification with TE modules
' and will be removed soon. Please update your code accordingly.' """
) if fp8 is not None:
warnings.warn(
mlp = _get_mlp_module_spec( 'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
use_te=True, ' and will be removed soon. Please update your code accordingly.'
num_experts=num_experts, )
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, mlp = get_mlp_module_spec(
) use_te=True,
num_experts=num_experts,
if multi_latent_attention: moe_grouped_gemm=moe_grouped_gemm,
return ModuleSpec( moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
module=TransformerLayer, )
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm, if multi_latent_attention:
self_attention=ModuleSpec( return ModuleSpec(
module=MLASelfAttention, module=TransformerLayer,
params={"attn_mask_type": AttnMaskType.causal}, submodules=TransformerLayerSubmodules(
submodules=MLASelfAttentionSubmodules( input_layernorm=TENorm,
linear_q_proj=TEColumnParallelLinear, self_attention=ModuleSpec(
linear_q_down_proj=TEColumnParallelLinear, module=MLASelfAttention,
linear_q_up_proj=TEColumnParallelLinear, params={"attn_mask_type": AttnMaskType.causal},
linear_kv_down_proj=TEColumnParallelLinear, submodules=MLASelfAttentionSubmodules(
linear_kv_up_proj=TEColumnParallelLinear, linear_q_proj=TEColumnParallelLinear,
core_attention=TEDotProductAttention, linear_q_down_proj=TEColumnParallelLinear,
linear_proj=TERowParallelLinear, linear_q_up_proj=(
q_layernorm=TENorm if qk_layernorm else IdentityOp, TELayerNormColumnParallelLinear
kv_layernorm=TENorm if qk_layernorm else IdentityOp, if qk_layernorm
), else TEColumnParallelLinear
), ),
self_attn_bda=get_bias_dropout_add, linear_kv_down_proj=TEColumnParallelLinear,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp, linear_kv_up_proj=(
mlp=mlp, TELayerNormColumnParallelLinear
mlp_bda=get_bias_dropout_add, if qk_layernorm
), else TEColumnParallelLinear
) ),
else: core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
# TENorm significantly harms convergence when used q_layernorm=IdentityOp,
# for QKLayerNorm if TE Version < 1.9; kv_layernorm=IdentityOp,
# we instead use the Apex implementation. ),
qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm ),
self_attn_bda=get_bias_dropout_add,
return ModuleSpec( pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
module=TransformerLayer, mlp=mlp,
submodules=TransformerLayerSubmodules( mlp_bda=get_bias_dropout_add,
self_attention=ModuleSpec( ),
module=SelfAttention, )
params={"attn_mask_type": AttnMaskType.causal}, else:
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear, # TENorm significantly harms convergence when used
core_attention=TEDotProductAttention, # for QKLayerNorm if TE Version < 1.9;
linear_proj=TERowParallelLinear, # we instead use the Apex implementation.
q_layernorm=qk_norm if qk_layernorm else IdentityOp, qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm
k_layernorm=qk_norm if qk_layernorm else IdentityOp,
), return ModuleSpec(
), module=TransformerLayer,
self_attn_bda=get_bias_dropout_add, submodules=TransformerLayerSubmodules(
pre_mlp_layernorm=TENorm if num_experts else IdentityOp, self_attention=ModuleSpec(
mlp=mlp, module=SelfAttention,
mlp_bda=get_bias_dropout_add, params={"attn_mask_type": AttnMaskType.causal},
), submodules=SelfAttentionSubmodules(
) linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
def get_gpt_layer_local_spec( q_layernorm=qk_norm if qk_layernorm else IdentityOp,
num_experts: Optional[int] = None, k_layernorm=qk_norm if qk_layernorm else IdentityOp,
moe_grouped_gemm: Optional[bool] = False, ),
qk_layernorm: Optional[bool] = False, ),
multi_latent_attention: Optional[bool] = False, self_attn_bda=get_bias_dropout_add,
fp8: Optional[str] = None, # pylint: disable=unused-arguments pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
moe_use_legacy_grouped_gemm: Optional[bool] = False, mlp=mlp,
) -> ModuleSpec: mlp_bda=get_bias_dropout_add,
"""Use this spec for an implementation using only modules in Megatron-Core. ),
)
Args:
num_experts (int, optional): Number of experts. Defaults to None. def get_gpt_layer_local_spec(
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. num_experts: Optional[int] = None,
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. moe_grouped_gemm: Optional[bool] = False,
fp8 (str, optional): Deprecated. For temporary Nemo compatibility. qk_layernorm: Optional[bool] = False,
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. multi_latent_attention: Optional[bool] = False,
Defaults to False. fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
Returns: ) -> ModuleSpec:
ModuleSpec: Module specification with Megatron-Core modules """Use this spec for an implementation using only modules in Megatron-Core.
"""
if fp8 is not None:
warnings.warn( Args:
'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated' num_experts (int, optional): Number of experts. Defaults to None.
' and will be removed soon. Please update your code accordingly.' moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
) qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
mlp = _get_mlp_module_spec( moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
use_te=False, Defaults to False.
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm, Returns:
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, ModuleSpec: Module specification with Megatron-Core modules
) """
if fp8 is not None:
if multi_latent_attention: warnings.warn(
return ModuleSpec( 'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated'
module=TransformerLayer, ' and will be removed soon. Please update your code accordingly.'
submodules=TransformerLayerSubmodules( )
input_layernorm=LNImpl,
self_attention=ModuleSpec( mlp = get_mlp_module_spec(
module=MLASelfAttention, use_te=False,
params={"attn_mask_type": AttnMaskType.causal}, num_experts=num_experts,
submodules=MLASelfAttentionSubmodules( moe_grouped_gemm=moe_grouped_gemm,
linear_q_proj=ColumnParallelLinear, moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
linear_q_down_proj=ColumnParallelLinear, )
linear_q_up_proj=ColumnParallelLinear,
linear_kv_down_proj=ColumnParallelLinear, if multi_latent_attention:
linear_kv_up_proj=ColumnParallelLinear, return ModuleSpec(
core_attention=DotProductAttention, module=TransformerLayer,
linear_proj=RowParallelLinear, submodules=TransformerLayerSubmodules(
q_layernorm=LNImpl if qk_layernorm else IdentityOp, input_layernorm=LNImpl,
kv_layernorm=LNImpl if qk_layernorm else IdentityOp, self_attention=ModuleSpec(
), module=MLASelfAttention,
), params={"attn_mask_type": AttnMaskType.causal},
self_attn_bda=get_bias_dropout_add, submodules=MLASelfAttentionSubmodules(
pre_mlp_layernorm=LNImpl, linear_q_proj=ColumnParallelLinear,
mlp=mlp, linear_q_down_proj=ColumnParallelLinear,
mlp_bda=get_bias_dropout_add, linear_q_up_proj=ColumnParallelLinear,
), linear_kv_down_proj=ColumnParallelLinear,
) linear_kv_up_proj=ColumnParallelLinear,
else: core_attention=DotProductAttention,
return ModuleSpec( linear_proj=RowParallelLinear,
module=TransformerLayer, q_layernorm=LNImpl if qk_layernorm else IdentityOp,
submodules=TransformerLayerSubmodules( kv_layernorm=LNImpl if qk_layernorm else IdentityOp,
input_layernorm=LNImpl, ),
self_attention=ModuleSpec( ),
module=SelfAttention, self_attn_bda=get_bias_dropout_add,
params={"attn_mask_type": AttnMaskType.causal}, pre_mlp_layernorm=LNImpl,
submodules=SelfAttentionSubmodules( mlp=mlp,
linear_qkv=ColumnParallelLinear, mlp_bda=get_bias_dropout_add,
core_attention=DotProductAttention, ),
linear_proj=RowParallelLinear, )
q_layernorm=LNImpl if qk_layernorm else IdentityOp, else:
k_layernorm=LNImpl if qk_layernorm else IdentityOp, return ModuleSpec(
), module=TransformerLayer,
), submodules=TransformerLayerSubmodules(
self_attn_bda=get_bias_dropout_add, input_layernorm=LNImpl,
pre_mlp_layernorm=LNImpl, self_attention=ModuleSpec(
mlp=mlp, module=SelfAttention,
mlp_bda=get_bias_dropout_add, params={"attn_mask_type": AttnMaskType.causal},
sharded_state_dict_keys_map={ submodules=SelfAttentionSubmodules(
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', linear_qkv=ColumnParallelLinear,
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', core_attention=DotProductAttention,
}, linear_proj=RowParallelLinear,
), q_layernorm=LNImpl if qk_layernorm else IdentityOp,
) k_layernorm=LNImpl if qk_layernorm else IdentityOp,
),
),
def _get_mlp_module_spec( self_attn_bda=get_bias_dropout_add,
use_te: Optional[bool] = True, pre_mlp_layernorm=LNImpl,
num_experts: Optional[int] = None, mlp=mlp,
moe_grouped_gemm: Optional[bool] = False, mlp_bda=get_bias_dropout_add,
fp8: Optional[str] = None, # pylint: disable=unused-arguments sharded_state_dict_keys_map={
moe_use_legacy_grouped_gemm: Optional[bool] = False, 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
) -> ModuleSpec: 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
"""Helper function to get module spec for MLP/MoE""" },
if fp8 is not None: ),
warnings.warn( )
'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
) def _get_mlp_module_spec(
use_te: Optional[bool] = True,
if num_experts is None: num_experts: Optional[int] = None,
# Dense MLP w/ or w/o TE modules. moe_grouped_gemm: Optional[bool] = False,
return ModuleSpec( fp8: Optional[str] = None, # pylint: disable=unused-arguments
module=MLP, moe_use_legacy_grouped_gemm: Optional[bool] = False,
submodules=MLPSubmodules( ):
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear, warnings.warn(
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, """This private function is on a deprecation track. Please switch to `get_mlp_module_spec`
), since it will be removed in a future release."""
) )
else:
# Mixture of experts with modules in megatron core. return get_mlp_module_spec(
return get_moe_module_spec( use_te=use_te,
use_te=use_te, num_experts=num_experts,
num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm,
moe_grouped_gemm=moe_grouped_gemm, fp8=fp8,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
) )
def get_gpt_decoder_block_spec( def get_mlp_module_spec(
config: TransformerConfig, use_transformer_engine: bool use_te: Optional[bool] = True,
) -> TransformerBlockSubmodules: num_experts: Optional[int] = None,
"""GPT block spec.""" moe_grouped_gemm: Optional[bool] = False,
if use_transformer_engine: fp8: Optional[str] = None, # pylint: disable=unused-arguments
layer_norm_impl = TENorm moe_use_legacy_grouped_gemm: Optional[bool] = False,
else: ) -> ModuleSpec:
layer_norm_impl = LNImpl """Helper function to get module spec for MLP/MoE"""
if fp8 is not None:
# Layer specs. warnings.warn(
dense_layer_spec = ( 'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
get_gpt_layer_with_transformer_engine_spec( ' and will be removed soon. Please update your code accordingly.'
num_experts=None, )
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm, if num_experts is None:
multi_latent_attention=config.multi_latent_attention, # Dense MLP w/ or w/o TE modules.
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, return ModuleSpec(
) module=MLP,
if use_transformer_engine submodules=MLPSubmodules(
else get_gpt_layer_local_spec( linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
num_experts=None, linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
moe_grouped_gemm=False, ),
qk_layernorm=config.qk_layernorm, )
multi_latent_attention=config.multi_latent_attention, else:
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, # Mixture of experts with modules in megatron core.
) return get_moe_module_spec(
) use_te=use_te,
moe_layer_spec = ( num_experts=num_experts,
get_gpt_layer_with_transformer_engine_spec( moe_grouped_gemm=moe_grouped_gemm,
num_experts=config.num_moe_experts, moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
moe_grouped_gemm=config.moe_grouped_gemm, )
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, def get_gpt_decoder_block_spec(
) config: TransformerConfig, use_transformer_engine: bool
if use_transformer_engine ) -> TransformerBlockSubmodules:
else get_gpt_layer_local_spec( """GPT block spec."""
num_experts=config.num_moe_experts, if use_transformer_engine:
        layer_norm_impl = TENorm
    else:
        layer_norm_impl = LNImpl

    # Layer specs.
    dense_layer_spec = (
        get_gpt_layer_with_transformer_engine_spec(
            num_experts=None,
            moe_grouped_gemm=False,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
        if use_transformer_engine
        else get_gpt_layer_local_spec(
            num_experts=None,
            moe_grouped_gemm=False,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
    )
    moe_layer_spec = (
        get_gpt_layer_with_transformer_engine_spec(
            num_experts=config.num_moe_experts,
            moe_grouped_gemm=config.moe_grouped_gemm,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
        if use_transformer_engine
        else get_gpt_layer_local_spec(
            num_experts=config.num_moe_experts,
            moe_grouped_gemm=config.moe_grouped_gemm,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
    )

    # Parse config.moe_layer_freq to determine the pattern of expert/dense layers.
    # 0 stands for dense layers, 1 stands for expert layers.
    # For integer N: Creates a pattern with one expert layer every N layers.
    # For string pattern: Evaluates the str directly (e.g. "[1,0,1]" for alternating expert/dense).
    if isinstance(config.moe_layer_freq, int):
        moe_layer_pattern = [
            1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers)
        ]
    elif isinstance(config.moe_layer_freq, list):
        moe_layer_pattern = config.moe_layer_freq
        assert len(moe_layer_pattern) == config.num_layers, (
            f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
            f"expected {config.num_layers}, "
            f"current moe layer pattern: {config.moe_layer_freq}"
        )
    else:
        raise ValueError(
            f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}"
        )

    # Create the layer specs for the model.
    layer_specs = []
    for layer_number in range(config.num_layers):
        if moe_layer_pattern[layer_number] == 1:
            layer_specs.append(moe_layer_spec)
        elif moe_layer_pattern[layer_number] == 0:
            layer_specs.append(dense_layer_spec)
        else:
            raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}")

    # Slice the layer specs to only include the layers that are built in this pipeline stage.
    # Note: MCore layer_number starts at 1
    offset = get_transformer_layer_offset(config)
    num_layers_to_build = get_num_layers_to_build(config)
    layer_specs = layer_specs[offset : offset + num_layers_to_build]

    # Block spec.
    block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl)

    return block_spec
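For reference, the moe_layer_freq parsing above turns either an integer frequency or an explicit 0/1 list into a per-layer expert/dense pattern. The standalone sketch below mirrors that logic outside of Megatron-Core so the resulting pattern can be inspected on its own; expand_moe_layer_pattern is a hypothetical helper name, not a library API.

# Standalone illustration of the moe_layer_freq -> layer pattern logic above.
# `expand_moe_layer_pattern` is a hypothetical helper, not a Megatron-Core function.
from typing import List, Union


def expand_moe_layer_pattern(moe_layer_freq: Union[int, List[int]], num_layers: int) -> List[int]:
    """Return a per-layer pattern: 1 marks an MoE layer, 0 marks a dense layer."""
    if isinstance(moe_layer_freq, int):
        # One expert layer every `moe_layer_freq` layers, starting at layer 0.
        return [1 if (i % moe_layer_freq == 0) else 0 for i in range(num_layers)]
    if isinstance(moe_layer_freq, list):
        assert len(moe_layer_freq) == num_layers, "pattern length must equal num_layers"
        return list(moe_layer_freq)
    raise ValueError(f"Invalid moe_layer_freq: {type(moe_layer_freq)}")


print(expand_moe_layer_pattern(2, 6))          # [1, 0, 1, 0, 1, 0]
print(expand_moe_layer_pattern([1, 0, 1], 3))  # [1, 0, 1]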
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from collections import OrderedDict
from typing import Dict, Literal, Optional

import torch
from torch import Tensor

from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig


class GPTModel(LanguageModule):
    """GPT Transformer language model.

    Args:
        config (TransformerConfig):
            Transformer config
        transformer_layer_spec (ModuleSpec):
            Specifies module to use for transformer layers
        vocab_size (int):
            Vocabulary size
        max_sequence_length (int):
            maximum size of sequence. This is used for positional embedding
        pre_process (bool, optional):
            Include embedding layer (used with pipeline parallelism). Defaults to True.
        post_process (bool, optional):
            Include an output layer (used with pipeline parallelism). Defaults to True.
        fp16_lm_cross_entropy (bool, optional):
            Defaults to False.
        parallel_output (bool, optional):
            Do not gather the outputs, keep them split across tensor
            parallel ranks. Defaults to True.
        share_embeddings_and_output_weights (bool, optional):
            When True, input embeddings and output logit weights are shared. Defaults to False.
        position_embedding_type (Literal[learned_absolute,rope], optional):
            Position embedding type. Defaults to 'learned_absolute'.
        rotary_percent (float, optional):
            Percent of rotary dimension to use for rotary position embeddings.
            Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
        rotary_base (int, optional):
            Base period for rotary position embeddings. Ignored unless
            position_embedding_type is 'rope'.
            Defaults to 10000.
        rope_scaling (bool, optional): Toggle RoPE scaling.
        rope_scaling_factor (float): RoPE scaling factor. Default 8.
        scatter_embedding_sequence_parallel (bool, optional):
            Whether embeddings should be scattered across sequence parallel
            region or not. Defaults to True.
        seq_len_interpolation_factor (Optional[float], optional):
            scale of linearly interpolating RoPE for longer sequences.
            The value must be a float larger than 1.0. Defaults to None.
    """

    def __init__(
        self,
        config: TransformerConfig,
        transformer_layer_spec: ModuleSpec,
        vocab_size: int,
        max_sequence_length: int,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        rotary_base: int = 10000,
        rope_scaling: bool = False,
        rope_scaling_factor: float = 8.0,
        scatter_embedding_sequence_parallel: bool = True,
        seq_len_interpolation_factor: Optional[float] = None,
    ) -> None:
        super().__init__(config=config)

        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.position_embedding_type = position_embedding_type

        # megatron core pipelining currently depends on model type
        # TODO: remove this dependency ?
        self.model_type = ModelType.encoder_or_decoder

        # These 4 attributes are needed for TensorRT-LLM export.
        self.max_position_embeddings = max_sequence_length
        self.rotary_percent = rotary_percent
        self.rotary_base = rotary_base
        self.rotary_scaling = rope_scaling

        if self.pre_process:
            self.embedding = LanguageModelEmbedding(
                config=self.config,
                vocab_size=self.vocab_size,
                max_sequence_length=self.max_sequence_length,
                position_embedding_type=position_embedding_type,
                scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
            )

        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
                rope_scaling=rope_scaling,
                rope_scaling_factor=rope_scaling_factor,
                use_cpu_initialization=self.config.use_cpu_initialization,
            )

        # Cache for RoPE tensors which do not change between iterations.
        self.rotary_pos_emb_cache = {}

        # Transformer.
        self.decoder = TransformerBlock(
            config=self.config,
            spec=transformer_layer_spec,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        # Output
        if post_process:
            if self.config.defer_embedding_wgrad_compute:
                # The embedding activation buffer preserves a reference to the input activations
                # of the final embedding projection layer GEMM. It will hold the activations for
                # all the micro-batches of a global batch for the last pipeline stage. Once we are
                # done with all the back props for all the microbatches for the last pipeline stage,
                # it will be in the pipeline flush stage. During this pipeline flush we use the
                # input activations stored in embedding activation buffer and gradient outputs
                # stored in gradient buffer to calculate the weight gradients for the embedding
                # final linear layer.
                self.embedding_activation_buffer = []
                self.grad_output_buffer = []
            else:
                self.embedding_activation_buffer = None
                self.grad_output_buffer = None

            self.output_layer = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=False,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=self.pre_process
                and self.share_embeddings_and_output_weights,
                embedding_activation_buffer=self.embedding_activation_buffer,
                grad_output_buffer=self.grad_output_buffer,
            )

        if self.pre_process or self.post_process:
            self.setup_embeddings_and_output_layer()

        if has_config_logger_enabled(self.config):
            log_config_to_disk(
                self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
            )

    def set_input_tensor(self, input_tensor: Tensor) -> None:
        """Sets input tensor to the model.

        See megatron.model.transformer.set_input_tensor()

        Args:
            input_tensor (Tensor): Sets the input tensor for the model.
        """
        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
        self.decoder.set_input_tensor(input_tensor[0])

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
    ) -> Tensor:
        """Forward function of the GPT Model. This function passes the input tensors
        through the embedding layer, then the decoder, and finally into the post
        processing layer (optional).

        It either returns the loss values if labels are given or the final hidden units.

        Args:
            runtime_gather_output (bool): Gather output at runtime. Default None means
                `parallel_output` arg in the constructor will be used.
        """
        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

        # Decoder embedding.
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
        else:
            # intermediate stage of pipeline
            # decoder will get hidden_states from encoder.input_tensor
            decoder_input = None

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        rotary_pos_cos = None
        rotary_pos_sin = None
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            if not self.training and self.config.flash_decode and inference_params:
                # Flash decoding uses precomputed cos and sin for RoPE
                rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
                    inference_params.max_sequence_length,
                    self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
                )
            else:
                rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                    inference_params, self.decoder, decoder_input, self.config, packed_seq_params
                )
                rotary_pos_emb = self.rotary_pos_emb(
                    rotary_seq_len,
                    packed_seq=packed_seq_params is not None
                    and packed_seq_params.qkv_format == 'thd',
                )
        if (
            (self.config.enable_cuda_graph or self.config.flash_decode)
            and rotary_pos_cos is not None
            and inference_params
        ):
            sequence_len_offset = torch.tensor(
                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
                dtype=torch.int32,
                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
            )
        else:
            sequence_len_offset = None

        # Run decoder.
        hidden_states = self.decoder(
            hidden_states=decoder_input,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            **(extra_block_kwargs or {}),
        )

        if not self.post_process:
            return hidden_states

        # logits and loss
        output_weight = None
        if self.share_embeddings_and_output_weights:
            output_weight = self.shared_embedding_or_output_weight()
        logits, _ = self.output_layer(
            hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
        )

        if has_config_logger_enabled(self.config):
            payload = OrderedDict(
                {
                    'input_ids': input_ids,
                    'position_ids': position_ids,
                    'attention_mask': attention_mask,
                    'decoder_input': decoder_input,
                    'logits': logits,
                }
            )
            log_config_to_disk(self.config, payload, prefix='input_and_logits')

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)

        return loss

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
    ) -> ShardedStateDict:
        """Sharded state dict implementation for GPTModel backward-compatibility
        (removing extra state).

        Args:
            prefix (str): Module name prefix.
            sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
            metadata (Optional[Dict]): metadata controlling sharded state dict creation.

        Returns:
            ShardedStateDict: sharded state dict for the GPTModel
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)

        output_layer_extra_state_key = f'{prefix}output_layer._extra_state'

        # Old GPT checkpoints only stored the output layer weight key. So we remove the
        # _extra_state key but check that it doesn't contain any data anyway
        output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
        assert not (
            output_extra_state and output_extra_state.data
        ), f'Expected output layer extra state to be empty, got: {output_extra_state}'

        return sharded_state_dict
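GPTModel above is typically constructed from a TransformerConfig and a transformer layer spec, then called with token ids, position ids, and an attention mask. The snippet below is a minimal single-process sketch under stated assumptions: one GPU, world size 1, the local (non-Transformer-Engine) layer spec, a boolean causal mask where True marks masked positions, and purely illustrative toy hyperparameters. It is not the Megatron-LM training entry point.

# Minimal sketch: build and call GPTModel on one GPU with world size 1.
# Assumptions: CUDA available, local layer spec, toy hyperparameters.
import os

import torch

from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig

# Single-process distributed + model-parallel state (world size 1).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
torch.distributed.init_process_group(backend="nccl", world_size=1, rank=0)
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=1, pipeline_model_parallel_size=1
)
model_parallel_cuda_manual_seed(123)

# Illustrative toy sizes, not a realistic model configuration.
config = TransformerConfig(
    num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True
)
model = GPTModel(
    config=config,
    transformer_layer_spec=get_gpt_layer_local_spec(),
    vocab_size=128,
    max_sequence_length=32,
).cuda()

batch, seq = 2, 16
input_ids = torch.randint(0, 128, (batch, seq), device="cuda")
position_ids = torch.arange(seq, device="cuda").unsqueeze(0).expand(batch, -1)
# Boolean causal mask, assuming True marks positions to be masked out.
causal = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device="cuda"), diagonal=1)
attention_mask = causal.unsqueeze(0).unsqueeze(0).expand(batch, 1, seq, seq)

# With labels=None, forward returns logits of shape [batch, seq, vocab_size].
logits = model(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask)
print(logits.shape)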