Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import OrderedDict
from typing import Any, Dict, OrderedDict
import torch
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
SimpleTextGenerationController,
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
class EncoderDecoderTextGenerationController(SimpleTextGenerationController):
class EncoderDecoderTextGenerationController(TextGenerationController):
"""The text generation controller for encoder-decoder architecture
This class ingherits from SimpleTextGenerationController, adding features
This class inherits from TextGenerationController, adding features
relating to encoder input encoder_prompt
"""
def prep_model_for_inference(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
):
"""Preparing batch for inference, using respective wrapper's prep_model_for_inference method
def prep_inference_input(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
) -> Dict[str, Any]:
"""Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[int, InferenceRequest]): The input active requests
active_requests (OrderedDict[str, InferenceRequest]): The input active requests
Returns:
A dict of the inference input for the current batch.
"""
encoder_prompts = list(
map(lambda request: request.encoder_prompt, active_requests.values())
)
self.inference_wrapped_model.prep_model_for_inference(
prompts_tokens=prompts_tokens, encoder_prompts=encoder_prompts, tokenizer=self.tokenizer
return self.inference_wrapped_model.prep_inference_input(
prompts_tokens, encoder_prompts, tokenizer=self.tokenizer
)
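# --- Illustrative usage sketch (not part of the diff); `inference_wrapped_model`,
# `tokenizer`, `prompts_tokens`, and `active_requests` are assumed to exist already.
# It shows the 0.12 flow above: prep_inference_input() returns a dict that is then
# sliced per decoding step and fed to the forward pass, instead of mutating wrapper state.
controller = EncoderDecoderTextGenerationController(inference_wrapped_model, tokenizer)
inference_input = controller.prep_inference_input(
    prompts_tokens=prompts_tokens, active_requests=active_requests
)
step_input = inference_wrapped_model.get_batch_for_context_window(
    inference_input, 0, prompts_tokens.size(1)
)
logits = inference_wrapped_model.run_one_forward_step(step_input)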
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, OrderedDict, Tuple
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
from megatron.core.inference.text_generation_controllers.text_generation_controller import ( # noqa: F401 # pylint: disable=unused-import
TextGenerationController as SimpleTextGenerationController,
)
class SimpleTextGenerationController:
"""The basic text generation controller
This class is responsible for tokenizing the input, running the inference, sampling,
and detokenizing the output
Args:
inference_wrapped_model (AbstractModelInferenceWrapper): A model that
is wrapped using the specs given in the abstract_model_inference_wrapper.py
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
"""
def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer):
self.inference_wrapped_model = inference_wrapped_model
self.tokenizer = tokenizer
# For models without pipeline parallelism, is_first_stage and is_last_stage return True
self.model_is_pipeline_parallel = not (
parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
)
def tokenize_prompt(
self, prompt: str, add_BOS: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility to tokenize the input prompts
Args:
prompt (str): The input prompt
Returns:
torch.Tensor: Returns the tokenized prompt
"""
prompt_tokens = self.tokenizer.tokenize(prompt)
if add_BOS:
prompt_tokens = [self.tokenizer.bos] + prompt_tokens
return prompt_tokens
def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str:
"""Detokenize the output generations
Args:
prompt_tokens_with_generated_tokens (torch.Tensor): The input prompt
tokens plus the generated tokens
Returns:
str: The detokenized output
"""
tokens = prompt_tokens_with_generated_tokens.cpu().numpy().tolist()
return self.tokenizer.detokenize(tokens)
def sample_from_logits(
self,
last_token_logits: torch.Tensor,
common_inference_params: CommonInferenceParams,
vocab_size: int = None,
) -> torch.Tensor:
"""Samples the logits to generate outputs
Given the logits of the last token, this function samples it
according to the parameters defined in common_inference_params
and returns the samples
Args:
last_token_logits (torch.Tensor): The last token logits. A tensor of
size [batch_size, vocab_size]
common_inference_params (CommonInferenceParams): The parameters to use
for inference
vocab_size (int): Obtained from the tokenizer. Defaults to None
Returns:
torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements
"""
top_p = common_inference_params.top_p
top_k = common_inference_params.top_k
temperature = common_inference_params.temperature
assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
assert top_p <= 1.0, 'top-p should be in (0,1]'
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Filtering based on the cumulative sum.
filter_ = cumulative_probs > top_p
# This shift by 1 is weird and I cannot justify it. This existed
# in the original implementation:
# https://github.com/ari-holtzman/degen/blob/master/gen.py
# and I guess it is needed so keeping it for now.
filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
# Fill in the filtered part
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
# Greedy sampling
if top_k == 1:
sampled_logits = torch.argmax(last_token_logits, dim=-1)
else:
last_token_logits = last_token_logits.clone()
if temperature != 1.0:
last_token_logits.div_(temperature)
if top_k > 1:
assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(last_token_logits, top_k)
elif top_p > 0.0:
modify_logits_for_top_p_filtering(last_token_logits, top_p)
# After filtering, we need to recalculate the distribution.
probabilities = last_token_logits.softmax(dim=-1)
sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)
# If vocab size is provided, make sure the samples are in the range [0, vocab_size).
if vocab_size:
sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
return sampled_logits
def update_generation_status(
self,
updated_prompts_tokens: torch.Tensor,
generation_started: torch.Tensor,
current_context_end_position: int,
is_generation_done_tensor: torch.Tensor,
generated_sequence_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Checks which prompts have reached an end condition
We check which prompts have reached an end condition and set the corresponding
flags of the is_generation_done_tensor to True. The generated sequence lengths
increase as we keep generating, until that prompt hits an end condition. The
generation_started tensor determines which prompts have started generating.
Args:
updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
generated tokens. A tensor of shape [batch_size, max_seq_len]
(i.e max_seq_len = max_prompt_len + tokens_to_generate)
generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
indicates the prompt at that index has started generating tokens.
current_context_end_position (int): An integer indicating which position to
extract from the prompts tokens to get the latest generated tokens.
is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
True indicates the prompt at that index has reached end condition.
generated_sequence_lengths (torch.Tensor): A int tensor of shape [batch_size].
Each value represents the generated sequence lengths for that prompt.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Returns the boolean
is_generation_done_tensor and the generated_sequence_lengths after updating it
"""
latest_samples = updated_prompts_tokens[:, current_context_end_position]
# Make sure we are checking eod criterion only for prompts that have started generating
# (i.e.) We only look at the generated tokens and not the input tokens.
reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
is_generation_done_tensor = is_generation_done_tensor | reached_eod
# We increment generated sequence lengths when that prompt has not hit the
# EOD and generation has started
generated_sequence_lengths += ~is_generation_done_tensor & generation_started
return is_generation_done_tensor, generated_sequence_lengths
def pad_input_prompt_tokens(
self,
batch_prompt_tokens_list: List[List[int]],
max_prompt_length_in_batch: int,
num_tokens_to_generate: int,
) -> torch.Tensor:
"""Method to pad input prompts
Given a list of prompts, pad them all to uniform length
Args:
batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens
num_tokens_to_generate (int): The number of tokens to generate for each prompt
Returns:
torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e)
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate,
with extra indices for each tensor padded with the EOD id.
"""
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
for prompt_tokens in batch_prompt_tokens_list:
padding_size = max_seq_len - len(prompt_tokens)
prompt_tokens.extend([self.tokenizer.eod] * padding_size)
return torch.tensor(batch_prompt_tokens_list).cuda()
def generate_output_tokens_dynamic_batch(
self, active_requests: OrderedDict[int, InferenceRequest]
) -> OrderedDict[int, InferenceRequest]:
"""Utility to generate the output tokens and probabilities for the prompts
This utility generates the output tokens for a dynamic batch. It will run one forward step
at a time, and pass control back to the engine, which will update the request pool and call
this method again.
Args:
active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
Returns:
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
after running one forward step.
"""
raise Exception("Not implemented yet")
def generate_all_output_tokens_static_batch(
self, active_requests: OrderedDict[int, InferenceRequest]
) -> OrderedDict[int, InferenceRequest]:
"""Utility to generate the all the output tokens and probabilities for the prompts .
This utility generates the output tokens for a static batch. It runs the forward steps till
all prompts complete generation, updates the status of these requests to completed, adds
the generated result and returns these requests
Args:
active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
Returns:
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
"""
batch_prompt_tokens_list = list(
map(lambda request: request.prompt_tokens, active_requests.values())
)
prompt_lengths_in_batch = torch.tensor(
[len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list]
).cuda()
max_prompt_length_in_batch = max(prompt_lengths_in_batch)
min_prompt_length_in_batch = min(prompt_lengths_in_batch)
# For batch inference the inference params are the same for all requests
common_inference_params: CommonInferenceParams = list(active_requests.values())[
0
].inference_parameters
# max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
batch_prompt_tokens = self.pad_input_prompt_tokens(
batch_prompt_tokens_list,
max_prompt_length_in_batch=max_prompt_length_in_batch,
num_tokens_to_generate=common_inference_params.num_tokens_to_generate,
)
batch_size, max_sequence_length = batch_prompt_tokens.shape
# Pre allocate log probs tensor
output_log_probs = None
if common_inference_params.return_log_probs:
output_log_probs = torch.empty(
(batch_size, max_sequence_length - 1), dtype=torch.float32
).cuda()
# An array to check which of the prompts have reached end of generation condition
is_generation_done_tensor = torch.zeros(batch_size, dtype=torch.bool).cuda()
# An array to act as a counter to keep track of generated sequence lengths
generated_sequence_lengths = torch.zeros(batch_size).cuda()
with torch.no_grad():
self.prep_model_for_inference(
prompts_tokens=batch_prompt_tokens, active_requests=active_requests
)
context_start_position = 0
# Pick the context window that we need to pass through the network.
for context_end_position in range(min_prompt_length_in_batch, max_sequence_length):
inference_input = self.inference_wrapped_model.get_batch_for_context_window(
context_start_position, context_end_position
)
# Returns the final logits of shape [batch_size, context_length, vocab_size]
# Note: This is returned in all TP ranks or last PP stage in PP models
logits = self.inference_wrapped_model.run_one_forward_step(inference_input)
if self.model_is_pipeline_parallel:
context_length = context_end_position - context_start_position
logits = broadcast_from_last_pipeline_stage(
[batch_size, context_length, self.tokenizer.vocab_size],
dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
tensor=logits,
)
# Indicates which of the input prompts have started generating tokens.
# A 1D boolean tensor with [batch_size] elements (i.e., the shortest
# prompts start generating first, and so on)
generation_started = prompt_lengths_in_batch <= context_end_position
last_token_logits = logits[:, -1, :]
sampled_logits = self.sample_from_logits(
last_token_logits, common_inference_params, self.tokenizer.vocab_size
)
# Substitute the sampled logits only for the prompts that
# have started generating tokens
batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
generation_started
]
if common_inference_params.return_log_probs:
log_probs = F.log_softmax(logits, dim=2)
indices = torch.unsqueeze(
batch_prompt_tokens[
:, (context_start_position + 1) : (context_end_position + 1)
],
2,
)
# Get the log probabilities for only the prompt tokens
output_log_probs[:, context_start_position:context_end_position] = torch.gather(
log_probs, 2, indices
).squeeze(2)
context_start_position = context_end_position
# Check end of generation status for each tensor
# and update generated sequence lengths
(is_generation_done_tensor, generated_sequence_lengths) = (
self.update_generation_status(
updated_prompts_tokens=batch_prompt_tokens,
generation_started=generation_started,
current_context_end_position=context_end_position,
is_generation_done_tensor=is_generation_done_tensor,
generated_sequence_lengths=generated_sequence_lengths,
)
)
# Boolean flag indicating if all prompts are finished
all_prompts_done = torch.all(is_generation_done_tensor)
if all_prompts_done:
break
# Include all the generated tokens
batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
if common_inference_params.return_log_probs:
output_log_probs = output_log_probs[:, :context_end_position]
generated_sequence_lengths[
generated_sequence_lengths > common_inference_params.num_tokens_to_generate
] = common_inference_params.num_tokens_to_generate
for idx, request in enumerate(active_requests.values()):
input_prompt_length = int(prompt_lengths_in_batch[idx])
# Shorter prompts might have generated more than required tokens. So we trim them down
required_sequence_length = int(
min(generated_sequence_lengths[idx], common_inference_params.num_tokens_to_generate)
)
# Extract only the generated tokens
required_result_tokens = batch_prompt_tokens_with_generations[
idx, input_prompt_length : (input_prompt_length + required_sequence_length)
]
request.generated_length = required_sequence_length
request.generated_tokens = required_result_tokens
request.generated_log_probs = (
None
if output_log_probs is None
else output_log_probs[idx, input_prompt_length:required_sequence_length]
)
request.status = Status.COMPLETED
request.generated_text = self.detokenize_generations(required_result_tokens)
return active_requests
def prep_model_for_inference(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
):
"""Preparing batch for inference, using respective wrapper's prep_model_for_inference method
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[int, InferenceRequest]): The input active requests
"""
self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=prompts_tokens)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import concurrent
import copy
import functools
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.inference.async_stream import AsyncStream
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.utils import get_model_config
class TextGenerationController:
"""The text generation controller (the main sampling loop)
This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.
Args:
inference_wrapped_model (AbstractModelInferenceWrapper): A model that
is wrapped using the specs given in the abstract_model_inference_wrapper.py
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
"""
def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer):
self.inference_wrapped_model = inference_wrapped_model
self.tokenizer = tokenizer
# For models without pipeline parallelism, is_first_stage and is_last_stage return True
self.model_is_pipeline_parallel = not (
parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
)
def tokenize_prompt(
self, prompt: str, add_BOS: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility to tokenize the input prompts
Args:
prompt (str): The input prompt
Returns:
torch.Tensor: Returns the tokenized prompt
"""
prompt_tokens = self.tokenizer.tokenize(prompt)
if add_BOS:
prompt_tokens = [self.tokenizer.bos] + prompt_tokens
return prompt_tokens
def detokenize_generations(
self,
tokens_gpu_tensor: torch.Tensor,
lengths_gpu_tensor: torch.Tensor,
detokenize_segments: bool,
) -> tuple[str, Optional[List[List[str]]]]:
"""Detokenize the generated tokens.
Args:
tokens_gpu_tensor (torch.Tensor): Tensor containing the tokens
lengths_gpu_tensor (torch.Tensor): Tensor containing the lengths of each sequence
detokenize_segments (bool): If True, returns individually detokenized tokens. If False,
returns None as second element. Helpful for understanding per-token boundaries in
generated text.
Returns:
tuple[str, List[str] | None]: A tuple containing:
- str: The complete detokenized text
- List[str] | None: List of segmented tokens if detokenize_segments is True, else None
"""
# TODO(helenn): Unify with `detokenize_generations` from legacy textgen path
if not detokenize_segments:
tokens = tokens_gpu_tensor.cpu().numpy().tolist()
return self.tokenizer.detokenize(tokens), None
prompts_plus_generations: List[str] = []
prompts_plus_generations_segments: List[List[str]] = []
tokens_gpu_tensor = torch.unsqueeze(tokens_gpu_tensor, 0)
tokens = tokens_gpu_tensor.cpu().numpy().tolist()
lengths = lengths_gpu_tensor.cpu().numpy().tolist()
for sequence_tokens, length in zip(tokens, lengths):
sequence_tokens = sequence_tokens[:length]
detok_str = self.tokenizer.detokenize(sequence_tokens)
prompts_plus_generations.append(detok_str)
offsets = self.tokenizer.offsets(sequence_tokens, detok_str)
words = [
detok_str[start:end] for start, end in zip(offsets, offsets[1:] + [len(detok_str)])
]
prompts_plus_generations_segments.append(words)
text = self.tokenizer.detokenize(tokens[0])
return text, prompts_plus_generations_segments
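# --- Self-contained sketch of how per-token segments are sliced from the detokenized
# string above, with hard-coded offsets standing in for tokenizer.offsets().
# Not part of the diff.
detok_str = "Hello world!"
offsets = [0, 5, 11]  # assumed start offset of each token in detok_str
segments = [
    detok_str[start:end] for start, end in zip(offsets, offsets[1:] + [len(detok_str)])
]
# segments == ['Hello', ' world', '!']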
def sample_from_logits(
self,
last_token_logits: torch.Tensor,
sampling_params: Optional[SamplingParams] = None,
vocab_size: Optional[int] = None,
**kwargs,
) -> torch.Tensor:
"""Samples the logits to generate outputs
Given the logits of the last token, this function samples it
according to the parameters defined in sampling_params
and returns the samples
Args:
last_token_logits (torch.Tensor): The last token logits. A tensor of
size [batch_size, vocab_size]
sampling_params (SamplingParams): The parameters to use for inference.
vocab_size (int): Obtained from the tokenizer. Defaults to None
Returns:
torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements
"""
if kwargs.get('common_inference_params'):
sampling_params = kwargs['common_inference_params']
top_p = sampling_params.top_p
top_k = sampling_params.top_k
temperature = sampling_params.temperature
assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
assert top_p <= 1.0, 'top-p should be in (0,1]'
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Filtering based on the cumulative sum.
filter_ = cumulative_probs > top_p
# This shift by 1 is weird and I cannot justify it. This existed
# in the original implementation:
# https://github.com/ari-holtzman/degen/blob/master/gen.py
# and I guess it is needed so keeping it for now.
filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
# Fill in the filtered part
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
# Greedy sampling
if top_k == 1:
sampled_logits = torch.argmax(last_token_logits, dim=-1)
else:
last_token_logits = last_token_logits.clone()
if temperature != 1.0:
last_token_logits.div_(temperature)
if top_k > 1:
assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(last_token_logits, top_k)
elif top_p > 0.0:
modify_logits_for_top_p_filtering(last_token_logits, top_p)
# After filtering, we need to recalculate the distribution.
probabilities = last_token_logits.softmax(dim=-1)
sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)
# If vocab size is provided, make sure the samples are in the range [0, vocab_size).
if vocab_size:
sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
return sampled_logits
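# --- Self-contained sketch of the nucleus (top-p) filtering used above; the logits
# are made up and this runs on CPU. Not part of the diff.
import torch

example_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # [batch_size=1, vocab_size=4]
sorted_logits, sorted_indices = torch.sort(example_logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
filter_ = cumulative_probs > 0.9  # top_p = 0.9
filter_[:, 1:] = filter_[:, :-1].clone()  # shift so the boundary token survives
filter_[..., 0] = 0  # always keep at least one token
filter_ = filter_.scatter(1, sorted_indices, filter_)  # map back to vocab order
filtered_logits = example_logits.masked_fill(filter_, float('-Inf'))
sample = torch.multinomial(filtered_logits.softmax(dim=-1), num_samples=1).view(-1)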
def update_generation_status(
self,
updated_prompts_tokens: torch.Tensor,
generation_started: torch.Tensor,
current_context_end_position: int,
is_generation_done_tensor: torch.Tensor,
generated_sequence_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Checks which prompts have reached an end condition
We check which prompts have reached an end condition and set the corresponding
flags of the is_generation_done_tensor to True. The generated sequence lengths
increase as we keep generating, until that prompt hits an end condition. The
generation_started tensor determines which prompts have started generating.
Args:
updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
generated tokens. A tensor of shape [batch_size, max_seq_len]
(i.e max_seq_len = max_prompt_len + tokens_to_generate)
generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
indicates the prompt at that index has started generating tokens.
current_context_end_position (int): An integer indicating which position to
extract from the prompts tokens to get the latest generated tokens.
is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
True indicates the prompt at that index has reached end condition.
generated_sequence_lengths (torch.Tensor): A int tensor of shape [batch_size].
Each value represents the generated sequence lengths for that prompt.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Returns the boolean
is_generation_done_tensor and the generated_sequence_lengths after updating it
"""
latest_samples = updated_prompts_tokens[:, current_context_end_position]
# Make sure we are checking eod criterion only for prompts that have started generating
# (i.e.) We only look at the generated tokens and not the input tokens.
reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
is_generation_done_tensor = is_generation_done_tensor | reached_eod
# We increment generated sequence lengths when that prompt has not hit the
# EOD and generation has started
generated_sequence_lengths += ~is_generation_done_tensor & generation_started
return is_generation_done_tensor, generated_sequence_lengths.int()
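# --- Self-contained sketch of the bookkeeping done above; the eod id, tokens, and
# positions are made up. Not part of the diff.
import torch

eod = 2
prompts_tokens = torch.tensor([[5, 6, 2, 0], [7, 8, 9, 4]])  # [batch_size=2, max_seq_len=4]
generation_started = torch.tensor([True, False])  # only prompt 0 is generating already
is_generation_done = torch.zeros(2, dtype=torch.bool)
generated_sequence_lengths = torch.zeros(2)
latest_samples = prompts_tokens[:, 2]  # current_context_end_position = 2
reached_eod = (latest_samples == eod) & generation_started  # -> tensor([True, False])
is_generation_done = is_generation_done | reached_eod
# Only prompts that have started generating and have not finished keep growing.
generated_sequence_lengths += ~is_generation_done & generation_started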
def pad_input_prompt_tokens(
self,
batch_prompt_tokens_list: List[List[int]],
max_prompt_length_in_batch: int,
num_tokens_to_generate: int,
) -> torch.Tensor:
"""Method to pad input prompts
Given a list of prompts, pad them all to uniform length
Args:
batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens
num_tokens_to_generate (int): The number of tokens to generate for each prompt
Returns:
torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e)
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate,
"""
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
for prompt_tokens in batch_prompt_tokens_list:
padding_size = max_seq_len - len(prompt_tokens)
prompt_tokens.extend([self.tokenizer.eod] * padding_size)
return torch.tensor(batch_prompt_tokens_list, device=torch.cuda.current_device())
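# --- Self-contained sketch of the padding above, on CPU (the real method allocates
# on the current CUDA device); the token ids are made up. Not part of the diff.
import torch

eod = 0
batch_prompt_tokens_list = [[11, 12, 13], [21, 22]]
num_tokens_to_generate = 4
max_prompt_length_in_batch = max(len(p) for p in batch_prompt_tokens_list)
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate  # 3 + 4 = 7
for prompt_tokens in batch_prompt_tokens_list:
    prompt_tokens.extend([eod] * (max_seq_len - len(prompt_tokens)))  # right-pad with eod
batch_prompt_tokens = torch.tensor(batch_prompt_tokens_list)  # shape [2, 7]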
def generate_output_tokens_dynamic_batch(
self, active_requests: OrderedDict[str, InferenceRequest]
) -> OrderedDict[str, InferenceRequest]:
"""Utility to generate the output tokens and probabilities for the prompts
This utility generates the output tokens for a dynamic batch. It will run one forward step
at a time, and pass control back to the engine, which will update the request pool and call
this method again.
Args:
active_requests (OrderedDict[str, InferenceRequest]): The input active requests.
Returns:
OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
after running one forward step.
"""
raise Exception("Not implemented yet")
def generate_all_output_tokens_static_batch(
self,
active_requests: OrderedDict[str, InferenceRequest],
active_streams: Optional[OrderedDict[str, AsyncStream]] = None,
) -> OrderedDict[str, InferenceRequest]:
"""Utility to generate the all the output tokens and probabilities for the prompts .
This utility generates the output tokens for a static batch. It runs the forward steps till
all prompts complete generation, updates the status of these requests to completed, adds
the generated result and returns these requests
Args:
active_requests (OrderedDict[str, InferenceRequest]): The input active requests.
Returns:
OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
"""
assert all(request.prompt_tokens is not None for request in active_requests.values())
# Perform a deep copy so that the request prompt tokens do not get modified.
batch_prompt_tokens_list: List[List[int]] = list(
map(
lambda request: copy.deepcopy(request.prompt_tokens), # type: ignore[arg-type]
active_requests.values(),
)
)
prompt_lengths_in_batch = torch.tensor(
[len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list],
device=torch.cuda.current_device(),
)
max_prompt_length_in_batch = max(prompt_lengths_in_batch)
min_prompt_length_in_batch = min(prompt_lengths_in_batch)
# For batch inference the inference params are the same for all requests
sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters
# max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
batch_prompt_tokens = self.pad_input_prompt_tokens(
batch_prompt_tokens_list,
max_prompt_length_in_batch=max_prompt_length_in_batch,
num_tokens_to_generate=sampling_params.num_tokens_to_generate,
)
batch_size, max_sequence_length = batch_prompt_tokens.shape
# Verify that output sequence length is within configured limit
# TODO(ksanthanam): Raise TokenOverflowError once !2518 is merged
inference_max_sequence_length = (
self.inference_wrapped_model.inference_wrapper_config.inference_max_seq_length
)
assert max_sequence_length <= inference_max_sequence_length, (
f"Maximum allowed sequence length was set to {inference_max_sequence_length} tokens "
f"but requested generation of {max_sequence_length} tokens"
)
# Pre allocate log probs tensor
output_log_probs = None
if sampling_params.return_log_probs:
output_log_probs = torch.empty(
(batch_size, max_sequence_length - 1),
dtype=torch.float32,
device=torch.cuda.current_device(),
)
# An array to check which of the prompts have reached end of generation condition
is_generation_done_tensor = torch.zeros(
batch_size, dtype=torch.bool, device=torch.cuda.current_device()
)
# An array to act as a counter to keep track of generated sequence lengths
generated_sequence_lengths = torch.zeros(
batch_size, device=torch.cuda.current_device()
).cuda()
# Use padded vocab size because tokenizer vocab size might not include padding
# to nearest power of 2
vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size
# Check whether CUDA graphs are enabled
enable_cuda_graph = get_model_config(self.inference_wrapped_model.model).enable_cuda_graph
streaming_enabled = active_streams is not None and len(active_streams) > 0
if streaming_enabled:
# Start a separate thread for streaming tokens to avoid blocking the
# main computation
streaming_idx: List[int] = [
i
for (i, request_id) in enumerate(active_requests.keys())
if request_id in active_streams
]
streaming_request_ids: List[str] = list(active_streams.keys())
streams: List[AsyncStream] = list(active_streams.values())
streaming_requests: List[InferenceRequest] = [
active_requests[request_id] for request_id in streaming_request_ids
]
streaming_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
stream_tokens = functools.partial(self.stream_tokens, sampling_params)
with torch.no_grad():
self.inference_wrapped_model.prep_model_for_inference(
prompts_tokens=batch_prompt_tokens
)
inference_input: Dict[str, Any] = self.prep_inference_input(
prompts_tokens=batch_prompt_tokens, active_requests=active_requests
)
assert (
not self.inference_wrapped_model.inference_params.decode_mode
), f"Generation must start in prefill mode"
context_start_position = 0
# Pick the context window that we need to pass through the network.
for context_end_position in range(min_prompt_length_in_batch, max_sequence_length):
inference_input_for_context_window: Dict[str, Any] = (
self.inference_wrapped_model.get_batch_for_context_window(
inference_input, context_start_position, context_end_position
)
)
# Disable attention mask when using CUDA graphs for decode
if (
enable_cuda_graph
and self.inference_wrapped_model.inference_params.decode_mode
and "attention_mask" in inference_input_for_context_window
):
inference_input_for_context_window["attention_mask"] = None
# Returns the final logits of shape [batch_size, context_length, vocab_size]
# Note: This is returned in all TP ranks or last PP stage in PP models
logits = self.inference_wrapped_model.run_one_forward_step(
inference_input_for_context_window
)
if enable_cuda_graph:
create_cudagraphs()
if self.model_is_pipeline_parallel:
context_length = context_end_position - context_start_position
logits = broadcast_from_last_pipeline_stage(
[batch_size, context_length, vocab_size],
dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
tensor=logits,
)
# Indicates which of the input prompts have started generating tokens.
# A 1D boolean tensor with [batch_size] elements (i.e., the shortest
# prompts start generating first, and so on)
generation_started = prompt_lengths_in_batch <= context_end_position
last_token_logits = logits[:, -1, :]
sampled_logits = self.sample_from_logits(
last_token_logits, sampling_params, vocab_size
)
# Substitute the sampled logits only for the prompts that
# have started generating tokens
batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
generation_started
]
if sampling_params.return_log_probs:
log_probs = F.log_softmax(logits, dim=2)
indices = torch.unsqueeze(
batch_prompt_tokens[
:, (context_start_position + 1) : (context_end_position + 1)
],
2,
)
# Get the log probabilities for only the prompt tokens
assert output_log_probs is not None
output_log_probs[:, context_start_position:context_end_position] = torch.gather(
log_probs, 2, indices
).squeeze(2)
context_start_position = context_end_position
# Check end of generation status for each tensor
# and update generated sequence lengths
(is_generation_done_tensor, generated_sequence_lengths) = (
self.update_generation_status(
updated_prompts_tokens=batch_prompt_tokens,
generation_started=generation_started,
current_context_end_position=context_end_position,
is_generation_done_tensor=is_generation_done_tensor,
generated_sequence_lengths=generated_sequence_lengths,
)
)
# Stream intermediate outputs
if streaming_enabled:
streaming_executor.submit(
stream_tokens,
streaming_request_ids,
streaming_requests,
streams,
generation_started[streaming_idx].cpu(),
is_generation_done_tensor[streaming_idx].cpu(),
batch_prompt_tokens[streaming_idx].cpu(),
prompt_lengths_in_batch[streaming_idx].cpu(),
generated_sequence_lengths[streaming_idx].cpu(),
(
output_log_probs[streaming_idx].cpu()
if output_log_probs is not None
else [None] * len(streaming_idx)
),
)
# Boolean flag indicating if all prompts are finished
all_prompts_done = torch.all(is_generation_done_tensor)
if all_prompts_done:
break
# Change to decode mode if all prefill is complete
if torch.all(generation_started):
self.inference_wrapped_model.inference_params.enable_decode_mode()
# Close all streams
if streaming_enabled:
streaming_executor.shutdown()
for stream in streams:
stream.finish()
# Include all the generated tokens
batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
if sampling_params.return_log_probs:
assert output_log_probs is not None
output_log_probs = output_log_probs[:, :context_end_position]
generated_sequence_lengths[
generated_sequence_lengths > sampling_params.num_tokens_to_generate
] = sampling_params.num_tokens_to_generate
for idx, request in enumerate(active_requests.values()):
input_prompt_length = int(prompt_lengths_in_batch[idx])
# Shorter prompts might have generated more than required tokens. So we trim them down
required_sequence_length = int(
min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate)
)
# Extract only the generated tokens
required_result_tokens = batch_prompt_tokens_with_generations[
idx, input_prompt_length : (input_prompt_length + required_sequence_length)
]
generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32)
request.generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32)
request.generated_length = required_sequence_length
request.generated_tokens = required_result_tokens
request.prompt_log_probs = (
None
if output_log_probs is None
else output_log_probs[idx, :input_prompt_length].cpu().numpy().tolist()
)
request.generated_log_probs = (
None
if output_log_probs is None
else output_log_probs[
idx,
input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1),
]
.cpu()
.numpy()
.tolist()
)
request.status = Status.COMPLETED
text, segments = self.detokenize_generations(
batch_prompt_tokens_with_generations[idx],
input_prompt_length + generated_sequence_lengths,
sampling_params.return_segments,
)
request.text = text # Inference server returns prompts & generations together
if sampling_params.return_segments:
request.segments = segments[0]
request.generated_text = text[len(request.prompt) :]
return active_requests
def prep_inference_input(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
) -> Dict[str, Any]:
"""Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[str, InferenceRequest]): The input active requests
Returns:
A dict of the inference input for the current batch.
"""
return self.inference_wrapped_model.prep_inference_input(prompts_tokens)
def stream_tokens(
self,
sampling_params: SamplingParams,
request_ids: List[str],
requests: List[InferenceRequest],
streams: List[AsyncStream],
generation_started: List[bool],
is_generation_done: List[bool],
tokens: torch.Tensor,
prompt_lengths: List[int],
generated_lengths: List[int],
output_log_probs: Union[torch.Tensor, None],
):
"""Asynchronously streams tokens for the given requests.
Args:
sampling_params (SamplingParams): The sampling parameters.
request_ids (List[str]): The request IDs.
requests (List[InferenceRequest]): The requests.
streams (List[AsyncStream]): The streams over which to send tokens.
generation_started (List[bool]): Whether the decode step has started.
is_generation_done (List[bool]): Whether generation has completed.
tokens (torch.Tensor): The tokens for this request.
prompt_lengths (List[int]): The number of prompt tokens for each request.
generated_lengths (List[int]): The number of output tokens for each request.
output_log_probs (torch.Tensor, optional): The log probs for each request.
"""
def stream_token(
request_id: str,
request: InferenceRequest,
stream: AsyncStream,
generation_started: bool,
is_generation_done: bool,
tokens: torch.Tensor,
prompt_length: int,
generated_length: int,
output_log_probs: Union[torch.Tensor, None],
):
"""Asynchronously streams a token for the given request."""
if not generation_started or stream.finished:
return
num_tokens_to_generate = sampling_params.num_tokens_to_generate
return_segments = sampling_params.return_segments
detokenize_streaming_text = not getattr(
sampling_params, "no_detokenize_streaming_text", False
)
generated_tokens = tokens[prompt_length : prompt_length + generated_length]
if detokenize_streaming_text:
generated_text, generated_segments = self.detokenize_generations(
generated_tokens, prompt_length + generated_length, return_segments
)
else:
generated_text = ""
generated_segments = []
if output_log_probs is not None:
generated_log_probs = (
output_log_probs[prompt_length - 1 : prompt_length + generated_length - 1]
.cpu()
.numpy()
.tolist()
)
else:
generated_log_probs = None
stream.put(
InferenceRequest(
request_id=request_id,
prompt=request.prompt,
inference_parameters=request.inference_parameters,
prompt_tokens=request.prompt_tokens,
arrival_time=request.arrival_time,
status=request.status,
encoder_prompt=request.encoder_prompt,
generated_text=generated_text,
generated_segments=generated_segments,
generated_tokens=generated_tokens,
generated_log_probs=generated_log_probs,
generated_length=generated_length,
)
)
if is_generation_done or generated_length == num_tokens_to_generate:
stream.finish()
ret = map(
stream_token,
request_ids,
requests,
streams,
generation_started,
is_generation_done,
tokens,
prompt_lengths,
generated_lengths,
output_log_probs,
)
list(ret)
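# --- Illustrative end-to-end sketch (not part of the diff). It assumes an
# `inference_wrapped_model` (an AbstractModelInferenceWrapper subclass), a `tokenizer`,
# and a `requests` OrderedDict[str, InferenceRequest] whose entries share one
# SamplingParams with num_tokens_to_generate set; streaming is optional and omitted here.
controller = TextGenerationController(inference_wrapped_model, tokenizer)
finished_requests = controller.generate_all_output_tokens_static_batch(requests)
for request_id, request in finished_requests.items():
    # Each returned request now carries generated_tokens, generated_text,
    # generated_length, and (if requested) generated_log_probs / segments.
    print(request_id, request.generated_text)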
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import OrderedDict
import torch
from megatron.core.inference.inference_request import InferenceRequest, VLMInferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
class VLMTextGenerationController(TextGenerationController):
"""The text generation controller for VLMs"""
def prep_inference_input(
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
):
"""Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long
Currently only supports batch size 1 inference.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[str, InferenceRequest]): The input active requests
"""
assert len(active_requests) == 1, f"VLM inference currently only supports batch size 1"
request = list(active_requests.values())[0]
assert isinstance(
request, VLMInferenceRequest
), f"Found inference request of type {type(request)}, expected VLMInferenceRequest"
return self.inference_wrapped_model.prep_inference_input(
prompts_tokens,
request.num_img_embeddings_per_tile,
request.imgs,
request.num_tiles,
request.decoder_seq_length,
)
File mode changed from 100755 to 100644
......@@ -6,9 +6,12 @@ class InferenceParams:
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.current_batch_size = max_batch_size # Required for bookkeeping variable-sized batches
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.decode_mode = False
self.key_value_memory_dict = {}
self.decode_mode = False
def swap_key_value_dict(self, batch_idx):
"swap between batches"
......@@ -27,5 +30,71 @@ class InferenceParams:
new_inference_value_memory,
)
def enable_prefill_mode(self):
"""
Indicates the generation loop is in the prefill phase (still processing
input prompt tokens). This should be enabled if the generation loop is
encoding prompt tokens for *any* request in a batch.
"""
self.decode_mode = False
def enable_decode_mode(self):
"""
Indicates the generation loop is in the decode phase (generating new output
tokens). This should only be enabled if the generation loop has fully encoded
the prompts for *all* requests in a batch.
"""
self.decode_mode = True
def reset(self):
"""Resets the inference state for a new batch."""
self.current_batch_size = self.max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.enable_prefill_mode()
def __str__(self):
return f"InferenceParams(max_seq_len = {self.max_sequence_length}, max_batch_size = {self.max_batch_size}, sequence_len_offset = {self.sequence_len_offset}, batch_size_offset = {self.batch_size_offset}, key_value_memory_dict = {self.key_value_memory_dict.keys()})"
return (
f"InferenceParams(max_seq_len = {self.max_sequence_length}, "
f"max_batch_size = {self.max_batch_size}, "
f"current_batch_size = {self.current_batch_size}, "
f"sequence_len_offset = {self.sequence_len_offset}, "
f"batch_size_offset = {self.batch_size_offset}, "
f"key_value_memory_dict = {self.key_value_memory_dict.keys()})"
f"decode_mode = {self.decode_mode}"
)
def __eq__(self, other):
if not isinstance(other, InferenceParams):
return False
# Check all attributes match
basic_attrs = [
'max_sequence_length',
'max_batch_size',
'current_batch_size',
'sequence_len_offset',
'batch_size_offset',
]
if not all(hasattr(other, attr) for attr in basic_attrs):
return False
# Check dictionary keys match; i.e. the same number of layers are cached
if self.key_value_memory_dict.keys() != other.key_value_memory_dict.keys():
return False
# Check each tensor tuple in the dictionary
for key in self.key_value_memory_dict:
self_tensors = self.key_value_memory_dict[key]
other_tensors = other.key_value_memory_dict[key]
# Compare each key, value tensor in the tuple
for self_tensor, other_tensor in zip(self_tensors, other_tensors):
if (
self_tensor.data_ptr() != other_tensor.data_ptr()
or self_tensor.shape != other_tensor.shape
):
return False
return True
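# --- Self-contained sketch of the new prefill/decode bookkeeping above.
# Not part of the diff.
params = InferenceParams(max_batch_size=2, max_sequence_length=16)
assert params.decode_mode is False  # construction starts in prefill mode
params.sequence_len_offset += 8  # pretend the prompts have just been encoded
params.enable_decode_mode()  # all prompts encoded -> generate new tokens one at a time
assert params.decode_mode is True
params.reset()  # reuse the same object for the next batch
assert params.sequence_len_offset == 0 and params.decode_mode is False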
File mode changed from 100755 to 100644
......@@ -19,6 +19,11 @@ class ModelParallelConfig:
tensor_model_parallel_size: int = 1
"""Intra-layer model parallelism. Splits tensors across GPU ranks."""
pipeline_model_parallel_comm_backend: Optional[str] = None
"""Configuring backend option of pipeline parallel communication (e.g., nccl, ucc)
If None, the default backend will be used.
"""
pipeline_model_parallel_size: int = 1
"""Inter-layer model parallelism. Splits transformer layers across GPU ranks."""
......
File mode changed from 100755 to 100644
......@@ -10,9 +10,11 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.enums import ModelType
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.relative_pos_embedding import RelativePositionEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
......@@ -135,9 +137,13 @@ class T5Model(LanguageModule):
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
position_embedding_type: Literal[
'learned_absolute', 'rope', 'relative'
] = 'learned_absolute',
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[float] = None,
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
add_encoder: bool = True,
add_decoder: bool = True,
):
......@@ -193,6 +199,23 @@ class T5Model(LanguageModule):
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Relative Position Embeddings
if self.position_embedding_type == 'relative':
self.encoder_relative_pos_emb = RelativePositionEmbedding(
bidirectional=True,
init_method=self.config.init_method,
num_attention_heads=self.config.num_attention_heads,
relative_attention_num_buckets=relative_attention_num_buckets,
relative_attention_max_distance=relative_attention_max_distance,
)
self.decoder_relative_pos_emb = RelativePositionEmbedding(
bidirectional=False,
init_method=self.config.init_method,
num_attention_heads=self.config.num_attention_heads,
relative_attention_num_buckets=relative_attention_num_buckets,
relative_attention_max_distance=relative_attention_max_distance,
)
# Transformer encoder
encoder_spec, decoder_spec = (
self.transformer_encoder_layer_spec,
......@@ -284,6 +307,27 @@ class T5Model(LanguageModule):
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Relative positional embeddings
encoder_attention_bias_parallel = None
if self.position_embedding_type == 'relative':
query_seq_length = RelativePositionEmbedding.get_relative_seq_len(
inference_params, self.encoder, encoder_input, self.config
)
key_seq_length = query_seq_length
attention_bias = self.encoder_relative_pos_emb(query_seq_length, key_seq_length)
# Scatter attention_bias to TP ranks
# First, permute [1, num_head, seqlen_q, seqlen_kv] to
# [1, seqlen_q, seqlen_kv, num_head] so it can be scattered along
# the last (num_heads) dimension
attention_bias = torch.permute(attention_bias, (0, 2, 3, 1))
# Then, scatter to TP region
attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias)
# Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv]
encoder_attention_bias_parallel = torch.permute(
attention_bias_parallel, (0, 3, 1, 2)
)
# Run encoder.
if self.add_encoder:
encoder_hidden_states = self.encoder(
......@@ -291,6 +335,7 @@ class T5Model(LanguageModule):
attention_mask=encoder_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
attention_bias=encoder_attention_bias_parallel,
)
else:
encoder_hidden_states = self.encoder_hidden_state
......@@ -315,10 +360,29 @@ class T5Model(LanguageModule):
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.encoder, encoder_input, self.config, packed_seq_params
inference_params, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Relative positional embeddings
decoder_attention_bias_parallel = None
if self.position_embedding_type == 'relative':
query_seq_length = RelativePositionEmbedding.get_relative_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
key_seq_length = query_seq_length
attention_bias = self.decoder_relative_pos_emb(query_seq_length, key_seq_length)
# Scatter attention_bias to TP ranks
# First, permute [1, num_head, seqlen_q, seqlen_kv] to
# [1, seqlen_q, seqlen_kv, num_head] so it can be scattered along
# the last (num_heads) dimension
attention_bias = torch.permute(attention_bias, (0, 2, 3, 1))
# Then, scatter to TP region
attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias)
# Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv]
decoder_attention_bias_parallel = torch.permute(attention_bias_parallel, (0, 3, 1, 2))
# Run decoder.
decoder_hidden_states = self.decoder(
hidden_states=decoder_input,
......@@ -327,12 +391,15 @@ class T5Model(LanguageModule):
context_mask=encoder_decoder_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
attention_bias=decoder_attention_bias_parallel,
)
if self.post_process:
lm_logits = self.lm_head(
decoder_hidden_states, self.shared_embedding_or_output_weight()
)
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
lm_logits = self.lm_head(decoder_hidden_states, word_embeddings_weight=output_weight)
if lm_labels is None:
# [s b h] => [b s h]
return lm_logits.transpose(0, 1).contiguous()
......
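# --- Self-contained sketch of the shape handling around the TP scatter above.
# torch.chunk stands in for scatter_to_tensor_model_parallel_region, which splits the
# last dimension across tensor-parallel ranks (the real call needs initialized
# parallel state). Not part of the diff.
import torch

num_heads, seqlen_q, seqlen_kv, tp_size = 8, 5, 5, 2
attention_bias = torch.randn(1, num_heads, seqlen_q, seqlen_kv)
# Move num_heads last so the split happens across heads.
bias_heads_last = torch.permute(attention_bias, (0, 2, 3, 1))  # [1, q, kv, heads]
bias_on_this_rank = torch.chunk(bias_heads_last, tp_size, dim=-1)[0]  # stand-in for the scatter
# Restore the [1, heads_per_rank, q, kv] layout expected by attention.
bias_parallel = torch.permute(bias_on_this_rank, (0, 3, 1, 2))
assert bias_parallel.shape == (1, num_heads // tp_size, seqlen_q, seqlen_kv)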
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
......@@ -28,38 +30,60 @@ try:
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
bert_layer_with_transformer_engine_spec = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
def get_bert_layer_with_transformer_engine_spec():
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
Returns:
ModuleSpec: Module specification with TE modules
"""
if not HAVE_TE:
raise ImportError(
"Transformer Engine is not installed. Please use local Bert layer spec instead."
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
),
self_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
self_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
),
),
mlp_bda=get_bias_dropout_add,
),
mlp_bda=get_bias_dropout_add,
),
)
)
def __getattr__(name):
if name == 'bert_layer_with_transformer_engine_spec':
warnings.warn(
"""Attribute bert_layer_specs.bert_layer_with_transformer_engine_spec is on a
deprecation track and will be removed in future releases. Please migrate to
bert_layer_specs.get_bert_layer_with_transformer_engine_spec()."""
)
return get_bert_layer_with_transformer_engine_spec()
# Use this spec for an implementation using only modules in megatron core
bert_layer_local_spec = ModuleSpec(
......
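# --- Self-contained sketch of the module-level __getattr__ (PEP 562) deprecation
# pattern used above; the names here are made up. Placed at module scope, it only
# runs for attributes that normal lookup cannot find. Not part of the diff.
import warnings


def _build_new_spec():
    """Hypothetical replacement for a deprecated module-level constant."""
    return {"module": "TransformerLayer"}


def __getattr__(name):
    if name == "old_spec":  # hypothetical deprecated name
        warnings.warn(
            "old_spec is deprecated; call _build_new_spec() instead.", DeprecationWarning
        )
        return _build_new_spec()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")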
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644