Commit d444a97a authored by yangzhong's avatar yangzhong
Browse files

首次上传

parents
Pipeline #3020 canceled with stages
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Dict, List
import torch
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.scheduler import Scheduler
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
class MCoreEngine(AbstractEngine):
"""The Megatron core backend constructor
This is the backend that does a simple forward pass on the model.
Supports any model that is callable (Accepts the inputs and outputs the tensor)
Args:
text_generation_controller (TextGenerationController): A text generation
controller that will be used to define how to preprocess prompts, generate
outputs and detokenizer the output tokens.
max_batch_size : The maxinum number of requests to process at once
random_seed (int, optional): Use a random seed if you want deterministic
results. Defaults to None.
"""
def __init__(
self,
text_generation_controller: TextGenerationController,
max_batch_size,
random_seed: int = None,
):
self.text_generation_controller = text_generation_controller
self.random_seed = random_seed
self.scheduler = Scheduler(max_batch_size=max_batch_size)
def generate(
self,
prompts: List[str],
add_BOS: bool = False,
encoder_prompts: List[str] = None,
common_inference_params: SamplingParams = None,
sampling_params: SamplingParams = None,
) -> dict:
"""The megatron core inference backend generate function
This backend returns the output generations as a dictionary.
It returns the prompt tokens along with the generated tokens, the prompt
plus the generated string and the output log probabilities if requested
Args:
prompts (List[str]): All the prompts as a list of strings
add_BOS (bool): Whether to add BOS token to beginning of prompts
encoder_prompts (List[dict]): All the encoder prompts as a list of strings
common_inference_params: Deprecated. Only used for backward compatibility with
MCore <= 0.9.0. Use `sampling_params` going forward.
sampling_params (SamplingParams): The request-level sampling parameters
Returns:
List[InferenceRequest]: The output is list of inference requests containing the
generated tokens, texts and log probs if required
"""
# TODO :M core- get rng state tracker
if common_inference_params:
sampling_params = common_inference_params
if self.random_seed:
torch.random.manual_seed(self.random_seed)
for i in range(len(prompts)):
prompt = prompts[i]
encoder_prompt = encoder_prompts[i] if encoder_prompts is not None else None
prompt_tokens = self.text_generation_controller.tokenize_prompt(prompt, add_BOS)
self.scheduler.add_request(
prompt=prompt,
prompt_tokens=prompt_tokens,
encoder_prompt=encoder_prompt,
inference_parameters=sampling_params,
)
self.run_engine()
result: List[InferenceRequest] = self.scheduler.completed_request_pool.values()
return result
def run_engine(self):
"""Main functionality to run inference
Runs the engine until there are no requests in the queue.
Args:
dynamic_generation (bool, optional): Set this to True, if you want
to enable dynamic batching. Mainly used with an inference server.
Defaults to False.
"""
while self.scheduler.have_requests_pending():
active_requests: Dict[int, InferenceRequest] = self.scheduler.active_request_pool.copy()
result_dict: Dict[int, InferenceRequest] = (
self.text_generation_controller.generate_all_output_tokens_static_batch(
active_requests
)
)
self.scheduler.update_requests_pools(result_dict=result_dict)
# TODO: Later for dynamic batching we will do something like this
"""
if dynamic_batching:
result_dict: Dict[
int, InferenceRequest
] = self.text_generation_controller.generate_output_tokens_one_step_dynamic_batch(
active_requests
)
self.scheduler.update_requests_pools(result_dict=result_dict)
"""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from enum import Enum
from typing import List
import torch
from megatron.core.inference.sampling_params import SamplingParams
# class syntax
class Status(Enum):
"""Enum for status"""
WAITING_IN_QUEUE = 1
ACTIVE_AND_GENERATING_TOKENS = 2
ACTIVE_BUT_NOT_GENERATING_TOKENS = 3
COMPLETED = 4
@dataclass
class InferenceRequest:
"""Class for one inference request
Containing relevant data for an inference request
"""
request_id: str
prompt: str
inference_parameters: SamplingParams
prompt_tokens: List[int]
arrival_time: float
status: Status
encoder_prompt: str = None
generated_text: str = None
generated_tokens: torch.Tensor = None
generated_log_probs: torch.Tensor = None
generated_length: int = 0
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import abc
import math
from typing import Iterable, List, Union
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.inference.communication_utils import (
recv_from_prev_pipeline_rank_,
send_to_next_pipeline_rank,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from megatron.core.inference_params import InferenceParams
from megatron.core.models.gpt.gpt_model import GPTModel
# pylint: disable=line-too-long
class AbstractModelInferenceWrapper(abc.ABC):
"""Abstract inference wrapper
Extend this to create a version for your model.
"""
def __init__(
self,
model: Union['LegacyGPTModel', GPTModel],
inference_wrapper_config: InferenceWrapperConfig,
):
"""Constructor for the model inference wrapper
The wrapper prepares the model for inference, provides the required input data and runs the forward pass.
Args:
model (Union[GPTModel, LegacyGPTModel]): The actual GPT model (MCore or MLM)
inference_wrapper_config (InferenceWrapperConfig): Has info like hidden size, vocab size etc.
"""
assert not isinstance(
model, Iterable
), 'interleaving schedule is not supported for inference'
self.model = model
self.inference_wrapper_config = inference_wrapper_config
self.pipeline_communication_dtype = (
torch.float
if self.inference_wrapper_config.fp32_residual_connection
else self.inference_wrapper_config.params_dtype
)
def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
"""A utility function for preparing model for inference
The function gets called once before the auto regressive inference loop. It puts the model in eval mode , and gets some model and inference data parameters. Extend this to build position ids ,attention mask etc, so that required slices can be extracted during the forward pass.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
"""
self.model.eval()
# For TP only model both is_pp_first_stage and _is_pp_last_stage returns True
self.model_is_pipeline_parallel = not (
parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
)
self.prompts_tokens = prompts_tokens
batch_size, max_sequence_length = self.prompts_tokens.shape
self.inference_params = InferenceParams(batch_size, max_sequence_length)
@abc.abstractmethod
def get_batch_for_context_window(self) -> List:
"""Returns the input data for inference
This function gets called iteratively in the inference loop . It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference.
"""
pass
def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor:
"""Utility to carry out simple forward pass for TP or no model parallel models
Runs a very simple forward pass for model. Used in the case of models without any parallelism or only tensor parallelism.
Args:
inference_input (List): A list containg the inputs for the gpt model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
"""
tokens, position_ids, attention_mask = inference_input
logits = self.model(
tokens, position_ids, attention_mask, inference_params=self.inference_params
)
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
self.inference_params.sequence_len_offset += tokens.size(1)
return logits
def _allocate_recv_buffer(self, batch_size, seq_len):
"""Receive happens between the layers with size [seq_len, batch_size, hidden_size]."""
recv_size = (seq_len, batch_size, self.inference_wrapper_config.hidden_size)
return torch.empty(
recv_size, dtype=self.pipeline_communication_dtype, device=torch.cuda.current_device()
)
def forward_pass_with_pipeline_parallel_small_input_batch(
self, inference_input: List
) -> torch.Tensor:
"""Utility to carry out forward pass for PP models with very small inputs
If a model is pipeline parallel, yet, the input global batch is very small, we compute a foward pass on the entire global batch, rather than splitting it up into micro batches and doing something more complex as in the forward_pass_with_pipeline_parallel_large_input_batch method
Args:
inference_input (List): A list containg the inputs for the gpt model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
"""
tokens, position_ids, attention_mask = inference_input
batch_size, seq_len = tokens.shape
recv_buffer = None
if not parallel_state.is_pipeline_first_stage():
recv_buffer = self._allocate_recv_buffer(batch_size, seq_len)
recv_from_prev_pipeline_rank_(recv_buffer)
self.model.set_input_tensor(recv_buffer)
output_tensor = self.model(
tokens, position_ids, attention_mask, inference_params=self.inference_params
)
if not parallel_state.is_pipeline_last_stage():
send_to_next_pipeline_rank(output_tensor.type(dtype=self.pipeline_communication_dtype))
self.inference_params.sequence_len_offset += seq_len
logits = None
if parallel_state.is_pipeline_last_stage():
logits = output_tensor
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
return logits
def forward_pass_with_pipeline_parallel_large_input_batch(
self, inference_input: List
) -> torch.Tensor:
"""Utility to carry out forward pass PP models.
Runs the forward pass for models which are pipeline parallel. This is more complex than forward_pass_with_pipeline_parallel_small_input_batch coz this splits the global batch into small micro batches and runs them through the model.
Args:
inference_input (List): A list containg the inputs for the gpt model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
"""
tokens, position_ids, attention_mask = inference_input
micro_batch_size = max(
1,
self.inference_wrapper_config.inference_batch_times_seqlen_threshold // tokens.size(1),
)
batch_size, seq_len = tokens.shape
# Round up to account for the last partial micro batch if present
num_micro_batches = math.ceil(batch_size / micro_batch_size)
logits = None
# Preallocate memory for output logits.
if parallel_state.is_pipeline_last_stage():
logits = torch.empty(
(batch_size, seq_len, self.inference_wrapper_config.padded_vocab_size),
dtype=torch.float32,
device=torch.cuda.current_device(),
)
recv_buffer = None
if not parallel_state.is_pipeline_first_stage():
recv_buffer = self._allocate_recv_buffer(micro_batch_size, seq_len)
for micro_batch_index in range(num_micro_batches):
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
current_micro_batch_size = end - start
# Need to change recv buffer shape for the last partial microbatch (if exists)
if current_micro_batch_size != micro_batch_size:
recv_buffer = self._allocate_recv_buffer(current_micro_batch_size, seq_len)
if not parallel_state.is_pipeline_first_stage():
recv_from_prev_pipeline_rank_(recv_buffer)
self.model.set_input_tensor(recv_buffer)
output_tensor = self.model(
tokens2use, position_ids2use, attention_mask, inference_params=self.inference_params
)
if not parallel_state.is_pipeline_last_stage():
send_to_next_pipeline_rank(output_tensor)
self.inference_params.batch_size_offset += current_micro_batch_size
if parallel_state.is_pipeline_last_stage():
output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
output_tensor
)
logits[start:end, ...] = output_tensor
# Once done with all micro batches, we reset batch size offset and seq len offset
self.inference_params.sequence_len_offset += seq_len
self.inference_params.batch_size_offset = 0
# NOTE: Only returns the logits on the last pipeline stage
return logits
def run_one_forward_step(self, inference_input: List) -> torch.Tensor:
"""The forward pass of the model for inference
Appropriate utility is called for the forward pass depending on the type of model parallelism used
Args:
inference_input (List): A list containg the inputs for the gpt model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]. The logits are returned only in the last pipeline stage for PP models.
"""
if self.model_is_pipeline_parallel:
tokens = inference_input[0]
current_batch_size, seq_len = tokens.shape
# If input batch is large, we need to split into micro batches and run the forward pass
if (
current_batch_size * seq_len
> self.inference_wrapper_config.inference_batch_times_seqlen_threshold
):
return self.forward_pass_with_pipeline_parallel_large_input_batch(inference_input)
else:
# If input batch is very small we can do a simple forward pass on the entire global batch
return self.forward_pass_with_pipeline_parallel_small_input_batch(inference_input)
else:
return self.forward_pass_without_pipeline_parallel(inference_input)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, Tuple
import torch
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from megatron.core.models.gpt import GPTModel
# pylint: disable=line-too-long
class GPTInferenceWrapper(AbstractModelInferenceWrapper):
"""Inference wrapper for GPT model"""
def __init__(self, model: GPTModel, inference_wrapper_config: InferenceWrapperConfig):
"""Constructor for the model inference wrapper
The wrapper prepares the model for inference, provides the required input data, and runs the forward pass
Args:
model (GPTModel): The GPT model (MCore or legacy)
inference_wrapper_config (InferenceWrapperConfig): Has info like hidden size, vocab size etc
"""
super().__init__(model, inference_wrapper_config)
def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
"""A utility function for preparing model for inference
This function is called before the forward pass. It puts the model in eval mode, builds position ids, and creates attention masks so that required slices can be extracted during the forward pass.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
"""
super().prep_model_for_inference(prompts_tokens=prompts_tokens)
self.attention_mask, self.position_ids = self._build_attention_mask_and_position_ids(
prompts_tokens
)
def _build_attention_mask_and_position_ids(
self, prompts_tokens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Builds the full attention mask and position ids for the input tokens
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
Returns:
Tuple[torch.Tensor, torch.Tensor]: The attention mask of shape [1, 1, max_seq_len, max_seq_len] and position ids of shape [batch_size, max_seq_len]
"""
seq_length = prompts_tokens.size(1)
attention_mask = torch.tril(
torch.ones((1, seq_length, seq_length), device=prompts_tokens.device)
).view(1, 1, seq_length, seq_length)
# Convert to boolean
attention_mask = attention_mask < 0.5
position_ids = (
torch.arange(seq_length, dtype=torch.long, device=prompts_tokens.device)
.unsqueeze(0)
.expand_as(prompts_tokens)
)
return attention_mask, position_ids
def get_batch_for_context_window(
self, context_start_position: int, context_end_position: int
) -> List:
"""Returns the inference data given context window
This function gets called iteratively in a loop . Given the start and end context positions , it extracts the appropriate data.
Args:
context_start_position (int): Start of the context window. During the first inference step it is mostly 0
context_end_position (int): End of the context window. During the last inference step it will mostly be the max generated sequence length.
Returns:
List: A list of inputs that will be used by your model in the forward step
"""
tokens2use = self.prompts_tokens[:, context_start_position:context_end_position]
positions2use = self.position_ids[:, context_start_position:context_end_position]
attention_mask2use = self.attention_mask[
..., context_start_position:context_end_position, :context_end_position
]
data_at_step_idx = [tokens2use, positions2use, attention_mask2use]
return data_at_step_idx
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
import torch
@dataclass
class InferenceWrapperConfig:
"""Config for the model inference wrapper
NOTE : All the arguments here are obtained from arguments.py file
"""
hidden_size: int
"""Receive happens between the layers during PP with size [seq_len, batch_size, hidden_size]"""
params_dtype: torch.dtype
"""Can be torch.float or torch.half if --fp16 is used, or torch.bfloat16 if --bf16 is used"""
inference_batch_times_seqlen_threshold: int
"""if (batch-size * sequence-length) is smaller than this threshold then we will not pipeline
the batch."""
padded_vocab_size: int
"""The final padded vocab size (Padded to make it divisible by
--make-vocab-size-divisible-by value)"""
fp32_residual_connection: bool = False
"""Move residual connections to fp32. Obtained from arguments.py"""
def add_attributes(self, attribute_value_pair: dict):
"""Utility to add more attributes to inference params
Use this method to pass in a custom dictionary to add more configs to the instance created.
Use as follows:
c = InferenceWrapperConfig
c.add_attributes({'precision':'fp32'})
Args:
attribute_value_pair (dict): A dictionary containing attributes as the key names and
corresponding values.
"""
for key, value in attribute_value_pair.items():
setattr(self, key, value)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from collections import deque
from typing import Any, List, Tuple
import numpy
import torch
from megatron.core import tensor_parallel
from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from megatron.core.models.T5 import T5Model
# pylint: disable=line-too-long
class T5InferenceWrapper(AbstractModelInferenceWrapper):
"""Constructor for the model inference wrapper
The wrapper prepares the model for inference, provides the required input
data, and runs the forward pass
Args:
model (T5Model): The T5 model (MCore or legacy)
inference_wrapper_config (InferenceWrapperConfig): The command line arguments that were passed
use_local (bool): Whether the T5 model's transformer impl
is local (vs transformer_engine)
"""
def __init__(
self,
model: T5Model,
inference_wrapper_config: InferenceWrapperConfig,
use_local: bool = False,
):
super().__init__(model, inference_wrapper_config)
self.use_local = use_local
def prep_model_for_inference(
self, prompts_tokens: torch.Tensor, encoder_prompts: List[str] = None, tokenizer: Any = None
):
"""A utility function for preparing model for inference
This function is called before the forward pass. It puts the model in eval mode, builds
position ids, and creates attention masks so that required slices can be extracted during
the forward pass.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
encoder_prompts (dict): List of string of encoder input prompts
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text
"""
super().prep_model_for_inference(prompts_tokens=prompts_tokens)
# get max_sequence_length
if hasattr(self.model, "module"): # if self.model is Float16Module
max_sequence_length = self.model.module.max_sequence_length
else:
max_sequence_length = self.model.max_sequence_length
encoder_prompts_tokens_list = [
self.tokenize_encoder_prompt(encoder_prompt, tokenizer)
for encoder_prompt in encoder_prompts
]
self.batch_encoder_prompts_tokens = self.pad_encoder_prompts_tokens(
encoder_prompts_tokens_list, max_sequence_length, tokenizer
)
# create batch mask for encoder_prompt (self.batch_input_tokens) and
# decoder_input (self.prompts_tokens), similar to megatron/core/datasets/t5_dataset.py
decoder_prompts_tokens = self.prompts_tokens.cpu().numpy()
encoder_prompts_tokens = self.batch_encoder_prompts_tokens.cpu().numpy()
self.batch_mask_encoder = []
self.batch_mask_decoder = []
for i in range(len(self.prompts_tokens)):
mask_encoder = encoder_prompts_tokens[i] == tokenizer.pad
mask_decoder = decoder_prompts_tokens[i] == tokenizer.pad
self.batch_mask_encoder.append(mask_encoder)
self.batch_mask_decoder.append(mask_decoder)
self.batch_mask_encoder = torch.tensor(numpy.array(self.batch_mask_encoder)).cuda()
self.batch_mask_decoder = torch.tensor(numpy.array(self.batch_mask_decoder)).cuda()
def tokenize_encoder_prompt(
self, encoder_prompt: str, tokenizer
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility to tokenize the encoder_prompt
Args:
encoder_prompt (str): The encoder_prompt
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing string
Returns:
torch.Tensor: Returns the tokenized prompt
"""
# if there is the word "<mask>" in prompt, replacing it with special_additional_token,
# similar to processing step in megatron/core/datasets/t5_dataset.py
divided_encoder_prompt_list = encoder_prompt.split("<mask>")
masks_count = len(divided_encoder_prompt_list) - 1
sentinels = deque(tokenizer.additional_special_tokens_ids)
encoder_prompt_tokens = []
for divided_encoder_prompt in divided_encoder_prompt_list:
divided_encoder_prompt_tokens = tokenizer.tokenize(divided_encoder_prompt)
encoder_prompt_tokens.extend(divided_encoder_prompt_tokens)
if masks_count > 0:
sentinel = sentinels.popleft()
encoder_prompt_tokens.extend([sentinel])
masks_count -= 1
return encoder_prompt_tokens
def pad_encoder_prompts_tokens(
self, encoder_prompts_tokens_list: List[List[int]], max_sequence_length: int, tokenizer
) -> torch.Tensor:
"""Method to pad input prompts
Given a list of prompts, pad them all to uniform length
Args:
encoder_prompts_tokens_list (List[List[int]]): A list containing the
encoder_input_tokens
max_sequence_length (int): Maximum of the length of the encoder inputs tokens
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text
Returns:
torch.Tensor: A torch tensor of shape [bs, max_sequence_length]
"""
for encoder_prompt_tokens in encoder_prompts_tokens_list:
padding_size = max_sequence_length - len(encoder_prompt_tokens)
encoder_prompt_tokens.extend([tokenizer.pad] * padding_size)
return torch.tensor(encoder_prompts_tokens_list).cuda()
def get_batch_for_context_window(
self, context_start_position: int, context_end_position: int
) -> List:
"""Returns the inference data given context window
This function gets called iteratively in a loop . Given the start and end context
positions , it extracts the appropriate data.
Args:
context_start_position (int): Start of the context window. During
the first inference step it is mostly 0
context_end_position (int): End of the context window. During the
last inference step it will mostly be the max generated sequence length.
Returns:
List: A list of inputs that will be used by your model in the forward step
"""
# T5 inference not yet support kv_cache
encoder_tokens2use = self.batch_encoder_prompts_tokens
decoder_tokens2use = self.prompts_tokens[:, :context_end_position]
encoder_mask2use = self.batch_mask_encoder
decoder_mask2use = self.batch_mask_decoder[:, :context_end_position]
# Configure attention mask based on different conditions
# (e.g., transformer-impl, TE versions, TE backends)
[encoder_mask2use, decoder_mask2use, encoder_decoder_mask2use] = (
T5MaskedWordPieceDataset.config_attention_mask(
encoder_tokens2use,
decoder_tokens2use,
encoder_mask2use,
decoder_mask2use,
self.use_local,
)
)
data_at_step_idx = [
encoder_tokens2use,
decoder_tokens2use,
encoder_mask2use,
decoder_mask2use,
encoder_decoder_mask2use,
]
return data_at_step_idx
def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor:
"""Utility to carry out simple forward pass for TP or no model parallel models
Runs a very simple forward pass for model. Used in the case of models without
any parallelism or only tensor parallelism.
Args:
inference_input (List): A list containg the inputs for the gpt
model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
"""
[encoder_tokens, decoder_tokens, encoder_mask, decoder_mask, encoder_decoder_mask] = (
inference_input
)
tokens = decoder_tokens
# T5 inference not yet support kv_cache
logits = self.model(
encoder_tokens,
decoder_tokens,
encoder_mask,
decoder_mask,
encoder_decoder_mask,
inference_params=None,
)
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
return logits
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Integrations with NVIDIA TensorRT Model Optimizer (referred as ModelOpt).
ModelOpt is a library comprising state-of-the-art model optimization techniques including quantization and sparsity to
compress model for efficient inference on NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless
experience for users to optimize their Megatron-core models for inference. More details on ModelOpt including
installation and usage can be found at https://github.com/NVIDIA/TensorRT-Model-Optimizer.
"""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_gpt_layer_modelopt_spec(
num_experts: int = None,
moe_grouped_gemm: bool = False,
remap_te_layernorm: bool = False,
qk_layernorm: bool = False,
) -> ModuleSpec:
"""Mix the native spec with TENorm.
This is essentially the native local spec except for the layernorm implementation
is using TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex
has stopped supporting RMSNorm needed by llama.
"""
mlp = _get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False
)
sharded_state_dict_keys_map = {}
if remap_te_layernorm:
if num_experts:
sharded_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'
}
else:
sharded_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
}
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
k_layernorm=TENorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
# Map TE-layernorm-fusion keys back
sharded_state_dict_keys_map=sharded_state_dict_keys_map,
),
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from logging import getLogger
import torch
logger = getLogger(__name__)
def mcore_gpt_load_legacy_state_dict_pre_hook(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
"""Register a pre-hook to fix the state_dict key difference.
This prehook is used when trying to load the legacy Megatron-LM GPTModel into its
megatron/core variant that uses native ParallelLinear and Transformer-Engine Norm.
Only this particular spec supports post-training quantization and TensorRT-LLM
config export through `nvidia-modelopt` package.
Args:
state_dict: state dictionary
prefix: module name prefix
local_metadata: local metatdata
strict: whether is in strict mode
missing_keys: missing state dict keys
unexpected_keys: unexpected state dict keys
error_msgs: error messages
"""
if "modelopt_state" in state_dict:
state_dict.pop("modelopt_state")
if "language_model" in state_dict:
language_model_state_dict = state_dict.pop("language_model")
if "embedding" in language_model_state_dict:
if "word_embeddings" in language_model_state_dict["embedding"]:
for key, param in language_model_state_dict["embedding"]["word_embeddings"].items():
state_dict.update({"embedding.word_embeddings." + key: param})
if "position_embeddings" in language_model_state_dict["embedding"]:
for key, param in language_model_state_dict["embedding"][
"position_embeddings"
].items():
state_dict.update({"embedding.position_embeddings." + key: param})
if "transformer" in language_model_state_dict:
for key, param in language_model_state_dict["transformer"].items():
state_dict.update({"decoder." + key: param})
else:
for key, param in language_model_state_dict["encoder"].items():
state_dict.update({"decoder." + key: param})
if "output_layer" in language_model_state_dict:
for key, param in language_model_state_dict["output_layer"].items():
state_dict.update({"output_layer." + key: param})
if torch.distributed.get_rank() == 0:
logger.info("ModelOptGPTModel {}".format(state_dict.keys()))
module_name_rewrite_list = [
("input_norm", "input_layernorm"),
(".attention.query_key_value", ".self_attention.linear_qkv"),
(".attention.dense", ".self_attention.linear_proj"),
("self_attention.query_key_value", "self_attention.linear_qkv"),
("self_attention.dense", "self_attention.linear_proj"),
("post_attention_layernorm", "pre_mlp_layernorm"),
("post_attention_norm", "pre_mlp_layernorm"),
("dense_h_to_4h", "linear_fc1"),
("dense_4h_to_h", "linear_fc2"),
("final_norm", "final_layernorm"),
]
key_rewrite_list = []
for key, _ in state_dict.items():
for old_name, new_name in module_name_rewrite_list:
if old_name in key:
key_rewrite_list += [(key, key.replace(old_name, new_name))]
for old_key, new_key in key_rewrite_list:
if torch.distributed.get_rank() == 0:
logger.info("replace {} with {}".format(old_key, new_key))
state_dict[new_key] = state_dict[old_key]
state_dict.pop(old_key)
def mcore_gpt_load_te_state_dict_pre_hook(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
"""Register a pre-hook to fix the state_dict key difference of.
This prehook is used when trying to load the megatron/core GPTModel that uses a
fused Transformer-Engine ParallelLinear into the variant that uses native ParallelLinear
and Transformer-Engine Norm (effectively to restore the fusion).
Only this particular spec supports post-training quantization and TensorRT-LLM
config export through `nvidia-modelopt` package.
Args:
state_dict: state dictionary
prefix: module name prefix
local_metadata: local metatdata
strict: whether is in strict mode
missing_keys: missing state dict keys
unexpected_keys: unexpected state dict keys
error_msgs: error messages
"""
if "modelopt_state" in state_dict:
state_dict.pop("modelopt_state")
key_with_te_extra_state_to_pop = []
for key, _ in state_dict.items():
if "_extra_state" in key:
key_with_te_extra_state_to_pop += [key]
for key in key_with_te_extra_state_to_pop:
state_dict.pop(key)
module_name_rewrite_list = [
("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
("mlp.linear_fc1.layer_norm_weight", "pre_mlp_layernorm.weight"),
("mlp.linear_fc1.layer_norm_bias", "pre_mlp_layernorm.bias"),
]
key_rewrite_list = []
for key, _ in state_dict.items():
for old_name, new_name in module_name_rewrite_list:
if old_name in key:
key_rewrite_list += [(key, key.replace(old_name, new_name))]
for old_key, new_key in key_rewrite_list:
if torch.distributed.get_rank() == 0:
logger.info("replace {} with {}".format(old_key, new_key))
state_dict[new_key] = state_dict[old_key]
state_dict.pop(old_key)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
@dataclass
class SamplingParams:
"""Inference parameters sent along with the prompts.
This class contains request-level attributes that control the sampling techniques used when
generating text. This is distinct from megatron.core.InferenceParams, which is sets model-level
inference attributes such as the maximum sequence length, and contains the KV cache.
For an explanation of these parameters refer to this blog
https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-
temperature-parameters-ed6a31313910
"""
temperature: float = 1.0
top_k: int = 0
top_p: float = 0.0
return_log_probs: bool = False
num_tokens_to_generate: int = 30
def add_attributes(self, attribute_value_pair: dict):
"""Utility to add more attributes to sampling params
Use this method to pass in a custom dictionary to add more sampling parameter attributes.
c = SamplingParams
c.add_attributes({'min_length':4, 'eod_id':153})
Args:
attribute_value_pair (dict): A dictionary containing attributes as the key names and
their values as the values.
"""
for key, value in attribute_value_pair.items():
setattr(self, key, value)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import time
import typing
from collections import OrderedDict
from typing import Dict
import torch
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import Counter
class Scheduler:
"""Scheduler for handling requests to inference engine
This class is responsible for handing of all the incomign requests
Args:
max_batch_size (int): The max batch size that we can pass to the
inference engine at a time.
"""
def __init__(self, max_batch_size: int):
self.max_batch_size = max_batch_size
self.active_request_pool: Dict[int, InferenceRequest] = OrderedDict()
self.waiting_request_pool: Dict[int, InferenceRequest] = OrderedDict()
self.completed_request_pool: Dict[int, InferenceRequest] = OrderedDict()
self.request_counter = Counter()
def add_request(
self,
prompt: str,
prompt_tokens: torch.Tensor,
encoder_prompt: str = None,
inference_parameters: SamplingParams = None,
arrival_time: float = None,
):
"""Add an incoming request
This method will add the request to either the active pool or the waiting pool
depending on the batch size.
Args:
prompt (str): Input prompt string
prompt_tokens (torch.Tensor): A torch tensor having the input prompts tokenized
encoder_prompt (str): Encoder input string
inference_parameters (SamplingParams): The inference parameters
arrival_time (float, optional): The incoming request time. Defaults to None.
"""
request_id = str(next(self.request_counter))
if arrival_time is None:
arrival_time = time.time()
status = (
Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
if len(self.active_request_pool) < self.max_batch_size
else Status.WAITING_IN_QUEUE
)
inference_request = InferenceRequest(
request_id=request_id,
prompt=prompt,
inference_parameters=inference_parameters,
arrival_time=arrival_time,
prompt_tokens=prompt_tokens,
status=status,
encoder_prompt=encoder_prompt,
)
if status == status.ACTIVE_BUT_NOT_GENERATING_TOKENS:
self.active_request_pool[request_id] = inference_request
else:
self.waiting_request_pool[request_id] = inference_request
def have_requests_pending(self) -> bool:
"""Method to check if there are requests pending
This method returns False only when there are no active requests or waiting requests.
"""
num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool)
return num_requests_pending > 0
def add_earliest_waiting_request_to_active_pool(self):
"""Utility to add the waiting request to active pool
This method will add the earliest request (FIFO) that is in the waiting request
pool to the active request pool.
"""
assert (
len(self.active_request_pool) < self.max_batch_size
), "Active request pool is already full. Cant add any more requests"
if len(self.waiting_request_pool) > 0:
(earliest_waiting_request_request_id, earliest_waiting_request) = (
self.waiting_request_pool.popitem(last=False)
)
earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
self.active_request_pool[earliest_waiting_request_request_id] = earliest_waiting_request
def update_requests_pools(self, result_dict: typing.OrderedDict[int, InferenceRequest] = None):
"""Update request pool status
This method will full up the active request pool, if it has less than max batch size
elements from the waiting request pool.
If provided with a request dict, it will put the completed requests into the completed
request pool and add waiting request into active pool.
Args:
result (typing.OrderedDict[int, InferenceRequest], optional): The result returned
by the engine. A dictionary with keys as the request ids, and values as the
requests. Defaults to None
"""
for result_request_id in list(result_dict.keys()):
active_request = self.active_request_pool[result_request_id]
# If a request has completed put it into the completed request pool.
if active_request.status == Status.COMPLETED:
completed_request = self.active_request_pool.pop(result_request_id)
self.completed_request_pool[result_request_id] = completed_request
# If the active request pool is not full, add waiting requests in FIFO order
while (
len(self.active_request_pool) < self.max_batch_size
and len(self.waiting_request_pool) > 0
):
self.add_earliest_waiting_request_to_active_pool()
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import OrderedDict
import torch
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
class EncoderDecoderTextGenerationController(TextGenerationController):
"""The text generation controller for encoder-decoder architecture
This class inherits from TextGenerationController, adding features
relating to encoder input encoder_prompt
"""
def prep_model_for_inference(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
):
"""Preparing batch for inference, using respective wrapper's prep_model_for_inference method
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[int, InferenceRequest]): The input active requests
"""
encoder_prompts = list(
map(lambda request: request.encoder_prompt, active_requests.values())
)
self.inference_wrapped_model.prep_model_for_inference(
prompts_tokens=prompts_tokens, encoder_prompts=encoder_prompts, tokenizer=self.tokenizer
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.text_generation_controllers.text_generation_controller import ( # noqa: F401 # pylint: disable=unused-import
TextGenerationController as SimpleTextGenerationController,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, OrderedDict, Tuple
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
from megatron.core.inference.sampling_params import SamplingParams
class TextGenerationController:
"""The text generation controller (the main sampling loop)
This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.
Args:
inference_wrapped_model (AbstractModelInferenceWrapper): A model that
is wrapped using the specs given in the abstract_model_inference_wrapper.py
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
"""
def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer):
self.inference_wrapped_model = inference_wrapped_model
self.tokenizer = tokenizer
# For models without pipeline parallelism, is_first_stage and is_last_stage returns True
self.model_is_pipeline_parallel = not (
parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
)
def tokenize_prompt(
self, prompt: str, add_BOS: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility to tokenize the input prompts
Args:
prompt (str): The input prompt
Returns:
torch.Tensor: Returns the tokenized prompt
"""
prompt_tokens = self.tokenizer.tokenize(prompt)
if add_BOS:
prompt_tokens = [self.tokenizer.bos] + prompt_tokens
return prompt_tokens
def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str:
"""Detokenize the output generations
Args:
prompt_tokens_with_generated_tokens (torch.Tensor): The input prompt
tokens plus the generated tokens
Returns:
str: The detokenized output
"""
tokens = prompt_tokens_with_generated_tokens.cpu().numpy().tolist()
return self.tokenizer.detokenize(tokens)
def sample_from_logits(
self,
last_token_logits: torch.Tensor,
sampling_params: SamplingParams = None,
vocab_size: int = None,
**kwargs
) -> torch.Tensor:
"""Samples the logits to generate outputs
Given the logits of the last token, this function samples it
according to the parameters defined in sampling_params
and returns the samples
Args:
last_token_logits (torch.Tensor): The last token logits. A tensor of
size [batch_size, vocab_size]
sampling_params (SamplingParams): The parameters to use for inference.
vocab_size (int): Obtained from the tokenizer. Defaults to None
Returns:
torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements
"""
if kwargs.get('common_inference_params'):
sampling_params = kwargs['common_inference_params']
top_p = sampling_params.top_p
top_k = sampling_params.top_k
temperature = sampling_params.temperature
assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
assert top_p <= 1.0, 'top-p should be in (0,1]'
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Filteration based on the cumulative sum.
filter_ = cumulative_probs > top_p
# This shift by 1 is weird and I cannot justify it. This existed
# in the original implementation:
# https://github.com/ari-holtzman/degen/blob/master/gen.py
# and I guess it is needed so keeping it for now.
filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
# Fill in the filtered part
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
# Greedy sampling
if top_k == 1:
sampled_logits = torch.argmax(last_token_logits, dim=-1)
else:
last_token_logits = last_token_logits.clone()
if temperature != 1.0:
last_token_logits.div_(temperature)
if top_k > 1:
assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(last_token_logits, top_k)
elif top_p > 0.0:
modify_logits_for_top_p_filtering(last_token_logits, top_p)
# After filtering, we need to recalculate the distribution.
probabilities = last_token_logits.softmax(dim=-1)
sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)
# If vocab size is provided, make sure the samples are in in the range [0, vocab-size).
if vocab_size:
sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
return sampled_logits
def update_generation_status(
self,
updated_prompts_tokens: torch.Tensor,
generation_started: torch.Tensor,
current_context_end_position: int,
is_generation_done_tensor: torch.Tensor,
generated_sequence_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Checks which prompts have reached an end condition
We check which prompts have reached an end condition and set the corresponding
flags of the is_generation_done_tensor to True. The generated sequence lengths
increase as we keep generating, until that prompts hits an end condition. The
generation_started tensor determines which prompts have started generating.
Args:
updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
generated tokens. A tensor of shape [batch_size, max_seq_len]
(i.e max_seq_len = max_prompt_len + tokens_to_generate)
generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
indicates the prompt at that index has started generating tokens.
current_context_end_position (int): An integer indicating which position to
extract from the prompts tokens to get the latest generated tokens.
is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
True indicates the prompt at that index has reached end condition.
generated_sequence_lengths (torch.Tensor): A int tensor of shape [batch_size].
Each value represents the generated sequence lengths for that prompt.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Returns the boolean
is_generation_done_tensor and the generated_sequence_lengths after updating it
"""
latest_samples = updated_prompts_tokens[:, current_context_end_position]
# Make sure we are checking eod criterion only for prompts that have started generating
# (i.e) We only look at the generated tokenns and not the input tokens.
reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
is_generation_done_tensor = is_generation_done_tensor | reached_eod
# We increment generated sequence lengths when that prompt has not hit the
# EOD and generation has started
generated_sequence_lengths += ~is_generation_done_tensor & generation_started
return is_generation_done_tensor, generated_sequence_lengths
def pad_input_prompt_tokens(
self,
batch_prompt_tokens_list: List[List[int]],
max_prompt_length_in_batch: int,
num_tokens_to_generate: int,
) -> torch.Tensor:
"""Method to pad input prompts
Given a list of prompts, pad them all to uniform length
Args:
batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens
num_tokens_togenerate (int): The number of tokens to generate for each prompt
Returns:
torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e)
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate,
with extra indices for each tensor padded with mask id.
"""
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
for prompt_tokens in batch_prompt_tokens_list:
padding_size = max_seq_len - len(prompt_tokens)
prompt_tokens.extend([self.tokenizer.eod] * padding_size)
return torch.tensor(batch_prompt_tokens_list).cuda()
def generate_output_tokens_dynamic_batch(
self, active_requests: OrderedDict[int, InferenceRequest]
) -> OrderedDict[int, InferenceRequest]:
"""Utility to generate the output tokens and probabilities for the prompts
This utility generates the output tokens for a dynamic batch. It will run one forward step
at a time, and pass control back to the engine, which will update the request pool and call
this method again.
Args:
active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
Returns:
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
after running one forward step.
"""
raise Exception("Not implemented yet")
def generate_all_output_tokens_static_batch(
self, active_requests: OrderedDict[int, InferenceRequest]
) -> OrderedDict[int, InferenceRequest]:
"""Utility to generate the all the output tokens and probabilities for the prompts .
This utility generates the output tokens for a static batch. It runs the forward steps till
all prompts complete generation, updates the status of these requests to completed, adds
the generated result and returns these requests
Args:
active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
Returns:
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
"""
batch_prompt_tokens_list = list(
map(lambda request: request.prompt_tokens, active_requests.values())
)
prompt_lengths_in_batch = torch.tensor(
[len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list]
).cuda()
max_prompt_length_in_batch = max(prompt_lengths_in_batch)
min_prompt_length_in_batch = min(prompt_lengths_in_batch)
# For batch inference the inference params are the same for all request
sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters
# max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
batch_prompt_tokens = self.pad_input_prompt_tokens(
batch_prompt_tokens_list,
max_prompt_length_in_batch=max_prompt_length_in_batch,
num_tokens_to_generate=sampling_params.num_tokens_to_generate,
)
batch_size, max_sequence_length = batch_prompt_tokens.shape
# Pre allocate log probs tensor
output_log_probs = None
if sampling_params.return_log_probs:
output_log_probs = torch.empty(
(batch_size, max_sequence_length - 1), dtype=torch.float32
).cuda()
# An array to check which of the prompts have reached end of generation condition
is_generation_done_tensor = torch.zeros(batch_size, dtype=torch.bool).cuda()
# An array to act as a counter to keep track of generated sequence lengths
generated_sequence_lengths = torch.zeros(batch_size).cuda()
with torch.no_grad():
self.prep_model_for_inference(
prompts_tokens=batch_prompt_tokens, active_requests=active_requests
)
context_start_position = 0
# Pick the context window that we need to pass through the network.
for context_end_position in range(min_prompt_length_in_batch, max_sequence_length):
inference_input = self.inference_wrapped_model.get_batch_for_context_window(
context_start_position, context_end_position
)
# Returns the final logits of shape [batch_size, context_length, vocab_size]
# Note: This is returned in all TP ranks or last PP stage in PP models
logits = self.inference_wrapped_model.run_one_forward_step(inference_input)
if self.model_is_pipeline_parallel:
context_length = context_end_position - context_start_position
logits = broadcast_from_last_pipeline_stage(
[batch_size, context_length, self.tokenizer.vocab_size],
dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
tensor=logits,
)
# Indicates which of the input prompts have started generating tokens.
# A 1D boolean tensor with [batch_size] elements (i.e) The shortest
# prompts will start generating first and so on
generation_started = prompt_lengths_in_batch <= context_end_position
last_token_logits = logits[:, -1, :]
sampled_logits = self.sample_from_logits(
last_token_logits, sampling_params, self.tokenizer.vocab_size
)
# Substitute the sampled logits only for only the prompts that
# have started generating tokens
batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
generation_started
]
if sampling_params.return_log_probs:
log_probs = F.log_softmax(logits, dim=2)
indices = torch.unsqueeze(
batch_prompt_tokens[
:, (context_start_position + 1) : (context_end_position + 1)
],
2,
)
# Get the log probabilities for only the prompt tokens
output_log_probs[:, context_start_position:context_end_position] = torch.gather(
log_probs, 2, indices
).squeeze(2)
context_start_position = context_end_position
# Check end of generation status for each tensor
# and update generated sequence lengths
(is_generation_done_tensor, generated_sequence_lengths) = (
self.update_generation_status(
updated_prompts_tokens=batch_prompt_tokens,
generation_started=generation_started,
current_context_end_position=context_end_position,
is_generation_done_tensor=is_generation_done_tensor,
generated_sequence_lengths=generated_sequence_lengths,
)
)
# Boolean flag indicating if all prompts are finished
all_prompts_done = torch.all(is_generation_done_tensor)
if all_prompts_done:
break
# Include all the generated tokens
batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
if sampling_params.return_log_probs:
output_log_probs = output_log_probs[:, :context_end_position]
generated_sequence_lengths[
generated_sequence_lengths > sampling_params.num_tokens_to_generate
] = sampling_params.num_tokens_to_generate
for idx, request in enumerate(active_requests.values()):
input_prompt_length = int(prompt_lengths_in_batch[idx])
# Shorter prompts might have generated more than required tokens. So we trim them down
required_sequence_length = int(
min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate)
)
# Extract only the generated tokens
required_result_tokens = batch_prompt_tokens_with_generations[
idx, input_prompt_length : (input_prompt_length + required_sequence_length)
]
request.generated_length = required_sequence_length
request.generated_tokens = required_result_tokens
request.generated_log_probs = (
None
if output_log_probs is None
else output_log_probs[idx, input_prompt_length:required_sequence_length]
)
request.status = Status.COMPLETED
request.generated_text = self.detokenize_generations(required_result_tokens)
return active_requests
def prep_model_for_inference(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
):
"""Preparing batch for inference, using respective wrapper's prep_model_for_inference method
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[int, InferenceRequest]): The input active requests
"""
self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=prompts_tokens)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
class Counter:
"""A simple counter class
This class is responsible for assigning request ids to incoming requests
"""
def __init__(self, start: int = 0) -> None:
self.counter = start
def __next__(self) -> int:
i = self.counter
self.counter += 1
return i
def reset(self) -> None:
self.counter = 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment