Commit 523ec9cc authored by wangsen
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import math
import torch
from megatron.core import parallel_state
def switch_load_balancing_loss_func(
probs: torch.Tensor, tokens_per_expert: torch.Tensor, topk: int, moe_aux_loss_coeff: float
):
"""Calculate the auxiliary loss for better load balacing.
Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.
Args:
probs (torch.Tensor): The softmax probs output by the router for each token. [num_tokens, num_experts]
tokens_per_expert (torch.Tensor): The number of assigned tokens for each expert. [num_experts]
topk (int): The number of experts selected for each token.
moe_aux_loss_coeff (float): The coefficient of the auxiliary loss.
Returns:
torch.Tensor: The auxiliary loss for load balancing.
"""
num_tokens = probs.shape[0] * topk
num_experts = probs.shape[1]
probs_mean_per_expert = probs.mean(dim=0)
aux_loss = torch.sum(probs_mean_per_expert * tokens_per_expert) * (
num_experts / num_tokens * moe_aux_loss_coeff
)
return aux_loss
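# Illustrative sketch (not part of the original file): with a uniform router and a balanced
# assignment, sum(mean_prob * tokens_per_expert) * num_experts / num_tokens == 1, so the aux
# loss collapses to moe_aux_loss_coeff.
# >>> probs = torch.full((8, 4), 0.25)                      # 8 tokens, 4 experts, uniform probs
# >>> tokens_per_expert = torch.tensor([2.0, 2.0, 2.0, 2.0])
# >>> switch_load_balancing_loss_func(probs, tokens_per_expert, topk=1, moe_aux_loss_coeff=1e-2)
# tensor(0.0100)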
def z_loss_func(logits, z_loss_coeff):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
Args:
logits (torch.Tensor): The logits of the router.
Returns:
torch.Tensor: The z-loss.
"""
z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
return z_loss
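# Illustrative sketch (assumed values, not from the original file): for all-zero logits the
# logsumexp over num_experts is log(num_experts), so the z-loss is coeff * log(num_experts) ** 2.
# >>> logits = torch.zeros(4, 8)                            # 4 tokens, 8 experts
# >>> z_loss_func(logits, z_loss_coeff=1e-3)                # 1e-3 * log(8) ** 2
# tensor(0.0043)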
def sinkhorn(cost: torch.Tensor, tol: float = 0.0001):
"""Sinkhorn based MoE routing function"""
cost = torch.exp(cost)
d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
eps = 0.00000001
error = 1e9
d1_old = d1
while error > tol:
d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
error = torch.mean(torch.abs(d1_old - d1))
d1_old = d1
return d1 * cost * d0.unsqueeze(1)
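# Illustrative sketch: sinkhorn() iteratively rescales exp(cost) so that the resulting
# assignment matrix is approximately balanced across experts; the input shape is preserved.
# >>> cost = torch.randn(16, 4)                             # [num_tokens, num_experts]
# >>> sinkhorn(cost).shape
# torch.Size([16, 4])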
def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_capacity=None):
"""
Calculate the capacity of each expert.
Args:
num_tokens (int): Number of input tokens.
num_experts (int): Number of experts.
capacity_factor (float): Capacity factor.
min_capacity (int, optional): Minimum capacity. Defaults to None.
Returns:
int: Capacity of each expert.
"""
capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
if min_capacity is not None and capacity < min_capacity:
capacity = min_capacity
return capacity
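# Illustrative example: 1024 tokens routed over 8 experts with capacity_factor=1.25 gives
# ceil(1024 / 8 * 1.25) = 160 token slots per expert.
# >>> get_capacity(num_tokens=1024, num_experts=8, capacity_factor=1.25)
# 160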
class MoEAuxLossAutoScaler(torch.autograd.Function):
"""An AutoScaler that compute and scales the grad for auxiliary loss.
"""
main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
@staticmethod
def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
"""Preserve the aux_loss by storing it in the context to avoid garbage collection.
Args:
output (torch.Tensor): The output tensor.
aux_loss (torch.Tensor): The auxiliary loss tensor.
Returns:
torch.Tensor: The output tensor.
"""
ctx.save_for_backward(aux_loss)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
"""Compute and scale the gradient for auxiliary loss..
Args:
grad_output (torch.Tensor): The gradient of the output.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient.
"""
(aux_loss,) = ctx.saved_tensors
aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
return grad_output, scaled_aux_loss_grad
@staticmethod
def set_loss_scale(scale: torch.Tensor):
"""set the scale of the aux loss.
Args:
scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
"""
MoEAuxLossAutoScaler.main_loss_backward_scale = scale
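# Usage sketch (illustrative, hypothetical tensors): the aux loss is attached to an activation
# so that its gradient is produced during backward(), scaled to match the main-loss grad scale.
# >>> activation = torch.randn(4, 16, requires_grad=True)
# >>> aux_loss = torch.tensor(0.5, requires_grad=True)
# >>> out = MoEAuxLossAutoScaler.apply(activation, aux_loss)  # forward returns activation unchanged
# >>> MoEAuxLossAutoScaler.set_loss_scale(torch.tensor(1.0))
# >>> out.sum().backward()                                    # fills aux_loss.grad with the scale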
def permute(tokens, indices, num_out_tokens: int = None, padded_mode: bool = False):
"""Permute the tokens based on the indices. Token with the same index will be grouped together.
The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately.
Args:
tokens (torch.Tensor): The input token tensor.
indices (torch.Tensor): The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk].
num_out_tokens (int, optional): The effective output token count, when enabling the capacity factor, should equal the number of tokens not dropped. By default, set to None, meaning no tokens are dropped.
padded_mode (bool, optional): If True, indicates that the indices are padded to [num_expert, capacity] to denote selected tokens per expert. Defaults to False.
Returns:
torch.Tensor: The permuted tensor.
torch.Tensor: The sorted indices corresponding to the permuted tensor.
"""
if padded_mode:
return permute_with_padded_tokens(tokens, indices)
if indices.dim() == 1:
topk = 1
else:
topk = indices.size(1)
flatten_indices = indices.view(-1)
sorted_indices = torch.argsort(flatten_indices, stable=True)
if num_out_tokens is not None:
sorted_indices = sorted_indices[:num_out_tokens]
permuted_tokens = tokens.index_select(0, sorted_indices // topk)
return permuted_tokens, sorted_indices
def unpermute(
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
probs: torch.Tensor = None,
padded_mode: bool = False,
restore_shape: torch.Size = None,
):
"""Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities.
Args:
permuted_tokens (torch.Tensor): The tensor of permuted tokens to be unpermuted.
sorted_indices (torch.Tensor): The tensor of sorted indices used to unpermute the tokens.
probs (torch.Tensor, optional): The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities.
padded_mode (bool, optional): If True, indicates that the indices are padded to [num_expert, capacity] to denote selected tokens per expert. Defaults to False.
restore_shape (torch.Size, optional): The input shape before permutation, only used in padding mode. Defaults to None.
Returns:
torch.Tensor: The unpermuted tokens, optionally merged with probabilities.
"""
if padded_mode:
return unpermute_with_padded_tokens(
permuted_tokens, sorted_indices, probs, restore_shape=restore_shape
)
assert sorted_indices.numel() == permuted_tokens.size(0)
if probs is not None:
# Unpermute and merge the tokens with their probabilities
num_unpermuted_tokens = probs.numel()
topk = probs.size(1)
else:
# Unpermute the tokens without merge
num_unpermuted_tokens = permuted_tokens.size(0)
topk = 1
unpermuted_tokens = torch.zeros(
[num_unpermuted_tokens, permuted_tokens.shape[-1]],
dtype=permuted_tokens.dtype,
device=permuted_tokens.device,
)
unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
if probs is not None:
unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
unpermuted_tokens = unpermuted_tokens.sum(dim=1)
return unpermuted_tokens
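# Round-trip sketch (illustrative, hypothetical shapes): permute() groups token copies by their
# expert index; unpermute() restores the original order and merges the top-k copies with probs.
# >>> tokens = torch.randn(4, 8)                                # 4 tokens, hidden size 8
# >>> indices = torch.tensor([[0, 1], [1, 2], [0, 3], [2, 3]])  # top-2 expert choice per token
# >>> permuted, sorted_idx = permute(tokens, indices)
# >>> restored = unpermute(permuted, sorted_idx, probs=torch.full((4, 2), 0.5))
# >>> torch.allclose(restored, tokens)                          # 0.5 * t + 0.5 * t == t
# True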
def permute_with_padded_tokens(tokens, indices):
"""Permute the tokens based on the indices, only used in padding mode.
The input indices shape is [num_expert, capacity], it indicates which tokens were selected by each expert separately.
Args:
tokens (torch.Tensor): The input token tensor.
indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert.
Returns:
torch.Tensor: The permuted tensor.
torch.Tensor: The indices corresponding to the permuted tensor.
"""
permuted_tokens = tokens.index_select(dim=0, index=indices.view(-1))
return permuted_tokens, indices
def unpermute_with_padded_tokens(
permuted_tokens: torch.Tensor,
indices: torch.Tensor,
probs: torch.Tensor,
restore_shape: torch.Size,
) -> torch.Tensor:
"""
Unpermutes padded permuted tokens based on the sorted indices and merges the tokens with their corresponding probabilities.
This function takes a tensor of permuted tokens and reorders them according to the provided indices. It also combines the tokens with their associated probabilities.
Parameters:
permuted_tokens (torch.Tensor): A 2D tensor containing permuted tokens.
indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert.
probs (torch.Tensor): A tensor with the same shape as indices, containing probabilities corresponding to each token.
restore_shape (torch.Size): The target shape for the unpermuted tokens tensor.
Returns:
torch.Tensor: A tensor of unpermuted tokens, merged with their probabilities.
"""
# Ensure permuted_tokens is 2D
assert permuted_tokens.dim() == 2, f"Got {permuted_tokens.dim()}D."
# Reshape and expand probabilities and indices to match permuted_tokens
probs = probs.view(-1).unsqueeze(-1)
indices = indices.view(-1, 1).expand(-1, permuted_tokens.shape[1])
assert (
permuted_tokens.shape == indices.shape
), "Shape mismatch between permuted_tokens and indices."
# Combine tokens with their probabilities
combined_output = probs * permuted_tokens
# Prepare a tensor of zeros with the desired output shape
empty_tokens = torch.zeros(
restore_shape,
dtype=combined_output.dtype,
device=combined_output.device,
requires_grad=True,
)
# Scatter the combined tokens back to their original positions
unpermuted_tokens = torch.scatter_add(empty_tokens, 0, indices, combined_output)
return unpermuted_tokens
def topk_softmax_with_capacity(
logits: torch.Tensor,
topk: int,
capacity_factor: float = None,
pad_to_capacity: bool = False,
drop_policy: str = "probs",
):
"""Apply capacity and padding to the top-k selection.
Args:
logits (torch.Tensor): Logits tensor.
topk (int): The number of experts to select for each token.
capacity_factor (float): The capacity factor of each expert. Tokens exceeding the capacity will be dropped.
pad_to_capacity (bool): Whether to pad the input of each expert to its capacity in token-drop mode.
drop_policy (str): The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Probs, indices and tokens_per_expert tensor.
(1) If there's no token padding, the shape of probs and indices is [tokens, top_k], indicating the selected experts for each token.
(2) If there's token padding, the shape of probs and indices is [num_expert, capacity], indicating the tokens selected for each expert.
"""
# TODO: Add Pre softmax.
assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
num_tokens = logits.shape[0]
num_experts = logits.shape[1]
scores, top_indices = torch.topk(logits, k=topk, dim=1)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
if capacity_factor is None:
# TopK without capacity
tokens_per_expert = torch.histc(top_indices, bins=num_experts, min=0, max=num_experts)
return probs, top_indices, tokens_per_expert
else:
# TopK with capacity
expert_capacity = get_capacity(
num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor,
)
# TopK selection, Maskout unused experts
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
topk_mask = torch.zeros_like(logits).scatter(1, top_indices, 1)
# Maskout exceeded tokens
if drop_policy == "probs":
capacity_probs, capacity_indices = torch.topk(
topk_masked_gates, k=expert_capacity, dim=0, sorted=False
)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
elif drop_policy == "position":
_, capacity_indices = torch.topk(topk_mask, k=expert_capacity, dim=0, sorted=False)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
capacity_probs = torch.gather(topk_masked_gates, 0, capacity_indices)
else:
raise ValueError(f"Invalid drop_policy: {drop_policy}")
if pad_to_capacity:
final_probs, final_indices = (
capacity_probs.T.contiguous(),
capacity_indices.T.contiguous(),
)
tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
else:
# Get exceed mask and maskout exceeded probs and indices
final_mask = torch.logical_and(topk_mask, capacity_mask)
drop_mask = torch.logical_not(final_mask)
exceed_mask = torch.gather(drop_mask, 1, top_indices)
final_probs = probs * torch.logical_not(exceed_mask)
final_indices = top_indices.clone().masked_fill_(
exceed_mask, torch.iinfo(torch.long).max
)
tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
return final_probs, final_indices, tokens_per_expert_before_capacity
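# Usage sketch (illustrative; assumes CUDA tensors as in training, since torch.histc is applied
# to the integer expert indices): top-2 selection without a capacity limit.
# >>> logits = torch.randn(16, 4, device="cuda")              # [num_tokens, num_experts]
# >>> probs, indices, tokens_per_expert = topk_softmax_with_capacity(logits, topk=2)
# >>> probs.shape, indices.shape                               # per-token top-k probs/experts
# (torch.Size([16, 2]), torch.Size([16, 2]))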
def save_to_aux_losses_tracker(name: str, loss: torch.Tensor, layer_number: int, num_layers: int):
"""Save the auxiliary loss for logging.
Args:
name (str): The name of the loss.
loss (torch.Tensor): The loss tensor.
layer_number (int): Layer index of the loss.
num_layers (int): The number of total layers.
"""
# Skip aux loss logging if layer_number is None.
if layer_number is None:
return
if name not in parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER:
parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name] = torch.zeros(
num_layers, device=loss.device
)
parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name][layer_number - 1] += loss.detach()
def clear_aux_losses_tracker():
"""Clear the auxiliary losses."""
for name in parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER:
parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name].zero_()
def get_aux_losses_tracker():
"""Return the auxiliary losses."""
return parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER
def aggregate_aux_losses_tracker_across_pipeline_parallel():
"""Sum aux losses across PP."""
for name in parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER:
loss = parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name]
torch.distributed.all_reduce(loss, group=parallel_state.get_pipeline_model_parallel_group())
def track_moe_metrics(
loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False
):
# Aux loss logging
aggregate_aux_losses_tracker_across_pipeline_parallel()
if writer is not None:
aux_losses = {k: v.float() * loss_scale for k, v in get_aux_losses_tracker().items()}
for name, loss_list in aux_losses.items():
if total_loss_dict is not None:
if name not in total_loss_dict:
total_loss_dict[name] = loss_list.mean()
else:
total_loss_dict[name] += loss_list.mean()
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# pollutes the runs list, so we just add each as a scalar
writer.add_scalar(name, loss_list.mean(), iteration)
if per_layer_logging:
for i, loss in enumerate(loss_list.tolist()):
writer.add_scalar(f"moe/{name}_layer_{i}", loss, iteration)
# W&B logging lacks support for logging multiple scalars simultaneously.
# As a workaround, we log each scalar individually first, then we can create
# a custom panel to manually group them to a single plot.
if wandb_writer:
wandb_writer.log({f"{name}": loss_list.mean()}, iteration)
if per_layer_logging:
wandb_writer.log(
{
f"moe/{name}_layer_{i}": loss
for i, loss in enumerate(loss_list.tolist())
},
iteration,
)
clear_aux_losses_tracker()
class moe_gather(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, map_):
ctx.input_size = input_.size()
ctx.map = map_
return torch.gather(input_, 0, map_)
@staticmethod
def backward(ctx, grad_output):
input_size = ctx.input_size
map_ = ctx.map
output = torch.zeros(
input_size, dtype=grad_output.dtype, device=torch.cuda.current_device()
)
output.scatter_add_(0, map_, grad_output)
return output, None, None
class moe_scatter(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, map_, output_size=None):
ctx.map = map_
if output_size is not None:
output = torch.zeros(
output_size, dtype=input_.dtype, device=torch.cuda.current_device()
)
else:
output = torch.zeros_like(input_)
output.scatter_add_(0, map_, input_)
return output
@staticmethod
def backward(ctx, grad_output):
map_ = ctx.map
grad_input = torch.gather(grad_output, 0, map_)
return grad_input, None, None, None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
import torch
from megatron.core import parallel_state
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.tensor_parallel.random import (
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import (
MoEAuxLossAutoScaler,
save_to_aux_losses_tracker,
sinkhorn,
switch_load_balancing_loss_func,
topk_softmax_with_capacity,
z_loss_func,
)
from megatron.core.transformer.transformer_config import TransformerConfig
class Router(ABC, MegatronModule):
"""Base Router class"""
def __init__(self, config: TransformerConfig) -> None:
"""
Initialize the Router module.
Args:
config (TransformerConfig): Configuration object for the Transformer model.
"""
super().__init__(config)
self.config = config
self.num_experts = self.config.num_moe_experts
self.moe_aux_loss_func = None
self.layer_number = None
# Initialize the gate weights.
self.weight = torch.nn.Parameter(
torch.empty((self.config.num_moe_experts, self.config.hidden_size))
)
with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
config.init_method(self.weight)
setattr(self.weight, 'sequence_parallel', config.sequence_parallel)
def gating(self, input: torch.Tensor):
"""Forward pass of the router gate.
Args:
input (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Logits tensor.
"""
logits = torch.nn.functional.linear(input, self.weight)
return logits
@abstractmethod
def routing(self, logits: torch.Tensor):
"""Routing function.
Args:
logits (torch.Tensor): Logits tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of tensors representing max probs and the indices.
"""
raise NotImplementedError("Routing function not implemented.")
@abstractmethod
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
"""
raise NotImplementedError("Forward function not implemented.")
def set_layer_number(self, layer_number: int):
"""Set the layer number for the router."""
self.layer_number = layer_number
class TopKRouter(Router):
"""Route each token to the top-k experts."""
def __init__(self, config: TransformerConfig,) -> None:
"""Initialize the zero token dropping router.
Args:
config (TransformerConfig): The configuration for the transformer model.
"""
super().__init__(config=config)
self.topk = self.config.moe_router_topk
self.routing_type = self.config.moe_router_load_balancing_type
self.input_jitter = None
def sinkhorn_load_balancing(self, logits: torch.Tensor):
"""Apply sinkhorn routing to the logits tensor.
Args:
logits (torch.Tensor): The logits tensor.
Returns:
torch.Tensor: The logits tensor after applying sinkhorn routing.
"""
def _sinkhorn_activation(logits):
if self.topk == 1:
logits = torch.sigmoid(logits)
else: # k > 1
logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
return logits
assert self.config.moe_aux_loss_coeff == 0, "Sinkhorn routing does not support aux loss."
if self.training:
with torch.no_grad():
norm_logits = sinkhorn(
logits.to(dtype=torch.float32)
) # explicit fp32 conversion for stability
_, indices = torch.topk(norm_logits, k=self.topk, dim=1)
logits = _sinkhorn_activation(logits)
scores = torch.gather(logits, 1, indices)
else:
logits = _sinkhorn_activation(logits)
scores, indices = torch.topk(logits, k=self.topk, dim=1)
return scores, indices
def aux_loss_load_balancing(self, logits: torch.Tensor):
"""Apply loss-based load balancing to the logits tensor.
Args:
logits (torch.Tensor): the logits tensor after gating, shape: [num_tokens, num_experts].
Returns:
probs (torch.Tensor): the probabilities tensor after load balancing.
indices (torch.Tensor): the indices tensor after top-k selection.
"""
probs, indices, tokens_per_expert = topk_softmax_with_capacity(
logits,
self.topk,
capacity_factor=self.config.moe_expert_capacity_factor,
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
)
# Apply load balancing loss
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
probs = self.apply_load_balancing_loss(scores, tokens_per_expert, activation=probs)
return probs, indices
def apply_load_balancing_loss(
self,
probs: torch.Tensor,
num_local_tokens_per_expert: torch.Tensor,
activation: torch.Tensor,
):
"""Applies auxiliary loss to the MoE layer.
Args:
probs (torch.Tensor): The probs output by the router for each token. [num_tokens, num_experts]
num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert. [num_experts]
activation (torch.Tensor): The activation tensor to attach the gradient function to.
Returns:
torch.Tensor: The activation tensor with the attached gradient function.
"""
moe_aux_loss_coeff = (
self.config.moe_aux_loss_coeff / parallel_state.get_tensor_model_parallel_world_size()
)
aux_loss = switch_load_balancing_loss_func(
probs, num_local_tokens_per_expert, self.topk, moe_aux_loss_coeff
)
save_to_aux_losses_tracker(
"load_balancing_loss",
aux_loss / moe_aux_loss_coeff,
self.layer_number,
self.config.num_layers,
)
activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
return activation
def apply_z_loss(self, logits):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
Args:
logits (torch.Tensor): The logits of the router.
Returns:
torch.Tensor: The logits after applying the z-loss.
"""
if self.config.moe_z_loss_coeff is not None:
moe_z_loss_coeff = (
self.config.moe_z_loss_coeff / parallel_state.get_tensor_model_parallel_world_size()
)
z_loss = z_loss_func(logits, moe_z_loss_coeff)
logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
save_to_aux_losses_tracker(
"z_loss",
z_loss / self.config.moe_z_loss_coeff,
self.layer_number,
self.config.num_layers,
)
return logits
def apply_input_jitter(self, input: torch.Tensor):
"""Add noise to the input tensor.
Refer to https://arxiv.org/abs/2101.03961.
Args:
input (Tensor): Input tensor.
Returns:
Tensor: Jittered input.
"""
if self.config.moe_input_jitter_eps is not None:
eps = self.config.moe_input_jitter_eps
if self.input_jitter is None:
self.input_jitter = torch.distributions.uniform.Uniform(
torch.tensor(1.0 - eps, device=input.device),
torch.tensor(1.0 + eps, device=input.device),
).rsample
return input * self.input_jitter(input.shape)
else:
return input
def routing(self, logits: torch.Tensor):
"""Top-k routing function
Args:
logits (torch.Tensor): Logits tensor after gating.
Returns:
probs (torch.Tensor): the probabilities tensor after load balancing.
indices (torch.Tensor): the indices tensor after top-k selection.
"""
logits = logits.view(-1, self.config.num_moe_experts)
# Apply Z-Loss
logits = self.apply_z_loss(logits)
if (
parallel_state.get_tensor_model_parallel_world_size() > 1
and self.config.moe_token_dispatcher_type == "alltoall"
):
# Gather the logits from the TP region
logits = gather_from_sequence_parallel_region(logits)
if self.routing_type == "sinkhorn":
scores, indices = self.sinkhorn_load_balancing(logits)
elif self.routing_type == "aux_loss":
scores, indices = self.aux_loss_load_balancing(logits)
elif self.routing_type == "none":
# A naive top-k routing without load balancing
scores, indices, _ = topk_softmax_with_capacity(
logits,
self.topk,
capacity_factor=self.config.moe_expert_capacity_factor,
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
)
else:
raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")
return scores, indices
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
"""
self.hidden = input.shape[-1]
# Apply input jitter
input = self.apply_input_jitter(input)
logits = self.gating(input)
logits = logits.view(-1, self.config.num_moe_experts)
scores, indices = self.routing(logits)
return scores, indices
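# Configuration sketch (illustrative, hedged; `config` is a TransformerConfig and the field
# names follow the attributes read above): the routing branch is chosen by
# config.moe_router_load_balancing_type in {"sinkhorn", "aux_loss", "none"} together with
# config.moe_router_topk, e.g.
# >>> config.moe_router_load_balancing_type = "aux_loss"
# >>> config.moe_router_topk = 2
# >>> router = TopKRouter(config=config)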
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from abc import abstractmethod
from typing import List, Optional, Tuple
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.tensor_parallel.mappings import _gather_along_first_dim_expert_parallel
from megatron.core.transformer.moe.moe_utils import moe_gather, moe_scatter, permute, unpermute
from megatron.core.transformer.transformer_config import TransformerConfig
class MoETokenDispatcher:
"""
MoE Token Dispatcher
"""
def __init__(self, config: TransformerConfig) -> None:
"""
Initialize the MoE Token Dispatcher.
"""
self.config = config
@abstractmethod
def token_permutation(
self, tokens: torch.Tensor, indices: torch.Tensor,
):
"""Dispatch tokens to experts.
Args:
tokens (torch.Tensor): Input tokens.
indices (torch.Tensor): indices tensor.
Returns:
torch.Tensor: Tokens tensor.
"""
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
def token_unpermutation(
self, expert_output: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor,
):
"""Restores the expert output to its original ordering.
Args:
expert_output (torch.Tensor): The output tensor from the expert models.
probs (torch.Tensor): Each token's score with each expert.
indices (torch.Tensor): The indices used to reorder the expert output.
Returns:
(torch.Tensor, torch.Tensor): Unpermuted activation and optional bias.
"""
raise NotImplementedError("Restore function not implemented.")
class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
"""
AllGather Based Token dispatcher.
"""
def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig,
) -> None:
"""
Initialize the AllGather based token dispatcher.
"""
super().__init__(config=config)
self.num_local_experts = num_local_experts
assert self.num_local_experts > 0, "Expected at least one expert"
self.local_expert_indices = local_expert_indices
assert len(self.local_expert_indices) > 0, "Expected at least one local expert index"
self.router_topk = config.moe_router_topk
self.add_bias = config.add_bias_linear
# self.local_probs: probs of global token assignment to local experts.
self.local_probs = None
# self.indices: The indices of `local_indices` (which holds the un-sorted expert indices of tokens that the local experts can process) that give its sorted order along dim 0.
self.indices = None
# self.global_local_map: 2D tensor. A mask of mapping between global and local tokens where each element is True if it's between the local_expert_indices. Only useful when cross-device token permutation is enabled and **AllGather** is performed.
self.global_local_map = None
def token_permutation(
self, hidden_states: torch.Tensor, max_prob: torch.Tensor, max_ind: torch.Tensor
):
"""Dispatch tokens to local experts. It's composed of two stages:
(1) Permute the tokens across the expert parallel devices. After this stage,
each device receives all of the tokens assigned to its local set of experts
in its local HBM.
(2) Permute the tokens locally so that they are grouped by their expert
assignment. After the stage (1), the tokens are grouped by which device
they came from. We re-order them locally for subsequent efficient computation.
Args:
hidden_states: input tokens of shape [SeqLen/TP, MBS, HiddenSize]
max_prob: probs of local token assignment to global experts.
max_ind: token assignment to local experts.
Returns:
permuted_local_hidden_states: Permutation of tokens to local experts group.
tokens_per_expert: the number of tokens each local expert needs to process.
"""
self.hidden_shape = hidden_states.shape
# [S/TP, B, H] -> [S*B/TP, H]
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
# Permute the tokens across the expert parallel devices.
if (self.config.tensor_model_parallel_size > 1) or (
self.config.expert_model_parallel_size > 1
):
with torch.no_grad():
global_indices = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
max_ind
)
# Create a mask of mapping between global and local tokens where each
# element is True if it's between the local_expert_indices
global_local_mask = (global_indices >= self.local_expert_indices[0]) & (
global_indices <= self.local_expert_indices[-1]
)
local_indices = global_indices.masked_select(global_local_mask)
if self.router_topk > 1: # k > 1
global_probs = tensor_parallel.gather_from_sequence_parallel_region_to_moe(max_prob)
self.local_probs = global_probs.masked_select(global_local_mask)
else:
self.local_probs = max_prob
# [S*B/TP, H] -> [S*B, H]
global_hidden_states = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
hidden_states, use_global_buffer=True
)
# Reshape global_local_mask to be compatible with Tensor.gather
global_local_map = global_local_mask.nonzero()[:, 0]
self.global_local_map = global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1])
local_hidden_states = moe_gather.apply(global_hidden_states, self.global_local_map)
else:
if self.router_topk > 1:
global_local_mask = torch.ones_like(max_ind).bool()
local_indices = max_ind.masked_select(global_local_mask)
self.local_probs = max_prob.masked_select(global_local_mask)
global_local_map = global_local_mask.nonzero()[:, 0]
self.global_local_map = global_local_map.view(-1, 1).expand(
-1, hidden_states.shape[-1]
)
local_hidden_states = torch.gather(hidden_states, 0, self.global_local_map)
else:
local_indices = max_ind
self.local_probs = max_prob
local_hidden_states = hidden_states
self.global_local_map = None
with torch.no_grad():
# The indices of local_indices that give its sorted order along dim 0.
self.indices = torch.argsort(local_indices, dim=0)
tokens_per_expert = torch.histc(
local_indices,
bins=self.num_local_experts,
min=self.local_expert_indices[0],
max=self.local_expert_indices[-1],
)
tokens_per_expert = tokens_per_expert.cpu().to(torch.long)
# Stage2: permute the tokens locally so that they are grouped by their expert assignment
# Reshape indices to be compatible with Tensor.gather
self.indices = self.indices.view(-1, 1).expand(-1, hidden_states.shape[-1])
if self.num_local_experts > 1:
permuted_local_hidden_states = moe_gather.apply(local_hidden_states, self.indices)
else:
permuted_local_hidden_states = local_hidden_states
return (
permuted_local_hidden_states,
tokens_per_expert,
)
def token_unpermutation(
self, hidden_states: torch.Tensor, bias: torch.Tensor = None,
):
"""
Reverse process of `token_permutation()`, which unpermutes the output of the local
experts locally and across expert parallel ranks to restore the original order and
produce the final output.
Args:
hidden_states: 2D tensor of shape [sum_tokens_of_all_local_experts, HiddenSize],
output of local experts.
bias (optional): The bias tensor.
Returns:
output_total: un-permuted updated hidden states output from all local experts
with shape of [SeqLen/TP, MBS, HiddenSize]
"""
# Stage1: unpermute the tokens and bias locally respectively.
scores = self.local_probs.to(dtype=hidden_states.dtype)
if self.num_local_experts > 1:
assert self.indices.shape == hidden_states.shape
unpermuted_local_hidden = moe_scatter.apply(hidden_states, self.indices)
else:
unpermuted_local_hidden = hidden_states
# If k > 1, scale the expert output after local unpermutation and before the cross-device reduction.
if self.router_topk > 1:
unpermuted_local_hidden = unpermuted_local_hidden * scores.view(-1, 1)
unpermuted_local_bias = None
if self.add_bias:
assert bias is not None
unpermuted_local_bias = torch.zeros_like(hidden_states)
assert self.indices.shape == bias.shape
unpermuted_local_bias = unpermuted_local_bias.scatter(0, self.indices, bias)
if self.router_topk > 1:
unpermuted_local_bias = unpermuted_local_bias * scores.view(-1, 1)
output_total = unpermuted_local_hidden
output_bias_total = unpermuted_local_bias
# Unpermute the tokens across expert parallel devices.
if (self.config.tensor_model_parallel_size > 1) or (
self.config.expert_model_parallel_size > 1
):
assert (
self.global_local_map is not None
), "global_local_map is necessary for `AllGather`."
ep_group_size = parallel_state.get_tensor_and_expert_parallel_world_size()
# hidden_shape: [SeqLen/TP, MBS, HiddenSize], global_num_tokens = SeqLen/TP*MBS*(TP*EP)
global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1] * ep_group_size
global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
assert self.global_local_map.shape == unpermuted_local_hidden.shape
unpermuted_global_hidden = moe_scatter.apply(
unpermuted_local_hidden, self.global_local_map, global_hidden_shape
)
output_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
unpermuted_global_hidden
)
if self.add_bias:
# Unpermute the bias across expert parallel devices.
unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
unpermuted_global_bias = unpermuted_global_bias.scatter_add(
0, self.global_local_map, unpermuted_local_bias
)
output_bias_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
unpermuted_global_bias
)
# The bias is duplicated across tensor parallel ranks;
# the reduce-scatter sums it, so divide by the tensor parallel world size.
output_bias_total = (
output_bias_total / parallel_state.get_tensor_model_parallel_world_size()
)
else:
if self.router_topk > 1:
global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1]
global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
unpermuted_global_hidden = torch.zeros(
global_hidden_shape,
dtype=hidden_states.dtype,
device=torch.cuda.current_device(),
)
output_total = unpermuted_global_hidden.scatter_add(
0, self.global_local_map, unpermuted_local_hidden
)
if self.add_bias:
unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
output_bias_total = unpermuted_global_bias.scatter_add(
0, self.global_local_map, unpermuted_local_bias
)
if self.router_topk == 1:
output_total = output_total * scores
output_total = output_total.view(self.hidden_shape)
if self.add_bias:
assert output_bias_total is not None
if self.router_topk == 1:
output_bias_total = output_bias_total * scores
output_bias_total = output_bias_total.view(self.hidden_shape)
else:
output_bias_total = None
return output_total, output_bias_total
class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"""
AlltoAll Based Token dispatcher.
"""
def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig,
) -> None:
"""
Initialize the AlltoAll token dispatcher.
Args:
num_local_experts (int): Number of local experts on the current device.
local_expert_indices (List[int]): Indices of local experts on the current device.
config (TransformerConfig): Configuration for the transformer model.
"""
super().__init__(config=config)
self.hidden_shape = None
self.num_input_tokens = None
self.num_local_experts = num_local_experts
self.num_experts = config.num_moe_experts
assert self.num_local_experts > 0, "Expected at least one expert"
self.local_expert_indices = local_expert_indices
assert (
len(self.local_expert_indices) == self.num_local_experts
), "Invalid local expert indices"
self.router_topk = config.moe_router_topk
self.add_bias = config.add_bias_linear
self.ep_size = config.expert_model_parallel_size
self.probs = None
self.input_splits = None
self.output_splits = None
self.num_global_tokens_per_local_expert = None
# Token drop and padding.
# We need to keep track of the token num if we drop tokens without padding them.
self.num_out_tokens = None
# Drop and pad the input to capacity.
self.drop_and_pad = self.config.moe_pad_expert_input_to_capacity
if self.drop_and_pad:
assert self.config.moe_expert_capacity_factor is not None
self.capacity = None
def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
"""
Preprocess token indices for AlltoAll communication and token permutation. This method computes the number of tokens assigned to each expert based on the input indices.
It also initializes the necessary data structures for AlltoAll communication, such as input
and output splits, and the mapping between global tokens and local experts.
Args:
indices (torch.Tensor): Tensor of indices mapping tokens to experts.
Returns:
torch.Tensor: Tensor containing the number of tokens assigned to each local expert.
"""
num_local_tokens_per_expert = torch.histc(
indices, bins=self.num_experts, min=0, max=self.num_experts
)
# num_local_tokens_per_expert: [num_experts]
ep_size = self.config.expert_model_parallel_size
if self.drop_and_pad:
# probs: [num_experts, capacity]
self.capacity = self.probs.size(1)
num_tokens_per_local_expert = torch.full(
(self.num_local_experts,), self.capacity * self.ep_size, dtype=torch.long
)
return num_tokens_per_local_expert
elif self.config.moe_expert_capacity_factor is not None:
self.num_out_tokens = num_local_tokens_per_expert.sum().cpu()
if ep_size > 1:
# ===================================================
# Calculate input_splits, output_splits for alltoall-v.
# ===================================================
self.input_splits = (
num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
.sum(axis=1)
.to(torch.device("cpu"))
.numpy()
)
num_global_tokens_per_expert = _gather_along_first_dim_expert_parallel(
num_local_tokens_per_expert
).reshape(ep_size, self.num_experts)
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[
:, self.local_expert_indices
]
self.output_splits = (
self.num_global_tokens_per_local_expert.sum(axis=-1).to(torch.device("cpu")).numpy()
)
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0).to(
torch.device("cpu"), non_blocking=True
)
# ===================================================
# num_global_tokens_per_expert: [ep_size, num_experts]
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
# num_tokens_per_local_expert: [num_local_experts]
# ===================================================
else:
self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
-1, self.num_experts
)
num_tokens_per_local_expert = num_local_tokens_per_expert.to(
torch.device("cpu"), non_blocking=True
)
if self.num_local_experts > 1:
expert_ids_per_ep_rank = torch.tensor(
[i % self.num_local_experts for i in range(self.config.num_moe_experts)],
dtype=torch.int32,
device=torch.cuda.current_device(),
)
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()
)
return num_tokens_per_local_expert
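# Worked example (illustrative, hypothetical counts): with ep_size == 2, num_experts == 4 and
# num_local_experts == 2, local counts [3, 1, 2, 0] reshape to [[3, 1], [2, 0]], so
# input_splits == [4, 2] (4 tokens go to EP rank 0, 2 to EP rank 1); output_splits is the
# per-rank sum of the gathered counts restricted to this rank's local_expert_indices.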
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
Args:
hidden_states (torch.Tensor): Input token embeddings.
probs (torch.Tensor): Probs of tokens assigned to experts.
indices (torch.Tensor): Indices of tokens assigned to experts.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- Permuted token embeddings for local experts.
- Number of tokens per expert.
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
self.hidden_shape = hidden_states.shape
self.probs = probs
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert indices.dim() == 2, "Expected 2D tensor for indices"
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self.preprocess(indices)
# Perform tensor parallel AlltoAll communication
# hidden_states: [S*B/TP, H] -> [S*B, H/TP]
if parallel_state.get_tensor_model_parallel_world_size() > 1:
hidden_states = tensor_parallel.all_to_all_sp2hp(hidden_states)
# Permutation 1: input to AlltoAll input
self.hidden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states,
indices,
num_out_tokens=self.num_out_tokens,
padded_mode=self.drop_and_pad,
)
# Perform expert parallel AlltoAll communication
global_input_tokens = tensor_parallel.all_to_all(
parallel_state.get_expert_model_parallel_group(),
permutated_local_input_tokens,
self.output_splits,
self.input_splits,
)
# Permutation 2: Sort alltoall output by local experts when num_local_experts > 1.
if self.num_local_experts > 1:
if not self.drop_and_pad:
global_input_tokens, self.reversed_global_input_permutation_mapping = permute(
global_input_tokens, self.global_input_tokens_local_experts_indices
)
else:
global_input_tokens = global_input_tokens.reshape(
self.ep_size, self.num_local_experts, self.capacity, -1
)
global_input_tokens = (
global_input_tokens.transpose(0, 1)
.reshape(self.num_local_experts * self.ep_size * self.capacity, -1)
.contiguous()
)
# Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens.
# global_input_tokens: [SEQL, H/TP] -> [SEQL, H]
if parallel_state.get_tensor_model_parallel_world_size() > 1:
global_input_tokens = tensor_parallel.all_gather_last_dim_from_tensor_parallel_region(
global_input_tokens
)
return global_input_tokens, tokens_per_expert
def token_unpermutation(
self, hidden_states: torch.Tensor, bias: torch.Tensor = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Reverse the token permutation to restore the original order.
Args:
hidden_states (torch.Tensor): Output from local experts.
bias (torch.Tensor, optional): Bias tensor (not supported).
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- Unpermuted token embeddings in the original order.
- None (bias is not supported).
"""
assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
# Perform tensor parallel Reduce-Scatter
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
if parallel_state.get_tensor_model_parallel_world_size() > 1:
hidden_states = tensor_parallel.reduce_scatter_last_dim_to_tensor_parallel_region(
hidden_states
)
# Unpermutation 2: expert output to AlltoAll input
if self.num_local_experts > 1:
if not self.drop_and_pad:
hidden_states = unpermute(
hidden_states, self.reversed_global_input_permutation_mapping,
)
else:
hidden_states = hidden_states.reshape(
self.num_local_experts, self.ep_size, self.capacity, -1
)
hidden_states = (
hidden_states.transpose(0, 1)
.reshape(self.ep_size * self.num_local_experts * self.capacity, -1)
.contiguous()
)
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
permutated_local_input_tokens = tensor_parallel.all_to_all(
parallel_state.get_expert_model_parallel_group(),
hidden_states,
self.input_splits,
self.output_splits,
)
# Unpermutation 1: AlltoAll output to output
output = unpermute(
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
probs=self.probs,
padded_mode=self.drop_and_pad,
restore_shape=self.hidden_shape_before_permute,
)
# Perform tensor parallel AlltoAll communication
# output: [S*B, H/TP] -> [S*B/TP, H]
if parallel_state.get_tensor_model_parallel_world_size() > 1:
output = tensor_parallel.all_to_all_hp2sp(output)
# Reshape the output tensor
output = output.view(self.hidden_shape)
return output, None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import types
from dataclasses import dataclass, field
from typing import Tuple, Union
import torch
@dataclass
class ModuleSpec:
"""This is a Module Specification dataclass.
Specification defines the location of the module (to import dynamically)
or the imported module itself. It also defines the params that need to be
passed to initialize the module.
Args:
module (Union[Tuple, type]): A tuple describing the location of the
module class e.g. `(module.location, ModuleClass)` or the imported
module class itself e.g. `ModuleClass` (which is already imported
using `from module.location import ModuleClass`).
params (dict): A dictionary of params that need to be passed while init.
"""
module: Union[Tuple, type]
params: dict = field(default_factory=lambda: {})
submodules: type = None
def import_module(module_path: Tuple[str]):
"""Import a named object from a module in the context of this function.
TODO: make this importer module more robust, at least make sure there
are no side effects of using this as is
"""
base_path, name = module_path
try:
module = __import__(base_path, globals(), locals(), [name])
except ImportError as e:
print(f"couldn't import module due to {e}")
return None
return vars(module)[name]
def get_module(spec_or_module: Union[ModuleSpec, type], **additional_kwargs):
# If a module class is already provided, return it as is
if isinstance(spec_or_module, (type, types.FunctionType)):
return spec_or_module
# If the module is provided instead of module path, then return it as is
if isinstance(spec_or_module.module, (type, types.FunctionType)):
return spec_or_module.module
# Otherwise, return the dynamically imported module from the module path
return import_module(spec_or_module.module)
def build_module(spec_or_module: Union[ModuleSpec, type], *args, **kwargs):
# If the passed `spec_or_module` is
# a `Function`, then return it as it is
# NOTE: to support an already initialized module add the following condition
# `or isinstance(spec_or_module, torch.nn.Module)` to the following if check
if isinstance(spec_or_module, types.FunctionType):
return spec_or_module
# If the passed `spec_or_module` is actually a spec (instance of
# `ModuleSpec`) and it specifies a `Function` using its `module`
# field, return the `Function` as it is
if isinstance(spec_or_module, ModuleSpec) and isinstance(
spec_or_module.module, types.FunctionType
):
return spec_or_module.module
# Check if a module class is provided as a spec or if the module path
# itself is a class
if isinstance(spec_or_module, type):
module = spec_or_module
elif hasattr(spec_or_module, "module") and isinstance(spec_or_module.module, type):
module = spec_or_module.module
else:
# Otherwise, dynamically import the module from the module path
module = import_module(spec_or_module.module)
# If the imported module is actually a `Function` return it as it is
if isinstance(module, types.FunctionType):
return module
# Finally return the initialized module with params from the spec as well
# as those passed as **kwargs from the code
# Add the `submodules` argument to the module init call if it exists in the
# spec.
if hasattr(spec_or_module, "submodules") and spec_or_module.submodules is not None:
kwargs["submodules"] = spec_or_module.submodules
try:
return module(
*args, **spec_or_module.params if hasattr(spec_or_module, "params") else {}, **kwargs
)
except Exception as e:
# improve the error message since we hide the module name in the line above
import sys
tb = sys.exc_info()[2]
raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
sys.exc_info()[2]
)
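# Usage sketch (illustrative): a spec can carry an already-imported class (or a
# (module_path, class_name) tuple) plus init params that build_module forwards.
# >>> spec = ModuleSpec(module=torch.nn.Linear, params={"in_features": 8, "out_features": 4})
# >>> build_module(spec)
# Linear(in_features=8, out_features=4, bias=True)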
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import re
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Tuple, Union
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.packed_seq_params import PackedSeqParams
try:
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDelayedScaling,
TENorm,
get_cpu_offload_context,
te_checkpoint,
)
except ImportError:
TEDelayedScaling = None
TENorm = None
get_cpu_offload_context = None
te_checkpoint = None
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import BaseTransformerLayer, TransformerLayer
from megatron.core.transformer.utils import sharded_state_dict_default
from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor
def get_num_layers_to_build(config: TransformerConfig) -> int:
num_layers_per_pipeline_rank = (
config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
# Interleaved pipeline parallelism:
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
num_layers_to_build = num_layers_per_virtual_rank
else:
# Non-interleaved pipeline parallelism:
# Each stage gets a contiguous set of layers.
num_layers_to_build = num_layers_per_pipeline_rank
return num_layers_to_build
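# Worked example (illustrative): with config.num_layers == 8, a pipeline-parallel size of 2 and
# a virtual-pipeline size of 2, each pipeline rank owns 8 // 2 == 4 layers and each of its model
# chunks builds 4 // 2 == 2 layers, matching the interleaved assignment sketched above.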
@dataclass
class TransformerBlockSubmodules:
layer_specs: List[ModuleSpec] = None
def _get_block_submodules(
config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec],
) -> TransformerBlockSubmodules:
# Transformer block submodules.
if isinstance(spec, TransformerBlockSubmodules):
return spec
# ModuleSpec here is generally assumed to be for a transformer layer that
# is implemented in `transformer_layer.py` or if it subclasses
# `BaseTransformerLayer` from the `transformer_layer.py` file.
elif isinstance(spec, ModuleSpec):
if issubclass(spec.module, TransformerBlock):
return spec.submodules
elif issubclass(spec.module, BaseTransformerLayer):
num_layers = get_num_layers_to_build(config)
return TransformerBlockSubmodules(layer_specs=[spec] * num_layers)
else:
raise Exception(f"specialize for {spec.module.__name__}.")
else:
raise Exception(f"specialize for {type(spec).__name__}.")
class TransformerBlock(MegatronModule):
"""Transformer class."""
def __init__(
self,
config: TransformerConfig,
spec: Union[TransformerBlockSubmodules, ModuleSpec],
post_layer_norm: bool = True,
pre_process: bool = True,
post_process: bool = True,
):
super().__init__(config=config)
self.submodules = _get_block_submodules(config, spec)
self.post_layer_norm = post_layer_norm
self.pre_process = pre_process
self.post_process = post_process
# Dictionary to store CUDA graphs. Number of items in the dictionary = len(self.layers).
# Item `i` in the dictionary is a list of `N` CUDA graphs for layer 'i' where N is the
# number of microbatches. Multiple CUDA graphs per layer is required to support
# pipelining which requires running FWD graph of multiple microbatches before BWD graph.
self.cuda_graphs = {}
self.current_microbatch = -1
# required for pipeline parallel schedules
self.input_tensor = None
self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
if get_cpu_offload_context is not None:
(
self.offload_context,
self.group_prefetch_offload_commit_async,
) = get_cpu_offload_context(
self.config.cpu_offloading,
self.config.cpu_offloading_num_layers,
self.config.cpu_offloading_activations,
self.config.cpu_offloading_weights,
)
self.config._cpu_offloading_context = (
self.offload_context if self.config.cpu_offloading else None
)
else:
assert (
self.config.cpu_offloading == False
), "CPU Offloading is enabled when TE is not present"
self.offload_context, self.group_prefetch_offload_commit_async = nullcontext(), None
self.config._cpu_offloading_context = None
self._build_layers()
self.num_layers_per_pipeline_rank = len(self.layers)
def _build_layers(self):
# Transformer layers.
# @jcasper can we improve how we deal with layer_number?
# currently it's only used in CoreAttention?
# if self.apply_query_key_layer_scaling:
# coeff = self.layer_number
# self.norm_factor *= coeff
def build_layer(layer_spec, layer_number):
return build_module(layer_spec, config=self.config, layer_number=layer_number,)
# offset is implicit in TransformerLayer
self.layers = torch.nn.ModuleList(
[
build_layer(layer_spec, i + 1)
for i, layer_spec in enumerate(self.submodules.layer_specs)
]
)
# # TODO: add back standalone_embedding_stage
# if self.num_layers == 0:
# # When a standalone embedding stage is used (e.g.,
# # args.standalone_embedding_stage == True), virtual pipeline ranks
# # on pipeline rank 0 will have zero transformer layers assigned to
# # them. This results in the model's input and output tensors to be
# # the same, which will cause failure for certain output tensor
# # optimizations (e.g., pipeline output deallocation). To remedy
# # this, we assign a 'no-op' layer on these ranks, which will
# # disconnect the input tensor from the output tensor.
# self.num_layers = 1
# self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
# else:
# self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process and self.post_layer_norm:
# Final layer norm before output.
self.final_layernorm = TENorm(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
def _get_layer(self, layer_number: int):
return self.layers[layer_number]
def _checkpointed_forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
context: Tensor,
context_mask: Tensor,
rotary_pos_emb: Tensor,
packed_seq_params: PackedSeqParams,
):
"""Forward method with activation checkpointing."""
def custom(start: int, end: int):
def custom_forward(
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
):
for index in range(start, end):
layer = self._get_layer(index)
hidden_states, context = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=None,
packed_seq_params=packed_seq_params,
)
return hidden_states, context
return custom_forward
def checkpoint_handler(forward_func):
if self.config.fp8:
return te_checkpoint(
forward_func,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
else:
return tensor_parallel.checkpoint(
forward_func,
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
if self.config.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# This further reduces memory usage by checkpointing fewer activations.
l = 0
while l < self.num_layers_per_pipeline_rank:
hidden_states, context = checkpoint_handler(
custom(l, l + self.config.recompute_num_layers)
)
l += self.config.recompute_num_layers
elif self.config.recompute_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# This makes full use of device memory while removing redundant re-computation.
recompute_skip_num_layers = 0
for l in range(self.num_layers_per_pipeline_rank):
# Skip recomputation when input grad computation is not needed.
# Need to have at least one input tensor with gradient computation
# for the re-entrant autograd engine.
if self.config.fp8 and not hidden_states.requires_grad:
recompute_skip_num_layers += 1
if (
l >= recompute_skip_num_layers
and l < self.config.recompute_num_layers + recompute_skip_num_layers
):
hidden_states, context = checkpoint_handler(custom(l, l + 1))
else:
hidden_states, context = custom(l, l + 1)(
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states
def set_input_tensor(self, input_tensor: Tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
context: Tensor = None,
context_mask: Tensor = None,
rotary_pos_emb: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
):
# hidden_states (float): [s, b, h]
# attention_mask (bool): [1, 1, s, s]
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = make_viewless_tensor(
inp=hidden_states, requires_grad=True, keep_graph=True,
)
if self.config.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
if self.config.fp8:
import transformer_engine # To keep out TE dependency when not training in fp8
if self.config.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif self.config.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
fp8_recipe = TEDelayedScaling(
config=self.config,
fp8_format=fp8_format,
override_linear_precision=(False, False, not self.config.fp8_wgrad),
)
fp8_group = None
if parallel_state.model_parallel_is_initialized():
fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True)
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
)
else:
fp8_context = nullcontext()
with rng_context and fp8_context:
# Forward pass.
if self.config.recompute_granularity == 'full' and self.training:
hidden_states = self._checkpointed_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
)
else:
for l_no, layer in enumerate(self.layers):
with self.offload_context:
if (len(self.cuda_graphs) == 0) or (not self.training):
hidden_states, context = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
)
# CUDA graphs don't output context, so it is expected to be None
assert (
(context is None)
or (not self.config.enable_cuda_graph)
or (not self.training)
)
else:
# CUDA graph replay for layer `l_no` and microbatch `self.current_microbatch`
# CUDA graph requires positional arguments with the exception of is_first_microbatch.
# Also CUDA graph accepts only Tensor inputs and outputs. Hence, the arg list and
# returned list is limited to `hidden_states`.
assert (len(self.cuda_graphs) > l_no) and (
self.current_microbatch < len(self.cuda_graphs[l_no])
)
hidden_states = self.cuda_graphs[l_no][self.current_microbatch](
hidden_states, is_first_microbatch=(self.current_microbatch == 0),
)
if (
torch.is_grad_enabled()
and self.config.cpu_offloading
and self.group_prefetch_offload_commit_async is not None
):
hidden_states = self.group_prefetch_offload_commit_async(hidden_states)
# Final layer norm.
if self.post_process and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: dict = None
) -> ShardedStateDict:
assert not sharded_offsets, "Unexpected sharded offsets"
non_homogeneous_layers = metadata is not None and metadata.get(
'non_homogeneous_layers', False
)
sharded_state_dict = {}
layer_prefix = f'{prefix}layers.'
num_layers = self.config.num_layers
for layer in self.layers:
offset = layer._get_layer_offset()
global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1
state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.' # module list index in TransformerBlock
if non_homogeneous_layers:
sharded_prefix = f'{layer_prefix}{global_layer_offset}.'
sharded_pp_offset = []
else:
sharded_prefix = layer_prefix
sharded_pp_offset = [
(0, global_layer_offset, num_layers)
] # PP sharding offset for ShardedTensors
layer_sharded_state_dict = layer.sharded_state_dict(
state_dict_prefix, sharded_pp_offset, metadata
)
replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
sharded_state_dict.update(layer_sharded_state_dict)
# Add modules other than self.layers
for name, module in self.named_children():
if not module is self.layers:
sharded_state_dict.update(
sharded_state_dict_default(
module, f'{prefix}{name}.', sharded_offsets, metadata
)
)
return sharded_state_dict
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import types
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import torch
import torch.nn.functional as F
from ..model_parallel_config import ModelParallelConfig
from ..utils import init_method_normal, scaled_init_method_normal
@dataclass
class TransformerConfig(ModelParallelConfig):
"""Configuration object for megatron-core transformers.
The initialization function has an argument for each parameter, including those in ModelParallelConfig.
"""
####################
# model architecture
####################
num_layers: int = 0
"""Number of transformer layers in a transformer block."""
hidden_size: int = 0
"""Transformer hidden size."""
num_attention_heads: int = 0
"""Number of transformer attention heads."""
num_query_groups: int = None
"""Number of query groups for group query attention. If None, normal attention is used."""
ffn_hidden_size: int = None
"""Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size if not provided."""
kv_channels: int = None
"""Projection weights dimension in multi-head attention. This is set to hidden_size //
num_attention_heads if not provided."""
hidden_dropout: float = 0.1
"""Dropout probability for transformer hidden state."""
attention_dropout: float = 0.1
"""Post attention dropout probability."""
fp32_residual_connection: bool = False
"""If true, move residual connections to fp32."""
# @jcasper should we keep this option?
apply_residual_connection_post_layernorm: bool = False
"""If True, uses the original BERT residule connection ordering."""
layernorm_epsilon: float = 1e-5
"""Epsilon value for any LayerNorm operations."""
layernorm_zero_centered_gamma: bool = False
"""If set to True, the LayerNorm is adjusted to center the gamma values around 0. This improves
numerical stability."""
add_bias_linear: bool = True
"""Include a bias term in all linear layers (QKV projections, after core attention, and two in
MLP layer)."""
add_qkv_bias: bool = False
"""Add a bias term only for QKV projections."""
gated_linear_unit: bool = False
"""Use a gated linear unit for the first linear layer in the MLP."""
activation_func: Callable = F.gelu
"""Activation function to use for the non-linearity in the MLP."""
activation_func_fp8_input_store: bool = False
"""Store the input of MLP activation function in FP8 for backprop to save memory.
The stored input is cast back to the original precision before the backward computation."""
num_moe_experts: int = None
"""Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None
for no MoE."""
rotary_interleaved: bool = False
"""True is rotate pairs of even and odd dimensions (RoFormer style), False is rotate pairs of
first half and second half (LLaMa style). Default to False."""
window_size: Optional[Tuple[int, int]] = None
"""If not None, then will use sliding window attention. The size of the window is specified by
the numbers inside the tuple; -1 is special value meaning "infinite window size"."""
normalization: str = "RMSNorm"
"""Which norm to use for normalization layers; valid options are `LayerNorm` and `RMSNorm`."""
qk_layernorm: bool = False
"""Whether to apply LayerNorm to the query and key embeddings."""
test_mode: bool = False
"""Whether to run real-time tests."""
calculate_per_token_loss: bool = False
"""Whether cross entropy loss is calculated over the actual number of non-padded tokens in the
global batch, versus the default behavior of assuming all tokens are non-padded."""
####################
# initialization
####################
init_method: Callable = None
"""Method to initialize weights. Note that bias is always set to zero. Should be a function that
takes a single Tensor and initializes it. If None, will be set to
megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with
mean=0.0 and std=init_method_std."""
output_layer_init_method: Callable = None
"""Method to initialize weights of the output layer of both attention and MLP blocks. If None,
will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn
init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers)."""
init_method_std: float = 0.02
"""Standard deviation of the zero mean normal for the default initialization method, not used if
init_method and output_layer_init_method are provided."""
####################
# mixed-precision
####################
apply_query_key_layer_scaling: bool = False
"""If true, scale Q * K^T by 1 / layer-number. This improve numeric stability when training with
fp16."""
attention_softmax_in_fp32: bool = True
"""If True, run attention masking and softmax in fp32. This should be True if
apply_query_key_layer_scaling is True."""
####################
# fusion
####################
bias_activation_fusion: bool = False
"""If True, fuses bias addition and the activation function when possible."""
masked_softmax_fusion: bool = False
"""If True, uses softmax fusion."""
persist_layer_norm: bool = False
"""If True, uses the persistent fused layer norm kernel. This kernel only supports a fixed set
of hidden sizes."""
memory_efficient_layer_norm: bool = False
"""If True, and using local layers (not from TransformerEngine), tells Apex to use the memory
efficient fused LayerNorm kernel. Ignored if not using LayerNorm."""
bias_dropout_fusion: bool = False # TODO: this should be bias_dropout_add_fusion?
"""If True, uses bias dropout fusion."""
apply_rope_fusion: bool = False
"""If True, use fused RoPE kernel."""
####################
# activation recomputation
####################
recompute_granularity: str = None
"""Determines which type of activation recompute to use. Megatron-core supports 'selective'
activation checkpointing where only the memory intensive part of attention is checkpointed.
These memory intensive activations are also less compute intensive which makes activation
checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large
Transformer Models (https://arxiv.org/abs/2205.05198) for more details. 'full' will checkpoint
the entire transformer layer. If None, no recompute is performed and all activations are saved.
If set, must be 'selective' or 'full'. 'selective' always uses all layers.
"""
recompute_method: str = None
"""Determines which transformer layers will be recomputed. uniform will uniformly divide the
total number of transformer layers in a transformer block and recompute the input activation of
each divided chunk at the specified granularity. block will recompute the input activations for
only a set number of transformer layers per pipeline stage. The rest of the layers in the
pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all
layers will do recomputation. If set, must be 'uniform' or 'block'."""
recompute_num_layers: int = None
"""When recompute_method is uniform, recompute_num_layers is the number of transformer layers in
each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is
the number of transformer layers to recompute within each pipeline stage. Must be None for
'selective' activation checkpointing."""
distribute_saved_activations: bool = None
"""If True, distribute recomputed activations across the model parallel group."""
####################
# fp8 related
####################
fp8: str = None
"""If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined
choices (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8
activation and weight tensors and e5m2 for all FP8 output activation gradient tensors."""
fp8_margin: int = 0
"""Margin for the scaling factor computation."""
fp8_interval: int = 1
"""Controls how often the scaling factor is recomputed."""
fp8_amax_history_len: int = 1
"""The length of the amax history window used for scaling factor computation."""
fp8_amax_compute_algo: str = "most_recent"
"""Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
always chooses the most recently seen value.
"""
fp8_wgrad: bool = True
"""When set to False, override FP8 config options and do the wgrad computation in higher precision."""
fp8_dot_product_attention: bool = False
"""When set to True, use the FP8 implementation of Dot Product Attention."""
fp8_multi_head_attention: bool = False
"""When set to True, use the FP8 implementation of Multi Head Attention."""
####################
# MoE related
####################
moe_router_load_balancing_type: str = "aux_loss"
"""Determines the load balancing strategy for the router. "aux_loss" corresponds to the load
balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing
algorithm used in S-BASE, and "none" implies no load balancing."""
moe_router_topk: int = 2
"""Number of experts to route to for each token."""
moe_grouped_gemm: bool = False
"""When there are multiple experts per rank, compress multiple local (potentially small) gemms
in a single kernel launch to improve the utilization and performance by leveraging the Grouped
GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
"""
moe_aux_loss_coeff: float = 0 # 1e-2 would be a good start value for load balance loss.
"""Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended."""
moe_z_loss_coeff: float = None # 1e-3 would be a good start value for z-loss
"""Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended."""
moe_input_jitter_eps: float = None
"""Add noise to the input tensor by applying jitter with a specified epsilon value."""
moe_token_dropping: bool = False # TODO: Support token dropping.
"""This feature involves selectively dropping and padding tokens for each expert to achieve a
specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is
currently unsupported so should remain False."""
moe_token_dispatcher_type: str = "allgather"
"""The type of token dispatcher to use. The default is 'allgather'. Options are 'allgather' and 'alltoall'."""
moe_per_layer_logging: bool = False
"""Enable per-layer logging for MoE, currently supports auxiliary loss and z loss."""
moe_expert_capacity_factor: float = None
"""moe_expert_capacity_factor (float): The capacity factor for each expert, None means no token will be dropped. The default is None."""
moe_pad_expert_input_to_capacity: bool = False
"""moe_pad_expert_input_to_capacity (bool): If True, pads the input for each expert to match the expert capacity length, effective only after the moe_expert_capacity_factor is set. The default setting is False."""
moe_token_drop_policy: str = 'probs'
"""The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped.
"""
moe_layer_recompute: bool = False
"""Memory optimization: checkpointing moe_layer to save actiavtion memory."""
####################
# miscellaneous
####################
clone_scatter_output_in_embedding: bool = True
"""When set to True, clone the output of scatter_to_sequence_parallel_region in embedding layer
to facilitate garbage collection of input."""
disable_parameter_transpose_cache: bool = False
"""When set to true, the parameter transposes are not cached for subsequent iterations."""
enable_cuda_graph: bool = False
"""When set to true, TransformerLayer blocks are wrapped with CUDA graph."""
def __post_init__(self):
""" Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
"""
super().__post_init__()
if self.fp16 and self.bf16:
raise ValueError(
f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.'
)
if self.num_attention_heads % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
if self.ffn_hidden_size is None:
self.ffn_hidden_size = 4 * self.hidden_size
if self.kv_channels is None:
self.kv_channels = self.hidden_size // self.num_attention_heads
if self.num_query_groups is None:
self.num_query_groups = self.num_attention_heads
if self.num_query_groups % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_query_groups ({self.num_query_groups}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
if self.expert_model_parallel_size > 1 and self.num_moe_experts is None:
raise ValueError('num_moe_experts must be set (not None) to use expert parallelism.')
if self.num_moe_experts is not None and self.num_moe_experts <= 0:
raise ValueError('num_moe_experts must be a positive integer.')
if self.moe_expert_capacity_factor is not None:
if self.moe_token_dispatcher_type != "alltoall":
raise ValueError(
f'moe_expert_capacity_factor only works with alltoall token dispatcher'
)
if self.moe_expert_capacity_factor < 0:
self.moe_expert_capacity_factor = None
if self.moe_router_load_balancing_type not in ["aux_loss", "none"]:
raise ValueError(
f'moe_expert_capacity_factor only works with aux_loss or none load balancing'
)
if self.moe_pad_expert_input_to_capacity:
if self.moe_expert_capacity_factor is None:
raise ValueError(
f'moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity'
)
if self.cpu_offloading and (
self.cpu_offloading_num_layers < 0 or self.cpu_offloading_num_layers >= self.num_layers
):
raise ValueError(
f'CPU offloading can be done only for layers less than {self.num_layers}'
)
if self.cpu_offloading and self.pipeline_model_parallel_size > 1:
raise ValueError(
f'Currently there is no support for Pipeline parallelism with CPU offloading'
)
if self.cpu_offloading and self.recompute_granularity is not None:
raise ValueError(
f'CPU offloading does not work when activation recomputation is enabled'
)
if self.recompute_granularity is not None:
if not self.recompute_granularity in ['full', 'selective']:
raise ValueError(
f'recompute_granularity: {self.recompute_granularity} must be "full" or "selective".'
)
if self.recompute_method is not None:
if not self.recompute_method in ['block', 'uniform']:
raise ValueError(
f'recompute_method: {self.recompute_method} must be "block" or "uniform".'
)
elif self.recompute_granularity != 'selective':
raise ValueError(
f'Using recompute_granularity: {self.recompute_granularity} so recompute_method must be "block" or "uniform"'
)
if self.recompute_granularity != 'selective' and self.recompute_num_layers is None:
raise ValueError(
f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be between '
f'1 and num_layers_per_pipeline_rank: {self.num_layers // self.pipeline_model_parallel_size}'
)
elif (
self.recompute_granularity == 'selective' and self.recompute_num_layers is not None
):
raise ValueError(
f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be None.'
)
if self.distribute_saved_activations and self.sequence_parallel:
raise ValueError(
f'distribute_saved_activations: {self.distribute_saved_activations} must be false when sequence parallel is enabled: {self.sequence_parallel}'
)
if self.virtual_pipeline_model_parallel_size is not None:
if not self.num_layers % self.virtual_pipeline_model_parallel_size == 0:
raise ValueError(
f'num_layers: {self.num_layers} must be divisible by virtual_pipeline_model_parallel_size {self.virtual_pipeline_model_parallel_size}'
)
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
if self.bias_activation_fusion:
if self.activation_func not in [F.gelu, F.silu]:
raise ValueError(
"When bias_activation_fusion is True, activation function should be either gelu or swiglu"
)
if (
self.activation_func == F.gelu
and not self.gated_linear_unit
and not self.add_bias_linear
):
raise ValueError(
"When bias_activation_fusion is True, gated_linear_unit is False, "
"and activation function is gelu, add_bias_linear must also be True."
)
if self.activation_func_fp8_input_store:
if self.activation_func != F.silu or not self.gated_linear_unit:
raise ValueError("Storing activation input in FP8 is supported only for SwiGLU.")
if self.apply_rope_fusion and self.rotary_interleaved:
raise ValueError(f'rotary_interleaved does not work with apply_rope_fusion.')
if self.init_method is None:
self.init_method = init_method_normal(self.init_method_std)
if self.output_layer_init_method is None:
self.output_layer_init_method = scaled_init_method_normal(
self.init_method_std, self.num_layers
)
if self.moe_extended_tp:
if self.moe_token_dispatcher_type != 'allgather':
raise ValueError(
"Moe extended TP parallelism only applies to allgather based token dispatcher."
)
extended_tp_size = self.tensor_model_parallel_size * self.expert_model_parallel_size
if self.ffn_hidden_size % extended_tp_size != 0:
raise ValueError(
f'ffn_hidden_size: {self.ffn_hidden_size} must be divisible by extended_tp_size {extended_tp_size}'
)
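# A minimal usage sketch (illustrative only, not part of the library API): construct a
# TransformerConfig with just the required architecture fields and let __post_init__
# derive the rest. The field values below are arbitrary assumptions for demonstration.
if __name__ == "__main__":
    _demo_config = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=4)
    # __post_init__ fills in the derived defaults:
    #   ffn_hidden_size = 4 * hidden_size                 -> 512
    #   kv_channels = hidden_size // num_attention_heads  -> 32
    #   num_query_groups = num_attention_heads            -> 4
    print(_demo_config.ffn_hidden_size, _demo_config.kv_channels, _demo_config.num_query_groups)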
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from abc import ABC
from dataclasses import dataclass, field
from typing import Dict, Optional, Union
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import apply_prefix_mapping
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor
@dataclass
class TransformerLayerSubmodules:
input_layernorm: Union[ModuleSpec, type] = IdentityOp
self_attention: Union[ModuleSpec, type] = IdentityOp
self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
cross_attention: Union[ModuleSpec, type] = IdentityOp
cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
mlp: Union[ModuleSpec, type] = IdentityOp
mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp
# Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
class BaseTransformerLayer(ABC):
""" A common parent class for `TransformerLayer` like implementations.
A dummy class that is subclassed by similar `TransformerLayer`s e.g. the
`TransformerLayer` in this file and possibly other `TransformerLayer`
implementations that aim to use `TransformerBlock` as the base module.
The main purpose is to check if any layer (or module) provided in the spec
is a subclass of this class to allow fanning-out of that spec for all the
layers in the `TransformerBlock`. See `_get_block_submodules` method
implementation in `transformer_block.py` file for more details.
"""
def __init__(self):
pass
class TransformerLayer(MegatronModule, BaseTransformerLayer):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self,
config: TransformerConfig,
submodules: TransformerLayerSubmodules,
layer_number: int = 1,
hidden_dropout: float = None,
):
super().__init__(config=config)
self.submodules_config = submodules
self.layer_number = layer_number + self._get_layer_offset()
self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout
## [Module 1: Input Layernorm] Optional Layernorm on the input data
# TODO: add pytorch only layernorm
self.input_layernorm = build_module(
submodules.input_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
## [Module 2: SelfAttention]
self.self_attention = build_module(
submodules.self_attention, config=self.config, layer_number=layer_number,
)
## [Module 3: BiasDropoutFusion]
self.self_attn_bda = build_module(submodules.self_attn_bda)
## [Module 4: Post SelfAttention] Optional Layernorm after self-attn
self.pre_cross_attn_layernorm = build_module(
submodules.pre_cross_attn_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
## [Module 5: CrossAttention]
self.cross_attention = build_module(
submodules.cross_attention, config=self.config, layer_number=layer_number,
)
## [Module 6: BiasDropoutFusion]
self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config,)
## [Module 7: Pre MLP] Optional Layernorm before MLP
self.pre_mlp_layernorm = build_module(
submodules.pre_mlp_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
## [Module 8: MLP block]
# TODO how to set the gpt_layer_spec.py when we have moe_frequency > 1,
# where MLP and MoE layer both appear alternately?
self.mlp = build_module(submodules.mlp, config=self.config)
if hasattr(self.mlp, 'set_layer_number'):
self.mlp.set_layer_number(self.layer_number)
## [Module 9: BiasDropoutFusion]
self.mlp_bda = build_module(submodules.mlp_bda)
# @jcasper how should we handle nvfuser?
# Set bias+dropout+add fusion grad_enable execution handler.
# TORCH_MAJOR = int(torch.__version__.split('.')[0])
# TORCH_MINOR = int(torch.__version__.split('.')[1])
# use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
# self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
self.bias_dropout_add_exec_handler = torch.enable_grad
def _get_layer_offset(self):
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
num_layers_per_pipeline_rank = (
self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
total_num_layers = self.config.num_layers
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
total_virtual_chunks = total_num_layers // vp_size
offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)
else:
# Each stage gets a contiguous set of layers.
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
offset = pipeline_rank * num_layers_per_pipeline_rank
else:
offset = 0
return offset
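# Worked example for _get_layer_offset (a sketch with assumed parallel sizes, not extra
# logic): with num_layers=8, pipeline_model_parallel_size=2 and
# virtual_pipeline_model_parallel_size=2, num_layers_per_pipeline_rank=4,
# num_layers_per_virtual_rank=2 and total_virtual_chunks=4, so pipeline rank 1 gets
# offset 2 for virtual chunk 0 and offset 6 for virtual chunk 1. Without virtual
# pipelining, pipeline rank 1 simply gets offset 4.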
def forward(
self,
hidden_states,
attention_mask,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
):
# hidden_states: [s, b, h]
# Residual connection.
residual = hidden_states
# Optional Input Layer norm
input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
# Residual connection.
residual = hidden_states
# Optional Layer norm after self-attention
pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)
# Cross attention.
attention_output_with_bias = self.cross_attention(
pre_cross_attn_layernorm_output,
attention_mask=context_mask,
key_value_states=context,
inference_params=inference_params,
)
if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
context = attention_output_with_bias["context"]
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
# Residual connection.
residual = hidden_states
# Optional Layer norm post the cross-attention.
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
# MLP.
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
return output, context
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
prefixed_map = {
f'{prefix}{k}': f'{prefix}{v}'
for k, v in self.submodules_config.sharded_state_dict_keys_map.items()
}
if prefixed_map:
apply_prefix_mapping(sharded_state_dict, prefixed_map)
return sharded_state_dict
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for transformer layers."""
from functools import lru_cache
from operator import itemgetter
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple, Union
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedStateDict, StateDict
from megatron.core.jit import jit_fuser
from megatron.core.utils import (
make_sharded_tensor_for_checkpoint,
make_tp_sharded_tensor_for_checkpoint,
)
def get_linear_layer(rows, columns, init_method, perform_initialization=True):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
if perform_initialization: # Take from modelparallel config
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
@lru_cache(maxsize=32)
def get_default_causal_mask(sq: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input."""
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
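# Example of how the two helpers above compose (a sketch; values are illustrative):
#   mask = get_default_causal_mask(3) is True strictly above the diagonal:
#       [[False,  True,  True],
#        [False, False,  True],
#        [False, False, False]]
#   attention_mask_func(scores, mask) then fills the masked (True) positions of
#   `scores` with -10000.0 in place before the softmax.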
@jit_fuser
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def openai_gelu(x):
return gelu_impl(x)
# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@jit_fuser
def erf_gelu(x):
return (
x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
)
def make_sharded_tensors_for_checkpoint(
state_dict: StateDict,
prefix: str,
tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None,
sharded_offsets: Iterable[Tuple[int, int, int]] = (),
extra_state_suffix: str = '_extra_state',
):
"""Wraps tensors from transformer layers with ShardedTensor or ShardedObject.
For a given `state_dict`, wraps:
- all _extra_states with ShardedObject
- all tensors specified in tensor_parallel_layers_axis_map with TP and DP sharded ShardedTensor
- other values with DP sharded ShardedTensor
Args:
state_dict (StateDict): state_dict to convert
prefix (str): prefix appended to keys in final state dict
tensor_parallel_layers_axis_map (Dict[str, int], optional): dict mapping layer
names to the axis for TP sharding
sharded_offsets (Iterable[Tuple[int, int, int]], optional): sharding already
applied (e.g. PP related), passed along to ShardedTensor
extra_state_suffix (str, default = '_extra_state'): layers with this
suffix will be wrapped with ShardedObject instead of ShardedTensor.
"""
if tensor_parallel_layers_axis_map is None:
tensor_parallel_layers_axis_map = {}
sharded_state_dict = {}
for layer_name in state_dict.keys():
tensor = state_dict[layer_name]
layer_key = f'{prefix}{layer_name}'
if layer_name.endswith(extra_state_suffix):
sharded_state_dict[layer_key] = make_sharded_object_for_checkpoint(
tensor, layer_key, sharded_offsets
)
elif layer_name in tensor_parallel_layers_axis_map:
tp_axis = tensor_parallel_layers_axis_map[layer_name]
sharded_state_dict[layer_key] = make_tp_sharded_tensor_for_checkpoint(
tensor, layer_key, tp_axis, prepend_offsets=sharded_offsets,
)
else:
sharded_state_dict[layer_key] = make_sharded_tensor_for_checkpoint(
tensor, layer_key, prepend_offsets=sharded_offsets,
)
return sharded_state_dict
def make_sharded_object_for_checkpoint(
obj: Any,
key: str,
sharded_offsets: Iterable[Tuple[int, int, int]] = (),
replica_id: Union[None, int, Tuple[int, ...]] = None,
**kwargs,
):
""" Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group).
Args:
obj (object): any object to be sharded
key (str): unique identifier of the object
sharded_offsets (Iterable[Tuple[int, int, int]]): offsets normally
prepended to ShardedTensors, will be used as global offsets for
ShardedObject
replica_id (Union[None, int, Tuple[int, ...]]): replica id
"""
if replica_id is None:
replica_id = (
0,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
return ShardedObject(key, obj, *_get_extra_state_offsets(sharded_offsets), replica_id, **kwargs)
def _get_extra_state_offsets(
sharded_offsets: Iterable[Tuple[int, int, int]]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
""" Turns ShardedTensor offsets into offsets suitable for ShardedObject. """
if sharded_offsets:
sharded_offsets = sorted(sharded_offsets, key=itemgetter(0)) # sort by axis
axis, extra_state_offset, extra_state_shape = zip(*sharded_offsets)
assert list(axis) == list(
range(len(axis))
), f'Expected contiguous axis for offsets: {sharded_offsets}'
else:
extra_state_shape = (1,)
extra_state_offset = (0,)
return extra_state_shape, extra_state_offset
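# Example (a sketch with assumed pipeline offsets): for sharded_offsets = [(0, 3, 8)],
# i.e. axis 0, offset 3 of 8 shards, this returns ((8,), (3,)); with no offsets it
# returns the defaults ((1,), (0,)).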
def sharded_state_dict_default(
module: torch.nn.Module,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) -> ShardedStateDict:
"""Provides implementation for sharded_state_dict method for non-MegatronModules.
Tries to call `module.sharded_state_dict` when possible,
otherwise uses regular state dict and assumes tensors are replicated across TP and DP.
`keep_vars=True` is passed to module.state_dict so that optimizer states
can be sharded later on.
Args:
module (torch.nn.Module): module whose sharded state dict we want to obtain
prefix (str): prefix for the state dict keys
sharded_offsets (Tuple[Tuple[int, int, int]], optional): sharding already
applied (e.g. PP related) by super-modules. Passed along to ShardedTensor
metadata (dict, optional): metadata passed to module sharded_state_dict method
Returns:
dict: dictionary of state dict keys mapped to ShardedTensors
"""
if hasattr(module, 'sharded_state_dict'):
module_sharded_sd = module.sharded_state_dict(
prefix=prefix, sharded_offsets=sharded_offsets, metadata=metadata
)
else:
module_sd = module.state_dict(prefix='', keep_vars=True)
module_sharded_sd = make_sharded_tensors_for_checkpoint(
module_sd, prefix, {}, sharded_offsets,
)
return module_sharded_sd
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Utility functions used throughout Megatron core"""
import array
import hashlib
import logging
import math
import operator
import queue
import socket
import sys
import threading
import time
import traceback
from dataclasses import dataclass
from datetime import datetime
from functools import reduce
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedTensor
logger = logging.getLogger(__name__)
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def get_attr_wrapped_model(model, attr, allow_none=True, return_model_obj=False):
"""Get an attribute from a wrapped model.
If return_model_obj is true, return the object that has the 'attr' attribute;
otherwise, return the attribute directly."""
if isinstance(model, list):
raise RuntimeError("_get_attr_wrapped_model given a list of models")
if allow_none:
def condition(model, attr):
return not hasattr(model, attr)
else:
def condition(model, attr):
return getattr(model, attr, None) is None
while condition(model, attr):
if not hasattr(model, "module"):
raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
model = model.module
if return_model_obj:
return model
return getattr(model, attr)
def get_model_type(model):
return get_attr_wrapped_model(model, 'model_type')
def get_model_config(model):
return get_attr_wrapped_model(model, 'config', allow_none=False)
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if (
self.buffer.get((name, dtype), None) is None
or self.buffer[(name, dtype)].numel() < required_len
):
self.buffer[(name, dtype)] = torch.empty(
required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
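# Usage sketch (illustrative; requires an initialized CUDA device): repeated calls with
# the same (name, dtype) reuse one underlying allocation and return a view of the
# requested shape.
#   buf = GlobalMemoryBuffer()
#   x = buf.get_tensor([16, 4, 1024], torch.bfloat16, "mpu_0")
#   y = buf.get_tensor([8, 4, 1024], torch.bfloat16, "mpu_0")   # reuses x's storage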
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-effect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad,)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
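# Usage sketch (illustrative): a viewed tensor keeps a reference to its base via
# `._base`; make_viewless_tensor breaks that link while sharing the same storage.
#   base = torch.ones(4, 4)
#   view = base.view(2, 8)                      # view._base is base
#   out = make_viewless_tensor(view, requires_grad=False, keep_graph=False)
#   assert out._base is None and out.data_ptr() == view.data_ptr()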
def assert_viewless_tensor(tensor, extra_msg=None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[assert_viewless_tensor(t) for t in tensor]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(
tensor,
extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
% ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape),
)
tensor.data = new_data_tensor
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
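# Worked example (a sketch with assumed values): for sigma = 0.02 and num_layers = 24,
# std = 0.02 / sqrt(2 * 24) ≈ 0.00289. This is the default output-layer initialization
# used when output_layer_init_method is not provided in TransformerConfig.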
def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any):
"""If torch distributed is initialized, log only on rank
Args:
logger (logging.Logger): The logger to write the logs
args (Tuple[Any]): All logging.Logger.log positional arguments
rank (int, optional): The rank to write on. Defaults to 0.
kwargs (Dict[str, Any]): All logging.Logger.log keyword arguments
"""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == rank:
logger.log(*args, **kwargs)
else:
logger.log(*args, **kwargs)
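# Usage sketch (illustrative): positional args are forwarded to logging.Logger.log, so
# the log level comes first.
#   log_single_rank(logger, logging.INFO, "starting iteration %d", iteration, rank=0)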
def log_on_each_pipeline_stage(logger: logging.Logger, *args: Any, **kwargs: Any):
"""Log on first rank in each pipeline stage
Args:
logger (logging.Logger): The logger to write the logs
args (Tuple[Any]): All logging.Logger.log positional arguments
kwargs (Dict[str, Any]): All logging.Logger.log keyword arguments
"""
assert torch.distributed.is_initialized()
if (
parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0
and parallel_state.get_tensor_model_parallel_rank() == 0
):
logger.log(*args, **kwargs)
def check_param_hashes_across_dp_replicas(model: List[torch.nn.Module]) -> bool:
"""Computes hashes of all parameters in model, all-gathers hashes across DP replicas,
and then checks for equality between the locally-computed hashes and the hashes
from DP replica 0.
NOTE: This function computes SHA-1 hashes on the CPU and thus needs to move all param
tensors from GPU to CPU first; as a result, this function is not intended to be called
very frequently in the main training loop.
Args:
model (List[torch.nn.Module]): List of model chunks whose parameter hashes need to
be checked.
Returns:
True if all param hashes match with corresponding hash on DP replica 0, False
otherwise.
"""
# Compute per-parameter hashes on this rank.
params = []
local_param_hashes = []
for model_chunk_id, model_chunk in enumerate(model):
for (param_name, param) in model_chunk.named_parameters():
param_hash = torch.frombuffer(
array.array(
'B', hashlib.sha1(param.data.to("cpu").float().numpy(force=True)).digest()
),
dtype=torch.uint8,
)
params.append((model_chunk_id, param_name, param))
local_param_hashes.append(param_hash)
local_param_hashes = torch.stack(local_param_hashes)
# Collect per-parameter hashes across all ranks in DP group.
all_param_hashes = [
torch.zeros_like(local_param_hashes)
for _ in range(parallel_state.get_data_parallel_world_size())
]
torch.distributed.all_gather(
all_param_hashes, local_param_hashes, group=parallel_state.get_data_parallel_group_gloo()
)
# Make sure local per-parameter hash matches DP rank 0.
param_hashes_match = torch.equal(local_param_hashes, all_param_hashes[0])
if not param_hashes_match:
for i, (model_chunk_id, param_name, param) in enumerate(params):
if not torch.equal(local_param_hashes[i], all_param_hashes[0][i]):
rank = torch.distributed.get_rank()
logger.info(
f"[Rank {rank}] Hash not matching for {param_name} in model chunk {model_chunk_id}"
)
return param_hashes_match
def make_tp_sharded_tensor_for_checkpoint(
tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs
):
""" Helper for instantiating a ShardedTensor where the `tp_axis` dimension is sharded across TP group.
Optionally, can provide offsets which prepend new dimensions to the tensor.
"""
prepend_axis_num = len(prepend_offsets)
if replica_id is None:
replica_id = (0, 0, parallel_state.get_data_parallel_rank(with_context_parallel=True))
return ShardedTensor.from_rank_offsets(
key,
tensor,
*prepend_offsets,
(
tp_axis + prepend_axis_num,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_tensor_model_parallel_world_size(),
),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
**kwargs,
)
def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_id=None, **kwargs):
""" Helper for instantiating a non-sharded ShardedTensor (replicated across TP and DP group).
Optionally, can provide offsets which prepend new dimensions to the tensor.
"""
prepend_axis_num = len(prepend_offsets)
if replica_id is None:
replica_id = (
0,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
return ShardedTensor.from_rank_offsets(
key,
tensor,
*prepend_offsets,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
**kwargs,
)
def prepare_input_tensors_for_wgrad_compute(grad_output, all_gathered_input):
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if grad_output.dim() == 3:
grad_output = grad_output.view(
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
)
all_gathered_input = all_gathered_input.view(
all_gathered_input.shape[0] * all_gathered_input.shape[1], all_gathered_input.shape[2]
)
return grad_output, all_gathered_input
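# Shape sketch (illustrative): for grad_output of shape [s, b, h_out] and
# all_gathered_input of shape [s, b, h_in], both are flattened to 2D,
# [s * b, h_out] and [s * b, h_in], so the wgrad GEMM can be executed directly.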
def drain_embedding_wgrad_compute(config, embedding_activation_buffer, grad_output_buffer, weight):
""" Helper for performing embedding wgrad GEMM's during the pipeline drain phase, pipelines the AllGather and GEMM's.
Should only be used when pipeline model parallelism and gradient accumulation fusion are enabled.
"""
assert len(embedding_activation_buffer) == len(
grad_output_buffer
), "Length of activation and gradient buffers need to be equal!"
import fused_weight_gradient_mlp_cuda
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size,
)
input = embedding_activation_buffer.pop(0)
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gathered_input = [None, None]
if config.sequence_parallel:
all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu_0")
handle = torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=False
)
all_gathered_input[0] = all_gather_buffer
all_gather_buffer = None
else:
all_gathered_input[0] = input
input = None
def wgrad_compute(all_gathered_input, grad_output, weight):
grad_output, all_gathered_input = prepare_input_tensors_for_wgrad_compute(
grad_output, all_gathered_input
)
if config.gradient_accumulation_fusion:
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
all_gathered_input, grad_output, weight.main_grad
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
all_gathered_input, grad_output, weight.main_grad
)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
# We have all_gathered_input list acting as a double buffer here,
# since we are pipelining the AllGather and GEMM, one buffer all-gathers
# the input while the other buffer reads from it for the GEMM. We use i
# and (i+1) for indexing to enable this double buffering.
for i in range(len(embedding_activation_buffer)):
input = embedding_activation_buffer.pop(0)
if config.sequence_parallel:
name = "mpu_" + str((i + 1) % 2)
all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, name)
handle = torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
)
all_gathered_input[(i + 1) % 2] = all_gather_buffer
all_gather_buffer = None
else:
all_gathered_input[(i + 1) % 2] = input
grad_output = grad_output_buffer.pop(0)
wgrad_compute(all_gathered_input[i % 2], grad_output, weight)
input, all_gathered_input[i % 2], grad_output = None, None, None
if config.sequence_parallel:
handle.wait()
grad_output = grad_output_buffer.pop(0)
wgrad_compute(all_gathered_input[1], grad_output, weight)
input, all_gathered_input[1], grad_output = None, None, None
class _ValueWithRank:
"""This is an internal class, not for use outside this module
Attributes:
_rank (int): rank for the value
_value (float) : the value it stores, eg elapsed time
_unit (str) : unit for the value
"""
def __init__(self, value: float, rank: int, unit: str = "") -> None:
"""Initializer
Args:
_value (float): the initial value with which it is inited
_rank (int): the rank number
_unit (str) : the unit of the value, eg ms or flops
"""
self._rank = rank
self._value = value
self._unit = unit
def __lt__(self, other) -> bool:
""" Check if value of self is smaller than other's value
Args:
other (_ValueWithRank): The other object to compare with
Returns:
bool: True if lhs._value of operand is less than rhs._value, else False
"""
return self._value < other._value
def __gt__(self, other) -> bool:
"""Check if value of self is larger than other's value
Args:
other (_ValueWithRank): The other object to compare with
Returns:
bool: True if lhs._value of operand is greater than rhs._value, else False
"""
return self._value > other._value
def __call__(self) -> Tuple[float, int, str]:
"""Returns the value, the rank, and unit as a Tuple
Returns:
Tuple[float, int, str]: value, rank, unit
"""
return self._value, self._rank, self._unit
def __str__(self) -> str:
"""String representation of the object
Returns:
str: stringified object
"""
return f"{self._value:.2f}{self._unit}/{self._rank}"
@dataclass
class _StragglerData:
"""This is an internal dataclass, not for use outside this module
Attributes:
min_elapsed (_ValueWithRank): min iteration time across all ranks
max_elapsed (_ValueWithRank): max iteration time across all ranks
min_btime (_ValueWithRank): min cpu time across all ranks
max_btime (_ValueWithRank): max cpu time across all ranks
min_temp (_ValueWithRank): min gpu temp across all ranks
max_temp (_ValueWithRank): max gpu temp across all ranks
min_power (_ValueWithRank): min gpu power across all ranks
max_power (_ValueWithRank): max gpu power across all ranks
min_util (_ValueWithRank): min gpu util across all ranks
max_util (_ValueWithRank): max gpu util across all ranks
min_clock (_ValueWithRank): min gpu clock across all ranks
max_clock (_ValueWithRank): max gpu clock across all ranks
aflops (List[_ValueWithRank]): sorted array of (_ValueWithRank)
"""
# gemm time
min_elapsed = _ValueWithRank(sys.float_info.max, 0, "ms")
max_elapsed = _ValueWithRank(sys.float_info.min, 0, "ms")
# get_batch time
min_btime = _ValueWithRank(sys.float_info.max, 0, "us")
max_btime = _ValueWithRank(sys.float_info.min, 0, "us")
# temp
min_temp = _ValueWithRank(sys.float_info.max, 0, "C")
max_temp = _ValueWithRank(sys.float_info.min, 0, "C")
# power
min_power = _ValueWithRank(sys.float_info.max, 0, "W")
max_power = _ValueWithRank(sys.float_info.min, 0, "W")
# util
min_util = _ValueWithRank(sys.float_info.max, 0, "%")
max_util = _ValueWithRank(sys.float_info.min, 0, "%")
# clock
min_clock = _ValueWithRank(sys.float_info.max, 0, "MHz")
max_clock = _ValueWithRank(sys.float_info.min, 0, "MHz")
aflops: Union[List[_ValueWithRank], None] = None
class StragglerDetector:
"""Singleton Class implementing per rank Straggler Detector
It use cuda events to time operation of choice using the
start and stop methods which can be directly invoked using
the class instance or can be used like a python context.
After collection, a report() method is available to display
the collected metrics. It is only supported if CUDA is
available. megatron/core/README_STRAGGLER.md for more info
Note:
The instance and class attributes mentioned below are all
private to the class and has no use outside the class
Attributes:
_off (bool): current state of the toggle
start (FunctionType): start method
stop (FunctionType): stop method
world (int): world size
rank (int): rank for this instance
mmcnt (int): number of ranks to report
port (int): control port
amp (float): amplification factor for TFLOPs, default 3.0
toggle (bool): whether to start/stop detector collection
bdata (bool): when true, just collect get_batch
dev (int): cuda device
evt_q (LifoQueue): cuda event queue
start_gemm_ev (list[torch.cuda.Event]): cuda start event
stop_gemm_ev (list[torch.cuda.Event]): cuda stop event
start_data_ev (list[torch.cuda.Event]): cuda start event
stop_data_ev (list[torch.cuda.Event]): cuda stop event
start_gemm_tm (list[int]): start time (wallclock)
stop_gemm_tm (list[int]): stop time (wallclock)
start_data_tm (list[int]): start time for get_batch
stop_data_tm (list[int]): stop time for get_batch
sock (socket): the controller socket
ctrlr (Thread): the controller thread
"""
_configured = False
"""Indicates if the singleton instance is configured or not
"""
def __new__(cls: Type["StragglerDetector"]) -> "StragglerDetector":
"""Constructor
Creates an instance of the class if not created
Args:
cls (Type[&#39;StragglerDetector&#39;]): The class type
Returns:
StragglerDetector: the class instance
"""
if not hasattr(cls, "_instance"):
cls._instance = super(StragglerDetector, cls).__new__(cls)
return cls._instance
def __init__(self) -> None:
"""Initializer
The initial state of the StragglerDetector instance is disabled.
The enabled state is indicated using the self._off member variable
and the property `enabled`.
"""
self._off: bool = True
self.start = self.null_method
self.stop = self.null_method
self.world: int = 0
self.rank: int = 0
self.mmcnt: int = 1
self.port: int = 0
self.amp: float = 3.0
self.toggle: bool = False
self.bdata: bool = False
self.dev: Union[torch.device, int, None] = None
self.evt_q: Union[queue.LifoQueue, None] = None
self.start_gemm_ev: List[torch.cuda.Event] = []
self.stop_gemm_ev: List[torch.cuda.Event] = []
self.start_data_ev: List[torch.cuda.Event] = []
self.stop_data_ev: List[torch.cuda.Event] = []
self.start_gemm_tm: List[int] = []
self.stop_gemm_tm: List[int] = []
self.start_data_tm: List[int] = []
self.stop_data_tm: List[int] = []
self.sock: Union[socket.socket, None] = None
self.ctrlr: Union[threading.Thread, None] = None
def configure(
self,
world: int,
rank: int,
mmcnt: int = 1,
amp: float = 3.0,
port: int = 65535,
prefill: int = 1024,
enabled: bool = False,
) -> None:
"""This method is called to configure the Singleton instance
It should be called once per instantiation per process.
Note:
The constructor keeps the instance in the disabled state,
i.e. no collection will happen even when the start/stop methods are
called. Only when enabled is True (so self._off is False) do the
start/stop method pointers get assigned the real collection
methods; otherwise they remain initialized with null_method.
Args:
world (int): World Size
rank (int): The rank of this trainer
mmcnt (int, optional): Number of ranks to print for showing Min/Max Etpt.
Defaults to 1.
amp (float, optional): Set to 3.0 if we only use timers in fwd pass.
Defaults to 3.0.
port (int, optional): Control port, useful only for rank-0. Defaults to 65535.
            prefill (int, optional): How many CUDA events to pre-populate. Defaults to 1024.
enabled (bool, optional): Whether or not collection is enabled on startup.
Defaults to False.
"""
if StragglerDetector._configured:
# don't throw
return
StragglerDetector._configured = True
self.bdata = False
self.start = self.null_method
self.stop = self.null_method
self._off = True
# No CUDA, No Support
if torch.cuda.is_available():
self._off = not enabled
self.world = world
self.rank = rank
self.mmcnt = mmcnt if mmcnt > 1 else 1
self.amp = amp
self.port = port
self.toggle = False
self.bdata = False
self.evt_q = queue.LifoQueue()
self.start_gemm_ev = []
self.stop_gemm_ev = []
self.start_data_ev = []
self.stop_data_ev = []
self.start_gemm_tm = []
self.stop_gemm_tm = []
self.start_data_tm = []
self.stop_data_tm = []
backend = torch.distributed.get_backend()
if backend == "nccl":
self.dev = torch.cuda.current_device()
else:
self.dev = torch.device("cpu")
# cache some events
for _ in range(prefill):
self.evt_q.put(torch.cuda.Event(enable_timing=True))
if self.rank == 0:
# Start the controller
self._controller()
if not self._off:
self.start = self.start_method
self.stop = self.stop_method
def reset(self) -> None:
"""This method is called to reset the metrics state of the instance
It is generally called from within elapsed() after extracting per rank metrics.
"""
if self._off:
return
# Pool them
if self.evt_q is not None:
_ = [self.evt_q.put(ev) for ev in self.start_gemm_ev]
_ = [self.evt_q.put(ev) for ev in self.stop_gemm_ev]
_ = [self.evt_q.put(ev) for ev in self.start_data_ev]
_ = [self.evt_q.put(ev) for ev in self.stop_data_ev]
self.start_gemm_ev = []
self.stop_gemm_ev = []
self.start_data_ev = []
self.stop_data_ev = []
# Use regular timers
self.start_gemm_tm = []
self.stop_gemm_tm = []
self.start_data_tm = []
self.stop_data_tm = []
self.bdata = False
def start_method(self) -> None:
"""This method adds the start timers.
Both cuda event and perf_counter are added. If bdata is set to
true from __call__, this method skips inserting cuda
timer. This way it can be used to measure time spent on
CPU - generally useful for timing get_batch()
"""
# Not reentrant
if self.evt_q is not None and self.evt_q.qsize() > 1:
sev = self.evt_q.get() # no try-catch
eev = self.evt_q.get() # no try-catch
else:
sev = torch.cuda.Event(enable_timing=True)
eev = torch.cuda.Event(enable_timing=True)
# First check if this start is for data
if self.bdata:
self.start_data_ev.append(sev)
self.stop_data_ev.append(eev)
self.start_data_tm.append(0)
self.stop_data_tm.append(0)
idx = len(self.stop_data_tm) - 1
self.start_data_tm[idx] = time.perf_counter_ns()
self.start_data_ev[idx].record()
self.bdata = False
return
self.start_gemm_ev.append(sev)
self.stop_gemm_ev.append(eev)
self.start_gemm_tm.append(0)
self.stop_gemm_tm.append(0)
idx = len(self.stop_gemm_tm) - 1
self.start_gemm_tm[idx] = time.perf_counter_ns()
self.start_gemm_ev[idx].record()
def stop_method(self) -> None:
"""This method adds the stop timers.
Both cuda event and perf_counter are added. If bdata is set to
true from __call__, this method skips inserting cuda
timer. Also see start_method()
"""
# Not reentrant
# First check if this stop is for data
idx = len(self.stop_data_tm) - 1
if idx >= 0 and self.stop_data_tm[idx] == 0:
self.stop_data_tm[idx] = time.perf_counter_ns()
self.stop_data_ev[idx].record()
return
idx = len(self.stop_gemm_tm) - 1
if idx >= 0 and self.stop_gemm_tm[idx] == 0:
self.stop_gemm_tm[idx] = time.perf_counter_ns()
self.stop_gemm_ev[idx].record()
def elapsed(self) -> Tuple[float, float, int, int, int, int]:
"""This method is called from report(), or can be called directly
It is called to collect all the elapsed time since last reset().
It finally calls reset()
Returns:
Tuple[float, float, int, int, int, int]: see below for returns
delta : time spent in kernel
batch_delta : time spent in get_batch
temp : observed gpu temp
power : observed gpu power
util : observed gpu utilization
clock : observed gpu clock
"""
if self._off:
# match with return below
return 0, 0, 0, 0, 0, 0
ls_ev = len(self.start_gemm_ev)
le_ev = len(self.stop_gemm_ev)
ls_bs = len(self.start_data_ev)
ls_be = len(self.stop_data_ev)
delta = 0.0
batch_delta = 0.0
temp = 0
power = 0
clock = 0
if ls_ev != le_ev:
logger.warning(f"Event Start/Stop out of sync {ls_ev}/{le_ev}")
elif ls_bs != ls_be:
logger.warning(f"get_batch Start/Stop out of sync {ls_bs}/{ls_be}")
else:
temp = torch.cuda.temperature()
power = torch.cuda.power_draw()
util = torch.cuda.utilization()
clock = torch.cuda.clock_rate()
torch.cuda.synchronize()
# Process Events
for i in range(ls_ev):
e_ev = self.start_gemm_ev[i].elapsed_time(self.stop_gemm_ev[i])
e_tm = (self.stop_gemm_tm[i] - self.start_gemm_tm[i]) / 1e6 # ns to ms
# Pick the larger of Event and perf_counter time?
delta += max(e_ev, e_tm)
# Process get_batch
for i in range(ls_bs):
b_ev = self.start_data_ev[i].elapsed_time(self.stop_data_ev[i])
b_tm = (self.stop_data_tm[i] - self.start_data_tm[i]) / 1e6 # ns to ms
# data fetching has prefetch, hence take the max, instead of avg
batch_delta = max(batch_delta, max(b_ev, b_tm))
self.reset() # Prepare for next round
# time in ms, batch_delta in ms, check return above
return delta, batch_delta, temp, power, util, clock
def report(self, total_flops: float = 0.0, log_interval: int = 0) -> bool:
"""Function to log the min/max metircs and the associated rank over a time period
It finds the slowest and fastest rank among all ranks. It should be
called by all ranks, but only rank-0 prints the analysis
At the end it checks, if the straggler detector should
remain active or if it should be deactivated.
Args:
total_flops (float, optional): The theoretical flops over the period. Defaults to 0.0.
            log_interval (int, optional): The training interval (number of iterations) over which
                reporting is called. Defaults to 0.
Returns:
bool: True if reported, else False
"""
ret = False
if not self._off and total_flops > 0.0 and log_interval > 0:
elapsed, btime, temp, power, util, clock = self.elapsed() # get raw time
# btime (get_batch time is max in the iteration)
ptime = elapsed / (log_interval * 1.0) # avg per iteration elapsed time, ms
            api_flops = total_flops / (log_interval * 1.0)  # avg per iteration flops
            apir_flops = api_flops / (
                ptime * 10 ** 9 * self.world
            )  # this rank's avg per-iteration throughput, TFLOP/s (note 10**9)
et_flops = apir_flops / self.amp # Estimated TFLOPs, not tracing backward
o_dt = self._min_max(
ptime, btime, float(temp), float(power), float(util), float(clock), et_flops,
)
if self.rank == 0 and o_dt is not None and o_dt.aflops is not None:
now = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
min_flops, min_frank, _ = o_dt.aflops[0]()
max_flops, max_frank, _ = o_dt.aflops[-1]()
logger.info(
f"{now} | "
f"MnRtt/Rnk: {o_dt.min_elapsed} | "
f"MxRtt/Rnk: {o_dt.max_elapsed} | "
f"MnPwr/Rnk: {o_dt.min_power} | "
f"MxPwr/Rnk: {o_dt.max_power} | "
f"MnTmp/Rnk: {o_dt.min_temp} | "
f"MxTmp/Rnk: {o_dt.max_temp} | "
f"MnUtl/Rnk: {o_dt.min_util} | "
f"MxUtl/Rnk: {o_dt.max_util} | "
f"MnClk/Rnk: {o_dt.min_clock} | "
f"MxClk/Rnk: {o_dt.max_clock} | "
f"MnDRtt/Rnk: {o_dt.min_btime} | "
f"MxDRtt/Rnk: {o_dt.max_btime} | "
f"MnEtpt/Rnk: {min_flops:.2f}TF/{min_frank} | "
f"MxEtpt/Rnk: {max_flops:.2f}TF/{max_frank}"
)
if self.mmcnt > 1 and self.mmcnt < self.world:
line = f"^^^^ Bottom {self.mmcnt} Ranks with lowest Etpt(TF):"
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i]},"
logger.info(line)
line = f"^^^^ Top {self.mmcnt} Ranks with highest Etpt(TF):"
shift = self.world - self.mmcnt
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i+shift]},"
logger.info(line)
ret = True
# Check/Communicate if tracking is turned off or on
self._check_toggle()
return ret
def _check_toggle(self) -> None:
"""Helper method to check if a request to toggle the collection state was made
It checks iof collection state toggle req was made via the server listening on
rank-0 since last call to report(). Called by report(). Calling this method
indirectly from report() is the only way to activate the change that is made
via rank-0
"""
        # If no change, just communicate the current state
off = self._off
if self.rank == 0 and self.toggle:
off = not self._off
self.toggle = False
st = torch.tensor(off, dtype=torch.bool, device=self.dev)
torch.distributed.broadcast(st, 0) # Blocking
# save old switch
off = self._off
self._off = bool(st.item())
if off != self._off:
if not self._off:
self.start = self.start_method
self.stop = self.stop_method
state = "ON"
else:
self.start = self.null_method
self.stop = self.null_method
state = "OFF"
if self.rank == 0:
logger.info(f"Toggling StragglerDetector State {state}")
def _handler(self) -> None:
"""Thread function for the controller.
        It is a TCP server that listens on a port and uses the HTTP protocol.
If connected to it using curl, it indicates a toggle of the
collection state. The actual toggling happens at the end of
calling report() when _check_toggle() is called.
"""
resp = f"HTTP/1.0 200 OK\r\nConnection: Close\r\nContent-length: "
if self.rank == 0:
state = "OFF" if self._off else "ON"
logger.info(
f"Controller ready to recv " f"commands on port {self.port}. Current state {state}"
)
            while self.sock is not None:
try:
conn, _ = self.sock.accept()
_ = conn.recv(1024)
self.toggle = True
state = "ON" if self._off else "OFF"
msg = f"Will turn StragglerDetector {state} at next logging interval"
msg_len = len(msg)
final_resp = f"{resp}{msg_len}\r\n\r\n{msg}"
conn.send(final_resp.encode())
conn.close()
logger.info(msg)
except Exception as err:
logger.error(f"Error in stragler handler.. {str(err)}")
return
def _controller(self):
"""Installs a controller listener that is used to toggle collection state.
Called from configure(). Ignored for all ranks other than rank-0
"""
try:
if self.rank == 0:
neth = "0.0.0.0"
netp = self.port
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((neth, netp))
self.sock.listen(128)
self.ctrlr = threading.Thread(
target=self._handler, args=(), name="straggler", daemon=True
)
self.ctrlr.start()
except Exception as err:
logger.warning(f"StragglerDetector cannot be controlled.. {str(err)}")
def _min_max(
self,
ptime: float,
btime: float,
temp: float,
power: float,
util: float,
clock: float,
flops: float,
) -> Union[_StragglerData, None]:
"""Helper function to find the min/max values
Args:
ptime (float): avg per iteration gpu time
btime (float): avg per iteration cpu time
temp (float): gpu temp at the time of reporting
power (float): gpu power at the time of reporting
util (float): gpu util at the time of reporting
clock (float): gpu clock at the time of reporting
flops (float): estimated flops for the rank
Returns:
            Union[_StragglerData, None]: Contains the min/max of a few metrics and the
                                         corresponding ranks; it also has a list of all
                                         (flops, rank) pairs sorted by flops (aflops).
                                         Returns None if collection is disabled.
"""
if self._off:
return None
# initialize output data object
o_dt = _StragglerData()
prof_data: Dict[str, Union[int, float]] = {}
data_list: List[Dict[str, Union[int, float]]] = []
prof_data["rank"] = self.rank
prof_data["time"] = ptime
prof_data["btime"] = btime
prof_data["temp"] = temp
prof_data["power"] = power
prof_data["util"] = util
prof_data["clock"] = clock
prof_data["flops"] = flops
if self.rank == 0:
data_list = [prof_data] * self.world
# this is blocking by default
torch.distributed.gather_object(prof_data, object_gather_list=data_list, dst=0)
if self.rank == 0:
min_ctime = min(data_list, key=lambda k: k["time"]) # elapsed
max_ctime = max(data_list, key=lambda k: k["time"]) # elapsed
min_cbatch = min(data_list, key=lambda k: k["btime"]) # batch time
max_cbatch = max(data_list, key=lambda k: k["btime"]) # batch time
min_ctemp = min(data_list, key=lambda k: k["temp"]) # temp
max_ctemp = max(data_list, key=lambda k: k["temp"]) # temp
min_cpower = min(data_list, key=lambda k: k["power"]) # power
max_cpower = max(data_list, key=lambda k: k["power"]) # power
min_cutil = min(data_list, key=lambda k: k["util"]) # gpu util
max_cutil = max(data_list, key=lambda k: k["util"]) # gpu util
min_cclock = min(data_list, key=lambda k: k["clock"]) # gpu clock
max_cclock = max(data_list, key=lambda k: k["clock"]) # gpu clock
min_val = min_ctime["time"]
min_rank = min_ctime["rank"]
max_val = max_ctime["time"]
max_rank = max_ctime["rank"]
o_dt.min_elapsed = _ValueWithRank(min_val, int(min_rank), "ms")
o_dt.max_elapsed = _ValueWithRank(max_val, int(max_rank), "ms")
min_val = min_cbatch["btime"]
min_rank = min_cbatch["rank"]
max_val = max_cbatch["btime"]
max_rank = max_cbatch["rank"]
o_dt.min_btime = _ValueWithRank(min_val, int(min_rank), "ms")
o_dt.max_btime = _ValueWithRank(max_val, int(max_rank), "ms")
min_val = min_ctemp["temp"]
min_rank = min_ctemp["rank"]
max_val = max_ctemp["temp"]
max_rank = max_ctemp["rank"]
o_dt.min_temp = _ValueWithRank(min_val, int(min_rank), "C")
o_dt.max_temp = _ValueWithRank(max_val, int(max_rank), "C")
min_val = min_cpower["power"]
min_rank = min_cpower["rank"]
max_val = max_cpower["power"]
max_rank = max_cpower["rank"]
o_dt.min_power = _ValueWithRank(min_val, int(min_rank), "W")
o_dt.max_power = _ValueWithRank(max_val, int(max_rank), "W")
min_val = min_cutil["util"]
min_rank = min_cutil["rank"]
max_val = max_cutil["util"]
max_rank = max_cutil["rank"]
o_dt.min_util = _ValueWithRank(min_val, int(min_rank), "%")
o_dt.max_util = _ValueWithRank(max_val, int(max_rank), "%")
min_val = min_cclock["clock"]
min_rank = min_cclock["rank"]
max_val = max_cclock["clock"]
max_rank = max_cclock["rank"]
o_dt.min_clock = _ValueWithRank(min_val, int(min_rank), "MHz")
o_dt.max_clock = _ValueWithRank(max_val, int(max_rank), "MHz")
o_dt.aflops = [
_ValueWithRank(d.get("flops", 0.0), int(d.get("rank", -1)))
for _, d in enumerate(data_list)
]
o_dt.aflops.sort(key=lambda val_with_rank: val_with_rank()[0])
# wait for everyone here
torch.distributed.barrier()
return o_dt
@property
def enabled(self) -> bool:
"""Can be called to check the enabled state of the instance
Note:
After the request to toggle the state, the
actual state change happens at end of call
to report()
"""
return not self._off
@property
def configured(self) -> bool:
"""Can be called to check if the the instance is already configured
Returns:
bool: returns True if configure was called and was a success, else False
"""
return StragglerDetector._configured
@property
def my_rank(self):
"""Can be called to get configured rank of this instance
Returns:
int: Configured rank for this instance
"""
return self.rank
@property
def world_size(self) -> int:
"""Can be called to get configured world of this instance
Returns:
int: World size configured for this instance
"""
return self.world
def null_method(self) -> None:
"""Default method to initialize start/stop method ptrs"""
pass
def __enter__(self) -> "StragglerDetector":
"""Define context/instance entry
Returns:
StragglerDetector: the instance
"""
self.start()
return self
def __call__(self, bdata: bool = False) -> "StragglerDetector":
"""Callable for the instance. Set context state,
Useful when the context is used for cpu timers only when bdata=True
Args:
bdata (bool, optional): when true, only enables cpu timers. Defaults to False.
Returns:
StragglerDetector: the instance
"""
self.bdata = bdata
return self
def __exit__(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> bool:
"""Define context/instance exit, calls the stop method
Args:
ex_type (Optional[Type[BaseException]]): Exception type
            ex_val (Optional[BaseException]): Exception value
            ex_tb (Optional[TracebackType]): Exception traceback
Returns:
bool: True if the exception was handled
"""
# Should not suppress errors even if turned off
if ex_type is not None:
err = traceback.format_exception(ex_type, ex_val, ex_tb)
logger.warning(f"{str(ex_val)}\n{err}")
self.stop()
return False
# Singleton, global visibility
__straggler__ = StragglerDetector()
"""StragglerDetector: private module variable, not be directly accessed
"""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
def add_modelopt_args(parser):
"""Add additional arguments for using TensorRT Model Optimizer (modelopt) features."""
group = parser.add_argument_group(title="modelopt-generic")
group.add_argument(
"--export-legacy-megatron",
action="store_true",
help="Export a legacy megatron-lm checkpoint.",
)
group.add_argument(
"--export-te-mcore-model",
action="store_true",
help="Export a megatron-core transformer-engine checkpoint.",
)
group.add_argument(
"--export-quant-cfg",
type=str,
default=None,
choices=["int8", "int8_sq", "fp8", "int4_awq", "w4a8_awq", "int4", "None"],
help="Specify a quantization config from the supported choices.",
)
return parser
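# Minimal usage sketch (editor's addition): wiring add_modelopt_args() into a
# standalone argparse parser. In Megatron training this hook is normally passed
# in as an extra-args provider; the parser below is only for illustration.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="modelopt args demo")
    parser = add_modelopt_args(parser)
    args = parser.parse_args(["--export-quant-cfg", "fp8", "--export-te-mcore-model"])
    print(args.export_quant_cfg, args.export_te_mcore_model, args.export_legacy_megatron)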
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
from pathlib import Path
from typing import Optional, Dict
from megatron.core import dist_checkpointing
from megatron.training import get_args
from megatron.training.checkpointing import _load_base_checkpoint, load_checkpoint
from megatron.training.utils import print_rank_0, unwrap_model
try:
from modelopt.torch.opt.plugins import (
get_sharded_modelopt_state,
restore_modelopt_state_metadata,
)
except ImportError as e:
raise ImportError("Required `\"nvidia-modelopt[torch]\"` is not installed!") from e
def load_modelopt_state(load_dir: Optional[str] = None) -> Dict:
"""Loading modelopt_state without a model.
If --use-dist-ckpt, we try to load from the sharded modelopt_state. This will not load the model
state_dict. Otherwise, if the checkpoint is not sharded, we load the base checkpoint (that
contains the model state as well) and extract the modelopt_state.
Args:
load_dir: optionally provide a different loading path
"""
args = get_args()
if load_dir is None:
load_dir = args.load
if args.use_dist_ckpt:
# Read the tracker file and set the iteration.
tracker_filename = os.path.join(load_dir, 'latest_checkpointed_iteration.txt')
# If no tracker file, assuming that it is a .nemo checkpoint.
if not os.path.isfile(tracker_filename):
sharded_load_dir = Path(load_dir) / "model_weights"
else:
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
sharded_load_dir = Path(load_dir) / 'iter_{:07d}'.format(iteration)
except ValueError:
sharded_load_dir = Path(load_dir) / metastring
modelopt_state_dir = sharded_load_dir / "modelopt_state"
if modelopt_state_dir.exists():
print_rank_0("Loading sharded modelopt_state ({})".format(modelopt_state_dir))
modelopt_state = restore_modelopt_state_metadata(
dist_checkpointing.load(
get_sharded_modelopt_state(args.num_layers), modelopt_state_dir,
)
)
return modelopt_state
else:
print_rank_0(
"sharded modelopt_state ({}) does not exist!".format(modelopt_state_dir)
)
return {}
else:
print_rank_0("Loading modelopt_state from base checkpoint ({})".format(load_dir))
try:
            state_dict, _, _ = _load_base_checkpoint(load_dir, rank0=False)
except Exception:
print_rank_0("Failed to load base checkpoint via megatron _load_base_checkpoint!")
return {}
if state_dict is None:
return {}
return state_dict.get("modelopt_state", {})
def load_modelopt_checkpoint(
model,
optimizer=None,
opt_param_scheduler=None,
strict: bool = True,
additional_sharded_prefix: str = "model.",
load_arg: str = "load",
) -> None:
"""Load a sharded (untar .nemo or megatron --use-dist-ckpt) or unsharded checkpoint.
Essentially, the function is detecting whether the checkpoint is a .nemo sharded checkpoint.
If so, we load the sharded state_dict with additional_sharded_prefix `model.`.
    This additional prefix is an artifact of the lightning module wrapper. Once the sharded
state_dict is loaded, we use a state_dict pre_hook to pop this additional prefix (`model.`)
from all state_dict keys.
If this is not a .nemo sharded checkpoint, then this function will simply call
load_checkpoint. See megatron.checkpointing.load_checkpoint for explanation.
Args:
additional_sharded_prefix: append additional prefix to align the sharded checkpoint keys.
            When loading a .nemo sharded checkpoint, this is usually `model.`. Otherwise, this is
typically an empty string.
"""
def _remove_prefix_state_dict_pre_hook(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs,
):
"""Pytorch state_dict pre_hook to remove prefix of the state_dict keys."""
if additional_sharded_prefix is None:
return
key_rewrite_list = []
for key, _ in state_dict.items():
if key.startswith(additional_sharded_prefix):
key_rewrite_list.append(key)
for old_key in key_rewrite_list:
new_key = old_key[len(additional_sharded_prefix) :]
state_dict[new_key] = state_dict.pop(old_key)
args = get_args()
load_dir = getattr(args, load_arg)
sharded_load_dir = Path(load_dir) / "model_weights"
if sharded_load_dir.exists() and optimizer is None and opt_param_scheduler is None:
unwrapped_model = unwrap_model(model)
        # Setting this attribute alters the sharded_offsets of transformer_block.
unwrapped_model[0].decoder.config.non_homogeneous_layers = False
sharded_state_dict = unwrapped_model[0].sharded_state_dict(prefix=additional_sharded_prefix)
if additional_sharded_prefix:
unwrapped_model[0]._register_load_state_dict_pre_hook(
_remove_prefix_state_dict_pre_hook
)
unwrapped_model[0].load_state_dict(
dist_checkpointing.load(sharded_state_dict, sharded_load_dir)
)
        # Set the attribute back to True so that, by default, we store the heterogeneous arch.
unwrapped_model[0].decoder.config.non_homogeneous_layers = True
else:
_ = load_checkpoint(model, optimizer, opt_param_scheduler, strict=strict, load_arg=load_arg)
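# Standalone sketch (editor's addition) of the key rewrite performed by the
# `_remove_prefix_state_dict_pre_hook` above: entries carrying the `model.`
# prefix (a NeMo/lightning wrapper artifact) are renamed in place before the
# state_dict is loaded. `_strip_prefix` is a hypothetical helper, shown only
# to illustrate the mechanism on a plain dict.
def _strip_prefix(state_dict: dict, prefix: str = "model.") -> dict:
    """Rename `<prefix><key>` entries to `<key>` in place and return the dict."""
    for old_key in [k for k in state_dict if k.startswith(prefix)]:
        state_dict[old_key[len(prefix):]] = state_dict.pop(old_key)
    return state_dict
if __name__ == "__main__":
    sd = {"model.decoder.weight": 1, "model.decoder.bias": 2, "other": 3}
    print(_strip_prefix(sd))  # {'other': 3, 'decoder.weight': 1, 'decoder.bias': 2}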
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""ModelOpt GPT model provider."""
import modelopt.torch.opt as mto
from megatron.core.inference.gpt.model_specs import get_gpt_layer_modelopt_spec
from megatron.core.inference.gpt.state_dict_hooks import (
mcore_gpt_load_legacy_state_dict_pre_hook,
mcore_gpt_load_te_state_dict_pre_hook,
)
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.parallel_state import get_tensor_model_parallel_rank
from megatron.core.transformer.spec_utils import import_module
from megatron.inference.checkpointing import load_modelopt_state
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
def model_provider(pre_process=True, post_process=True, parallel_output=True) -> MCoreGPTModel:
"""Builds the model.
    Only MCore GPT models are supported; setting args.use_legacy_models raises a ValueError.
Args:
        pre_process (bool, optional): Set to True if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to True if you want to compute output logits/loss. Defaults to True.
        parallel_output (bool): Whether to all-gather the output logits. This must be
True if `model_provider` is called in text_generation_server.
Returns:
MCoreGPTModel: The returned model
"""
args = get_args()
print_rank_0("building GPT model ...")
    # ModelOpt by default assumes non-homogeneous layers. This affects the storage format of the sharded checkpoint.
config = core_transformer_config_from_args(args)
config.non_homogeneous_layers = True
if args.use_legacy_models:
raise ValueError(
"ModelOpt integration only support MCore models. Use --use-mcore-modules instead."
)
if args.spec is not None:
transformer_layer_spec = import_module(args.spec)
else:
transformer_layer_spec = get_gpt_layer_modelopt_spec(
remap_te_layernorm=args.export_te_mcore_model, qk_layernorm=False,
)
model_type = MCoreGPTModel
model_kwargs = {
"config": config,
"transformer_layer_spec": transformer_layer_spec,
"vocab_size": args.padded_vocab_size,
"max_sequence_length": args.max_position_embeddings,
"pre_process": pre_process,
"post_process": post_process,
"fp16_lm_cross_entropy": args.fp16_lm_cross_entropy,
"parallel_output": parallel_output,
"share_embeddings_and_output_weights": not args.untie_embeddings_and_output_weights,
"position_embedding_type": args.position_embedding_type,
"rotary_percent": args.rotary_percent,
}
model = model_type(**model_kwargs)
# Load modelopt_state
modelopt_state = load_modelopt_state() if args.load else {}
if modelopt_state:
model = mto.restore_from_modelopt_state(model, modelopt_state)
# Register some load_state_dict prehooks to handle some known state_dict key mismatch.
# (legacy <-> modelopt) and (default te <-> modelopt)
if args.export_legacy_megatron:
model._register_load_state_dict_pre_hook(mcore_gpt_load_legacy_state_dict_pre_hook)
if args.export_te_mcore_model:
model._register_load_state_dict_pre_hook(mcore_gpt_load_te_state_dict_pre_hook)
# Print models on all pp ranks.
if get_tensor_model_parallel_rank() == 0:
print(str(model))
return model
<!-- coding=utf-8-->
<!-- Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.-->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Megatron</title>
<style>
.wrapper {
max-width: 75%;
margin: auto;
}
h1 {
margin: 3rem 0 1rem 0;
padding: 0;
font-size: 1.5rem;
}
textarea {
width: 100%;
min-height: 300px;
resize: none;
border-radius: 8px;
border: 1px solid #ddd;
padding: 0.5rem;
box-shadow: inset 0 0 0.25rem #ddd;
      }
      textarea:focus {
        outline: none;
        border: 1px solid #d0d0d0; /* approximates darken(#ddd, 5%) */
        box-shadow: inset 0 0 0.5rem #d0d0d0;
      }
#the-count {
float: right;
padding: 0.1rem 0 0 0;
font-size: 0.875rem;
}
/* Chat containers */
.container {
font-family: 'Arial', sans-serif;
font-size: 16px;
border: 2px solid #dedede;
background-color: #f1f1f1;
border-radius: 5px;
padding: 15px;
margin: 10px 0;
}
/* Clear floats */
.container::after {
content: "";
clear: both;
display: table;
}
/* Style images */
.container img {
float: left;
max-width: 60px;
width: 100%;
margin-right: 20px;
border-radius: 50%;
}
</style>
</head>
<body>
<div class="wrapper">
<h1>Prompt Megatron</h1>
<textarea name="prompt" id="prompt" maxlength="1024" placeholder="Add prompt"autofocus></textarea>
<label for="tokens_to_generate">Number tokens to generate (1-1024):</label>
<input type="number" id="tokens_to_generate" name="tokens_to_generate" min="10" max="256", value=32>
<button onclick="submit_query()">Submit</button>
<div id="the-count">
<span id="current">0</span>
<span id="maximum">/ 1000</span>
</div>
<textarea name="response" id="response" maxlength="2048" placeholder="Megatron response..."></textarea>
</div>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script type="text/javascript">
function submit_query() {
$("#response").val("Waiting for Megatron response...");
$.ajax({
url:"api",
type:"PUT",
data:JSON.stringify({prompts: [$("#prompt").val()], tokens_to_generate: parseInt($("#tokens_to_generate").val(),10)}),
contentType:"application/json; charset=utf-8",
dataType:"json",
success: function(data){
data.max_len=35;
$("#response").val(data.text);
}
});
}
$('textarea').keyup(function() {
var characterCount = $(this).val().length,
current = $('#current'),
maximum = $('#maximum'),
theCount = $('#the-count');
current.text(characterCount);
if (characterCount >= 800) {
maximum.css('color', '#8f0001');
current.css('color', '#8f0001');
theCount.css('font-weight','bold');
} else {
maximum.css('color','#666');
theCount.css('font-weight','normal');
}
});
</script>
</body>
</html>
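# Python client sketch (editor's addition) mirroring the JavaScript above: the
# page PUTs a JSON body {"prompts": [...], "tokens_to_generate": N} to the
# relative "api" endpoint and reads the "text" field of the response.
# SERVER_URL is a placeholder for wherever the Megatron text-generation server
# is reachable, and the `requests` package must be installed.
import requests
SERVER_URL = "http://localhost:5000/api"  # placeholder host/port
def query_megatron(prompt: str, tokens_to_generate: int = 32) -> str:
    """Send one prompt and return the generated text from the server."""
    payload = {"prompts": [prompt], "tokens_to_generate": tokens_to_generate}
    resp = requests.put(SERVER_URL, json=payload, timeout=120)
    resp.raise_for_status()
    return resp.json()["text"]
if __name__ == "__main__":
    print(query_megatron("Megatron is"))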
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from .api import (
generate,
generate_and_post_process,
beam_search_and_post_process)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Inference API."""
import torch
from megatron.core import mpu
from .communication import broadcast_float_list
from .generation import (
generate_tokens_probs_and_return_on_first_stage,
score_and_return_on_first_stage,
beam_search_and_return_on_first_stage)
from .tokenization import (
tokenize_prompts,
detokenize_generations)
from .forward_step import ForwardStep
def generate_and_post_process(model,
forward_step=ForwardStep,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
top_p_decay=0.0,
top_p_bound=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
prevent_newline_after_colon=False,
random_seed=-1,
return_logits=False):
"""Run inference and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""
# Main inference.
tokens, lengths, output_log_probs, logits = generate(
model,
forward_step=forward_step,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=return_output_log_probs,
top_k_sampling=top_k_sampling,
top_p_sampling=top_p_sampling,
top_p_decay=top_p_decay,
top_p_bound=top_p_bound,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=use_eod_token_for_early_termination,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
prevent_newline_after_colon=prevent_newline_after_colon,
random_seed=random_seed)
# Only post-process on first stage.
if mpu.is_pipeline_first_stage():
tokens, prompts_plus_generations, prompts_plus_generations_segments = \
detokenize_generations(tokens, lengths, True)
if return_output_log_probs:
output_log_probs = output_log_probs.cpu().numpy().tolist()
for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
output_log_probs[i] = prob[:len(seg)-1]
if return_logits:
assert(tokens_to_generate == 0)
assert(mpu.get_pipeline_model_parallel_world_size() == 1)
return prompts_plus_generations, prompts_plus_generations_segments, \
output_log_probs, tokens, logits
else:
return prompts_plus_generations, prompts_plus_generations_segments, \
output_log_probs, tokens
return None
def generate(model,
forward_step=None,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
top_p_decay=0.0,
top_p_bound=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True,
stop_on_double_eol=False,
stop_on_eol=False,
prevent_newline_after_colon=False,
random_seed=-1):
"""Given prompts and input parameters, run inference and return:
tokens: prompts plus the generated tokens.
lengths: length of the prompt + generations. Note that we can
discard tokens in the tokens tensor that are after the
corresponding length.
output_log_probs: log probs of the tokens.
"""
    # Make sure input params are available to all ranks.
values = [tokens_to_generate,
return_output_log_probs,
top_k_sampling, top_p_sampling, top_p_decay, top_p_bound,
temperature, add_BOS, use_eod_token_for_early_termination,
stop_on_double_eol,
stop_on_eol,
prevent_newline_after_colon,
random_seed]
values_float_tensor = broadcast_float_list(len(values), float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
top_k_sampling = int(values_float_tensor[2].item())
top_p_sampling = values_float_tensor[3].item()
top_p_decay = values_float_tensor[4].item()
top_p_bound = values_float_tensor[5].item()
temperature = values_float_tensor[6].item()
add_BOS = bool(values_float_tensor[7].item())
use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
stop_on_double_eol = bool(values_float_tensor[9].item())
stop_on_eol = bool(values_float_tensor[10].item())
prevent_newline_after_colon = bool(values_float_tensor[11].item())
random_seed = int(values_float_tensor[12].item())
if random_seed != -1:
torch.random.manual_seed(random_seed)
# Tokenize prompts and get the batch.
    # Note that these tensors are broadcast to all ranks.
if torch.distributed.get_rank() == 0:
assert prompts is not None
context_tokens_tensor, context_length_tensor = tokenize_prompts(
prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
if tokens_to_generate == 0:
return score_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor)
# Main inference function.
# Note that the outputs are available on the first stage.
return generate_tokens_probs_and_return_on_first_stage(
model, forward_step, context_tokens_tensor, context_length_tensor,
return_output_log_probs=return_output_log_probs,
top_k=top_k_sampling,
top_p=top_p_sampling,
top_p_decay=top_p_decay,
top_p_bound=top_p_bound,
temperature=temperature,
use_eod_token_for_early_termination=use_eod_token_for_early_termination,
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
prevent_newline_after_colon=prevent_newline_after_colon)
def beam_search_and_post_process(model,
forward_step=ForwardStep,
prompts=None,
tokens_to_generate=0,
beam_size=0,
add_BOS=False,
stop_token=50256,
num_return_gen=1,
length_penalty=1,
prevent_newline_after_colon=False):
"""Run beam search and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""
# Main inference.
tokens, scores = beam_search(model,
forward_step=forward_step,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
beam_size=beam_size,
add_BOS=add_BOS,
stop_token=stop_token,
num_return_gen=num_return_gen,
length_penalty=length_penalty,
prevent_newline_after_colon=prevent_newline_after_colon)
# Only post-process on first stage.
if mpu.is_pipeline_first_stage():
lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())
tokens, prompts_plus_generations, prompts_plus_generations_segments = detokenize_generations(tokens, lengths, True)
scores = scores.cpu().numpy().tolist()
return prompts_plus_generations, prompts_plus_generations_segments, scores
return None
def beam_search(model, forward_step, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1, length_penalty=1, prevent_newline_after_colon=False):
    # Make sure input params are available to all ranks.
values = [tokens_to_generate,
beam_size,
add_BOS,
stop_token,
num_return_gen,
length_penalty,
prevent_newline_after_colon]
values_float_tensor = broadcast_float_list(len(values), float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
beam_size = int(values_float_tensor[1].item())
add_BOS = bool(values_float_tensor[2].item())
stop_token = int(values_float_tensor[3].item())
num_return_gen = int(values_float_tensor[4].item())
length_penalty = values_float_tensor[5].item()
prevent_newline_after_colon = values_float_tensor[6].item()
context_tokens_tensor, context_length_tensor = tokenize_prompts(
prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
return beam_search_and_return_on_first_stage(model, forward_step, context_tokens_tensor, context_length_tensor,
beam_size, stop_token=stop_token, num_return_gen=num_return_gen, length_penalty=length_penalty,
prevent_newline_after_colon=prevent_newline_after_colon)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
## from huggingface beam search
class BeamHypotheses(object):
def __init__(self, num_beams, length_penalty=1.0, early_stopping=False):
"""
Initialize n-best list of hypotheses.
"""
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(self, hyp, sum_logprobs, length):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / length ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
del self.beams[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len):
"""
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
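# Usage sketch (editor's addition) for BeamHypotheses: scores are
# length-normalized by `length ** length_penalty`, only the `num_beams` best
# hypotheses are kept, and is_done() signals when no pending hypothesis can
# still beat the worst kept one.
if __name__ == "__main__":
    hyps = BeamHypotheses(num_beams=2, length_penalty=1.0)
    hyps.add([1, 2, 3], sum_logprobs=-3.0, length=3)     # score -1.0
    hyps.add([4, 5], sum_logprobs=-4.0, length=2)        # score -2.0
    hyps.add([6, 7, 8, 9], sum_logprobs=-2.0, length=4)  # score -0.5 evicts the -2.0 entry
    print(len(hyps), hyps.worst_score)                   # 2 -1.0
    # Best pending score would be -8.0 / 4 = -2.0, worse than -1.0, so we are done.
    print(hyps.is_done(best_sum_logprobs=-8.0, cur_len=4))  # True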
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Communications utilities."""
import torch
from megatron.core import mpu
# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
"""Receive from previous pipeline stage and update the
input buffer inplace."""
if not mpu.is_pipeline_first_stage():
assert recv_buffer is not None
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_buffer,
mpu.get_pipeline_model_parallel_prev_rank())
reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
"""Send output to the next pipeline stage."""
if not mpu.is_pipeline_last_stage():
assert tensor is not None
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor,
mpu.get_pipeline_model_parallel_next_rank())
reqs = torch.distributed.batch_isend_irecv([send_next_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
def _is_cuda(tensor):
"""Check if a tensor is not none and is cuda."""
assert tensor is not None
assert tensor.is_cuda
def _is_cuda_contiguous(tensor):
"""Check if a tensor is not none, is cuda, and is contiguous."""
_is_cuda(tensor)
assert tensor.is_contiguous()
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
"""Broadcast a tensor from last pipeline stage to all ranks."""
is_last_stage = mpu.is_pipeline_last_stage()
    # If first stage and last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if mpu.is_pipeline_first_stage() and is_last_stage:
return tensor
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Get the group and corresponding source rank.
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(tensor, src, group)
return tensor
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Broadcast tensor values from last stage into the first stage."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
    # If first stage and last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return tensor
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor, src, group)
else:
tensor = None
return tensor
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Copy tensor values from last stage into the first stage.
Note that the input tensor is updated in place."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
    # If first stage and last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
_is_cuda(tensor)
is_contiguous = tensor.is_contiguous()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
if is_contiguous:
tensor_ = tensor
else:
if is_last_stage:
tensor_ = tensor.contiguous()
else:
tensor_ = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor_, src, group)
# Update the first stage tensor
if is_first_stage and not is_contiguous:
tensor[...] = tensor_
def broadcast_tensor(size, dtype, tensor=None, rank=0):
""" Given size and type of a tensor on all ranks and the tensor value
only on a specific rank, broadcast from that rank to all other ranks.
"""
if torch.distributed.get_rank() == rank:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
torch.distributed.broadcast(tensor, rank)
return tensor
def broadcast_list(size, dtype, list_values=None, rank=0):
"""Broadcast a list of values with a given type."""
tensor = None
if torch.distributed.get_rank() == rank:
tensor = torch.tensor(list_values, dtype=dtype,
device=torch.cuda.current_device())
return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)
def broadcast_int_list(size, int_list=None, rank=0):
"""Broadcast a list of interger values."""
return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
def broadcast_float_list(size, float_list=None, rank=0):
"""Broadcast a list of float values."""
return broadcast_list(size, torch.float32, list_values=float_list,
rank=rank)
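# Single-process sketch (editor's addition) of the pack-as-floats broadcast
# pattern used by the text-generation API: typed sampling parameters are packed
# into one float tensor, broadcast once, then decoded back on every rank. It
# runs on CPU with the "gloo" backend and world_size=1; the helpers above
# instead place tensors on the current CUDA device.
if __name__ == "__main__":
    import os
    import torch
    import torch.distributed as dist
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    # [tokens_to_generate, return_output_log_probs, top_k, top_p, temperature]
    values = [64, True, 0, 0.9, 1.0]
    t = torch.tensor(values, dtype=torch.float32)
    dist.broadcast(t, src=0)  # every rank ends up with the same parameters
    tokens_to_generate = int(t[0].item())
    return_output_log_probs = bool(t[1].item())
    top_k, top_p, temperature = int(t[2].item()), t[3].item(), t[4].item()
    print(tokens_to_generate, return_output_log_probs, top_k, top_p, temperature)
    dist.destroy_process_group()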
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Forward step utilities."""
from collections.abc import Iterable
import torch
from megatron.training import get_args
from megatron.core import mpu, InferenceParams
from .communication import (
send_to_next_pipeline_rank,
recv_from_prev_pipeline_rank_)
class ForwardStep:
"""Forward step function with all the communications.
We use a class here to hide the inference parameters
from the outside caller."""
def __init__(self, model, max_batch_size, max_sequence_length):
"""Set values so we don't need to do it multiple times."""
# Make sure model is in eval mode.
assert not isinstance(model, Iterable), \
'interleaving schedule is not supported for inference'
model.eval()
self.model = model
# Initialize inference parameters.
self.inference_params = InferenceParams(max_batch_size,
max_sequence_length)
# Pipelining arguments.
args = get_args()
self.pipeline_size_larger_than_one = (
args.pipeline_model_parallel_size > 1)
# Threshold of pipelining.
self.pipelining_batch_x_seqlen = \
args.inference_batch_times_seqlen_threshold
def _forward(self, tokens, position_ids, attention_mask):
return self.model(tokens, position_ids, attention_mask, inference_params=self.inference_params)
def __call__(self, tokens, position_ids, attention_mask):
"""Invocation of the forward methods. Note that self.inference_params
is being modified by the forward step."""
# Pipelining case.
if self.pipeline_size_larger_than_one:
current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = \
max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
return self._with_pipelining_forward_step(tokens,
position_ids,
attention_mask,
micro_batch_size)
return self._no_pipelining_forward_step(tokens,
position_ids,
attention_mask)
def _forward_step_helper(self, tokens, position_ids, attention_mask, recv_buffer=None):
"""Single forward step. Update the allocate memory flag so
only the first time the memory is allocated."""
batch_size = tokens.size(0)
sequence_length = tokens.size(1)
if recv_buffer is None:
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
# Receive from previous stage.
recv_from_prev_pipeline_rank_(recv_buffer)
# Forward pass through the model.
self.model.set_input_tensor(recv_buffer)
output_tensor = self._forward(tokens, position_ids, attention_mask)
# Send output to the next stage.
send_to_next_pipeline_rank(output_tensor)
return output_tensor
def _no_pipelining_forward_step(self, tokens, position_ids, attention_mask,
recv_buffer=None):
"""If recv_buffer is none, we will allocate one on the fly."""
# Run a simple forward pass.
output_tensor = self._forward_step_helper(tokens, position_ids,
attention_mask, recv_buffer=recv_buffer)
# Update the sequence length offset.
self.inference_params.sequence_len_offset += tokens.size(1)
logits = None
if mpu.is_pipeline_last_stage():
logits = output_tensor
return logits
def _with_pipelining_forward_step(self, tokens, position_ids, attention_mask, micro_batch_size):
"""No interleaving is supported."""
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
# Divide the batch dimension into micro batches.
num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
if last_chunk > 0:
num_micro_batches += 1
# Preallocate memory for output logits.
logits = None
if mpu.is_pipeline_last_stage():
args = get_args()
logits = torch.empty(
(batch_size, sequence_length, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
# Preallocate recv buffer.
recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
for micro_batch_index in range(num_micro_batches):
            # Slice along the batch dimension.
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
# Run a simple forward pass.
if this_micro_batch_size != micro_batch_size:
recv_buffer = None
output = self._forward_step_helper(tokens2use, position_ids2use, attention_mask, recv_buffer=recv_buffer)
# Adjust the batch size offset to account for the micro-batch.
self.inference_params.batch_size_offset += this_micro_batch_size
# Copy logits.
if mpu.is_pipeline_last_stage():
logits[start:end, ...] = output
# Once we are done with all the micro-batches, we can
# adjust the sequence length offset.
self.inference_params.sequence_len_offset += sequence_length
# and reset the batch size offset
self.inference_params.batch_size_offset = 0
return logits
def _get_recv_buffer_dtype(args):
"""Receive happens between the layers."""
if args.fp32_residual_connection:
return torch.float
return args.params_dtype
def _allocate_recv_buffer(batch_size, sequence_length):
"""Receive happens between the layers with size [s, b, h]."""
if mpu.is_pipeline_first_stage():
return None
args = get_args()
recv_size = (sequence_length, batch_size, args.hidden_size)
return torch.empty(recv_size,
dtype=_get_recv_buffer_dtype(args),
device=torch.cuda.current_device())
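# Arithmetic sketch (editor's addition) of the pipelining decision in
# ForwardStep.__call__ and _with_pipelining_forward_step: once
# batch_size * sequence_length reaches the configured threshold, the batch is
# split into micro-batches of size max(1, threshold // sequence_length), with
# a possibly smaller final chunk. `plan_micro_batches` is a hypothetical
# helper local to this example.
def plan_micro_batches(batch_size, seq_len, threshold):
    """Return (micro_batch_size, chunk_sizes) for the pipelined forward pass."""
    if batch_size * seq_len < threshold:
        return batch_size, [batch_size]  # below threshold: no split
    micro_batch_size = max(1, threshold // seq_len)
    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
    chunks = [
        min((i + 1) * micro_batch_size, batch_size) - i * micro_batch_size
        for i in range(num_micro_batches)
    ]
    return micro_batch_size, chunks
if __name__ == "__main__":
    print(plan_micro_batches(batch_size=13, seq_len=512, threshold=2048))  # (4, [4, 4, 4, 1])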