Commit bc5c7fa7 authored by wxj

First test commit

parent 70fddd0f
# Megatron Core MoE Key Features
### Parallelism
- **Expert Parallel**
- A parallelism scheme specific to MoE models: experts are partitioned across workers, each worker processes a different batch of training samples, and each worker hosts one or more experts per MoE layer.
- **3D Parallel**: Data Parallel, Tensor Parallel, Pipeline Parallel, Sequence Parallel
- Note: When using MoE with expert parallelism and tensor parallelism, sequence parallelism must be enabled.
- **Richer parallel mappings**: EP can be combined with DP/TP/PP/SP to handle larger MoE variants (see the sizing sketch after this list).
- **Distributed optimizer.**
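
A quick sanity check for these mappings is to factor the GPU count into the individual parallel degrees. The sketch below is ours, not a Megatron API; it assumes the usual Megatron-Core layout in which expert-parallel groups are carved out of the data-parallel dimension, so the data-parallel size must be divisible by the expert-parallel size, and the number of experts must be divisible by the expert-parallel size.

```python
# Minimal sizing sketch (illustrative only; variable names are ours).
def check_moe_parallel_mapping(world_size, tp, pp, ep, num_experts):
    assert world_size % (tp * pp) == 0, "TP * PP must divide the world size"
    dp = world_size // (tp * pp)          # data-parallel degree
    assert dp % ep == 0, "EP groups are assumed to be nested inside DP"
    assert num_experts % ep == 0, "experts must split evenly across EP ranks"
    return dp, num_experts // ep          # (DP degree, experts hosted per EP rank)

# 16 GPUs with TP=4, PP=1, EP=4 and 8 experts -> DP=4 and 2 local experts per rank,
# which matches the Mixtral 8x7B script further below.
print(check_moe_parallel_mapping(16, tp=4, pp=1, ep=4, num_experts=8))
```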
### Router and Load Balancing
- Router type:
- Top-K MLP router
- Expert Choice router (coming soon)
- Load Balancing algorithms:
- Sinkhorn (S-BASE)
- Aux loss / Load balancing loss
### Performance Optimizations
- GroupedGEMM when num local experts > 1
- Supported dtype: bf16
### Token Dispatch Mechanism
- Dropless / No token drop.
- Token drop. (coming soon)
### Ease of use
- Checkpoint converter (coming soon)
## Upcoming features
- Enhanced cutlass GroupedGEMM kernels
- Reduced host-device syncs.
- More supported dtype: fp32/bf16/fp16
- Kernel heuristics tuned for A100/A10/L40S
- BWD cutlass GroupedGEMM kernels supported
- Token permutation / unpermutation fusion
- Fused Sinkhorn Kernel
- Context Parallel with MoE
- FP8 training support
- Enable `--tp-comm-overlap` for MoE
- Distributed optimizer for MoE params.
# User Guide
### MoE Related Arguments
| Item | Description |
| --- | --- |
| num-experts | Number of Experts in MoE (None means no MoE) |
| expert-model-parallel-size | Degree of expert model parallelism. |
| moe-grouped-gemm | When there are multiple experts per rank, combine the local GEMMs into a single kernel launch to improve utilization and performance, leveraging the Grouped GEMM feature introduced in CUTLASS 2.8. |
| moe-router-load-balancing-type | Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss". |
| moe-router-topk | Number of experts to route to for each token. The default is 2. |
| moe-aux-loss-coeff | Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended. |
| moe-z-loss-coeff | Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended. |
| moe-input-jitter-eps | Add noise to the input tensor by applying jitter with a specified epsilon value. |
| moe-token-dropping | This feature involves selectively dropping and padding tokens for each expert to achieve a specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note: Currently unsupported. |
### Example
To train a top-2 MoE model with an auxiliary loss, include the following arguments:
```bash
--num-experts 8
--expert-model-parallel-size 8
--moe-grouped-gemm
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--use-distributed-optimizer
```
## A detailed MoE script:
<details>
<summary>Click here.</summary>

```bash
#!/bin/bash
# Runs Mixtral 8x7B model on 16 A100 GPUs
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6000"}
NNODES=${NNODES:-"1"}
NODE_RANK=${RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
CHECKPOINT_PATH=$1
TOKENIZER_MODEL=$2
DATA_PATH=$3
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NNODES
--node_rank $NODE_RANK
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
MODEL_ARGS=(
--use-mcore-models
--disable-bias-linear
--seq-length 2048
--max-position-embeddings 32768
--num-layers 32
--hidden-size 4096
--ffn-hidden-size 14336
--num-attention-heads 32
--init-method-std 0.01
--attention-dropout 0.0
--hidden-dropout 0.0
--normalization RMSNorm
--position-embedding-type rope
--swiglu
--untie-embeddings-and-output-weights
--group-query-attention
--num-query-groups 8
--no-masked-softmax-fusion
--no-position-embedding
)
MOE_ARGS=(
--num-experts 8
--expert-model-parallel-size 4
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--moe-grouped-gemm
)
DATA_ARGS=(
--tokenizer-type Llama2Tokenizer
--tokenizer-model ${TOKENIZER_MODEL}
--data-path $DATA_PATH
--split 99990,8,2
)
TRAINING_ARGS=(
--micro-batch-size 1
--global-batch-size 128
--lr 1e-4
--train-iters 500000
--lr-decay-iters 320000
--lr-decay-style cosine
--min-lr 1.0e-5
--weight-decay 0.1
--lr-warmup-iters 500
--clip-grad 1.0
--bf16
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 4
--pipeline-model-parallel-size 1
--sequence-parallel
--use-distributed-optimizer
)
LOGGING_ARGS=(
--log-interval 1
--save-interval 10000
--eval-interval 1000
--eval-iters 10
--save $CHECKPOINT_PATH
--load $CHECKPOINT_PATH
--tensorboard-dir "${CHECKPOINT_PATH}/tensorboard"
--no-load-optim
--no-load-rng
)
if [ -n "${WANDB_API_KEY}" ]; then
LOGGING_ARGS+=(
--wandb-project ${WANDB_PROJECT:-"Mixtral-Finetuning"}
--wandb-exp-name ${WANDB_NAME:-"Mixtral_8x7B"}
)
fi
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${MODEL_ARGS[@]} \
${MOE_ARGS[@]} \
${DATA_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${LOGGING_ARGS[@]}
```
</details>
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.jit import jit_fuser
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
)
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe import grouped_gemm_util as gg
from megatron.core.transformer.transformer_config import TransformerConfig
class GroupedMLP(MegatronModule):
"""An efficient implementation of the Experts layer using CUTLASS GroupedGEMM.
This class is designed to execute multiple experts in parallel, thereby maximizing computational efficiency.
"""
def __init__(self, num_local_experts: int, config: TransformerConfig):
super().__init__(config=config)
self.config: TransformerConfig = config
self.num_local_experts = num_local_experts
gg.assert_grouped_gemm_is_available()
assert (
config.add_bias_linear == False
), "bias in the expert layer is not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
self.expert_parallel = config.expert_model_parallel_size > 1
if self.config.gated_linear_unit:
if self.config.activation_func != F.silu:
raise ValueError("Activation function must be silu when using GroupedMLP.")
@jit_fuser
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.activation_func = glu
else:
self.activation_func = self.config.activation_func
# How many features each rank holds for fc1 and fc2, respectively.
tp_size = parallel_state.get_tensor_model_parallel_world_size()
fc1_output_size = self.config.ffn_hidden_size * self.num_local_experts
if config.gated_linear_unit:
# Project to 4h. If using swiglu double the output width,
# see https://arxiv.org/pdf/2002.05202.pdf
fc1_output_size *= 2
fc1_output_size_per_partition = divide(fc1_output_size, tp_size)
fc2_input_size = self.config.ffn_hidden_size * self.num_local_experts
fc2_input_size_per_partition = divide(fc2_input_size, tp_size)
# Note: The current kernel implementations of grouped_gemm
# do not support transposition with CUTLASS grouped GEMM
# (https://github.com/fanshiqing/grouped_gemm/blob/main/csrc/grouped_gemm.cu#L355-L358)
# and as a result we avoid allocating the transpose of the weights.
# Initialize weight.
if config.use_cpu_initialization:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition,
self.config.hidden_size,
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight1,
self.config.hidden_size,
fc1_output_size,
fc1_output_size_per_partition,
partition_dim=1,
init_method=config.init_method,
params_dtype=config.params_dtype,
)
_initialize_affine_weight_cpu(
self.weight2,
fc2_input_size,
self.config.hidden_size,
fc2_input_size_per_partition,
partition_dim=0,
init_method=config.output_layer_init_method,
params_dtype=config.params_dtype,
)
else:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition,
self.config.hidden_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight1,
config.init_method,
partition_dim=1,
expert_parallel=self.expert_parallel,
)
_initialize_affine_weight_gpu(
self.weight2,
config.output_layer_init_method,
partition_dim=0,
expert_parallel=self.expert_parallel,
)
setattr(self.weight1, 'allreduce', not self.expert_parallel)
setattr(self.weight2, 'allreduce', not self.expert_parallel)
def forward(self, permuted_local_hidden_states, tokens_per_expert):
if permuted_local_hidden_states.nelement() != 0:
# Reshape the weights for the grouped GEMMs.
w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
fc1_output = gg.ops.gmm(
permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False
)
intermediate_parallel = self.activation_func(fc1_output)
fc2_output = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False)
else:
# No token is allocated for local experts.
assert torch.count_nonzero(tokens_per_expert) == 0
# Make sure parameters still have gradients when no tokens are routed to this set of experts.
w1 = self.weight1.view(self.config.hidden_size, -1)
w2 = self.weight2.view(-1, self.config.hidden_size)
h = torch.matmul(permuted_local_hidden_states, w1)
h = self.activation_func(h)
h = torch.matmul(h, w2)
fc2_output = h
return fc2_output, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
raise NotImplementedError(
'Currently distributed checkpointing is not supported for GroupedMLP'
)
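# Usage sketch (ours, not part of the original module): GroupedMLP consumes hidden
# states that have already been permuted so all tokens for expert 0 come first,
# then expert 1, and so on, plus the per-expert token counts, e.g.
#   experts = GroupedMLP(num_local_experts=2, config=config)    # `config` is a TransformerConfig
#   output, bias = experts(permuted_hidden, tokens_per_expert)  # bias is always None here
# weight1/weight2 are stored as flat 2D tensors and viewed per forward pass as
# [num_local_experts, hidden, -1] and [num_local_experts, -1, hidden] for the grouped GEMM.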
class SequentialMLP(MegatronModule):
"""An implementation of the Experts layer using a sequence of MLP layers.
This class executes each expert sequentially.
"""
def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules):
super().__init__(config=config)
self.add_bias = config.add_bias_linear
self.num_local_experts = num_local_experts
self.local_experts = torch.nn.ModuleList()
for _ in range(self.num_local_experts):
expert = MLP(self.config, submodules, is_expert=True)
self.local_experts.append(expert)
def forward(self, permuted_local_hidden_states, tokens_per_expert):
output_local = torch.zeros_like(permuted_local_hidden_states)
output_bias_local = None
if self.add_bias:
output_bias_local = torch.zeros_like(permuted_local_hidden_states)
cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
# Insert a zero at the beginning for convenient offset indexing
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
for expert_num, expert in enumerate(self.local_experts):
start = cumsum_num_tokens[expert_num]
end = cumsum_num_tokens[expert_num + 1]
hidden = permuted_local_hidden_states[start:end]
output, output_bias = expert(hidden)
output_local[start:end] = output
if self.add_bias:
output_bias = output_bias.expand_as(output)
output_bias_local[start:end, :] = output_bias
return output_local, output_bias_local
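# Worked example (ours) of the offset bookkeeping above: with
# tokens_per_expert = [3, 0, 2], the padded cumulative sum is [0, 3, 3, 5], so
# expert 0 sees rows 0:3 of permuted_local_hidden_states, expert 1 gets the
# empty slice 3:3, and expert 2 sees rows 3:5.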
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Maps local expert to global experts. """
sharded_state_dict = {}
num_global_experts = (
parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
)
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)
expert_sharded_prefix = f'{prefix}experts.'
for expert_local_idx, expert in enumerate(self.local_experts):
expert_global_idx = local_expert_indices_offset + expert_local_idx
expert_state_dict_prefix = f'{prefix}local_experts.{expert_local_idx}.'
expert_sharded_offsets = (
*sharded_offsets,
(len(sharded_offsets), expert_global_idx, num_global_experts),
)
expert_state_dict = expert.sharded_state_dict(
expert_state_dict_prefix, expert_sharded_offsets, metadata
)
# Remove expert layers indexing from sharded keys
replace_prefix_for_sharding(
expert_state_dict, expert_state_dict_prefix, expert_sharded_prefix
)
# Adjust replica ids - replication along DP modulo EP
for k, sh_ten in expert_state_dict.items():
replica_id = sh_ten.replica_id
assert (
len(replica_id) == 3
), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
sh_ten.replica_id = (
*replica_id[:2],
parallel_state.get_data_modulo_expert_parallel_rank(),
)
sharded_state_dict.update(expert_state_dict)
return sharded_state_dict
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
try:
import grouped_gemm
except ImportError:
grouped_gemm = None
def grouped_gemm_is_available():
return grouped_gemm is not None
def assert_grouped_gemm_is_available():
assert grouped_gemm_is_available(), (
"Grouped GEMM is not available. Please run "
"`pip install git+https://github.com/fanshiqing/grouped_gemm@v1.0`."
)
ops = grouped_gemm.ops if grouped_gemm_is_available() else None
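# Usage sketch (ours): callers import this module as `gg` and guard the optional
# dependency before touching the kernels, mirroring GroupedMLP above:
#   from megatron.core.transformer.moe import grouped_gemm_util as gg
#   gg.assert_grouped_gemm_is_available()
#   out = gg.ops.gmm(tokens, weights, tokens_per_expert, trans_b=False)
# where `tokens` is [total_tokens, hidden], `weights` is [num_local_experts, hidden, ffn],
# and `tokens_per_expert` is a CPU int64 tensor of per-expert token counts.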
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
import torch
from megatron.core import parallel_state
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.moe.token_dispatcher import (
MoEAllGatherTokenDispatcher,
MoEAlltoAllTokenDispatcher,
)
from megatron.core.transformer.transformer_config import TransformerConfig
class BaseMoELayer(MegatronModule, ABC):
"""Base class for a mixture of experts layer.
Args:
config (TransformerConfig): Configuration object for the transformer model.
"""
def __init__(self, config: TransformerConfig, layer_number: int = None):
super(BaseMoELayer, self).__init__(config)
self.config = config
self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size()
assert self.expert_parallel_size > 0, "Expected positive expert parallel size"
assert self.config.num_moe_experts % self.expert_parallel_size == 0
self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)
self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.num_local_experts)
]
assert all(map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices))
self.router = None
self.experts = None
self.token_dispatcher = None
self.layer_number = layer_number
@abstractmethod
def forward(self, hidden_states):
pass
def set_layer_number(self, layer_number: int):
self.layer_number = layer_number
self.router.set_layer_number(layer_number)
class MoELayer(BaseMoELayer):
"""Mixture of experts Layer **currently only supports no token dropping**.
Args:
BaseMoELayer (MegatronModule): Base class for MoE layers
"""
def __init__(
self, config: TransformerConfig, submodules: MLPSubmodules = None, layer_number: int = None
):
self.submodules = submodules
super(MoELayer, self).__init__(config=config, layer_number=layer_number)
self.router = TopKRouter(config=self.config)
if self.config.moe_grouped_gemm:
self.experts = GroupedMLP(self.num_local_experts, self.config)
else:
assert isinstance(self.submodules, MLPSubmodules)
self.experts = SequentialMLP(self.num_local_experts, self.config, self.submodules)
if config.moe_token_dispatcher_type == "allgather":
self.token_dispatcher = MoEAllGatherTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
elif config.moe_token_dispatcher_type == "alltoall":
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
else:
raise ValueError(
f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
)
def forward(self, hidden_states: torch.Tensor):
# process MoE
scores, indices = self.router(hidden_states)
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, scores, indices
)
expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
return output, mlp_bias
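# Dataflow sketch (ours) for a single MoELayer.forward call:
#   scores, indices = router(hidden_states)                  # top-k probs and expert ids per token
#   dispatched, tokens_per_expert = dispatcher.token_permutation(hidden_states, scores, indices)
#   expert_output, bias = experts(dispatched, tokens_per_expert)
#   output, bias = dispatcher.token_unpermutation(expert_output, bias)
# Routing is decided locally, tokens are shuffled to the ranks that own their
# experts, computed there, and shuffled back into the original order.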
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core import parallel_state
def switch_load_balancing_loss_func(gates, mask, moe_aux_loss_coeff):
"""Calculate the auxiliary loss for better load balacing.
Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.
Args:
gates (torch.Tensor): The gates tensor representing the routing probabilities for each expert.
mask (torch.Tensor): The 2D mask tensor indicating which experts are selected.
moe_aux_loss_coeff (float): Scaling coefficient for the auxiliary loss.
Returns:
torch.Tensor: The auxiliary loss for load balancing.
"""
num_experts = mask.size(-1)
gates_mean = gates.mean(dim=0)
top_k = mask[0].count_nonzero()
selection_mean = mask.float().mean(dim=0) / top_k
aux_loss = torch.sum(gates_mean * selection_mean) * num_experts
aux_loss *= moe_aux_loss_coeff
return aux_loss
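# Worked example (ours): with perfectly balanced routing over E experts,
# gates_mean[e] == selection_mean[e] == 1/E, so the unscaled loss is
# E * sum_e (1/E * 1/E) = 1.0; concentrating load on fewer experts pushes it
# above 1.0, which is the Switch Transformer load-balancing argument.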
def z_loss_func(logits, z_loss_coeff):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
Args:
logits (torch.Tensor): The logits of the router.
Returns:
torch.Tensor: The logits after applying the z-loss.
"""
z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
return z_loss
def sinkhorn(cost: torch.Tensor, tol: float = 0.0001):
"""Sinkhorn based MoE routing function"""
cost = torch.exp(cost)
d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
eps = 0.00000001
error = 1e9
d1_old = d1
while error > tol:
d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
error = torch.mean(torch.abs(d1_old - d1))
d1_old = d1
return d1 * cost * d0.unsqueeze(1)
class MoEAuxLossAutoScaler(torch.autograd.Function):
"""An AutoScaler that compute and scales the grad for auxiliary loss.
"""
main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
@staticmethod
def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
"""Preserve the aux_loss by storing it in the context to avoid garbage collection.
Args:
output (torch.Tensor): The output tensor.
aux_loss (torch.Tensor): The auxiliary loss tensor.
Returns:
torch.Tensor: The output tensor.
"""
ctx.save_for_backward(aux_loss)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
"""Compute and scale the gradient for auxiliary loss..
Args:
grad_output (torch.Tensor): The gradient of the output.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient.
"""
(aux_loss,) = ctx.saved_tensors
aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
return grad_output, scaled_aux_loss_grad
@staticmethod
def set_loss_scale(scale: torch.Tensor):
"""set the scale of the aux loss.
Args:
scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
"""
MoEAuxLossAutoScaler.main_loss_backward_scale = scale
def permute(tokens, indices, topk: int = 1):
"""Permute the tokens based on the indices. Token with the same index will be grouped together.
Args:
tokens (torch.Tensor): The input token tensor.
indices (torch.Tensor): The token to expert indices tensor, should have a shape of [num_tokens, topk].
topk (int, optional): The topk value. Defaults to 1.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The permuted tokens and the sorted indices used to restore the original order.
"""
if topk > 1:
assert indices.size(1) == topk
flatten_indices = indices.view(-1)
sorted_indices = torch.argsort(flatten_indices, stable=True)
permuted_tokens = tokens.index_select(0, sorted_indices // topk)
return permuted_tokens, sorted_indices
def unpermute(permuted_tokens, sorted_indices, probs: torch.Tensor = None, topk: int = 1):
"""Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities.
Args:
permuted_tokens (torch.Tensor): The tensor of permuted tokens to be unpermuted.
sorted_indices (torch.Tensor): The tensor of sorted indices used to unpermute the tokens.
probs (torch.Tensor, optional): The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities.
topk (int, optional): The number of top tokens to consider for merging with probabilities. Defaults to 1.
"""
if topk > 1:
assert probs is not None
assert (
probs.size(0) == permuted_tokens.size(0) // topk
), f"{probs.size()} {permuted_tokens.size()}"
if probs is not None:
assert probs.size(0) == permuted_tokens.size(0) // topk
assert probs.size(1) == topk, f"probs size {probs.size()} merge_factor {topk}"
unpermuted_tokens = torch.zeros_like(permuted_tokens)
unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
if probs is not None:
unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
unpermuted_tokens = unpermuted_tokens.sum(dim=1)
return unpermuted_tokens
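# Worked example (ours) of the permute/unpermute pair with topk=2 and two tokens
# routed as indices = [[1, 0], [0, 1]]:
#   flattened indices = [1, 0, 0, 1]; stable argsort -> sorted_indices = [1, 2, 0, 3]
#   permute() gathers source rows sorted_indices // topk = [0, 1, 0, 1], i.e. token
#   copies grouped by expert (expert 0 first, then expert 1).
#   unpermute() scatters the expert outputs back, reshapes to [tokens, topk, hidden],
#   weights each copy by its routing prob, and sums over topk to get one row per token.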
def save_to_aux_losses_tracker(name: str, loss: torch.Tensor, layer_number: int, num_layers: int):
"""Save the auxiliary loss for logging.
Args:
name (str): The name of the loss.
loss (torch.Tensor): The loss tensor.
layer_number (int): Layer index of the loss.
num_layers (int): The number of total layers.
"""
# Skip aux loss logging if layer_number is None.
if layer_number is None:
return
if name not in parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER:
parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name] = torch.zeros(
num_layers, device=loss.device
)
parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name][layer_number - 1] += loss.detach()
def clear_aux_losses_tracker():
"""Clear the auxiliary losses."""
for name in parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER:
parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name].zero_()
def get_aux_losses_tracker():
"""Return the auxiliary losses."""
return parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER
def aggregate_aux_losses_tracker_across_pipeline_parallel():
"""Sum aux losses across PP."""
for name in parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER:
loss = parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name]
torch.distributed.all_reduce(loss, group=parallel_state.get_pipeline_model_parallel_group())
def track_moe_metrics(
loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False
):
# Aux loss logging
aggregate_aux_losses_tracker_across_pipeline_parallel()
if writer is not None:
aux_losses = {k: v.float() * loss_scale for k, v in get_aux_losses_tracker().items()}
for name, loss_list in aux_losses.items():
if total_loss_dict is not None:
if name not in total_loss_dict:
total_loss_dict[name] = loss_list.mean()
else:
total_loss_dict[name] += loss_list.mean()
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# pollutes the runs list, so we just add each as a scalar
writer.add_scalar(name, loss_list.mean(), iteration)
if per_layer_logging:
for i, loss in enumerate(loss_list.tolist()):
writer.add_scalar(f"moe/{name}_layer_{i}", loss, iteration)
# W&B logging lacks support for logging multiple scalars simultaneously.
# As a workaround, we log each scalar individually first, then we can create
# a custom panel to manually group them to a single plot.
if wandb_writer:
wandb_writer.log({f"{name}": loss_list.mean()}, iteration)
if per_layer_logging:
wandb_writer.log(
{
f"moe/{name}_layer_{i}": loss
for i, loss in enumerate(loss_list.tolist())
},
iteration,
)
clear_aux_losses_tracker()
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import math
from abc import ABC, abstractmethod
from typing import Callable, List
import torch
from megatron.core.tensor_parallel import (
gather_from_sequence_parallel_region,
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
)
from megatron.core.tensor_parallel.random import (
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import (
MoEAuxLossAutoScaler,
save_to_aux_losses_tracker,
sinkhorn,
switch_load_balancing_loss_func,
z_loss_func,
)
from megatron.core.transformer.transformer_config import TransformerConfig
class Router(ABC, MegatronModule):
"""Base Router class"""
def __init__(self, config: TransformerConfig) -> None:
"""
Initialize the Router module.
Args:
config (TransformerConfig): Configuration object for the Transformer model.
"""
super().__init__(config)
self.config = config
self.num_experts = self.config.num_moe_experts
self.moe_aux_loss_func = None
self.layer_number = None
# Initialize the gate weights.
self.weight = torch.nn.Parameter(
torch.empty((self.config.num_moe_experts, self.config.hidden_size))
)
with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
config.init_method(self.weight)
setattr(self.weight, 'sequence_parallel', config.sequence_parallel)
def gating(self, input: torch.Tensor):
"""Forward pass of the router gate.
Args:
input (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Logits tensor.
"""
logits = torch.nn.functional.linear(input, self.weight)
return logits
@abstractmethod
def routing(self, logits: torch.Tensor):
"""Routing function.
Args:
logits (torch.Tensor): Logits tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of tensors representing max probs and the indices.
"""
raise NotImplementedError("Routing function not implemented.")
@abstractmethod
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
"""
raise NotImplementedError("Forward function not implemented.")
def set_layer_number(self, layer_number: int):
"""Set the layer number for the router."""
self.layer_number = layer_number
class TopKRouter(Router):
"""Route each token to the top-k experts."""
def __init__(self, config: TransformerConfig,) -> None:
"""Initialize the zero token dropping router.
Args:
config (TransformerConfig): The configuration for the transformer model.
"""
super().__init__(config=config)
assert config.moe_token_dropping is False
self.topk = self.config.moe_router_topk
self.routing_type = self.config.moe_router_load_balancing_type
self.input_jitter = None
def sinkhorn_load_balancing(self, logits: torch.Tensor):
"""Apply sinkhorn routing to the logits tensor.
Args:
logits (torch.Tensor): The logits tensor.
Returns:
torch.Tensor: The logits tensor after applying sinkhorn routing.
"""
def _sinkhorn_activation(logits):
if self.topk == 1:
logits = torch.sigmoid(logits)
else: # k > 1
logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
return logits
assert self.config.moe_aux_loss_coeff == 0, "Sinkhorn routing does not support aux loss."
if self.training:
with torch.no_grad():
norm_logits = sinkhorn(
logits.to(dtype=torch.float32)
) # explicit fp32 conversion for stability
_, indices = torch.topk(norm_logits, k=self.topk, dim=1)
logits = _sinkhorn_activation(logits)
scores = torch.gather(logits, 1, indices)
else:
logits = _sinkhorn_activation(logits)
scores, indices = torch.topk(logits, k=self.topk, dim=1)
return scores, indices
def aux_loss_load_balancing(self, logits: torch.Tensor):
"""Apply loss-based load balancing to the logits tensor.
Args:
logits (torch.Tensor): The logits tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The scores and the indices tensor after applying load balancing.
"""
top_logits, indices = torch.topk(logits, k=self.topk, dim=1)
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)
# Apply load balancing loss
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
scores = self.apply_load_balancing_loss(probs, indices, activation=scores)
return scores, indices
def apply_load_balancing_loss(
self, probs: torch.Tensor, indices: torch.Tensor, activation: torch.Tensor,
):
"""Applies auxiliary loss to the MoE layer.
Args:
loss_func (callable): The loss function to be used.
probs (torch.Tensor): The probabilities output by the MoE layer.
indices (torch.Tensor): The indices of the selected experts.
activation (torch.Tensor): The activation tensor to attach the gradient function to.
Returns:
torch.Tensor: The activation tensor with the attached gradient function.
"""
mask = torch.nn.functional.one_hot(indices, num_classes=self.num_experts).sum(dim=1)
aux_loss = switch_load_balancing_loss_func(probs, mask, self.config.moe_aux_loss_coeff)
save_to_aux_losses_tracker(
"load_balancing_loss",
aux_loss / self.config.moe_aux_loss_coeff,
self.layer_number,
self.config.num_layers,
)
activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
return activation
def apply_z_loss(self, logits):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
Args:
logits (torch.Tensor): The logits of the router.
Returns:
torch.Tensor: The logits after applying the z-loss.
"""
if self.config.moe_z_loss_coeff is not None:
z_loss = z_loss_func(logits, self.config.moe_z_loss_coeff)
logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
save_to_aux_losses_tracker(
"z_loss",
z_loss / self.config.moe_z_loss_coeff,
self.layer_number,
self.config.num_layers,
)
return logits
def apply_input_jitter(self, input: torch.Tensor):
"""Add noise to the input tensor.
Refer to https://arxiv.org/abs/2101.03961.
Args:
input (Tensor): Input tensor.
Returns:
Tensor: Jittered input.
"""
if self.config.moe_input_jitter_eps is not None:
eps = self.config.moe_input_jitter_eps
if self.input_jitter is None:
self.input_jitter = torch.distributions.uniform.Uniform(
torch.tensor(1.0 - eps, device=input.device),
torch.tensor(1.0 + eps, device=input.device),
).rsample
return input * self.input_jitter(input.shape)
else:
return input
def routing(self, logits: torch.Tensor):
"""Top-k routing function
Args:
logits (torch.Tensor): Logits tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Probs and the indices tensor.
"""
logits = logits.view(-1, self.config.num_moe_experts)
# Apply Z-Loss
logits = self.apply_z_loss(logits)
if (
self.config.tensor_model_parallel_size > 1
and self.config.moe_token_dispatcher_type == "alltoall"
):
# Gather the logits from the TP region
logits = gather_from_sequence_parallel_region(logits)
if self.routing_type == "sinkhorn":
scores, indices = self.sinkhorn_load_balancing(logits)
elif self.routing_type == "aux_loss":
scores, indices = self.aux_loss_load_balancing(logits)
elif self.routing_type == "none":
# A naive top-k routing without load balancing
top_logits, indices = torch.topk(logits, k=self.topk, dim=1)
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)
else:
raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")
return scores, indices
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: scores and indices.
"""
self.hidden = input.shape[-1]
# Apply input jitter
input = self.apply_input_jitter(input)
logits = self.gating(input)
logits = logits.view(-1, self.config.num_moe_experts)
scores, indices = self.routing(logits)
return scores, indices
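# Usage sketch (ours): with num_moe_experts=8, moe_router_topk=2 and the default
# "aux_loss" load balancing, TopKRouter.forward maps hidden states of shape
# [S, B, H] to
#   scores:  [S*B, 2]  # softmax-normalized weights of the selected experts
#   indices: [S*B, 2]  # ids of the selected experts per token
# which is exactly the (scores, indices) pair consumed by the token dispatchers.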
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from abc import abstractmethod
from typing import List, Optional, Tuple
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.tensor_parallel.mappings import _gather_along_first_dim_expert_parallel
from megatron.core.transformer.moe.moe_utils import permute, unpermute
from megatron.core.transformer.transformer_config import TransformerConfig
class MoETokenDispatcher:
"""
MoE Token Dispatcher
"""
def __init__(self, config: TransformerConfig) -> None:
"""
Initialize the MoE Token Dispatcher.
"""
self.config = config
@abstractmethod
def token_permutation(
self, tokens: torch.Tensor, indices: torch.Tensor,
):
"""Dispatch tokens to experts.
Args:
tokens (torch.Tensor): Input tokens.
indices (torch.Tensor): indices tensor.
Returns:
torch.Tensor: Tokens tensor.
"""
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
def token_unpermutation(
self, expert_output: torch.Tensor, scores: torch.Tensor, indices: torch.Tensor,
):
"""Restores the expert output to its original ordering.
Args:
expert_output (torch.Tensor): The output tensor from the expert models.
scores (torch.Tensor): Each token's score with each expert.
indices (torch.Tensor): The indices used to reorder the expert output.
Returns:
(torch.Tensor, torch.Tensor): Unpermuted activation and optional bias.
"""
raise NotImplementedError("Restore function not implemented.")
class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
"""
AllGather Based Token dispatcher.
"""
def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig,
) -> None:
"""
Initialize the AllGather token dispatcher.
"""
super().__init__(config=config)
self.num_local_experts = num_local_experts
assert self.num_local_experts > 0, "Expected at least one expert"
self.local_expert_indices = local_expert_indices
assert len(self.local_expert_indices) > 0, "Expected at least one local expert index"
self.router_topk = config.moe_router_topk
self.add_bias = config.add_bias_linear
# self.local_probs: probs of global token assignment to local experts.
self.local_probs = None
# self.indices: The indices of `local_indices` (which holds the un-sorted expert indices of tokens that local expert can process) that give its sorted order along dim 0.
self.indices = None
# self.global_local_map: 2D tensor. A mask of mapping between global and local tokens where each element is True if it's between the local_expert_indices. Only useful when cross-device token permutation is enabled and **AllGather** is performed.
self.global_local_map = None
def token_permutation(
self, hidden_states: torch.Tensor, max_prob: torch.Tensor, max_ind: torch.Tensor
):
"""Dispatch tokens to local experts. It's composed of two stages:
(1) Permute the tokens across the expert parallel devices. After this stage,
each device receives all of the tokens assigned to its local set of experts
in its local HBM.
(2) Permute the tokens locally so that they are grouped by their expert
assignment. After the stage (1), the tokens are grouped by which device
they came from. We re-order them locally for subsequent efficient computation.
Args:
hidden_states: input tokens of shape [SeqLen/TP, MBS, HiddenSize]
max_prob: probs of local token assignment to global experts.
max_ind: token assignment to local experts.
Returns:
permuted_local_hidden_states: Permutation of tokens to local experts group.
tokens_per_expert: the number of tokens for each local expert to process.
"""
self.hidden_shape = hidden_states.shape
# [S/TP, B, H] -> [S*B/TP, H]
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
# Permute the tokens across the expert parallel devices.
if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
# [S*B/TP, H] -> [S*B, H]
global_hidden_states = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
hidden_states
)
with torch.no_grad():
global_indices = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
max_ind
)
# Create a mask of mapping between global and local tokens where each
# element is True if it's between the local_expert_indices
global_local_mask = (global_indices >= self.local_expert_indices[0]) & (
global_indices <= self.local_expert_indices[-1]
)
local_indices = global_indices.masked_select(global_local_mask)
if self.router_topk > 1: # k > 1
global_probs = tensor_parallel.gather_from_sequence_parallel_region_to_moe(max_prob)
self.local_probs = global_probs.masked_select(global_local_mask)
else:
self.local_probs = max_prob
# Reshape global_local_mask to be compatible with Tensor.gather
global_local_map = global_local_mask.nonzero()[:, 0]
self.global_local_map = global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1])
local_hidden_states = torch.gather(global_hidden_states, 0, self.global_local_map)
else:
if self.router_topk > 1:
global_local_mask = torch.ones_like(max_ind).bool()
local_indices = max_ind.masked_select(global_local_mask)
self.local_probs = max_prob.masked_select(global_local_mask)
global_local_map = global_local_mask.nonzero()[:, 0]
self.global_local_map = global_local_map.view(-1, 1).expand(
-1, hidden_states.shape[-1]
)
local_hidden_states = torch.gather(hidden_states, 0, self.global_local_map)
else:
local_indices = max_ind
self.local_probs = max_prob
local_hidden_states = hidden_states
self.global_local_map = None
with torch.no_grad():
# The indices of local_indices that give its sorted order along dim 0.
self.indices = torch.argsort(local_indices, dim=0)
tokens_per_expert = torch.histc(
local_indices,
bins=self.num_local_experts,
min=self.local_expert_indices[0],
max=self.local_expert_indices[-1],
)
tokens_per_expert = tokens_per_expert.cpu().to(torch.long)
# Stage2: permute the tokens locally so that they are grouped by their expert assignment
# Reshape indices to be compatible with Tensor.gather
self.indices = self.indices.view(-1, 1).expand(-1, hidden_states.shape[-1])
permuted_local_hidden_states = torch.gather(local_hidden_states, 0, self.indices)
return (
permuted_local_hidden_states,
tokens_per_expert,
)
def token_unpermutation(
self, hidden_states: torch.Tensor, bias: torch.Tensor = None,
):
"""
Reverse process of `token_permutation()`, which permutes the output of the local
experts locally and across expert-parallel ranks back into the original order to
produce the final output.
Args:
hidden_states: 2D tensor of shape [sum_tokens_of_all_local_experts, HiddenSize],
output of the local experts.
bias (optional): The bias tensor.
Returns:
output_total: un-permuted updated hidden states output from all local experts
with shape of [SeqLen/TP, MBS, HiddenSize]
"""
# Stage1: unpermute the tokens and bias locally respectively.
scores = self.local_probs.to(dtype=hidden_states.dtype)
unpermuted_local_hidden = torch.zeros_like(hidden_states)
assert self.indices.shape == hidden_states.shape
unpermuted_local_hidden = unpermuted_local_hidden.scatter(0, self.indices, hidden_states)
# Scale the expert output prior to reduction and subsequent to local unpermutation if k > 1.
if self.router_topk > 1:
unpermuted_local_hidden = unpermuted_local_hidden * scores.view(-1, 1)
unpermuted_local_bias = None
if self.add_bias:
assert bias is not None
unpermuted_local_bias = torch.zeros_like(hidden_states)
assert self.indices.shape == bias.shape
unpermuted_local_bias = unpermuted_local_bias.scatter(0, self.indices, bias)
if self.router_topk > 1:
unpermuted_local_bias = unpermuted_local_bias * scores.view(-1, 1)
output_total = unpermuted_local_hidden
output_bias_total = unpermuted_local_bias
# Unpermute the tokens across expert parallel devices.
if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
assert (
self.global_local_map is not None
), "global_local_map is necessary for `AllGather`."
ep_group_size = parallel_state.get_tensor_and_expert_parallel_world_size()
# hidden_shape: [SeqLen/TP, MBS, HiddenSize], global_num_tokens = SeqLen/TP*MBS*(TP*EP)
global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1] * ep_group_size
global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
unpermuted_global_hidden = torch.zeros(
global_hidden_shape, dtype=hidden_states.dtype, device=torch.cuda.current_device()
)
# Reshape global_local_map to be compatible with Tensor.scatter
assert self.global_local_map.shape == unpermuted_local_hidden.shape
unpermuted_global_hidden = unpermuted_global_hidden.scatter_add(
0, self.global_local_map, unpermuted_local_hidden
)
output_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
unpermuted_global_hidden
)
if self.add_bias:
# Unpermute the bias across expert parallel devices.
unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
unpermuted_global_bias = unpermuted_global_bias.scatter_add(
0, self.global_local_map, unpermuted_local_bias
)
output_bias_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
unpermuted_global_bias
)
# bias is duplicated across tensor parallelism ranks;
# reduce scatter reduces bias across tensor parallel ranks
output_bias_total = (
output_bias_total / parallel_state.get_tensor_model_parallel_world_size()
)
else:
if self.router_topk > 1:
global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1]
global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
unpermuted_global_hidden = torch.zeros(
global_hidden_shape,
dtype=hidden_states.dtype,
device=torch.cuda.current_device(),
)
output_total = unpermuted_global_hidden.scatter_add(
0, self.global_local_map, unpermuted_local_hidden
)
if self.add_bias:
unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
output_bias_total = unpermuted_global_bias.scatter_add(
0, self.global_local_map, unpermuted_local_bias
)
if self.router_topk == 1:
output_total = output_total * scores
output_total = output_total.view(self.hidden_shape)
if self.add_bias:
assert output_bias_total is not None
if self.router_topk == 1:
output_bias_total = output_bias_total * scores
output_bias_total = output_bias_total.view(self.hidden_shape)
else:
output_bias_total = None
return output_total, output_bias_total
class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"""
AlltoAll Based Token dispatcher.
"""
def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig,
) -> None:
"""
Initialize the AlltoAll token dispatcher.
Args:
num_local_experts (int): Number of local experts on the current device.
local_expert_indices (List[int]): Indices of local experts on the current device.
config (TransformerConfig): Configuration for the transformer model.
"""
super().__init__(config=config)
self.num_local_experts = num_local_experts
self.num_experts = config.num_moe_experts
assert self.num_local_experts > 0, "Expected at least one expert"
self.local_expert_indices = local_expert_indices
assert (
len(self.local_expert_indices) == self.num_local_experts
), "Invalid local expert indices"
self.router_topk = config.moe_router_topk
self.add_bias = config.add_bias_linear
self.ep_size = config.expert_model_parallel_size
self.scores: torch.Tensor = None
self.input_splits = None
self.output_splits = None
self.num_global_tokens_per_local_expert = None
def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
"""
Preprocess token indices for AlltoAll communication and token permutation. This method computes the number of tokens assigned to each expert based on the input indices.
It also initializes the necessary data structures for AlltoAll communication, such as input
and output splits, and the mapping between global tokens and local experts.
Args:
indices (torch.Tensor): Tensor of indices mapping tokens to experts.
Returns:
torch.Tensor: Tensor containing the number of tokens assigned to each local expert.
"""
num_local_tokens_per_expert = torch.histc(
indices, bins=self.num_experts, min=0, max=self.num_experts
)
# num_local_tokens_per_expert: [num_experts]
ep_size = self.config.expert_model_parallel_size
if ep_size > 1:
# ===================================================
# Calculate input_splits, output_splits for alltoall-v.
# ===================================================
self.input_splits = (
num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
.sum(axis=1)
.to(torch.device("cpu"))
.numpy()
)
num_global_tokens_per_expert = _gather_along_first_dim_expert_parallel(
num_local_tokens_per_expert
).reshape(ep_size, self.num_experts)
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[
:, self.local_expert_indices
]
self.output_splits = (
self.num_global_tokens_per_local_expert.sum(axis=-1).to(torch.device("cpu")).numpy()
)
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0).to(
torch.device("cpu"), non_blocking=True
)
# ===================================================
# num_global_tokens_per_expert: [ep_size, num_experts]
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
# num_tokens_per_local_expert: [num_local_experts]
# ===================================================
else:
self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
-1, self.num_experts
)
num_tokens_per_local_expert = num_local_tokens_per_expert.to(
torch.device("cpu"), non_blocking=True
)
if self.num_local_experts > 1:
expert_ids_per_ep_rank = torch.tensor(
[i % self.num_local_experts for i in range(self.config.num_moe_experts)],
dtype=torch.int32,
device=torch.cuda.current_device(),
)
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()
)
return num_tokens_per_local_expert
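# Worked example (ours) of the split bookkeeping above with ep_size=2 and
# num_experts=4 (two local experts per rank): if this rank routes its tokens as
# num_local_tokens_per_expert = [3, 1, 2, 0], then
#   input_splits  = [3 + 1, 2 + 0] = [4, 2]   # tokens sent to EP ranks 0 and 1
#   output_splits = row sums of the gathered counts restricted to this rank's
#                   expert columns            # tokens received from each EP rank
# and num_tokens_per_local_expert later sizes the per-expert grouped GEMMs.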
def token_permutation(
self, hidden_states: torch.Tensor, scores: torch.Tensor, indices: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
Args:
hidden_states (torch.Tensor): Input token embeddings.
scores (torch.Tensor): Scores of tokens assigned to experts.
indices (torch.Tensor): Indices of tokens assigned to experts.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- Permuted token embeddings for local experts.
- Number of tokens per expert.
"""
self.hidden_shape = hidden_states.shape
self.scores = scores
assert scores.dim() == 2, "Expected 2D tensor for scores"
assert indices.dim() == 2, "Expected 2D tensor for indices"
tokens_per_expert = self.preprocess(indices)
# TODO Optimize EP=1 case
# Flatten the input tensor
# hidden_states: [S/TP, B, H] -> [S*B/TP, H]
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
# Perform tensor parallel AlltoAll communication
# hidden_states: [S*B/TP, H] -> [S*B, H/TP]
if parallel_state.get_tensor_model_parallel_world_size() > 1:
hidden_states = tensor_parallel.all_to_all_sp2hp(hidden_states)
# Permutation 1: input to AlltoAll input
self.local_input_tokens_global_experts_indices = indices
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states, self.local_input_tokens_global_experts_indices, topk=self.router_topk,
)
# Perform expert parallel AlltoAll communication
global_input_tokens = tensor_parallel.all_to_all(
parallel_state.get_expert_model_parallel_group(),
permutated_local_input_tokens,
self.output_splits,
self.input_splits,
)
# Permutation 2: AlltoAll output to expert input if num_local_experts > 1
if self.num_local_experts > 1:
global_input_tokens, self.reversed_global_input_permutation_mapping = permute(
global_input_tokens, self.global_input_tokens_local_experts_indices
)
# Perform tensor parallel All-Gather
# global_input_tokens: [SEQL, H/TP] -> [SEQL, H]
if parallel_state.get_tensor_model_parallel_world_size() > 1:
global_input_tokens = tensor_parallel.all_gather_last_dim_from_tensor_parallel_region(
global_input_tokens
)
return global_input_tokens, tokens_per_expert
def token_unpermutation(
self, hidden_states: torch.Tensor, bias: torch.Tensor = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Reverse the token permutation to restore the original order.
Args:
hidden_states (torch.Tensor): Output from local experts.
bias (torch.Tensor, optional): Bias tensor (not supported).
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- Unpermuted token embeddings in the original order.
- None (bias is not supported).
"""
assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
# Perform tensor parallel Reduce-Scatter
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
if parallel_state.get_tensor_model_parallel_world_size() > 1:
hidden_states = tensor_parallel.reduce_scatter_last_dim_to_tensor_parallel_region(
hidden_states
)
# Unpermutation 2: expert output to AlltoAll input (reorders rows; shape remains [SEQL, H/TP])
if self.num_local_experts > 1:
hidden_states = unpermute(
hidden_states, self.reversed_global_input_permutation_mapping,
)
# Perform expert parallel AlltoAll communication
permutated_local_input_tokens = tensor_parallel.all_to_all(
parallel_state.get_expert_model_parallel_group(),
hidden_states,
self.input_splits,
self.output_splits,
)
# Unpermutation 1: AlltoAll output to output
output = unpermute(
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
probs=self.scores,
topk=self.router_topk,
)
# Perform tensor parallel AlltoAll communication
if parallel_state.get_tensor_model_parallel_world_size() > 1:
# output: [S*B, H/TP] -> [S*B/TP, H]
output = tensor_parallel.all_to_all_hp2sp(output)
# Reshape the output tensor
output = output.view(self.hidden_shape)
return output, None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import types
from dataclasses import dataclass, field
from typing import Tuple, Union
import torch
@dataclass
class ModuleSpec:
"""This is a Module Specification dataclass.
Specification defines the location of the module (to import dynamically)
or the imported module itself. It also defines the params that need to be
passed to initialize the module.
Args:
module (Union[Tuple, type]): A tuple describing the location of the
module class e.g. `(module.location, ModuleClass)` or the imported
module class itself e.g. `ModuleClass` (which is already imported
using `from module.location import ModuleClass`).
params (dict): A dictionary of params that need to be passed while init.
"""
module: Union[Tuple, type]
params: dict = field(default_factory=lambda: {})
submodules: type = None
def import_module(module_path: Tuple[str]):
"""Import a named object from a module in the context of this function.
TODO: make this importer module more robust, at least make sure there
are no side effects of using this as is
"""
base_path, name = module_path
try:
module = __import__(base_path, globals(), locals(), [name])
except ImportError as e:
print(f"couldn't import module due to {e}")
return None
return vars(module)[name]
def get_module(spec_or_module: Union[ModuleSpec, type], **additional_kwargs):
# If a module class is already provided, return it as is
if isinstance(spec_or_module, (type, types.FunctionType)):
return spec_or_module
# If the module is provided instead of module path, then return it as is
if isinstance(spec_or_module.module, (type, types.FunctionType)):
return spec_or_module.module
# Otherwise, return the dynamically imported module from the module path
return import_module(spec_or_module.module)
def build_module(spec_or_module: Union[ModuleSpec, type], *args, **kwargs):
# If the passed `spec_or_module` is
# a `Function`, then return it as it is
# NOTE: to support an already initialized module add the following condition
# `or isinstance(spec_or_module, torch.nn.Module)` to the following if check
if isinstance(spec_or_module, types.FunctionType):
return spec_or_module
# If the passed `spec_or_module` is actually a spec (instance of
# `ModuleSpec`) and it specifies a `Function` using its `module`
# field, return the `Function` as it is
if isinstance(spec_or_module, ModuleSpec) and isinstance(
spec_or_module.module, types.FunctionType
):
return spec_or_module.module
# Check if a module class is provided as a spec or if the module path
# itself is a class
if isinstance(spec_or_module, type):
module = spec_or_module
elif hasattr(spec_or_module, "module") and isinstance(spec_or_module.module, type):
module = spec_or_module.module
else:
# Otherwise, dynamically import the module from the module path
module = import_module(spec_or_module.module)
# If the imported module is actually a `Function` return it as it is
if isinstance(module, types.FunctionType):
return module
# Finally return the initialized module with params from the spec as well
# as those passed as **kwargs from the code
# Add the `submodules` argument to the module init call if it exists in the
# spec.
if hasattr(spec_or_module, "submodules") and spec_or_module.submodules is not None:
kwargs["submodules"] = spec_or_module.submodules
try:
return module(
*args, **spec_or_module.params if hasattr(spec_or_module, "params") else {}, **kwargs
)
except Exception as e:
# improve the error message since we hide the module name in the line above
import sys
tb = sys.exc_info()[2]
raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
sys.exc_info()[2]
)
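# Usage sketch (ours; names are illustrative): a ModuleSpec can hold an imported
# class or a (module_path, class_name) tuple plus default params, e.g.
#   spec = ModuleSpec(module=MoELayer, params={}, submodules=mlp_submodules)
#   layer = build_module(spec, config=config, layer_number=1)
# build_module resolves the class, forwards spec.submodules as the `submodules`
# kwarg when present, merges spec.params with the call-site kwargs, and returns
# the initialized module.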
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import re
import warnings
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDelayedScaling,
TENorm,
get_cpu_offload_context,
te_checkpoint,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import BaseTransformerLayer, TransformerLayer
from megatron.core.transformer.utils import sharded_state_dict_default
from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor
def get_num_layers_to_build(config: TransformerConfig) -> int:
num_layers_per_pipeline_rank = (
config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
# Interleaved pipeline parallelism:
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
num_layers_to_build = num_layers_per_virtual_rank
else:
# Non-interleaved pipeline parallelism:
# Each stage gets a contiguous set of layers.
num_layers_to_build = num_layers_per_pipeline_rank
return num_layers_to_build
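# Worked example (ours): with config.num_layers = 32 and a pipeline-parallel size
# of 4, each pipeline rank owns 32 // 4 = 8 layers; with a virtual pipeline size
# of 2 those are split into model chunks of 8 // 2 = 4 layers, so this function
# returns 4 layers to build per chunk.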
@dataclass
class TransformerBlockSubmodules:
layer_specs: List[ModuleSpec] = None
layer_norm: Optional[Union[ModuleSpec, torch.nn.Module]] = None
def _get_block_submodules(
config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec],
) -> TransformerBlockSubmodules:
# Transformer block submodules.
if isinstance(spec, TransformerBlockSubmodules):
return spec
# ModuleSpec here is generally assumed to be for a transformer layer that
# is implemented in `transformer_layer.py` or if it subclasses
# `BaseTransformerLayer` from the `transformer_layer.py` file.
elif isinstance(spec, ModuleSpec):
if issubclass(spec.module, TransformerBlock):
return spec.submodules
elif issubclass(spec.module, BaseTransformerLayer):
num_layers = get_num_layers_to_build(config)
return TransformerBlockSubmodules(layer_specs=[spec] * num_layers, layer_norm=TENorm,)
else:
raise Exception(f"specialize for {spec.module.__name__}.")
else:
raise Exception(f"specialize for {type(spec).__name__}.")
class TransformerBlock(MegatronModule):
"""Transformer class."""
def __init__(
self,
config: TransformerConfig,
spec: Union[TransformerBlockSubmodules, ModuleSpec],
post_layer_norm: bool = True,
pre_process: bool = True,
post_process: bool = True,
):
super().__init__(config=config)
self.submodules = _get_block_submodules(config, spec)
self.post_layer_norm = post_layer_norm
self.pre_process = pre_process
self.post_process = post_process
# Dictionary to store CUDA graphs. Number of items in the dictionary = len(self.layers).
        # Item `i` in the dictionary is a list of `N` CUDA graphs for layer 'i' where N is the
        # number of microbatches. Multiple CUDA graphs per layer are required to support
        # pipelining, which requires running the FWD graphs of multiple microbatches before the BWD graphs.
self.cuda_graphs = {}
self.current_microbatch = -1
# required for pipeline parallel schedules
self.input_tensor = None
self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
if get_cpu_offload_context is not None:
(
self.offload_context,
self.group_prefetch_offload_commit_async,
) = get_cpu_offload_context(
self.config.cpu_offloading,
self.config.cpu_offloading_num_layers,
self.config.cpu_offloading_activations,
self.config.cpu_offloading_weights,
)
self.config._cpu_offloading_context = (
self.offload_context if self.config.cpu_offloading else None
)
else:
assert (
self.config.cpu_offloading == False
), "CPU Offloading is enabled when TE is not present"
self.offload_context, self.group_prefetch_offload_commit_async = nullcontext(), None
self.config._cpu_offloading_context = None
self._build_layers()
self.num_layers_per_pipeline_rank = len(self.layers)
def _build_layers(self):
# Transformer layers.
# @jcasper can we improve how we deal with layer_number?
# currently it's only used in CoreAttention?
# if self.apply_query_key_layer_scaling:
# coeff = self.layer_number
# self.norm_factor *= coeff
def build_layer(layer_spec, layer_number):
return build_module(layer_spec, config=self.config, layer_number=layer_number,)
# offset is implicit in TransformerLayer
self.layers = torch.nn.ModuleList(
[
build_layer(layer_spec, i + 1)
for i, layer_spec in enumerate(self.submodules.layer_specs)
]
)
# # TODO: add back standalone_embedding_stage
# if self.num_layers == 0:
# # When a standalone embedding stage is used (e.g.,
# # args.standalone_embedding_stage == True), virtual pipeline ranks
# # on pipeline rank 0 will have zero transformer layers assigned to
# # them. This results in the model's input and output tensors to be
# # the same, which will cause failure for certain output tensor
# # optimizations (e.g., pipeline output deallocation). To remedy
# # this, we assign a 'no-op' layer on these ranks, which will
# # disconnect the input tensor from the output tensor.
# self.num_layers = 1
# self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
# else:
# self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
# In pipeline parallelism, we want to add this LN only to the last stage of the pipeline
# self.post_process and self.post_layer_norm guide this behavior
if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
self.final_layernorm = build_module(
self.submodules.layer_norm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
else:
self.final_layernorm = None # Either this or nn.Identity
def _get_layer(self, layer_number: int):
return self.layers[layer_number]
def _checkpointed_forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
context: Tensor,
context_mask: Tensor,
rotary_pos_emb: Tensor,
packed_seq_params: PackedSeqParams,
):
"""Forward method with activation checkpointing."""
def custom(start: int, end: int):
def custom_forward(
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
):
for index in range(start, end):
layer = self._get_layer(index)
hidden_states, context = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=None,
packed_seq_params=packed_seq_params,
)
return hidden_states, context
return custom_forward
def checkpoint_handler(forward_func):
if self.config.fp8:
return te_checkpoint(
forward_func,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
else:
return tensor_parallel.checkpoint(
forward_func,
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
if self.config.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
            # Checkpointing at chunk granularity (rather than per layer) further reduces memory usage.
l = 0
while l < self.num_layers_per_pipeline_rank:
hidden_states, context = checkpoint_handler(
custom(l, l + self.config.recompute_num_layers)
)
l += self.config.recompute_num_layers
elif self.config.recompute_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
            # This makes fuller use of device memory and avoids redundant recomputation for the skipped layers.
recompute_skip_num_layers = 0
for l in range(self.num_layers_per_pipeline_rank):
# Skip recomputation when input grad computation is not needed.
# Need to have at least one input tensor with gradient computation
                # for the re-entrant autograd engine.
if self.config.fp8 and not hidden_states.requires_grad:
recompute_skip_num_layers += 1
if (
l >= recompute_skip_num_layers
and l < self.config.recompute_num_layers + recompute_skip_num_layers
):
hidden_states, context = checkpoint_handler(custom(l, l + 1))
else:
hidden_states, context = custom(l, l + 1)(
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states
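    # Illustrative note (not part of the original file): with num_layers_per_pipeline_rank=8
    # and recompute_num_layers=2, 'uniform' re-runs every layer in backward but only stores
    # the inputs of chunks [0-1], [2-3], [4-5], [6-7], while 'block' checkpoints only layers
    # 0 and 1 and keeps the activations of layers 2-7 in memory.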
def set_input_tensor(self, input_tensor: Tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
context: Tensor = None,
context_mask: Tensor = None,
rotary_pos_emb: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
):
# hidden_states (float): [s, b, h]
# attention_mask (bool): [1, 1, s, s]
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = make_viewless_tensor(
inp=hidden_states, requires_grad=True, keep_graph=True,
)
if self.config.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
if self.config.fp8:
import transformer_engine # To keep out TE dependency when not training in fp8
if self.config.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif self.config.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
fp8_recipe = TEDelayedScaling(
config=self.config,
fp8_format=fp8_format,
override_linear_precision=(False, False, not self.config.fp8_wgrad),
)
fp8_group = None
if parallel_state.model_parallel_is_initialized():
fp8_group = parallel_state.get_tensor_model_parallel_group()
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
)
else:
fp8_context = nullcontext()
with rng_context and fp8_context:
# Forward pass.
if self.config.recompute_granularity == 'full' and self.training:
hidden_states = self._checkpointed_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
)
else:
for l_no, layer in enumerate(self.layers):
with self.offload_context:
if (len(self.cuda_graphs) == 0) or (not self.training):
hidden_states, context = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
)
# CUDA graph doesn't output context and is expected to be None
assert (
(context is None)
or (not self.config.enable_cuda_graph)
or (not self.training)
)
else:
# CUDA graph replay for layer `l_no` and microbatch `self.current_microbatch`
# CUDA graph requires positional arguments with the exception of is_first_microbatch.
# Also CUDA graph accepts only Tensor inputs and outputs. Hence, the arg list and
# returned list is limited to `hidden_states`.
assert (len(self.cuda_graphs) > l_no) and (
self.current_microbatch < len(self.cuda_graphs[l_no])
)
hidden_states = self.cuda_graphs[l_no][self.current_microbatch](
hidden_states, is_first_microbatch=(self.current_microbatch == 0),
)
if (
torch.is_grad_enabled()
and self.config.cpu_offloading
and self.group_prefetch_offload_commit_async is not None
):
hidden_states = self.group_prefetch_offload_commit_async(hidden_states)
# Final layer norm.
if self.final_layernorm is not None:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: dict = None
) -> ShardedStateDict:
assert not sharded_offsets, "Unexpected sharded offsets"
non_homogeneous_layers = metadata is not None and metadata.get(
'non_homogeneous_layers', False
)
sharded_state_dict = {}
layer_prefix = f'{prefix}layers.'
num_layers = self.config.num_layers
for layer in self.layers:
offset = layer._get_layer_offset()
global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1
state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.' # module list index in TransformerBlock
if non_homogeneous_layers:
sharded_prefix = f'{layer_prefix}{global_layer_offset}.'
sharded_pp_offset = []
else:
sharded_prefix = layer_prefix
sharded_pp_offset = [
(0, global_layer_offset, num_layers)
] # PP sharding offset for ShardedTensors
layer_sharded_state_dict = layer.sharded_state_dict(
state_dict_prefix, sharded_pp_offset, metadata
)
replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
sharded_state_dict.update(layer_sharded_state_dict)
# Add modules other than self.layers
for name, module in self.named_children():
if not module is self.layers:
sharded_state_dict.update(
sharded_state_dict_default(
module, f'{prefix}{name}.', sharded_offsets, metadata
)
)
return sharded_state_dict
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import types
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import torch
import torch.nn.functional as F
from ..model_parallel_config import ModelParallelConfig
from ..utils import init_method_normal, scaled_init_method_normal
@dataclass
class TransformerConfig(ModelParallelConfig):
"""Configuration object for megatron-core transformers.
The initialization function has an argument for each parameter, including those in ModelParallelConfig.
"""
####################
# model architecture
####################
num_layers: int = 0
"""Number of transformer layers in a transformer block."""
hidden_size: int = 0
"""Transformer hidden size."""
num_attention_heads: int = 0
"""Number of transformer attention heads."""
num_query_groups: int = None
"""Number of query groups for group query attention. If None, normal attention is used."""
ffn_hidden_size: int = None
"""Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size if not provided."""
kv_channels: int = None
"""Projection weights dimension in multi-head attention. This is set to hidden_size //
num_attention_heads if not provided."""
hidden_dropout: float = 0.1
"""Dropout probability for transformer hidden state."""
attention_dropout: float = 0.1
"""Post attention dropout probability."""
fp32_residual_connection: bool = False
"""If true, move residual connections to fp32."""
# @jcasper should we keep this option?
apply_residual_connection_post_layernorm: bool = False
"""If True, uses the original BERT residule connection ordering."""
layernorm_epsilon: float = 1e-5
"""Epsilon value for any LayerNorm operations."""
layernorm_zero_centered_gamma: bool = False
"""If set to True, the LayerNorm is adjusted to center the gamma values around 0. This improves
numerical stability."""
add_bias_linear: bool = True
"""Include a bias term in all linear layers (QKV projections, after core attention, and two in
MLP layer)."""
add_qkv_bias: bool = False
"""Add a bias term only for QKV projections."""
gated_linear_unit: bool = False
"""Use a gated linear unit for the first linear layer in the MLP."""
activation_func: Callable = F.gelu
"""Activation function to use for the non-linearity in the MLP."""
activation_func_fp8_input_store: bool = False
"""Store the input of MLP activation function in FP8 for backprop to save memory.
    The stored input is cast back to the original precision before the backprop computation."""
num_moe_experts: int = None
"""Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None
for no MoE."""
rotary_interleaved: bool = False
"""True is rotate pairs of even and odd dimensions (RoFormer style), False is rotate pairs of
first half and second half (LLaMa style). Default to False."""
window_size: Optional[Tuple[int, int]] = None
"""If not None, then will use sliding window attention. The size of the window is specified by
    the numbers inside the tuple; -1 is a special value meaning "infinite window size"."""
    normalization: str = "LayerNorm"
"""Which norm to use for normalization layers, valid options are `LayerNorm` and `RMSNorm`."""
qk_layernorm: bool = False
"""Whether to apply LayerNorm to the query and key embeddings."""
test_mode: bool = False
"""Whether to run real-time tests."""
####################
# initialization
####################
init_method: Callable = None
"""Method to initialize weights. Note that bias is always set to zero. Should be a function that
takes a single Tensor and initializes it. If None, will be set to
megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with
mean=0.0 and std=init_method_std."""
output_layer_init_method: Callable = None
"""Method to initialize weights of the output layer of both attention and MLP blocks. If None,
will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn
init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers)."""
init_method_std: float = 0.02
"""Standard deviation of the zero mean normal for the default initialization method, not used if
init_method and output_layer_init_method are provided."""
####################
# mixed-precision
####################
apply_query_key_layer_scaling: bool = False
"""If true, scale Q * K^T by 1 / layer-number. This improve numeric stability when training with
fp16."""
attention_softmax_in_fp32: bool = True
"""If True, run attention masking and softmax in fp32. This should be True if
apply_query_key_layer_scaling is True."""
####################
# fusion
####################
bias_activation_fusion: bool = False
"""If True, fuses bias addition and the activation function when possible."""
masked_softmax_fusion: bool = False
"""If True, uses softmax fusion."""
persist_layer_norm: bool = False
"""If True, uses the persistent fused layer norm kernel. This kernel only supports a fixed set
of hidden sizes."""
memory_efficient_layer_norm: bool = False
"""If True, and using local layers (not from TransformerEngine), tells Apex to use the memory
efficient fused LayerNorm kernel. Ignored if not using LayerNorm."""
bias_dropout_fusion: bool = False # TODO: this should be bias_dropout_add_fusion?
"""If True, uses bias dropout fusion."""
apply_rope_fusion: bool = False
"""If True, use fused RoPE kernel."""
####################
# activation recomputation
####################
    recompute_granularity: str = None
"""Determines which type of activation recompute to use. Megatron-core supports 'selective'
activation checkpointing where only the memory intensive part of attention is checkpointed.
These memory intensive activations are also less compute intensive which makes activation
checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large
Transformer Models (https://arxiv.org/abs/2205.05198) for more details. 'full' will checkpoint
the entire transformer layer. If None, no recompute is performed and all activations are saved.
If set, must be 'selective' or 'full'. 'selective' always uses all layers.
"""
recompute_method: str = None
"""Determines which transformer layers will be recomputed. uniform will uniformly divide the
total number of transformer layers in a transformer block and recompute the input activation of
each divided chunk at the specified granularity. block will recompute the input activations for
only a set number of transformer layers per pipeline stage. The rest of the layers in the
pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all
layers will do recomputation. If set, must be 'uniform' or 'block'."""
recompute_num_layers: int = None
"""When recompute_method is uniform, recompute_num_layers is the number of transformer layers in
each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is
the number of transformer layers to recompute within each pipeline stage. Must be None for
'selective' activation checkpointing."""
distribute_saved_activations: bool = None
"""If True, distribute recomputed activations across the model parallel group."""
####################
# fp8 related
####################
fp8: str = None
"""If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined
choices (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8
activation and weight tensors and e5m2 for all FP8 output activation gradient tensors."""
fp8_margin: int = 0
"""Margin for the scaling factor computation."""
fp8_interval: int = 1
"""Controls how often the scaling factor is recomputed."""
fp8_amax_history_len: int = 1
"""The length of the amax history window used for scaling factor computation."""
fp8_amax_compute_algo: str = "most_recent"
"""Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
always chooses the most recently seen value.
"""
fp8_wgrad: bool = True
"""When set to False, override FP8 config options and do the wgrad computation in higher precision."""
fp8_dot_product_attention: bool = False
"""When set to True, use the FP8 implementation of Dot Product Attention."""
fp8_multi_head_attention: bool = False
"""When set to True, use the FP8 implementation of Multi Head Attention."""
####################
# MoE related
####################
moe_router_load_balancing_type: str = "aux_loss"
"""Determines the load balancing strategy for the router. "aux_loss" corresponds to the load
balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing
algorithm used in S-BASE, and "none" implies no load balancing."""
moe_router_topk: int = 2
"""Number of experts to route to for each token."""
moe_grouped_gemm: bool = False
"""When there are multiple experts per rank, compress multiple local (potentially small) gemms
in a single kernel launch to improve the utilization and performance by leveraging the Grouped
GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
"""
moe_aux_loss_coeff: float = 0 # 1e-2 would be a good start value for load balance loss.
"""Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended."""
moe_z_loss_coeff: float = None # 1e-3 would be a good start value for z-loss
"""Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended."""
moe_input_jitter_eps: float = None
"""Add noise to the input tensor by applying jitter with a specified epsilon value."""
moe_token_dropping: bool = False # TODO: Support token dropping.
"""This feature involves selectively dropping and padding tokens for each expert to achieve a
specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is
currently unsupported so should remain False."""
moe_token_dispatcher_type: str = "allgather"
"""The type of token dispatcher to use. The default is 'allgather'. Options are 'allgather' and 'alltoall'."""
moe_per_layer_logging: bool = False
"""Enable per-layer logging for MoE, currently supports auxiliary loss and z loss."""
####################
# miscellaneous
####################
clone_scatter_output_in_embedding: bool = True
"""When set to True, clone the output of scatter_to_sequence_parallel_region in embedding layer
to facilitate garbage collection of input."""
disable_parameter_transpose_cache: bool = False
"""When set to true, the parameter transposes are not cached for subsequent iterations."""
enable_cuda_graph: bool = False
"""When set to true, TransformerLayer blocks are wrapped with CUDA graph."""
# These 2 attributes are WAR for TRTLLM export. DO NOT USE!! WILL BE DEPRECATED SOON!!
max_position_embeddings: int = 0
"""Deprecated. Do not use."""
rotary_percent: float = 0
"""Deprecated. Do not use."""
def __post_init__(self):
""" Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
"""
super().__post_init__()
if self.fp16 and self.bf16:
raise ValueError(
f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.'
)
if self.num_attention_heads % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
if self.ffn_hidden_size is None:
self.ffn_hidden_size = 4 * self.hidden_size
if self.kv_channels is None:
self.kv_channels = self.hidden_size // self.num_attention_heads
if self.num_query_groups is None:
self.num_query_groups = self.num_attention_heads
if self.num_query_groups % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_query_groups ({self.num_query_groups}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
if self.expert_model_parallel_size > 1 and self.num_moe_experts is None:
            raise ValueError('num_moe_experts must be set (non-None) to use expert parallelism.')
if self.num_moe_experts is not None and self.num_moe_experts <= 0:
            raise ValueError('num_moe_experts must be a positive integer.')
if self.cpu_offloading and (
self.cpu_offloading_num_layers < 0 or self.cpu_offloading_num_layers >= self.num_layers
):
raise ValueError(
                f'cpu_offloading_num_layers must be in the range [0, num_layers ({self.num_layers}))'
)
if self.cpu_offloading and self.pipeline_model_parallel_size > 1:
raise ValueError(
f'Currently there is no support for Pipeline parallelism with CPU offloading'
)
if self.cpu_offloading and self.recompute_granularity is not None:
raise ValueError(
f'CPU offloading does not work when activation recomputation is enabled'
)
if self.recompute_granularity is not None:
if not self.recompute_granularity in ['full', 'selective']:
raise ValueError(
                    f'recompute_granularity: {self.recompute_granularity} must be "full" or "selective".'
)
if self.recompute_method is not None:
if not self.recompute_method in ['block', 'uniform']:
raise ValueError(
f'recompute_method: {self.recompute_method} must be "block" or "uniform".'
)
elif self.recompute_granularity != 'selective':
raise ValueError(
f'Using recompute_granularity: {self.recompute_granularity} so recompute_method must be "block" or "uniform"'
)
if self.recompute_granularity != 'selective' and self.recompute_num_layers is None:
raise ValueError(
f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be between '
f'1 and num_layers_per_pipeline_rank: {self.num_layers // self.pipeline_model_parallel_size}'
)
elif (
self.recompute_granularity == 'selective' and self.recompute_num_layers is not None
):
raise ValueError(
f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be None.'
)
if self.distribute_saved_activations and self.sequence_parallel:
raise ValueError(
f'distribute_saved_activations: {self.distribute_saved_activations} must be false when sequence parallel is enabled: {self.sequence_parallel}'
)
if self.virtual_pipeline_model_parallel_size is not None:
if not self.num_layers % self.virtual_pipeline_model_parallel_size == 0:
raise ValueError(
                    f'num_layers: {self.num_layers} must be divisible by virtual_pipeline_model_parallel_size {self.virtual_pipeline_model_parallel_size}'
)
if self.bias_activation_fusion:
if self.activation_func not in [F.gelu, F.silu]:
raise ValueError(
"When bias_activation_fusion is True, activation function should be either gelu or swiglu"
)
if (
self.activation_func == F.gelu
and not self.gated_linear_unit
and not self.add_bias_linear
):
raise ValueError(
"When bias_activation_fusion is True, gated_linear_unit is False, "
"and activation function is gelu, add_bias_linear must also be True."
)
if self.activation_func_fp8_input_store:
if self.activation_func != F.silu or not self.gated_linear_unit:
raise ValueError("Storing activation input in FP8 is supported only for SwiGLU.")
if self.apply_rope_fusion and self.rotary_interleaved:
raise ValueError(f'rotary_interleaved does not work with apply_rope_fusion.')
if self.init_method is None:
self.init_method = init_method_normal(self.init_method_std)
if self.output_layer_init_method is None:
self.output_layer_init_method = scaled_init_method_normal(
self.init_method_std, self.num_layers
)
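def _example_moe_transformer_config() -> TransformerConfig:
    # Hedged usage sketch (hypothetical helper, not part of the original file): builds a
    # small MoE config. The field values are illustrative, apart from the starting values
    # already suggested in the docstrings above. __post_init__ then derives ffn_hidden_size
    # (4 * hidden_size), kv_channels and num_query_groups, and validates the MoE settings.
    return TransformerConfig(
        num_layers=4,
        hidden_size=256,
        num_attention_heads=8,
        num_moe_experts=8,        # replaces the dense MLP with an MoE layer
        moe_router_topk=2,        # top-2 routing
        moe_aux_loss_coeff=1e-2,  # recommended starting value per the docstring
    )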
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from abc import ABC
from dataclasses import dataclass, field
from typing import Dict, Optional, Union
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import apply_prefix_mapping
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor
@dataclass
class TransformerLayerSubmodules:
input_layernorm: Union[ModuleSpec, type] = IdentityOp
self_attention: Union[ModuleSpec, type] = IdentityOp
self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
cross_attention: Union[ModuleSpec, type] = IdentityOp
cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
mlp: Union[ModuleSpec, type] = IdentityOp
mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp
# Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
class BaseTransformerLayer(ABC):
""" A common parent class for `TransformerLayer` like implementations.
A dummy class that is subclassed by similar `TransformerLayer`s e.g. the
`TransformerLayer` in this file and possibly other `TransformerLayer`
implementations that aim to use `TransformerBlock` as the base module.
The main purpose is to check if any layer (or module) provided in the spec
is a subclass of this class to allow fanning-out of that spec for all the
layers in the `TransformerBlock`. See `_get_block_submodules` method
implementation in `transformer_block.py` file for more details.
"""
def __init__(self):
pass
class TransformerLayer(MegatronModule, BaseTransformerLayer):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self,
config: TransformerConfig,
submodules: TransformerLayerSubmodules,
layer_number: int = 1,
hidden_dropout: float = None,
):
super().__init__(config=config)
self.submodules_config = submodules
self.layer_number = layer_number + self._get_layer_offset()
self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout
## [Module 1: Input Layernorm] Optional Layernorm on the input data
# TODO: add pytorch only layernorm
self.input_layernorm = build_module(
submodules.input_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
## [Module 2: SelfAttention]
self.self_attention = build_module(
submodules.self_attention, config=self.config, layer_number=layer_number,
)
## [Module 3: BiasDropoutFusion]
self.self_attn_bda = build_module(submodules.self_attn_bda)
## [Module 4: Post SelfAttention] Optional Layernorm after self-attn
self.pre_cross_attn_layernorm = build_module(
submodules.pre_cross_attn_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
## [Module 5: CrossAttention]
self.cross_attention = build_module(
submodules.cross_attention, config=self.config, layer_number=layer_number,
)
## [Module 6: BiasDropoutFusion]
self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config,)
## [Module 7: Pre MLP] Optional Layernorm before MLP
self.pre_mlp_layernorm = build_module(
submodules.pre_mlp_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
## [Module 8: MLP block]
# TODO how to set the gpt_layer_spec.py when we have moe_frequency > 1,
# where MLP and MoE layer both appear alternately?
self.mlp = build_module(submodules.mlp, config=self.config)
if hasattr(self.mlp, 'set_layer_number'):
self.mlp.set_layer_number(self.layer_number)
## [Module 9: BiasDropoutFusion]
self.mlp_bda = build_module(submodules.mlp_bda)
# @jcasper how should we handle nvfuser?
# Set bias+dropout+add fusion grad_enable execution handler.
# TORCH_MAJOR = int(torch.__version__.split('.')[0])
# TORCH_MINOR = int(torch.__version__.split('.')[1])
# use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
# self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
self.bias_dropout_add_exec_handler = torch.enable_grad
def _get_layer_offset(self):
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
num_layers_per_pipeline_rank = (
self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
total_num_layers = self.config.num_layers
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
total_virtual_chunks = total_num_layers // vp_size
offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)
else:
# Each stage gets a contiguous set of layers.
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
offset = pipeline_rank * num_layers_per_pipeline_rank
else:
offset = 0
return offset
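    # Illustrative note (not part of the original file): with num_layers=8, a pipeline size
    # of 2 and a virtual pipeline size of 2, layers_per_pp_rank=4, layers_per_virtual_rank=2
    # and total_virtual_chunks=4, so pipeline rank 1 / virtual rank 1 gets
    # offset = 1 * 4 + 1 * 2 = 6, i.e. the chunk [6, 7] in the assignment sketched in
    # transformer_block.py.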
def forward(
self,
hidden_states,
attention_mask,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
):
# hidden_states: [s, b, h]
# Residual connection.
residual = hidden_states
# Optional Input Layer norm
input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
# Residual connection.
residual = hidden_states
# Optional Layer norm after self-attention
pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)
# Cross attention.
attention_output_with_bias = self.cross_attention(
pre_cross_attn_layernorm_output,
attention_mask=context_mask,
key_value_states=context,
inference_params=inference_params,
)
if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
context = attention_output_with_bias["context"]
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
# Residual connection.
residual = hidden_states
# Optional Layer norm post the cross-attention.
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
# MLP.
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
return output, context
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
prefixed_map = {
f'{prefix}{k}': f'{prefix}{v}'
for k, v in self.submodules_config.sharded_state_dict_keys_map.items()
}
if prefixed_map:
apply_prefix_mapping(sharded_state_dict, prefixed_map)
return sharded_state_dict
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for transformer layers."""
from functools import lru_cache
from operator import itemgetter
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple, Union
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedStateDict, StateDict
from megatron.core.jit import jit_fuser
from megatron.core.utils import (
make_sharded_tensor_for_checkpoint,
make_tp_sharded_tensor_for_checkpoint,
)
def get_linear_layer(rows, columns, init_method, perform_initialization=True):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
if perform_initialization: # Take from modelparallel config
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
@lru_cache(maxsize=32)
def get_default_causal_mask(sq: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input."""
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
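def _example_causal_mask_cpu(sq: int = 4) -> torch.Tensor:
    # Hedged sketch (hypothetical helper, not part of the original file): a CPU-only variant
    # of get_default_causal_mask() for illustration. True marks the future positions that
    # attention_mask_func() fills with -10000.0 before the softmax.
    return torch.triu(torch.ones(sq, sq), diagonal=1).bool()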
@jit_fuser
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def openai_gelu(x):
return gelu_impl(x)
# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@jit_fuser
def erf_gelu(x):
return (
x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
)
def make_sharded_tensors_for_checkpoint(
state_dict: StateDict,
prefix: str,
tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None,
sharded_offsets: Iterable[Tuple[int, int, int]] = (),
extra_state_suffix: str = '_extra_state',
):
"""Wraps tensors from transformer layers with ShardedTensor or ShardedObject.
For a given `state_dict`, wraps:
- all _extra_states with ShardedObject
- all tensors specified in tensor_parallel_layers_axis_map with TP and DP sharded ShardedTensor
- other values with DP sharded ShardedTensor
Args:
state_dict (StateDict): state_dict to convert
prefix (str): prefix appended to keys in final state dict
tensor_parallel_layers_axis_map (Dict[str, int], optional): dict mapping layer
names to the axis for TP sharding
sharded_offsets (Iterable[Tuple[int, int, int]], optional): sharding already
applied (e.g. PP related), passed along to ShardedTensor
extra_state_suffix (str, default = '_extra_state'): layers with this
suffix will be wrapped with ShardedObject instead of ShardedTensor.
"""
if tensor_parallel_layers_axis_map is None:
tensor_parallel_layers_axis_map = {}
sharded_state_dict = {}
for layer_name in state_dict.keys():
tensor = state_dict[layer_name]
layer_key = f'{prefix}{layer_name}'
if layer_name.endswith(extra_state_suffix):
sharded_state_dict[layer_key] = make_sharded_object_for_checkpoint(
tensor, layer_key, sharded_offsets
)
elif layer_name in tensor_parallel_layers_axis_map:
tp_axis = tensor_parallel_layers_axis_map[layer_name]
sharded_state_dict[layer_key] = make_tp_sharded_tensor_for_checkpoint(
tensor, layer_key, tp_axis, prepend_offsets=sharded_offsets,
)
else:
sharded_state_dict[layer_key] = make_sharded_tensor_for_checkpoint(
tensor, layer_key, prepend_offsets=sharded_offsets,
)
return sharded_state_dict
def make_sharded_object_for_checkpoint(
obj: Any,
key: str,
sharded_offsets: Iterable[Tuple[int, int, int]] = (),
replica_id: Union[None, int, Tuple[int, ...]] = None,
**kwargs,
):
""" Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group).
Args:
obj (object): any object to be sharded
key (str): unique identifier of the object
sharded_offsets (Iterable[Tuple[int, int, int]]): offsets normally
prepended to ShardedTensors, will be used as global offsets for
ShardedObject
replica_id (Union[None, int, Tuple[int, ...]]): replica id
"""
if replica_id is None:
replica_id = (
0,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
return ShardedObject(key, obj, *_get_extra_state_offsets(sharded_offsets), replica_id, **kwargs)
def _get_extra_state_offsets(
sharded_offsets: Iterable[Tuple[int, int, int]]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
""" Turns ShardedTensor offsets into offsets suitable for ShardedObject. """
if sharded_offsets:
sharded_offsets = sorted(sharded_offsets, key=itemgetter(0)) # sort by axis
axis, extra_state_offset, extra_state_shape = zip(*sharded_offsets)
assert list(axis) == list(
range(len(axis))
), f'Expected contiguous axis for offsets: {sharded_offsets}'
else:
extra_state_shape = (1,)
extra_state_offset = (0,)
return extra_state_shape, extra_state_offset
def sharded_state_dict_default(
module: torch.nn.Module,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) -> ShardedStateDict:
"""Provides implementation for sharded_state_dict method for non-MegatronModules.
Tries to call `module.sharded_state_dict` when possible,
otherwise uses regular state dict and assumes tensors are replicated across TP and DP.
`keep_vars=True` is passed to module.state_dict so that optimizer states
can be sharded later on.
Args:
module (torch.nn.Module): module which sharded state dict we want to obtain
prefix (str): prefix for the state dict keys
sharded_offsets (Tuple[Tuple[int, int, int]], optional): sharding already
            applied (e.g. PP related) by super-modules. Passed along to ShardedTensor
metadata (dict, optional): metadata passed to module sharded_state_dict method
Returns:
dict: dictionary of state dict keys mapped to ShardedTensors
"""
if hasattr(module, 'sharded_state_dict'):
module_sharded_sd = module.sharded_state_dict(
prefix=prefix, sharded_offsets=sharded_offsets, metadata=metadata
)
else:
module_sd = module.state_dict(prefix='', keep_vars=True)
module_sharded_sd = make_sharded_tensors_for_checkpoint(
module_sd, prefix, {}, sharded_offsets,
)
return module_sharded_sd
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Utility functions used throughout Megatron core"""
import logging
import math
import operator
import queue
import socket
import sys
import threading
import time
import traceback
from dataclasses import dataclass
from datetime import datetime
from functools import reduce
from types import TracebackType
from typing import List, Optional, Tuple, Type, Union
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedTensor
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def get_attr_wrapped_model(model, attr, allow_none=True, return_model_obj=False):
"""Get an attribute from a wrapped model.
If return_model_obj is true, return the object that has the 'attr' attribute;
otherwise, return the attribute directly."""
if isinstance(model, list):
raise RuntimeError("_get_attr_wrapped_model given a list of models")
if allow_none:
def condition(model, attr):
return not hasattr(model, attr)
else:
def condition(model, attr):
return getattr(model, attr, None) is None
while condition(model, attr):
if not hasattr(model, "module"):
raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
model = model.module
if return_model_obj:
return model
return getattr(model, attr)
def get_model_type(model):
return get_attr_wrapped_model(model, 'model_type')
def get_model_config(model):
return get_attr_wrapped_model(model, 'config', allow_none=False)
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if (
self.buffer.get((name, dtype), None) is None
or self.buffer[(name, dtype)].numel() < required_len
):
self.buffer[(name, dtype)] = torch.empty(
required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
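def _example_global_memory_buffer():
    # Hedged usage sketch (hypothetical helper, not part of the original file). Requires a
    # CUDA device, since get_tensor() allocates on torch.cuda.current_device(). Repeated
    # requests under the same (name, dtype) reuse one backing allocation and return reshaped
    # views of it; not using two live views of the same name concurrently is the caller's job.
    buf = GlobalMemoryBuffer()
    a = buf.get_tensor((4, 1024), torch.bfloat16, "demo")  # allocates 4096 elements
    b = buf.get_tensor((2, 1024), torch.bfloat16, "demo")  # reuses the same storage
    return a.shape, b.shape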
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
    View tensors have the undesirable side-effect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad,)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
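def _example_make_viewless():
    # Hedged sketch (hypothetical helper, not part of the original file): a reshaped tensor
    # keeps a reference to its base via `._base`; make_viewless_tensor() returns a tensor
    # sharing the same data whose `._base` is None, which is what assert_viewless_tensor()
    # checks below.
    base = torch.arange(6.0, requires_grad=True)
    view = base.view(2, 3)  # view._base is base
    viewless = make_viewless_tensor(view, requires_grad=True, keep_graph=True)
    assert viewless._base is None
    return viewless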
def assert_viewless_tensor(tensor, extra_msg=None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[assert_viewless_tensor(t) for t in tensor]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(
tensor,
extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
% ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape),
)
tensor.data = new_data_tensor
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
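def _example_scaled_init():
    # Hedged sketch (hypothetical helper, not part of the original file): with sigma=0.02
    # and num_layers=24 the output-layer std becomes 0.02 / sqrt(2 * 24) ~= 0.00289,
    # matching the formula in the docstring above.
    weight = torch.empty(4, 4)
    scaled_init_method_normal(0.02, 24)(weight)
    return weight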
def make_tp_sharded_tensor_for_checkpoint(
tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs
):
""" Helper for instantiating a ShardedTensor where the `tp_axis` dimension is sharded across TP group.
Optionally, can provide offsets which prepend new dimensions to the tensor.
"""
prepend_axis_num = len(prepend_offsets)
if replica_id is None:
replica_id = (0, 0, parallel_state.get_data_parallel_rank(with_context_parallel=True))
return ShardedTensor.from_rank_offsets(
key,
tensor,
*prepend_offsets,
(
tp_axis + prepend_axis_num,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_tensor_model_parallel_world_size(),
),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
**kwargs,
)
def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_id=None, **kwargs):
""" Helper for instantiating a non-sharded ShardedTensor (replicated across TP and DP group).
Optionally, can provide offsets which prepend new dimensions to the tensor.
"""
prepend_axis_num = len(prepend_offsets)
if replica_id is None:
replica_id = (
0,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
return ShardedTensor.from_rank_offsets(
key,
tensor,
*prepend_offsets,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
**kwargs,
)
def prepare_input_tensors_for_wgrad_compute(grad_output, all_gathered_input):
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if grad_output.dim() == 3:
grad_output = grad_output.view(
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
)
all_gathered_input = all_gathered_input.view(
all_gathered_input.shape[0] * all_gathered_input.shape[1], all_gathered_input.shape[2]
)
return grad_output, all_gathered_input
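def _example_wgrad_reshape():
    # Hedged sketch (hypothetical helper, not part of the original file): an [s, b, h]
    # gradient / all-gathered-input pair is flattened to [s*b, h] so the wgrad GEMM sees
    # plain 2D operands; .contiguous() is enforced on the gradient first.
    grad_output = torch.randn(4, 2, 8)
    all_gathered_input = torch.randn(4, 2, 16)
    g2d, a2d = prepare_input_tensors_for_wgrad_compute(grad_output, all_gathered_input)
    assert g2d.shape == (8, 8) and a2d.shape == (8, 16)
    return g2d, a2d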
def drain_embedding_wgrad_compute(config, embedding_activation_buffer, grad_output_buffer, weight):
""" Helper for performing embedding wgrad GEMM's during the pipeline drain phase, pipelines the AllGather and GEMM's.
Should only be used when pipeline model parallelism and gradient accumulation fusion are enabled.
"""
assert len(embedding_activation_buffer) == len(
grad_output_buffer
), "Length of activation and gradient buffers need to be equal!"
import fused_weight_gradient_mlp_cuda
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size,
)
input = embedding_activation_buffer.pop(0)
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gathered_input = [None, None]
if config.sequence_parallel:
all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu_0")
handle = torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=False
)
all_gathered_input[0] = all_gather_buffer
all_gather_buffer = None
else:
all_gathered_input[0] = input
input = None
def wgrad_compute(all_gathered_input, grad_output, weight):
grad_output, all_gathered_input = prepare_input_tensors_for_wgrad_compute(
grad_output, all_gathered_input
)
if config.gradient_accumulation_fusion:
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
all_gathered_input, grad_output, weight.main_grad
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
all_gathered_input, grad_output, weight.main_grad
)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
# We have all_gathered_input list acting as a double buffer here,
    # since we are pipelining the AllGather and GEMM, one buffer all gathers
# the input while the other buffer reads from it for the GEMM. We use i
# and (i+1) for indexing to enable this double buffering.
for i in range(len(embedding_activation_buffer)):
input = embedding_activation_buffer.pop(0)
if config.sequence_parallel:
name = "mpu_" + str((i + 1) % 2)
all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, name)
handle = torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
)
all_gathered_input[(i + 1) % 2] = all_gather_buffer
all_gather_buffer = None
else:
all_gathered_input[(i + 1) % 2] = input
grad_output = grad_output_buffer.pop(0)
wgrad_compute(all_gathered_input[i % 2], grad_output, weight)
input, all_gathered_input[i % 2], grad_output = None, None, None
if config.sequence_parallel:
handle.wait()
grad_output = grad_output_buffer.pop(0)
wgrad_compute(all_gathered_input[1], grad_output, weight)
input, all_gathered_input[1], grad_output = None, None, None
class _ValueWithRank:
"""This is an internal class, not for use outside this module
Attributes:
_rank (int): rank for the value
_value (float) : the value it stores, eg elapsed time
_unit (str) : unit for the value
"""
def __init__(self, value: float, rank: int, unit: str = "") -> None:
"""Initializer
Args:
            value (float): the initial value with which it is initialized
            rank (int): the rank number
            unit (str): the unit of the value, e.g. ms or flops
"""
self._rank = rank
self._value = value
self._unit = unit
def __lt__(self, other) -> bool:
""" Check if value of self is smaller than other's value
Args:
other (_ValueWithRank): The other object to compare with
Returns:
bool: True if lhs._value of operand is less than rhs._value, else False
"""
return self._value < other._value
def __gt__(self, other) -> bool:
"""Check if value of self is larger than other's value
Args:
other (_ValueWithRank): The other object to compare with
Returns:
bool: True if lhs._value of operand is greater than rhs._value, else False
"""
return self._value > other._value
def __call__(self) -> Tuple[float, int, str]:
"""Returns the value, the rank, and unit as a Tuple
Returns:
Tuple[float, int, str]: value, rank, unit
"""
return self._value, self._rank, self._unit
def __str__(self) -> str:
"""String representation of the object
Returns:
            str: stringified object
"""
return f"{self._value:.2f}{self._unit}/{self._rank}"
@dataclass
class _StragglerData:
"""This is an internal dataclass, not for use outside this module
Attributes:
        min_elapsed (_ValueWithRank): min iteration time across all ranks
        max_elapsed (_ValueWithRank): max iteration time across all ranks
        min_btime (_ValueWithRank): min cpu time across all ranks
        max_btime (_ValueWithRank): max cpu time across all ranks
        min_temp (_ValueWithRank): min gpu temp across all ranks
        max_temp (_ValueWithRank): max gpu temp across all ranks
        min_power (_ValueWithRank): min gpu power across all ranks
        max_power (_ValueWithRank): max gpu power across all ranks
        min_util (_ValueWithRank): min gpu util across all ranks
        max_util (_ValueWithRank): max gpu util across all ranks
        min_clock (_ValueWithRank): min gpu clock across all ranks
        max_clock (_ValueWithRank): max gpu clock across all ranks
        aflops (List[_ValueWithRank]): sorted array of (_ValueWithRank)
"""
# gemm time
min_elapsed = _ValueWithRank(sys.float_info.max, 0, "ms")
max_elapsed = _ValueWithRank(sys.float_info.min, 0, "ms")
# get_batch time
min_btime = _ValueWithRank(sys.float_info.max, 0, "us")
max_btime = _ValueWithRank(sys.float_info.min, 0, "us")
# temp
min_temp = _ValueWithRank(sys.float_info.max, 0, "C")
max_temp = _ValueWithRank(sys.float_info.min, 0, "C")
# power
min_power = _ValueWithRank(sys.float_info.max, 0, "W")
max_power = _ValueWithRank(sys.float_info.min, 0, "W")
# util
min_util = _ValueWithRank(sys.float_info.max, 0, "%")
max_util = _ValueWithRank(sys.float_info.min, 0, "%")
# clock
min_clock = _ValueWithRank(sys.float_info.max, 0, "MHz")
max_clock = _ValueWithRank(sys.float_info.min, 0, "MHz")
aflops: List[_ValueWithRank] = None
class StragglerDetector:
"""Singleton Class implementing per rank Straggler Detector
    It uses CUDA events to time operations of choice using the
    start and stop methods, which can be invoked directly on the
    class instance or used as a python context manager.
After collection, a report() method is available to display
the collected metrics. It is only supported if CUDA is
    available. See megatron/core/README_STRAGGLER.md for more info.
Note:
The instance and class attributes mentioned below are all
        private to the class and have no use outside the class
Attributes:
_off (bool): current state of the toggle
start (FunctionType): start method
stop (FunctionType): stop method
world (int): world size
rank (int): rank for this instance
mmcnt (int): number of ranks to report
port (int): control port
amp (float): amplification factor for TFLOPs, default 3.0
toggle (bool): whether to start/stop detector collection
bdata (bool): when true, just collect get_batch
dev (int): cuda device
idx (int): index into the list below
idx_q (LifoQueue): queue of index
evt_q (LifoQueue): cuda event queue
start_events (list[torch.cuda.Event]): cuda start event
stop_events (list[torch.cuda.Event]): cuda stop event
start_time (list[int]): start time (wallclock)
stop_time (list[int]): stop time (wallclock)
start_batch (list[int]): start time for get_batch
stop_batch (list[int]): stop time for get_batch
sock (socket): the controller socket
ctrlr (Thread): the controller thread
logger (Logger): the logger instance for this instance
"""
_configured = False
"""Indicates if the singleton instance is configured or not
"""
def __new__(cls: Type["StragglerDetector"]) -> "StragglerDetector":
"""Constructor
Creates an instance of the class if not created
Args:
cls (Type['StragglerDetector']): The class type
Returns:
StragglerDetector: the class instance
"""
if not hasattr(cls, "_instance"):
cls._instance = super(StragglerDetector, cls).__new__(cls)
return cls._instance
def __init__(self) -> None:
"""Initializer
The initial state of the StragglerDetector instance is disabled.
The enabled state is indicated using the self._off member variable
and the property enabled.
"""
self._off = True
self.start = self.null_method
self.stop = self.null_method
self.world = 0
self.rank = 0
self.mmcnt = 1
self.port = 0
self.amp = 3.0
self.toggle = False
self.bdata = False
self.dev = None
self.idx = 0
self.idx_q = None
self.evt_q = None
self.start_events = None
self.stop_events = None
self.start_time = None
self.stop_time = None
self.start_batch = None
self.stop_batch = None
self.sock = None
self.ctrlr = None
self.logger = logging.getLogger(__name__)
def configure(
self,
world: int,
rank: int,
mmcnt: int = 1,
amp: float = 3.0,
port: int = 65535,
prefill: int = 1024,
enabled: bool = False,
) -> None:
"""This method is called to configure the Singleton instance
It should be called once per process (the class is a singleton).
Note:
The constructor keeps the state of the instance disabled,
i.e. no collection will happen even when the start/stop methods are
called. Only when enabled is True (i.e. self._off is False) do the
start/stop method pointers get assigned the real collection
methods; otherwise they are initialized with null_method
Args:
world (int): World Size
rank (int): The rank of this trainer
mmcnt (int, optional): Number of ranks to print for showing Min/Max Etpt.
Defaults to 1.
amp (float, optional): Set to 3.0 if we only use timers in fwd pass.
Defaults to 3.0.
port (int, optional): Control port, useful only for rank-0. Defaults to 65535.
prefill (int, optional): How many cuda events to pre-populate. Defaults to 1024.
enabled (bool, optional): Whether or not collection is enabled on startup.
Defaults to False.
"""
if StragglerDetector._configured:
# don't throw
return
StragglerDetector._configured = True
self.bdata = False
self.start = self.null_method
self.stop = self.null_method
self._off = True
# No CUDA, No Support
if torch.cuda.is_available():
self._off = not enabled
self.world = world
self.rank = rank
self.mmcnt = mmcnt if mmcnt > 1 else 1
self.amp = amp
self.port = port
self.toggle = False
self.bdata = False
self.idx = 0
self.idx_q = queue.LifoQueue()
self.evt_q = queue.LifoQueue()
self.start_events = []
self.stop_events = []
self.start_time = []
self.stop_time = []
self.start_batch = []
self.stop_batch = []
backend = torch.distributed.get_backend()
if backend == "nccl":
self.dev = torch.cuda.current_device()
else:
self.dev = torch.device("cpu")
# cache some events
for _ in range(prefill):
self.evt_q.put(torch.cuda.Event(enable_timing=True))
if self.rank == 0:
# Start the controller
self._controller()
if not self._off:
self.start = self.start_method
self.stop = self.stop_method
def reset(self) -> None:
"""This method is called to reset the metrics state of the instance
It is generally called from within elapsed() after extracting per rank metrics.
"""
if self._off:
return
self.idx = 0
self.idx_q = queue.LifoQueue()
# Pool them
_ = [self.evt_q.put(ev) for ev in self.start_events]
_ = [self.evt_q.put(ev) for ev in self.stop_events]
self.start_events = []
self.stop_events = []
# Use regular timers
self.start_time = []
self.stop_time = []
self.start_batch = []
self.stop_batch = []
self.bdata = False
def start_method(self) -> None:
"""This method adds the start timers.
Both cuda event and perf_counter are added. If bdata is set to
true from __call__, this method skips inserting cuda
timer. This way it can be used to measure time spent on
CPU - generally useful for timing get_batch()
"""
# Not reentrant
# First check if this start is for data
if self.bdata:
self.start_batch.append(time.perf_counter_ns())
self.stop_batch.append(0)  # this indicates the matching stop timer still needs to be recorded
self.bdata = False
return
if self.evt_q.qsize() > 1:
sev = self.evt_q.get() # no try-catch
eev = self.evt_q.get() # no try-catch
else:
sev = torch.cuda.Event(enable_timing=True)
eev = torch.cuda.Event(enable_timing=True)
self.start_events.append(sev)
self.stop_events.append(eev)
self.start_time.append(0)
self.stop_time.append(0)
self.idx_q.put(self.idx)
self.start_time[self.idx] = time.perf_counter_ns()
self.start_events[self.idx].record()
self.idx += 1
def stop_method(self) -> None:
"""This method adds the stop timers.
Both cuda event and perf_counter are added. If bdata is set to
true from __call__, this method skips inserting cuda
timer. Also see start_method()
"""
# Not reentrant
# First check if this stop is for data
dle = len(self.stop_batch) - 1
if dle >= 0 and self.stop_batch[dle] == 0:
self.stop_batch[dle] = time.perf_counter_ns()
return
idx = self.idx_q.get()
self.stop_time[idx] = time.perf_counter_ns()
self.stop_events[idx].record()
def elapsed(self) -> Tuple[float, float, int, int, int, int]:
This method is called from report(), or can be called directly.
It collects all the elapsed time since the last reset(),
and finally calls reset().
Returns:
Tuple[float, float, int, int, int, int]: see below for returns
delta : time spent in kernel
batch_delta : time spent in get_batch
temp : observed gpu temp
power : observed gpu power
util : observed gpu utilization
clock : observed gpu clock
"""
if self._off:
# match with return below
return 0, 0, 0, 0, 0, 0
ls_ev = len(self.start_events)
le_ev = len(self.stop_events)
ls_bs = len(self.start_batch)
ls_be = len(self.stop_batch)
delta = 0.0
batch_delta = 0.0
temp = 0
power = 0
clock = 0
if ls_ev != le_ev:
self.logger.warning(f"Event Start/Stop out of sync {ls_ev}/{le_ev}")
elif ls_bs != ls_be:
self.logger.warning(f"get_batch Start/Stop out of sync {ls_bs}/{ls_be}")
else:
temp = torch.cuda.temperature()
power = torch.cuda.power_draw()
util = torch.cuda.utilization()
clock = torch.cuda.clock_rate()
torch.cuda.synchronize()
# Process Events
for i in range(ls_ev):
e_ev = self.start_events[i].elapsed_time(self.stop_events[i])
e_tm = (self.stop_time[i] - self.start_time[i]) / 1e6 # ns to ms
# Pick the larger of Event and perf_counter time?
delta += max(e_ev, e_tm)
# Process get_batch
for i in range(ls_bs):
batch_delta = (self.stop_batch[i] - self.start_batch[i]) / 1e3 # us
self.reset() # Prepare for next round
# time in ms, batch_delta in us, check return above
return delta, batch_delta, temp, power, util, clock
def report(self, total_flops: float = 0.0, log_interval: int = 0) -> bool:
Function to log the min/max metrics and the associated ranks over a time period.
It finds the slowest and fastest rank among all ranks. It should be
called by all ranks, but only rank-0 prints the analysis.
At the end it checks whether the straggler detector should
remain active or be deactivated.
Args:
total_flops (float, optional): The theoretical flops over the period. Defaults to 0.0.
log_interval (int, optional): The number of training iterations over which reporting is called.
Defaults to 0.
Returns:
bool: True if reported, else False
"""
ret = False
if not self._off and total_flops > 0.0 and log_interval > 0:
elapsed, btime_us, temp, power, util, clock = self.elapsed() # get raw time
ptime = elapsed / (log_interval * 1.0) # avg per iteration elapsed time, ms
btime = btime_us / (log_interval * 1.0) # avg per iteration get_batch time, us
api_flops = total_flops / (log_interval * 1.0) # avg per iteration flops
apir_flops = api_flops / (
ptime * 10 ** 9 * self.world
) # this is avg per iteration, this rank's throughput, TFLOP/s (note 10**9)
et_flops = apir_flops / self.amp # Estimated TFLOPs, not tracing backward
o_dt = self._min_max(
ptime, btime, float(temp), float(power), float(util), float(clock), et_flops,
)
if self.rank == 0:
now = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
min_flops, min_frank, _ = o_dt.aflops[0]()
max_flops, max_frank, _ = o_dt.aflops[-1]()
self.logger.info(
f"{now} | "
f"MnRtt/Rnk: {o_dt.min_elapsed} | "
f"MxRtt/Rnk: {o_dt.max_elapsed} | "
f"MnPwr/Rnk: {o_dt.min_power} | "
f"MxPwr/Rnk: {o_dt.max_power} | "
f"MnTmp/Rnk: {o_dt.min_temp} | "
f"MxTmp/Rnk: {o_dt.max_temp} | "
f"MnUtl/Rnk: {o_dt.min_util} | "
f"MxUtl/Rnk: {o_dt.max_util} | "
f"MnClk/Rnk: {o_dt.min_clock} | "
f"MxClk/Rnk: {o_dt.max_clock} | "
f"MnDRtt/Rnk: {o_dt.min_btime} | "
f"MxDRtt/Rnk: {o_dt.max_btime} | "
f"MnEtpt/Rnk: {min_flops:.2f}TF/{min_frank} | "
f"MxEtpt/Rnk: {max_flops:.2f}TF/{max_frank}"
)
if self.mmcnt > 1 and self.mmcnt < self.world:
line = f"^^^^ Bottom {self.mmcnt} Ranks with lowest Etpt(TF):"
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i]},"
self.logger.info(line)
line = f"^^^^ Top {self.mmcnt} Ranks with highest Etpt(TF):"
shift = self.world - self.mmcnt
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i+shift]},"
self.logger.info(line)
ret = True
# Check/Communicate if tracking is turned off or on
self._check_toggle()
return ret
def _check_toggle(self) -> None:
"""Helper method to check if a request to toggle the collection state was made
It checks if a collection-state toggle request was made via the server listening on
rank-0 since the last call to report(). Called by report(). Calling this method
indirectly from report() is the only way to activate the change that is made
via rank-0
"""
# If no change, just communicate the current state
off = self._off
if self.rank == 0 and self.toggle:
off = not self._off
self.toggle = False
state = torch.tensor(off, dtype=torch.bool, device=self.dev)
torch.distributed.broadcast(state, 0) # Blocking
self._off = state.item()
if not self._off:
self.start = self.start_method
self.stop = self.stop_method
state = "ON"
else:
self.start = self.null_method
self.stop = self.null_method
state = "OFF"
if self.rank == 0 and off is not self._off:
self.logger.info(f"Toggling StragglerDetector State {state}")
def _handler(self) -> None:
"""Thread function for the controller.
It is a tcp-server that listens on a port. Uses HTTP protocol.
If connected to it using curl, it indicates a toggle of the
collection state. The actual toggling happens at the end of
calling report() when _check_toggle() is called.
"""
resp = f"HTTP/1.0 200 OK\r\nConnection: Close\r\nContent-length: "
if self.rank == 0:
state = "OFF" if self._off else "ON"
self.logger.info(
f"Controller ready to recv " f"commands on port {self.port}. Current state {state}"
)
while True:
try:
conn, _ = self.sock.accept()
_ = conn.recv(1024)
self.toggle = True
state = "ON" if self._off else "OFF"
msg = f"Will turn StragglerDetector {state} at next logging interval"
msg_len = len(msg)
final_resp = f"{resp}{msg_len}\r\n\r\n{msg}"
conn.send(final_resp.encode())
conn.close()
self.logger.info(msg)
except Exception as err:
self.logger.error(f"Error in stragler handler.. {str(err)}")
return
def _controller(self):
"""Installs a controller listener that is used to toggle collection state.
Called from configure(). Ignored for all ranks other than rank-0
"""
try:
if self.rank == 0:
neth = "0.0.0.0"
netp = self.port
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((neth, netp))
self.sock.listen(128)
self.ctrlr = threading.Thread(
target=self._handler, args=(), name="straggler", daemon=True
)
self.ctrlr.start()
except Exception as err:
self.logger.warning(f"StragglerDetector cannot be controlled.. {str(err)}")
def _min_max(
self,
ptime: float,
btime: float,
temp: float,
power: float,
util: float,
clock: float,
flops: float,
) -> Union[_StragglerData, None]:
"""Helper function to find the min/max values
Args:
ptime (float): avg per iteration gpu time
btime (float): avg per iteration cpu time
temp (float): gpu temp at the time of reporting
power (float): gpu power at the time of reporting
util (float): gpu util at the time of reporting
clock (float): gpu clock at the time of reporting
flops (float): estimated flops for the rank
Returns:
Union[_StragglerData, None]: It contains the min/max of a few metrics and the
corresponding ranks. It also has a list of
all (flops, rank) pairs sorted by flops (aflops),
or returns None if collection is disabled
"""
if self._off:
return None
# initialize output data object
o_dt = _StragglerData()
prof_data = {}
prof_data["rank"] = self.rank
prof_data["time"] = ptime
prof_data["btime"] = btime
prof_data["temp"] = temp
prof_data["power"] = power
prof_data["util"] = util
prof_data["clock"] = clock
prof_data["flops"] = flops
if self.rank == 0:
data_list = [prof_data] * self.world
else:
data_list = None
# this is blocking by default
torch.distributed.gather_object(prof_data, object_gather_list=data_list, dst=0)
if self.rank == 0:
min_ctime = min(data_list, key=lambda k: k["time"]) # elapsed
max_ctime = max(data_list, key=lambda k: k["time"]) # elapsed
min_cbatch = min(data_list, key=lambda k: k["btime"]) # batch time
max_cbatch = max(data_list, key=lambda k: k["btime"]) # batch time
min_ctemp = min(data_list, key=lambda k: k["temp"]) # temp
max_ctemp = max(data_list, key=lambda k: k["temp"]) # temp
min_cpower = min(data_list, key=lambda k: k["power"]) # power
max_cpower = max(data_list, key=lambda k: k["power"]) # power
min_cutil = min(data_list, key=lambda k: k["util"]) # gpu util
max_cutil = max(data_list, key=lambda k: k["util"]) # gpu util
min_cclock = min(data_list, key=lambda k: k["clock"]) # gpu clock
max_cclock = max(data_list, key=lambda k: k["clock"]) # gpu clock
min_val = min_ctime["time"]
min_rank = min_ctime["rank"]
max_val = max_ctime["time"]
max_rank = max_ctime["rank"]
o_dt.min_elapsed = _ValueWithRank(min_val, min_rank, "ms")
o_dt.max_elapsed = _ValueWithRank(max_val, max_rank, "ms")
min_val = min_cbatch["btime"]
min_rank = min_cbatch["rank"]
max_val = max_cbatch["btime"]
max_rank = max_cbatch["rank"]
o_dt.min_btime = _ValueWithRank(min_val, min_rank, "us")
o_dt.max_btime = _ValueWithRank(max_val, max_rank, "us")
min_val = min_ctemp["temp"]
min_rank = min_ctemp["rank"]
max_val = max_ctemp["temp"]
max_rank = max_ctemp["rank"]
o_dt.min_temp = _ValueWithRank(min_val, min_rank, "C")
o_dt.max_temp = _ValueWithRank(max_val, max_rank, "C")
min_val = min_cpower["power"]
min_rank = min_cpower["rank"]
max_val = max_cpower["power"]
max_rank = max_cpower["rank"]
o_dt.min_power = _ValueWithRank(min_val, min_rank, "W")
o_dt.max_power = _ValueWithRank(max_val, max_rank, "W")
min_val = min_cutil["util"]
min_rank = min_cutil["rank"]
max_val = max_cutil["util"]
max_rank = max_cutil["rank"]
o_dt.min_util = _ValueWithRank(min_val, min_rank, "%")
o_dt.max_util = _ValueWithRank(max_val, max_rank, "%")
min_val = min_cclock["clock"]
min_rank = min_cclock["rank"]
max_val = max_cclock["clock"]
max_rank = max_cclock["rank"]
o_dt.min_clock = _ValueWithRank(min_val, min_rank, "MHz")
o_dt.max_clock = _ValueWithRank(max_val, max_rank, "MHz")
o_dt.aflops = [
_ValueWithRank(d.get("flops"), d.get("rank")) for _, d in enumerate(data_list)
]
o_dt.aflops.sort(key=lambda val_with_rank: val_with_rank()[0])
# wait for everyone here
torch.distributed.barrier()
return o_dt
@property
def enabled(self) -> bool:
"""Can be called to check the enabled state of the instance
Note:
After the request to toggle the state, the
actual state change happens at end of call
to report()
"""
return not self._off
@property
def configured(self) -> bool:
Can be called to check if the instance is already configured
Returns:
bool: returns True if configure was called and was a success, else False
"""
return StragglerDetector._configured
@property
def my_rank(self):
"""Can be called to get configured rank of this instance
Returns:
int: Configured rank for this instance
"""
return self.rank
@property
def world_size(self) -> int:
"""Can be called to get configured world of this instance
Returns:
int: World size configured for this instance
"""
return self.world
def null_method(self) -> None:
"""Default method to initialize start/stop method ptrs"""
pass
def __enter__(self) -> "StragglerDetector":
"""Define context/instance entry
Returns:
StragglerDetector: the instance
"""
self.start()
return self
def __call__(self, bdata: bool = False) -> "StragglerDetector":
"""Callable for the instance. Set context state,
Useful when the context is used for cpu timers only when bdata=True
Args:
bdata (bool, optional): when true, only enables cpu timers. Defaults to False.
Returns:
StragglerDetector: the instance
"""
self.bdata = bdata
return self
def __exit__(
self,
ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType],
) -> bool:
"""Define context/instance exit, calls the stop method
Args:
ex_type (Optional[Type[BaseException]]): Exception type
ex_val (Optional[BaseException]): Exception value
ex_tb (Optional[TracebackType]): Exception traceback
Returns:
bool: True if the exception was handled
"""
# Should not suppress errors even if turned off
ret = False
if ex_type is not None:
err = "".join(traceback.format_exception(ex_type, ex_val, ex_tb))
self.logger.warning(f"{str(ex_val)}\n{err}")
ret = True
self.stop()
return ret
# Singleton, global visibility
__straggler__ = StragglerDetector()
"""StragglerDetector: private module variable, not be directly accessed
"""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
def add_ammo_args(parser):
"""Add additional arguments for ammo."""
group = parser.add_argument_group(title="ammo-generic")
group.add_argument(
"--ammo-load-classic-megatron-to-mcore",
action="store_true",
help="Load a classic megatron-lm checkpoint to a new megatron-core model.",
)
group.add_argument(
"--ammo-convert-te-to-local-spec",
action="store_true",
help="Load a megatron-core transformer-engine checkpoint to a model with local spec.",
)
group.add_argument(
"--ammo-quant-cfg",
type=str,
default=None,
choices=["int8_sq", "fp8", "int4_awq", "None"],
help="Algorithms supported by atq.quantize.",
)
return parser
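# A minimal standalone sketch (not part of Megatron) showing how these flags parse with
# plain argparse; in Megatron proper, add_ammo_args would typically be supplied through an
# extra-args provider instead. The parser name and flag values below are illustrative
# assumptions.
if __name__ == "__main__":
    import argparse

    demo_parser = argparse.ArgumentParser(description="ammo argument demo")
    demo_parser = add_ammo_args(demo_parser)
    demo_args = demo_parser.parse_args(["--ammo-quant-cfg", "fp8"])
    # store_true flags default to False; choices restrict --ammo-quant-cfg values
    print(demo_args.ammo_quant_cfg, demo_args.ammo_load_classic_megatron_to_mcore)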
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""ModelOpt GPT model provider."""
from typing import Union
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.core.inference.gpt.model_specs import get_gpt_layer_ammo_spec
from megatron.core.inference.gpt.state_dict_hooks import (
mcore_gpt_load_classic_state_dict_pre_hook,
mcore_gpt_load_te_state_dict_pre_hook,
)
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
def model_provider(
pre_process=True, post_process=True, parallel_output=True,
) -> Union[MCoreGPTModel]:
"""Builds the GPT model.
This model_provider only supports use_mcore_models=True.
Args:
pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.
parallel_output (bool): Whether to allgather the output logits. This must be
True if `model_provider` is called in text_generation_server.
Returns:
Union[MCoreGPTModel]: The returned model
"""
args = get_args()
print_rank_0("building GPT model ...")
config = core_transformer_config_from_args(args)
if args.use_mcore_models:
if args.spec is not None:
raise ValueError("Custom layer specs are not supported!")
else:
if args.num_experts is None:
transformer_layer_spec = get_gpt_layer_ammo_spec()
else:
raise ValueError("MoE is not supported for now!")
model_type = MCoreGPTModel
model_kwargs = {
"config": config,
"transformer_layer_spec": transformer_layer_spec,
"vocab_size": args.padded_vocab_size,
"max_sequence_length": args.max_position_embeddings,
"pre_process": pre_process,
"post_process": post_process,
"fp16_lm_cross_entropy": args.fp16_lm_cross_entropy,
"parallel_output": parallel_output,
"share_embeddings_and_output_weights": not args.untie_embeddings_and_output_weights,
"position_embedding_type": args.position_embedding_type,
"rotary_percent": args.rotary_percent,
}
else:
raise ValueError("Classic Megatron-LM models are not supported!")
model = model_type(**model_kwargs)
print_rank_0(str(model))
if args.use_mcore_models:
if args.ammo_load_classic_megatron_to_mcore:
model._register_load_state_dict_pre_hook(mcore_gpt_load_classic_state_dict_pre_hook)
elif args.ammo_convert_te_to_local_spec:
model._register_load_state_dict_pre_hook(mcore_gpt_load_te_state_dict_pre_hook)
return model
<!-- coding=utf-8-->
<!-- Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.-->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Megatron</title>
<style>
.wrapper {
max-width: 75%;
margin: auto;
}
h1 {
margin: 3rem 0 1rem 0;
padding: 0;
font-size: 1.5rem;
}
textarea {
width: 100%;
min-height: 300px;
resize: none;
border-radius: 8px;
border: 1px solid #ddd;
padding: 0.5rem;
box-shadow: inset 0 0 0.25rem #ddd;
}
textarea:focus {
outline: none;
border: 1px solid #d0d0d0; /* approximates the original darken(#ddd, 5%) */
box-shadow: inset 0 0 0.5rem #d0d0d0;
}
#the-count {
float: right;
padding: 0.1rem 0 0 0;
font-size: 0.875rem;
}
/* Chat containers */
.container {
font-family: 'Arial', sans-serif;
font-size: 16px;
border: 2px solid #dedede;
background-color: #f1f1f1;
border-radius: 5px;
padding: 15px;
margin: 10px 0;
}
/* Clear floats */
.container::after {
content: "";
clear: both;
display: table;
}
/* Style images */
.container img {
float: left;
max-width: 60px;
width: 100%;
margin-right: 20px;
border-radius: 50%;
}
</style>
</head>
<body>
<div class="wrapper">
<h1>Prompt Megatron</h1>
<textarea name="prompt" id="prompt" maxlength="1024" placeholder="Add prompt"autofocus></textarea>
<label for="tokens_to_generate">Number tokens to generate (1-1024):</label>
<input type="number" id="tokens_to_generate" name="tokens_to_generate" min="10" max="256", value=32>
<button onclick="submit_query()">Submit</button>
<div id="the-count">
<span id="current">0</span>
<span id="maximum">/ 1000</span>
</div>
<textarea name="response" id="response" maxlength="2048" placeholder="Megatron response..."></textarea>
</div>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script type="text/javascript">
function submit_query() {
$("#response").val("Waiting for Megatron response...");
$.ajax({
url:"api",
type:"PUT",
data:JSON.stringify({prompts: [$("#prompt").val()], tokens_to_generate: parseInt($("#tokens_to_generate").val(),10)}),
contentType:"application/json; charset=utf-8",
dataType:"json",
success: function(data){
data.max_len=35;
$("#response").val(data.text);
}
});
}
$('textarea').keyup(function() {
var characterCount = $(this).val().length,
current = $('#current'),
maximum = $('#maximum'),
theCount = $('#the-count');
current.text(characterCount);
if (characterCount >= 800) {
maximum.css('color', '#8f0001');
current.css('color', '#8f0001');
theCount.css('font-weight','bold');
} else {
maximum.css('color','#666');
theCount.css('font-weight','normal');
}
});
</script>
</body>
</html>
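# The form above PUTs a JSON body of the shape {"prompts": [...], "tokens_to_generate": N}
# to the relative "api" endpoint and reads the generated text from the response. Below is a
# minimal Python client sketch mirroring that request; the server address is a hypothetical
# default and the `requests` package is assumed to be installed.
import requests


def query_megatron(prompt, tokens_to_generate=32, host="http://localhost:5000"):
    """Send one prompt to the text-generation API and return the parsed JSON response."""
    payload = {"prompts": [prompt], "tokens_to_generate": tokens_to_generate}
    resp = requests.put(f"{host}/api", json=payload, timeout=120)
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    # prints the generated text, matching what the page shows in the response textarea
    print(query_megatron("Megatron is", tokens_to_generate=32)["text"])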
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from .api import (
generate,
generate_and_post_process,
beam_search_and_post_process)