from typing import Dict, List
import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.scheduler import Scheduler
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
SimpleTextGenerationController,
)
class MCoreEngine(AbstractEngine):
def __init__(
self,
text_generation_controller: SimpleTextGenerationController,
max_batch_size: int,
random_seed: int = None,
):
"""The Megatron core backend constructor
This is the backend that does a simple forward pass on the model. Supports any model that is callable (i.e. accepts the inputs and returns the output tensor)
Args:
text_generation_controller (SimpleTextGenerationController): A text generation controller that will be used to define how to preprocess prompts, generate outputs and detokenize the output tokens.
max_batch_size (int): The maximum number of requests to process at once
random_seed (int, optional): Use a random seed if you want deterministic results. Defaults to None.
"""
self.text_generation_controller = text_generation_controller
self.random_seed = random_seed
self.scheduler = Scheduler(max_batch_size=max_batch_size)
def generate(self, prompts: List[str], common_inference_params: CommonInferenceParams) -> List[InferenceRequest]:
"""The megatron core inference backend generate function
This method returns the completed inference requests. Each request contains the prompt tokens along with the generated tokens, the prompt plus the generated string, and the output log probabilities if requested.
Args:
prompts (List[str]): All the prompts as a list of strings
common_inference_params (CommonInferenceParams): The inference parameters
Returns:
List[InferenceRequest]: A list of inference requests containing the generated tokens, texts and log probs if required
"""
# TODO: MCore - get rng state tracker
if self.random_seed:
torch.random.manual_seed(self.random_seed)
for prompt in prompts:
prompt_tokens = self.text_generation_controller.tokenize_prompt(prompt)
self.scheduler.add_request(
prompt=prompt,
prompt_tokens=prompt_tokens,
inference_parameters=common_inference_params,
)
self.run_engine()
result: List[InferenceRequest] = list(self.scheduler.completed_request_pool.values())
return result
def run_engine(self):
"""Main functionality to run inference
Runs the engine until there are no requests in the queue.
Args:
dynamic_generation (bool, optional): Set this to True, if you want to enable dynamic batching. Mainly used with an inference server. Defaults to False.
"""
while self.scheduler.have_requests_pending():
active_requests: Dict[int, InferenceRequest] = self.scheduler.active_request_pool.copy()
result_dict: Dict[int, InferenceRequest] = (
self.text_generation_controller.generate_all_output_tokens_static_batch(
active_requests
)
)
self.scheduler.update_requests_pools(result_dict=result_dict)
# TODO: Later for dynamic batching we will do something like this
"""
if dynamic_batching:
result_dict: Dict[
int, InferenceRequest
] = self.text_generation_controller.generate_output_tokens_one_step_dynamic_batch(
active_requests
)
self.scheduler.update_requests_pools(result_dict=result_dict)
"""
from dataclasses import dataclass
from enum import Enum
from typing import List
import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams
class Status(Enum):
WAITING_IN_QUEUE = 1
ACTIVE_AND_GENERATING_TOKENS = 2
ACTIVE_BUT_NOT_GENERATING_TOKENS = 3
COMPLETED = 4
@dataclass
class InferenceRequest:
request_id: str
prompt: str
inference_parameters: CommonInferenceParams
prompt_tokens: List[int]
arrival_time: float
status: Status
generated_text: str = None
generated_tokens: torch.Tensor = None
generated_log_probs: torch.Tensor = None
generated_length: int = 0
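# Illustrative sketch of a request as the scheduler would build it (the request id, token ids
# and arrival time below are placeholder values; `import time` is assumed, and
# CommonInferenceParams is assumed to provide defaults):
#
#   request = InferenceRequest(
#       request_id="0",
#       prompt="Hello world",
#       inference_parameters=CommonInferenceParams(),
#       prompt_tokens=[101, 7592, 2088],
#       arrival_time=time.time(),
#       status=Status.WAITING_IN_QUEUE,
#   )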
import abc
import math
from argparse import Namespace
from typing import Iterable, List, Union
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.communication_utils import (
recv_from_prev_pipeline_rank_,
send_to_next_pipeline_rank,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from megatron.core.inference_params import InferenceParams
from megatron.core.models.gpt.gpt_model import GPTModel
class AbstractModelInferenceWrapper(abc.ABC):
def __init__(
self,
model: Union['LegacyGPTModel', GPTModel],
inference_wrapper_config: InferenceWrapperConfig,
):
"""Constructor for the model inference wrapper
The wrapper prepares the model for inference, provides the required input data and runs the forward pass.
Args:
model (Union[GPTModel, LegacyGPTModel]): The actual GPT model (MCore or MLM)
inference_wrapper_config (InferenceWrapperConfig): The inference wrapper config that controls the forward pass
"""
assert not isinstance(
model, Iterable
), 'interleaving schedule is not supported for inference'
self.model = model
self.inference_wrapper_config = inference_wrapper_config
self.pipeline_communication_dtype = (
torch.float
if self.inference_wrapper_config.fp32_residual_connection
else self.inference_wrapper_config.params_dtype
)
def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
"""A utility function for preparing model for inference
The function gets called once before the auto-regressive inference loop. It puts the model in eval mode and gathers some model and inference data parameters. Extend this to build position ids, attention masks, etc., so that the required slices can be extracted during the forward pass.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
"""
self.model.eval()
# For a TP-only model, both is_pipeline_first_stage and is_pipeline_last_stage return True
self.model_is_pipeline_parallel = not (
parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
)
self.prompts_tokens = prompts_tokens
batch_size, max_sequence_length = self.prompts_tokens.shape
self.inference_params = InferenceParams(batch_size, max_sequence_length)
@abc.abstractmethod
def get_batch_for_context_window(self) -> List:
"""Returns the input data for inference
This function gets called iteratively in the inference loop. It can be used to extract the relevant input from the prompt tokens, attention mask, etc., required for each step of inference.
"""
pass
def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor:
"""Utility to carry out simple forward pass for TP or no model parallel models
Runs a very simple forward pass for model. Used in the case of models without any parallelism or only tensor parallelism.
Args:
inference_input (List): A list containg the inputs for the gpt model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
"""
tokens, position_ids, attention_mask = inference_input
logits = self.model(
tokens, position_ids, attention_mask, inference_params=self.inference_params
)
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
self.inference_params.sequence_len_offset += tokens.size(1)
return logits
def _allocate_recv_buffer(self, batch_size, seq_len):
"""Receive happens between the layers with size [seq_len, batch_size, hidden_size]."""
recv_size = (seq_len, batch_size, self.inference_wrapper_config.hidden_size)
return torch.empty(
recv_size, dtype=self.pipeline_communication_dtype, device=torch.cuda.current_device()
)
def forward_pass_with_pipeline_parallel_small_input_batch(
self, inference_input: List
) -> torch.Tensor:
"""Utility to carry out forward pass for PP models with very small inputs
If a model is pipeline parallel yet the input global batch is very small, we compute a forward pass on the entire global batch rather than splitting it up into micro batches, which is the more complex path taken in the forward_pass_with_pipeline_parallel_large_input_batch method
Args:
inference_input (List): A list containing the inputs for the gpt model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
"""
tokens, position_ids, attention_mask = inference_input
batch_size, seq_len = tokens.shape
recv_buffer = None
if not parallel_state.is_pipeline_first_stage():
recv_buffer = self._allocate_recv_buffer(batch_size, seq_len)
recv_from_prev_pipeline_rank_(recv_buffer)
self.model.set_input_tensor(recv_buffer)
output_tensor = self.model(
tokens, position_ids, attention_mask, inference_params=self.inference_params
)
if not parallel_state.is_pipeline_last_stage():
send_to_next_pipeline_rank(output_tensor.type(dtype=self.pipeline_communication_dtype))
self.inference_params.sequence_len_offset += seq_len
logits = None
if parallel_state.is_pipeline_last_stage():
logits = output_tensor
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
return logits
def forward_pass_with_pipeline_parallel_large_input_batch(
self, inference_input: List
) -> torch.Tensor:
"""Utility to carry out forward pass PP models.
Runs the forward pass for models which are pipeline parallel. This is more complex than forward_pass_with_pipeline_parallel_small_input_batch coz this splits the global batch into small micro batches and runs them through the model.
Args:
inference_input (List): A list containg the inputs for the gpt model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
"""
tokens, position_ids, attention_mask = inference_input
micro_batch_size = max(
1,
self.inference_wrapper_config.inference_batch_times_seqlen_threshold // tokens.size(1),
)
batch_size, seq_len = tokens.shape
# Round up to account for the last partial micro batch if present
num_micro_batches = math.ceil(batch_size / micro_batch_size)
logits = None
# Preallocate memory for output logits.
if parallel_state.is_pipeline_last_stage():
logits = torch.empty(
(batch_size, seq_len, self.inference_wrapper_config.padded_vocab_size),
dtype=torch.float32,
device=torch.cuda.current_device(),
)
recv_buffer = None
if not parallel_state.is_pipeline_first_stage():
recv_buffer = self._allocate_recv_buffer(micro_batch_size, seq_len)
for micro_batch_index in range(num_micro_batches):
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
current_micro_batch_size = end - start
# Need to change recv buffer shape for the last partial microbatch (if exists)
if current_micro_batch_size != micro_batch_size:
recv_buffer = self._allocate_recv_buffer(current_micro_batch_size, seq_len)
if not parallel_state.is_pipeline_first_stage():
recv_from_prev_pipeline_rank_(recv_buffer)
self.model.set_input_tensor(recv_buffer)
output_tensor = self.model(
tokens2use, position_ids2use, attention_mask, inference_params=self.inference_params
)
if not parallel_state.is_pipeline_last_stage():
send_to_next_pipeline_rank(output_tensor)
self.inference_params.batch_size_offset += current_micro_batch_size
if parallel_state.is_pipeline_last_stage():
output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
output_tensor
)
logits[start:end, ...] = output_tensor
# Once done with all micro batches, we reset batch size offset and seq len offset
self.inference_params.sequence_len_offset += seq_len
self.inference_params.batch_size_offset = 0
# NOTE: Only returns the logits on the last pipeline stage
return logits
def run_one_forward_step(self, inference_input: List) -> torch.Tensor:
"""The forward pass of the model for inference
Appropriate utility is called for the forward pass depending on the type of model parallelism used
Args:
inference_input (List): A list containing the inputs for the gpt model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]. The logits are returned only in the last pipeline stage for PP models.
"""
if self.model_is_pipeline_parallel:
tokens = inference_input[0]
current_batch_size, seq_len = tokens.shape
# If input batch is large, we need to split into micro batches and run the forward pass
if (
current_batch_size * seq_len
> self.inference_wrapper_config.inference_batch_times_seqlen_threshold
):
return self.forward_pass_with_pipeline_parallel_large_input_batch(inference_input)
else:
# If input batch is very small we can do a simple forward pass on the entire global batch
return self.forward_pass_with_pipeline_parallel_small_input_batch(inference_input)
else:
return self.forward_pass_without_pipeline_parallel(inference_input)
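# Dispatch example (illustrative): with inference_batch_times_seqlen_threshold = 512 and an
# input of shape [batch_size=8, seq_len=100], 8 * 100 = 800 > 512, so the large-input path is
# taken and the batch is split into micro batches of size max(1, 512 // 100) = 5, i.e. two
# micro batches of sizes 5 and 3. An input of shape [2, 100] (200 <= 512) would instead use
# the small-input path and run the whole global batch in a single forward pass.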
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from typing import List, Tuple
import torch
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
from megatron.core.models.gpt import GPTModel
class GPTInferenceWrapper(AbstractModelInferenceWrapper):
def __init__(self, model: GPTModel, inference_wrapper_config: InferenceWrapperConfig):
"""Constructor for the model inference wrapper
The wrapper prepares the model for inference, provides the required input data, and runs the forward pass
Args:
model (GPTModel): The GPT model (MCore or legacy)
inference_wrapper_config (InferenceWrapperConfig): The inference wrapper config
"""
super().__init__(model, inference_wrapper_config)
def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
"""A utility function for preparing model for inference
This function is called before the forward pass. It puts the model in eval mode, builds position ids, and creates attention masks so that required slices can be extracted during the forward pass.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
"""
super().prep_model_for_inference(prompts_tokens=prompts_tokens)
self.attention_mask, self.position_ids = self._build_attention_mask_and_position_ids(
prompts_tokens
)
def _build_attention_mask_and_position_ids(
self, prompts_tokens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Builds the full attention mask and position ids for the input tokens
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
Returns:
Tuple[torch.Tensor, torch.Tensor]: The attention mask of shape [1, 1, max_seq_len, max_seq_len] and position ids of shape [batch_size, max_seq_len]
"""
seq_length = prompts_tokens.size(1)
attention_mask = torch.tril(
torch.ones((1, seq_length, seq_length), device=prompts_tokens.device)
).view(1, 1, seq_length, seq_length)
# Convert to boolean
attention_mask = attention_mask < 0.5
position_ids = (
torch.arange(seq_length, dtype=torch.long, device=prompts_tokens.device)
.unsqueeze(0)
.expand_as(prompts_tokens)
)
return attention_mask, position_ids
def get_batch_for_context_window(
self, context_start_position: int, context_end_position: int
) -> List:
"""Returns the inference data given context window
This function gets called iteratively in a loop. Given the start and end context positions, it extracts the appropriate data.
Args:
context_start_position (int): Start of the context window. During the first inference step it is mostly 0
context_end_position (int): End of the context window. During the last inference step it will mostly be the max generated sequence length.
Returns:
List: A list of inputs that will be used by your model in the forward step
"""
tokens2use = self.prompts_tokens[:, context_start_position:context_end_position]
positions2use = self.position_ids[:, context_start_position:context_end_position]
attention_mask2use = self.attention_mask[
..., context_start_position:context_end_position, :context_end_position
]
data_at_step_idx = [tokens2use, positions2use, attention_mask2use]
return data_at_step_idx
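# Example usage (illustrative sketch; `gpt_model` and `config` are placeholders for an
# already-built GPTModel and InferenceWrapperConfig, and `padded_prompt_tokens` is a
# [batch_size, max_seq_len] token tensor):
#
#   wrapper = GPTInferenceWrapper(gpt_model, config)
#   wrapper.prep_model_for_inference(prompts_tokens=padded_prompt_tokens)
#   tokens, position_ids, attention_mask = wrapper.get_batch_for_context_window(0, 16)
#   logits = wrapper.run_one_forward_step([tokens, position_ids, attention_mask])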
from dataclasses import dataclass
import torch
@dataclass
class InferenceWrapperConfig:
"""Config for the model inference wrapper
NOTE: All the arguments here are obtained from the arguments.py file
"""
hidden_size: int
"""Receive happens between the layers during PP with size [seq_len, batch_size, hidden_size]"""
params_dtype: torch.dtype
"""Can be torch.float or torch.half if --fp16 is used, or torch.bfloat16 if --bf16 is used"""
inference_batch_times_seqlen_threshold: int
"""if batch-size times sequence-length is smaller than this threshold then we will not use pipelining, otherwise we will."""
padded_vocab_size: int
"""The final padded vocab size (Padded to make it divisible by --make-vocab-size-divisible-by value)"""
fp32_residual_connection: bool = False
"""Move residual connections to fp32. Obtained from arguments.py"""
def add_attributes(self, attribute_value_pair: dict):
"""Utility to add more attributes to inference params
Use this method to pass in a custom dictonary to add more config to the instance you created. Use as follows
c = InferenceWrapperConfig
c.add_attributes({'precision':'fp32'})
Args:
attribute_value_pair (dict): A dictionary containing attributes as the key names and their values as the values.
"""
for key, value in attribute_value_pair.items():
setattr(self, key, value)
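# Example (illustrative; the values below are placeholders, not recommendations):
#
#   inference_wrapper_config = InferenceWrapperConfig(
#       hidden_size=4096,
#       params_dtype=torch.bfloat16,
#       inference_batch_times_seqlen_threshold=512,
#       padded_vocab_size=128256,
#       fp32_residual_connection=False,
#   )
#   inference_wrapper_config.add_attributes({'precision': 'bf16'})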
import time
import typing
from collections import OrderedDict
from typing import Dict, List
import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.utils import Counter
class Scheduler:
def __init__(self, max_batch_size: int):
"""Scheduler for handling requests to inference engine
This class is responsible for handling all the incoming requests
Args:
max_batch_size (int): The max batch size that we can pass to the inference engine at a time.
"""
self.max_batch_size = max_batch_size
self.active_request_pool: Dict[int, InferenceRequest] = OrderedDict()
self.waiting_request_pool: Dict[int, InferenceRequest] = OrderedDict()
self.completed_request_pool: Dict[int, InferenceRequest] = OrderedDict()
self.request_counter = Counter()
def add_request(
self,
prompt: str,
prompt_tokens: torch.Tensor,
inference_parameters: CommonInferenceParams,
arrival_time: float = None,
):
"""Add an incoming request
This method will add the request to either the active pool or the waiting pool depending on the batch size.
Args:
prompt (str): Input prompt string
prompt_tokens (torch.Tensor): A torch tensor having the input prompts tokenized
inference_parameters (CommonInferenceParams): The inference parameters
arrival_time (float, optional): The incoming request time. Defaults to None.
"""
request_id = str(next(self.request_counter))
if arrival_time is None:
arrival_time = time.time()
status = (
Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
if len(self.active_request_pool) < self.max_batch_size
else Status.WAITING_IN_QUEUE
)
inference_request = InferenceRequest(
request_id=request_id,
prompt=prompt,
inference_parameters=inference_parameters,
arrival_time=arrival_time,
prompt_tokens=prompt_tokens,
status=status,
)
if status == Status.ACTIVE_BUT_NOT_GENERATING_TOKENS:
self.active_request_pool[request_id] = inference_request
else:
self.waiting_request_pool[request_id] = inference_request
def have_requests_pending(self) -> bool:
"""Method to check if there are requests pending
This method returns False only when there are no active requests or waiting requests.
"""
num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool)
return num_requests_pending > 0
def add_earliest_waiting_request_to_active_pool(self):
"""Utility to add the waiting request to active pool
This method will add the earliest request (FIFO) that is in the waiting request pool to the active request pool.
"""
assert (
len(self.active_request_pool) < self.max_batch_size
), "Active request pool is already full. Cant add any more requests"
if len(self.waiting_request_pool) > 0:
(
earliest_waiting_request_request_id,
earliest_waiting_request,
) = self.waiting_request_pool.popitem(last=False)
earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
self.active_request_pool[earliest_waiting_request_request_id] = earliest_waiting_request
def update_requests_pools(self, result_dict: typing.OrderedDict[int, InferenceRequest] = None):
"""Update request pool status
This method fills up the active request pool from the waiting request pool if the active pool has fewer than max_batch_size elements.
If provided with a result dict, it will put the completed requests into the completed request pool and add waiting requests into the active pool.
Args:
result_dict (typing.OrderedDict[int, InferenceRequest], optional): The result returned by the engine. A dictionary with keys as the request ids and values as the requests. Defaults to None
"""
for result_request_id in list(result_dict.keys()):
active_request = self.active_request_pool[result_request_id]
# If a request has completed put it into the completed request pool.
if active_request.status == Status.COMPLETED:
completed_request = self.active_request_pool.pop(result_request_id)
self.completed_request_pool[result_request_id] = completed_request
# If the active request pool is not full, add waiting requests in FIFO order
while (
len(self.active_request_pool) < self.max_batch_size
and len(self.waiting_request_pool) > 0
):
self.add_earliest_waiting_request_to_active_pool()
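# Example flow (illustrative sketch; `prompt_tokens` is a placeholder for tokenized prompts
# and CommonInferenceParams is assumed to provide defaults):
#
#   scheduler = Scheduler(max_batch_size=2)
#   for prompt in ["a", "b", "c"]:
#       scheduler.add_request(
#           prompt=prompt,
#           prompt_tokens=prompt_tokens,
#           inference_parameters=CommonInferenceParams(),
#       )
#   # The first two requests land in the active pool, the third waits in FIFO order.
#   # Once the engine marks active requests COMPLETED and calls
#   # update_requests_pools(result_dict=...), the waiting request is promoted.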
from typing import List, OrderedDict, Tuple
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
class SimpleTextGenerationController:
def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer):
"""The basic text generation controller
This class is responsible for tokenizing the input, running inference, sampling, and detokenizing the output
Args:
inference_wrapped_model (AbstractModelInferenceWrapper): A model that is wrapped using the specs given in the abstract_model_inference_wrapper.py
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
"""
self.inference_wrapped_model = inference_wrapped_model
self.tokenizer = tokenizer
# For models without pipeline parallelism, is_first_stage and is_last_stage returns True
self.model_is_pipeline_parallel = not (
parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
)
def tokenize_prompt(self, prompt: str) -> List[int]:
"""Utility to tokenize the input prompts
Args:
prompt (str): The input prompt
Returns:
List[int]: The tokenized prompt
"""
return self.tokenizer.tokenize(prompt)
def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str:
"""Detokenize the output generations
Args:
prompt_tokens_with_generated_tokens (torch.Tensor): The input prompt tokens plus the generated tokens
Returns:
str: The detokenized output
"""
tokens = prompt_tokens_with_generated_tokens.cpu().numpy().tolist()
return self.tokenizer.detokenize(tokens)
def sample_from_logits(
self,
last_token_logits: torch.Tensor,
common_inference_params: CommonInferenceParams,
vocab_size: int = None,
) -> torch.Tensor:
"""Samples the logits to generate outputs
Given the logits of the last token, this function samples it according to the parameters defined in common_inference_params and returns the samples
Args:
last_token_logits (torch.Tensor): The last token logits. A tensor of size [batch_size, vocab_size]
common_inference_params (CommonInferenceParams): The parameters to use for inference
vocab_size (int): Obtained from the tokenizer. Defaults to None
Returns:
torch.Tensor: 1D tensor with [batch_size] elements containing the sampled token ids
"""
top_p = common_inference_params.top_p
top_k = common_inference_params.top_k
temperature = common_inference_params.temperature
assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
assert top_p <= 1.0, 'top-p should be in (0,1]'
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Filtering based on the cumulative sum.
filter_ = cumulative_probs > top_p
# This shift by 1 is weird and I cannot justify it. This existed
# in the original implementation:
# https://github.com/ari-holtzman/degen/blob/master/gen.py
# and I guess it is needed so keeping it for now.
filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
# Fill in the filtered part
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
# Greedy sampling
if top_k == 1:
sampled_logits = torch.argmax(last_token_logits, dim=-1)
else:
last_token_logits = last_token_logits.clone()
if temperature != 1.0:
last_token_logits.div_(temperature)
if top_k > 1:
assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(last_token_logits, top_k)
elif top_p > 0.0:
modify_logits_for_top_p_filtering(last_token_logits, top_p)
# After filtering, we need to recalculate the distribution.
probabilities = last_token_logits.softmax(dim=-1)
sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)
# If vocab size is provided, make sure the samples are in the range [0, vocab_size).
if vocab_size:
sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
return sampled_logits
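# Worked example (illustrative): with top_k=1 the argmax token is taken (greedy decoding).
# With top_k=3 and logits [2.0, 1.0, 0.5, -1.0], only the three largest logits survive and the
# rest are set to -inf before softmax and multinomial sampling. With top_p=0.9, the logits are
# sorted, probabilities are accumulated, and every token beyond the point where the cumulative
# probability exceeds 0.9 is masked out (always keeping at least the most likely token).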
def update_generation_status(
self,
updated_prompts_tokens: torch.Tensor,
generation_started: torch.Tensor,
current_context_end_position: int,
is_generation_done_tensor: torch.Tensor,
generated_sequence_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Checks which prompts have reached an end condition
We check which prompts have reached an end condition and set the corresponding flags of the is_generation_done_tensor to True. The generated sequence lengths increase as we keep generating, until that prompt hits an end condition. The generation_started tensor determines which prompts have started generating.
Args:
updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest generated tokens. A tensor of shape [batch_size, max_seq_len] (i.e max_seq_len = max_prompt_len + tokens_to_generate)
generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True indicates the prompt at that index has started generating tokens.
current_context_end_position (int): An integer indicating which position to extract from the prompts tokens to get the latest generated tokens.
is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size]. True indicates the prompt at that index has reached end condition.
generated_sequence_lengths (torch.Tensor): An int tensor of shape [batch_size]. Each value represents the generated sequence length for that prompt.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Returns the boolean is_generation_done_tensor and the generated_sequence_lengths after updating it
"""
latest_samples = updated_prompts_tokens[:, current_context_end_position]
# Make sure we are checking the EOD criterion only for prompts that have started generating, i.e. we only look at the generated tokens and not the input tokens.
reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
is_generation_done_tensor = is_generation_done_tensor | reached_eod
# We increment generated sequence lengths when that prompt has not hit the EOD and generation has started
generated_sequence_lengths += ~is_generation_done_tensor & generation_started
return is_generation_done_tensor, generated_sequence_lengths
def pad_input_prompt_tokens(
self,
batch_prompt_tokens_list: List[List[int]],
max_prompt_length_in_batch: int,
num_tokens_to_generate: int,
) -> torch.Tensor:
"""Method to pad input prompts
Given a list of prompts, pad them all to uniform length
Args:
batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens
num_tokens_to_generate (int): The number of tokens to generate for each prompt
Returns:
torch.Tensor: A torch tensor of shape [batch_size, max_seq_len], where max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate, with the extra positions of each prompt padded with the EOD token id.
"""
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
for prompt_tokens in batch_prompt_tokens_list:
padding_size = max_seq_len - len(prompt_tokens)
prompt_tokens.extend([self.tokenizer.eod] * padding_size)
return torch.tensor(batch_prompt_tokens_list).cuda()
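# Example (illustrative): with prompts [[5, 6, 7], [5, 6]], max_prompt_length_in_batch=3 and
# num_tokens_to_generate=2, max_seq_len is 5 and the padded batch becomes
# [[5, 6, 7, eod, eod], [5, 6, eod, eod, eod]], returned as a [2, 5] tensor on the GPU.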
def generate_output_tokens_dynamic_batch(
self,
active_requests: OrderedDict[int, InferenceRequest],
) -> OrderedDict[int, InferenceRequest]:
"""Utility to generate the output tokens and probabilities for the prompts
This utility generates the output tokens for a dynamic batch. It will run one forward step at a time, and pass control back to the engine, which will update the request pool and call this method again.
Args:
active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
Returns:
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests after running one forward step.
"""
raise Exception("Not implemented yet")
def generate_all_output_tokens_static_batch(
self,
active_requests: OrderedDict[int, InferenceRequest],
) -> OrderedDict[int, InferenceRequest]:
"""Utility to generate the all the output tokens and probabilities for the prompts .
This utility generates the output tokens for a static batch. It runs the forward steps till all prompts complete generation, updates the status of these requests to completed, adds the generated result and returns these requests
Args:
active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
Returns:
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
"""
batch_prompt_tokens_list = list(
map(lambda request: request.prompt_tokens, active_requests.values())
)
prompt_lengths_in_batch = torch.tensor(
[len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list]
).cuda()
max_prompt_length_in_batch = max(prompt_lengths_in_batch)
min_prompt_length_in_batch = min(prompt_lengths_in_batch)
# For batch inference the inference params are the same for all requests
common_inference_params: CommonInferenceParams = list(active_requests.values())[
0
].inference_parameters
# max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
batch_prompt_tokens = self.pad_input_prompt_tokens(
batch_prompt_tokens_list,
max_prompt_length_in_batch=max_prompt_length_in_batch,
num_tokens_to_generate=common_inference_params.num_tokens_to_generate,
)
batch_size, max_sequence_length = batch_prompt_tokens.shape
# Pre allocate log probs tensor
output_log_probs = None
if common_inference_params.return_log_probs:
output_log_probs = torch.empty(
(batch_size, max_sequence_length - 1), dtype=torch.float32
).cuda()
# An array to check which of the prompts have reached end of generation condition
is_generation_done_tensor = torch.zeros(batch_size, dtype=torch.bool).cuda()
# An array to act as a counter to keep track of generated sequence lengths
generated_sequence_lengths = torch.zeros(batch_size).cuda()
with torch.no_grad():
self.inference_wrapped_model.prep_model_for_inference(
prompts_tokens=batch_prompt_tokens
)
context_start_position = 0
# Pick the context window that we need to pass through the network.
for context_end_position in range(min_prompt_length_in_batch, max_sequence_length):
inference_input = self.inference_wrapped_model.get_batch_for_context_window(
context_start_position, context_end_position
)
# Returns the final logits of shape [batch_size, context_length, vocab_size]
# Note: This is returned in all TP ranks or last PP stage in PP models
logits = self.inference_wrapped_model.run_one_forward_step(inference_input)
if self.model_is_pipeline_parallel:
context_length = context_end_position - context_start_position
logits = broadcast_from_last_pipeline_stage(
[batch_size, context_length, self.tokenizer.vocab_size],
dtype=torch.float32,
tensor=logits,
)
# Indicates which of the input prompts have started generating tokens. A 1D boolean tensor with [batch_size] elements (the shortest prompts start generating first, and so on).
generation_started = prompt_lengths_in_batch <= context_end_position
last_token_logits = logits[:, -1, :]
sampled_logits = self.sample_from_logits(
last_token_logits, common_inference_params, self.tokenizer.vocab_size
)
# Substitute the sampled tokens only for the prompts that have started generating
batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
generation_started
]
if common_inference_params.return_log_probs:
log_probs = F.log_softmax(logits, dim=2)
indices = torch.unsqueeze(
batch_prompt_tokens[
:, (context_start_position + 1) : (context_end_position + 1)
],
2,
)
# Get the log probabilities for only the prompt tokens
output_log_probs[:, context_start_position:context_end_position] = torch.gather(
log_probs, 2, indices
).squeeze(2)
context_start_position = context_end_position
# Check end of generation status for each tensor and update generated sequence lengths
(
is_generation_done_tensor,
generated_sequence_lengths,
) = self.update_generation_status(
updated_prompts_tokens=batch_prompt_tokens,
generation_started=generation_started,
current_context_end_position=context_end_position,
is_generation_done_tensor=is_generation_done_tensor,
generated_sequence_lengths=generated_sequence_lengths,
)
# Boolean flag indicating if all prompts are finished
all_prompts_done = torch.all(is_generation_done_tensor)
if all_prompts_done:
break
# Include all the generated tokens
batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
if common_inference_params.return_log_probs:
output_log_probs = output_log_probs[:, :context_end_position]
generated_sequence_lengths[
generated_sequence_lengths > common_inference_params.num_tokens_to_generate
] = common_inference_params.num_tokens_to_generate
for idx, request in enumerate(active_requests.values()):
input_prompt_length = int(prompt_lengths_in_batch[idx])
# Shorter prompts might have generated more than required tokens. So we trim them down
required_sequence_length = int(
min(generated_sequence_lengths[idx], common_inference_params.num_tokens_to_generate)
)
# Extract only the generated tokens
required_result_tokens = batch_prompt_tokens_with_generations[
idx, input_prompt_length : (input_prompt_length + required_sequence_length)
]
request.generated_length = required_sequence_length
request.generated_tokens = required_result_tokens
request.generated_log_probs = (
None
if output_log_probs is None
else output_log_probs[idx, input_prompt_length:required_sequence_length]
)
request.status = Status.COMPLETED
request.generated_text = self.detokenize_generations(required_result_tokens)
return active_requests
class Counter:
"""A simple counter class
This class is responsible for assigning request ids to incoming requests
"""
def __init__(self, start: int = 0) -> None:
self.counter = start
def __next__(self) -> int:
i = self.counter
self.counter += 1
return i
def reset(self) -> None:
self.counter = 0
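# Example (illustrative):
#
#   counter = Counter()
#   first_id, second_id = next(counter), next(counter)  # 0, 1
#   counter.reset()                                     # next(counter) yields 0 again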
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
def swap_key_value_dict(self, batch_idx):
"swap between batches"
if len(self.key_value_memory_dict) == 0:
raise ValueError("should not swap when dict in empty")
for layer_number in self.key_value_memory_dict.keys():
inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
assert (
len(batch_idx) == inference_key_memory.shape[1]
) # make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_idx]
new_inference_value_memory = inference_value_memory[:, batch_idx]
self.key_value_memory_dict[layer_number] = (
new_inference_key_memory,
new_inference_value_memory,
)
def __str__(self):
return f"InferenceParams(max_seq_len = {self.max_sequence_length}, max_batch_size = {self.max_batch_size}, sequence_len_offset = {self.sequence_len_offset}, batch_size_offset = {self.batch_size_offset}, key_value_memory_dict = {self.key_value_memory_dict.keys()})"
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
jit_fuser = torch.jit.script
# nvFuser is deprecated in PyTorch JIT starting from 2.2
if (TORCH_MAJOR > 2) or (TORCH_MAJOR == 2 and TORCH_MINOR >= 2):
jit_fuser = torch.compile
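# Example usage (illustrative sketch; bias_gelu here is a stand-in fused kernel, not part of
# this file): jit_fuser can be applied as a decorator so the same code path uses
# torch.jit.script on older PyTorch and torch.compile on PyTorch >= 2.2.
#
#   @jit_fuser
#   def bias_gelu(bias, y):
#       x = bias + y
#       return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))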
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable, ContextManager, Optional
import torch
@dataclass
class ModelParallelConfig:
"""Base configuration for Megatron Core
The initialization function has an argument for each parameter.
"""
###################
# Model parallelism
###################
tensor_model_parallel_size: int = 1
"""Intra-layer model parallelism. Splits tensors across GPU ranks."""
pipeline_model_parallel_size: int = 1
"""Inter-layer model parallelism. Splits transformer layers across GPU ranks."""
virtual_pipeline_model_parallel_size: Optional[int] = None
"""Interleaved pipeline parallelism is used to improve performance by reducing the pipeline
bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
The number of virtual blocks per pipeline model parallel rank is the virtual model parallel
size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM:
arxiv.org/pdf/2104.04473.pdf for more details.
"""
sequence_parallel: bool = False
"""Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms
and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models
(https://arxiv.org/abs/2205.05198) for more details.
"""
context_parallel_size: int = 1
"""Splits network input along sequence dimension across GPU ranks."""
expert_model_parallel_size: int = 1
"""Distributes Moe Experts across sub data parallel dimension."""
moe_extended_tp: bool = False
"""Alternative parallelization strategy for expert parallelism. Instead of distributing experts
across expert_model_parallel_size, each expert is sharded along the extended tensor parallel
domain (tensor_model_parallel_size * expert_model_parallel_size). It avoids the load balancing
problem with MOE training.
"""
###################
# Initialization
###################
perform_initialization: bool = True
"""If true, weights are initialized. This option can be useful when you know you are going to
load values from a checkpoint.
"""
use_cpu_initialization: bool = False
"""When set to False, we initialize the weights directly on the GPU. CPU initialization is the
same regardless of tensor model parallelism, but GPU initialization is not. Transferring
weights from CPU to GPU can take a significant amount of time for large models.
"""
###################
# Training
###################
fp16: bool = False
"""If true, train with fp16 mixed precision training."""
bf16: bool = False
"""If true, train with bf16 mixed precision training."""
params_dtype: torch.dtype = torch.float32
"""dtype used when intializing the weights."""
timers: Callable = None
"""Timers object to call for various timing functions. See megatron.core.timers.Timers"""
finalize_model_grads_func: Callable = None
"""Function that finalizes gradients on all workers. Could include ensuring that grads are
all-reduced across data parallelism, pipeline parallelism, and sequence parallelism
dimensions.
"""
grad_scale_func: Callable = None
"""If using loss scaling, this function should take the loss and return the scaled loss. If
None, no function is called on the loss.
"""
no_sync_func: Callable = None
"""Function that creates a context that suppresses asynchronous data-parallel communication. If
the model is an instance of core.distributed.DistributedDataParallel, the default is to use
core.distributed.DistributedDataParallel.no_sync.
"""
grad_sync_func: Callable = None
"""Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient
reduce-scatters). The function should take one argument: an iterable of parameters whose
gradients are to be synchronized.
"""
param_sync_func: Callable = None
"""Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer
parameter all-gathers). The function should take one argument: an iterable of parameters to
be synchronized.
"""
deterministic_mode: bool = False
"""If true, code that has deterministic execution will be chosen. This usually
means slower execution, but is good for debugging and testing. Defaults to False."""
enable_autocast: bool = False
"""If true runs the forward step function inside torch.autocast context."""
autocast_dtype: torch.dtype = None
"""dtype to pass to torch.amp.autocast when enabled. If None, is set to pipeline_dtype."""
num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
"""If int, set the number of microbatches where not all of the layers will be checkpointed and
recomputed. The rest of the microbatches within the window of maximum outstanding
microbatches will recompute all layers (either full recompute or selective recompute). If
None, the checkpoint and recompute will be left up to the forward_step function.
"""
###################
# Optimizations
###################
gradient_accumulation_fusion: bool = False
"""If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension
fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install
APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\"
--global-option=\"--cuda_ext\" ". Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion.
"""
async_tensor_model_parallel_allreduce: bool = False
"""NOTE: Deprecated. This flag is ignored."""
use_te_rng_tracker: bool = False
"""If true, uses RNG state tracker in TransformerEngine if exists.
"""
tp_comm_overlap: bool = False
"""If true, allows overlapping of Linear layer execution with tensor parallel communication
collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever
possible during the forward and the backward pass.
"""
tp_comm_bulk_wgrad: bool = True
"""If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if
tp_comm_overlap is False.
"""
tp_comm_bulk_dgrad: bool = True
"""If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't care if
tp_comm_overlap is False.
"""
tp_comm_overlap_ag: bool = True
"""If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather.
Don't care if tp_comm_overlap is False.
"""
tp_comm_overlap_rs: bool = True
"""If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter.
Don't care if tp_comm_overlap is False.
"""
tp_comm_overlap_rs_dgrad: bool = False
"""If true, allows Reduce-Scatter overlap with DGRAD GEMM by pipelining the
GEMM and Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
"""
tp_comm_split_ag: bool = True
"""Deprecated from TransformerEngine v1.6.0.
If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
splits. Don't care if tp_comm_overlap is False.
"""
tp_comm_atomic_ag: bool = False
"""Deprecated from TransformerEngine v1.6.0.
If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather both
done atomically. Don't care if tp_comm_overlap is False.
"""
tp_comm_split_rs: bool = True
"""Deprecated from TransformerEngine v1.6.0.
If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
"""
tp_comm_atomic_rs: bool = False
"""Deprecated from TransformerEngine v1.6.0.
If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False.
"""
cross_entropy_loss_fusion: bool = False
"""If this is enabled, the fused cross entropy implementation would be used.
Defaults to False.
"""
###################
# Pipeline Parallel
###################
pipeline_dtype: torch.dtype = None
"""dtype used in p2p communication, usually params_dtype"""
variable_seq_lengths: bool = False
"""Support for variable sequence lengths across microbatches. Setting this communicates the size
of tensors during pipeline parallelism communication, because of this extra overhead it
should only be set if the sequence length varies by microbatch within a global batch.
"""
overlap_p2p_comm: bool = False
"""When True some of the peer to peer communication for pipeline parallelism will overlap with
computation. Must be False if batch_p2p_comm is true.
"""
batch_p2p_comm: bool = True
"""Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if
overlap_p2p_comm is True.
"""
batch_p2p_sync: bool = True
"""When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in
older version of PyTorch.
"""
use_ring_exchange_p2p: bool = False
"""Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires
custom built torch with torch.distributed.ring_exchange.
"""
deallocate_pipeline_outputs: bool = False
"""If True, output data is deallocated after the tensor is sent to the next pipeline stage.
Helps with saving memory, does nothing when pipeline parallel is not used.
"""
defer_embedding_wgrad_compute: bool = False
"""If true, defers the embedding WGRAD GEMMs while pipeline flush is
taking place enabling us to hide pipeline flush latency. Defaults to False.
"""
wgrad_deferral_limit: int = 0
"""This value tunes the number of micro-batches for which the embedding weight gradient compute
needs to be deferred to pipeline flush, this argument is invalid if `defer_embedding_wgrad_compute` is False.
Defaults to 0, which means all micro-batches are deferred.
"""
pipeline_model_parallel_split_rank: Optional[int] = None
"""If int, rank where encoder and decoder should be split in cases where the model has both an
encoder and decoder (e.g., T5). Ignored if None.
"""
###################
# CPU Offloading
###################
cpu_offloading: bool = False
"""When set to True, all the activations are offloaded to the CPU asynchronously."""
cpu_offloading_num_layers: int = 0
"""Tells the number of transformer layers for which activations has to be offloaded."""
_cpu_offloading_context: ContextManager = (
None # Used for internal use only, not to be set by the user. TODO: Need to move to the 'right' place when possible.
)
"""For internal use only, do not set."""
cpu_offloading_activations: bool = True
"""If True, offloads the activations to CPU."""
cpu_offloading_weights: bool = True
"""If True, offloads the weights to CPU."""
###################
# Timing
###################
barrier_with_L1_time: bool = True
"""If true, use barrier with level 1 time measurements. It is up to the user to make sure
calling barrier with their timers will not result in hangs. This can happen if for example
the user adds a level 1 timer that is not called by all ranks.
"""
def __post_init__(self):
"""Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
"""
if self.sequence_parallel:
if self.tensor_model_parallel_size <= 1:
raise ValueError("Can not use sequence paralllelism without tensor parallelism")
if self.pipeline_model_parallel_size > 1:
if self.pipeline_dtype is None:
raise ValueError(
"When using pipeline parallelism, pipeline_dtype must be specified"
)
if self.autocast_dtype is None:
self.autocast_dtype = self.params_dtype
if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1:
raise ValueError(
"Cannot defer embedding wgrad compute when pipeline model parallel is not used"
)
if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion:
raise ValueError(
"Cannot defer embedding wgrad compute when gradient accumulation fusion is not used"
)
if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0:
raise ValueError(
"Wgrad deferral limit should be greater than or equal to 0 when this optimization is enabled!"
)
if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
if self.sequence_parallel is False:
raise ValueError(
"When using expert parallelism and tensor parallelism, sequence parallelism must be used"
)
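# Example (illustrative; the values are placeholders): a configuration that satisfies the
# __post_init__ checks above, combining tensor, pipeline and sequence parallelism.
#
#   config = ModelParallelConfig(
#       tensor_model_parallel_size=2,
#       pipeline_model_parallel_size=2,
#       sequence_parallel=True,
#       bf16=True,
#       params_dtype=torch.bfloat16,
#       pipeline_dtype=torch.bfloat16,
#   )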
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import logging
from typing import List, Literal, Optional, Tuple
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.transformer.enums import AttnMaskType, ModelType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
class T5LMHead(MegatronModule):
"""Masked LM head for T5
Args:
config (TransformerConfig): transformer config
parallel_output (bool): whether the output logits are distributed or not.
vocab_size (int): vocabulary size
pre_process (bool): Include embedding layer
share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are
shared.
"""
def __init__(
self,
config: TransformerConfig,
parallel_output: bool,
vocab_size: int,
pre_process: bool = True,
share_embeddings_and_output_weights: bool = False,
):
super(T5LMHead, self).__init__(config=config)
self.parallel_output = parallel_output
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
vocab_size,
config=config,
init_method=config.init_method,
bias=share_embeddings_and_output_weights,
skip_bias_add=not share_embeddings_and_output_weights,
gather_output=not self.parallel_output,
skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
)
def forward(self, hidden_states: Tensor, word_embeddings_weight: Tensor) -> Tensor:
"""Forward pass.
Args:
hidden_states (Tensor): output hidden states from decoder
word_embeddings_weight (Tensor): word embedding weight
Returns:
Tensor: logits tensor
"""
logits, _ = self.output_layer(hidden_states, weight=word_embeddings_weight)
return logits
class T5Model(LanguageModule):
"""T5 Language model.
Args:
config (TransformerConfig): transformer config
encoder_config (TransformerConfig): encoder transformer config
transformer_encoder_layer_spec (ModuleSpec): transformer layer customization specs for encoder
transformer_decoder_layer_spec (ModuleSpec): transformer layer customization specs for decoder
vocab_size (int): vocabulary size
max_sequence_length (int): maximum size of sequence. This is used for positional embedding
pre_process (bool): Include embedding layer (used with pipeline parallelism)
post_process (bool): Include an output layer (used with pipeline parallelism)
fp16_lm_cross_entropy (bool, optional): Defaults to False
parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are
shared. Defaults to False.
position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
Default is 'learned_absolute'.
rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
seq_len_interpolation_factor (float): scale of linearly interpolating RoPE for longer sequences.
The value must be a float larger than 1.0. Defaults to None.
add_encoder (bool): Create the encoder (used with pipeline parallelism). When using pipelining,
the encoder will only be created on a subset of the pipeline ranks.
add_decoder (bool): Include an output layer (used with pipeline parallelism). As with `add_encoder`, when
using this model and pipelining, the decoder will only be created on a subset of the pipeline ranks.
"""
def __init__(
self,
config: TransformerConfig,
encoder_config: TransformerConfig,
transformer_encoder_layer_spec: ModuleSpec,
transformer_decoder_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[float] = None,
add_encoder: bool = True,
add_decoder: bool = True,
):
super(T5Model, self).__init__(config=config)
self.config: TransformerConfig = config
self.encoder_config: TransformerConfig = encoder_config
self.transformer_encoder_layer_spec: ModuleSpec = transformer_encoder_layer_spec
self.transformer_decoder_layer_spec: ModuleSpec = transformer_decoder_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.add_encoder = add_encoder
self.add_decoder = add_decoder
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
self.encoder_hidden_state = None
# Tells schedules.py that this model has a skip connection between the encoder's output and the decoder
# (and hence both the encoder and decoder's tensors are required for correct backprop).
self.xattn_needed = True
# specify the position embeddings as a member variable in the T5 class
# so that they are easy to find for `finalize_model_grads._allreduce_position_embedding_grads`
self.position_embeddings = None
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=self.position_embedding_type,
)
self.position_embeddings = self.embedding.position_embeddings
# Rotary Position Embeddings
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
)
# Transformer encoder
encoder_spec, decoder_spec = (
self.transformer_encoder_layer_spec,
self.transformer_decoder_layer_spec,
)
if self.add_encoder:
self.encoder = TransformerBlock(
config=self.encoder_config,
spec=encoder_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
else:
self.encoder = None
if self.add_decoder:
# Transformer decoder
self.decoder = TransformerBlock(
config=self.config,
spec=decoder_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
else:
self.decoder = None
# Output
if post_process:
self.lm_head = T5LMHead(
config,
parallel_output,
self.vocab_size,
self.pre_process,
self.share_embeddings_and_output_weights,
)
self.output_layer = self.lm_head.output_layer
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
def forward(
self,
encoder_input_ids: Tensor,
decoder_input_ids: Tensor,
encoder_attn_mask: Tensor,
decoder_attn_mask: Tensor,
encoder_decoder_attn_mask: Tensor,
lm_labels: Tensor = None,
encoder_hidden_states: Tensor = None,
output_encoder_hidden_only: bool = False,
inference_params: InferenceParams = None,
) -> Tensor:
"""Forward pass.
Args:
encoder_input_ids (Tensor): input ids for encoder
decoder_input_ids (Tensor): input ids for decoder
encoder_attn_mask (Tensor): self-attention mask for encoder
decoder_attn_mask (Tensor): self-attention mask for decoder
encoder_decoder_attn_mask (Tensor): cross-attention mask between encoder and decoder
            lm_labels (Tensor): labels for decoder output
            encoder_hidden_states (Tensor): pre-computed encoder hidden states; when provided,
                the encoder forward pass is skipped
            output_encoder_hidden_only (bool): if True, return only the encoder hidden states
            inference_params (InferenceParams): parameters used for inference-time optimizations
        Returns:
            Tensor: loss tensor if `lm_labels` is provided, otherwise the decoder logits
                (or the encoder hidden states when only the encoder is run)
"""
(
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
) = t5_extended_attention_mask(
[encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]
)
## Encoder forward
if encoder_hidden_states is None:
# Encoder position ids
encoder_position_ids = t5_position_ids(encoder_input_ids)
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(
input_ids=encoder_input_ids, position_ids=encoder_position_ids
)
else:
# intermediate stage of pipeline
encoder_input = None
# Rotary positional embeddings
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.encoder, encoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run encoder.
if self.add_encoder:
encoder_hidden_states = self.encoder(
hidden_states=encoder_input,
attention_mask=encoder_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
else:
encoder_hidden_states = self.encoder_hidden_state
if not self.add_decoder or output_encoder_hidden_only:
return encoder_hidden_states
## Decoder forward
# Decoder position ids
decoder_position_ids = t5_position_ids(decoder_input_ids)
# Decoder embedding.
if self.pre_process:
decoder_input = self.embedding(
input_ids=decoder_input_ids, position_ids=decoder_position_ids
)
else:
# intermediate stage of pipeline
            decoder_input = None  # TODO: confirm whether this should instead take encoder_hidden_states
# Rotary positional embeddings
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
# Run decoder.
decoder_hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=decoder_attn_mask,
context=encoder_hidden_states,
context_mask=encoder_decoder_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
if self.post_process:
lm_logits = self.lm_head(
decoder_hidden_states, self.shared_embedding_or_output_weight()
)
if lm_labels is None:
# [s b h] => [b s h]
return lm_logits.transpose(0, 1).contiguous()
else:
# [b s] => [s b]
lm_loss = self.compute_language_model_loss(lm_labels, lm_logits)
return lm_loss
else:
return decoder_hidden_states
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
if self.add_encoder and self.add_decoder:
assert (
len(input_tensor) == 1
), 'input_tensor should only be length 1 for stage with both encoder and decoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert (
len(input_tensor) == 1
), 'input_tensor should only be length 1 for stage with only encoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception('input_tensor must have either length 1 or 2')
else:
raise Exception('Stage must have at least either encoder or decoder')
def shared_embedding_or_output_weight(self) -> Tensor:
"""Function to share the input embeddings and output logit weights."""
if self.pre_process:
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.lm_head.output_layer.weight
return None
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
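        """Sharded state dict used by distributed checkpointing.
        Args:
            prefix (str): prefix prepended to all state dict keys
            sharded_offsets (tuple): sharded offsets from higher-level modules; must be empty here
            metadata (Optional[dict]): metadata passed through to submodule sharded_state_dict calls
        Returns:
            ShardedStateDict: sharded state dict of the whole T5 model
        """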
assert not sharded_offsets, "Unexpected sharded offsets"
sharded_state_dict = {}
if self.pre_process:
embedding_prefix = f'{prefix}embedding.'
embedding_sharded_state_dict = self.embedding.sharded_state_dict(
prefix=embedding_prefix, metadata=metadata
)
sharded_state_dict.update(embedding_sharded_state_dict)
        # only the pipeline stages that own the encoder/decoder contribute their shards
        if self.add_encoder:
            encoder_prefix = f'{prefix}encoder.'
            encoder_sharded_state_dict = self.encoder.sharded_state_dict(
                prefix=encoder_prefix, metadata=metadata
            )
            sharded_state_dict.update(encoder_sharded_state_dict)
        if self.add_decoder:
            decoder_prefix = f'{prefix}decoder.'
            decoder_sharded_state_dict = self.decoder.sharded_state_dict(
                prefix=decoder_prefix, metadata=metadata
            )
            sharded_state_dict.update(decoder_sharded_state_dict)
if self.post_process:
output_layer_prefix = f'{prefix}output_layer.'
output_layer_weight_key = f'{output_layer_prefix}weight'
output_layer_bias_key = f'{output_layer_prefix}bias'
if self.share_embeddings_and_output_weights:
if not self.pre_process:
# when sharing embeddings with last stage, we need to use the weights from the first stage
# on pipeline first rank, word embeddings are saved to {prefix}embedding.word_embeddings.weight
tensor = self.shared_embedding_or_output_weight()
first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight'
dp_rank = parallel_state.get_data_parallel_rank()
dp_size = parallel_state.get_data_parallel_world_size()
last_stage_word_emb_replica_id = (
dp_rank + dp_size
) # copy of first stage embedding
sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=tensor,
key=first_stage_word_emb_key,
replica_id=last_stage_word_emb_replica_id,
allow_shape_mismatch=True,
)
sharded_state_dict[output_layer_weight_key] = sharded_output_layer_tensor
# output_layer.weight is shared, but we still need to process output_layer.bias
sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=self.lm_head.output_layer.bias,
key=output_layer_bias_key,
allow_shape_mismatch=True,
)
sharded_state_dict[output_layer_bias_key] = sharded_output_layer_tensor
else:
output_layer_state_dict = self.output_layer.state_dict(
prefix=output_layer_prefix, keep_vars=True
)
output_layer_tensor = output_layer_state_dict[output_layer_weight_key]
# independent output layer
sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
tensor=output_layer_tensor,
key=output_layer_weight_key,
replica_id=parallel_state.get_data_parallel_rank(),
allow_shape_mismatch=True,
)
sharded_state_dict[output_layer_weight_key] = sharded_output_layer_tensor
return sharded_state_dict
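# Illustrative sketch (not part of the original file): constructing a T5Model for a single
# forward/backward step. All configuration values below are placeholder assumptions, and the
# sketch assumes megatron.core.parallel_state has already been initialized (e.g. via
# initialize_model_parallel) and that the spec helpers in t5_spec are importable.
def _example_build_t5_model() -> 'T5Model':
    from megatron.core.models.T5.t5_spec import (
        get_t5_decoder_with_transformer_engine_block_spec,
        get_t5_encoder_with_transformer_engine_block_spec,
    )
    from megatron.core.transformer.transformer_config import TransformerConfig
    config = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=4)
    return T5Model(
        config=config,
        encoder_config=config,
        transformer_encoder_layer_spec=get_t5_encoder_with_transformer_engine_block_spec(2),
        transformer_decoder_layer_spec=get_t5_decoder_with_transformer_engine_block_spec(2),
        vocab_size=32128,  # placeholder vocabulary size
        max_sequence_length=512,
    )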
def t5_extended_attention_mask(attention_mask_list: List[Tensor]) -> List[Tensor]:
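    """Add a singleton head dimension to each attention mask ([b, s, s] -> [b, 1, s, s]).
    Args:
        attention_mask_list (List[Tensor]): attention masks (entries may be None)
    Returns:
        List[Tensor]: extended attention masks, with None entries passed through
    """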
def attn_mask_postprocess(attn_mask):
# [b, 1, s, s]
extended_attention_mask = attn_mask.unsqueeze(1)
return extended_attention_mask
return [
(attn_mask_postprocess(attn_mask) if attn_mask is not None else None)
for attn_mask in attention_mask_list
]
def t5_position_ids(token_ids: Tensor) -> Tensor:
"""Calculate position ids from token ids
Args:
token_ids (Tensor): input tokens
Returns:
Tensor: position ids
"""
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
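# Illustrative sketch (not part of the original file): what the two helpers above produce.
# The batch/sequence sizes are assumptions chosen only for demonstration.
def _example_t5_mask_and_position_helpers():
    token_ids = torch.zeros(2, 8, dtype=torch.long)  # [batch, seq]
    position_ids = t5_position_ids(token_ids)  # [batch, seq], each row is 0..7
    attn_mask = torch.ones(2, 8, 8, dtype=torch.bool)  # [batch, seq, seq]
    (extended_mask,) = t5_extended_attention_mask([attn_mask])  # [batch, 1, seq, seq]
    return position_ids, extended_mask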
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import (
CrossAttention,
CrossAttentionSubmodules,
SelfAttention,
SelfAttentionSubmodules,
)
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
get_num_layers_to_build,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
try:
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm
    warnings.warn('Apex is not installed. Falling back to Torch LayerNorm.')
LNImpl = WrappedTorchLayerNorm
def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
"""T5 encoder TE spec (uses Transformer Engine components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear,
linear_fc2=TERowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
),
)
def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
"""T5 decoder TE spec (uses Transformer Engine components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_cross_attn_layernorm=TENorm,
cross_attention=ModuleSpec(
module=CrossAttention,
submodules=CrossAttentionSubmodules(
linear_q=TEColumnParallelLinear,
linear_kv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
cross_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear,
linear_fc2=TERowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
),
)
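# Illustrative sketch (not part of the original file): the decoder spec above differs from the
# encoder spec mainly in its causal self-attention mask and its extra cross-attention block.
# This inspection assumes Transformer Engine is installed (HAVE_TE is True), since the spec
# builders above reference TE classes when called.
def _example_inspect_t5_te_specs():
    enc = encoder_model_with_transformer_engine_default_spec()
    dec = decoder_model_with_transformer_engine_default_spec()
    assert enc.submodules.self_attention.params["attn_mask_type"] == AttnMaskType.padding
    assert dec.submodules.self_attention.params["attn_mask_type"] == AttnMaskType.causal
    assert dec.submodules.cross_attention.module is CrossAttention
    return enc, dec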
def encoder_model_with_local_spec() -> ModuleSpec:
"""T5 encoder local spec (uses Megatron-Core components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear,
linear_fc2=RowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
def decoder_model_with_local_spec() -> ModuleSpec:
"""T5 decoder local spec (uses Megatron-Core components)."""
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_cross_attn_layernorm=LNImpl,
cross_attention=ModuleSpec(
module=CrossAttention,
submodules=CrossAttentionSubmodules(
linear_q=ColumnParallelLinear,
linear_kv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
),
cross_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear,
linear_fc2=RowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
def get_t5_encoder_with_transformer_engine_block_spec(
num_layers: int,
) -> TransformerBlockSubmodules:
"""T5 encoder block spec for Transformer Engine
Args:
        num_layers (int): number of encoder layers
"""
layer_spec = encoder_model_with_transformer_engine_default_spec()
block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=TENorm)
return block_spec
def get_t5_decoder_with_transformer_engine_block_spec(
num_layers: int,
) -> TransformerBlockSubmodules:
"""T5 decoder block spec for Transformer Engine
Args:
        num_layers (int): number of decoder layers
"""
layer_spec = decoder_model_with_transformer_engine_default_spec()
block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=TENorm)
return block_spec
def get_t5_encoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules:
"""T5 encoder block spec for local (uses Megatron-Core components)
Args:
num_layers (int): number of encoder layers
"""
layer_spec = encoder_model_with_local_spec()
    block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=LNImpl)  # local LayerNorm so this spec does not require TE
return block_spec
def get_t5_decoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules:
"""T5 decoder block spec for local (uses Megatron-Core components)
Args:
num_layers (int): number of decoder layers
"""
layer_spec = decoder_model_with_local_spec()
    block_spec = TransformerBlockSubmodules([layer_spec] * num_layers, layer_norm=LNImpl)  # local LayerNorm so this spec does not require TE
return block_spec
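# Illustrative sketch (not part of the original file): picking between the Transformer Engine
# and local block specs. The layer count is a placeholder assumption; the returned block specs
# are what would typically be handed to T5Model as its encoder/decoder layer spec arguments.
def _example_build_t5_block_specs(num_layers: int = 12):
    if HAVE_TE:
        encoder_spec = get_t5_encoder_with_transformer_engine_block_spec(num_layers)
        decoder_spec = get_t5_decoder_with_transformer_engine_block_spec(num_layers)
    else:
        encoder_spec = get_t5_encoder_with_local_block_spec(num_layers)
        decoder_spec = get_t5_decoder_with_local_block_spec(num_layers)
    return encoder_spec, decoder_spec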
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
try:
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TERowParallelLinear,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm
    warnings.warn('Apex is not installed. Falling back to Torch LayerNorm.')
LNImpl = WrappedTorchLayerNorm
# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
bert_layer_with_transformer_engine_spec = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear,
linear_fc2=TERowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
),
)
# Use this spec for an implementation using only modules in megatron core
bert_layer_local_spec = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear,
linear_fc2=RowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
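# Illustrative sketch (not part of the original file): selecting a BERT layer spec depending on
# whether Transformer Engine is available. Note that the TE spec above references TE classes at
# import time, so this example effectively assumes TE is installed; the fallback is shown only
# as a hedged illustration.
def _example_select_bert_layer_spec() -> ModuleSpec:
    return bert_layer_with_transformer_engine_spec if HAVE_TE else bert_layer_local_spec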