Commit 7c19b3a8 authored by wangsen's avatar wangsen
Browse files

Initial commit

parents
Pipeline #1721 failed with stages
in 0 seconds
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
class IdentityOp(torch.nn.Module):
"""
This is a placeholder for IdentityOp(x) -> x
"""
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
class IdentityFuncOp(IdentityOp):
"""
This is a placeholder for IdentityFuncOp(...)(x) -> IdentityOp(x) -> x.
Such a func is handy for ops like `bias_dropout_fusion` which themselves
return a function at runtime based on passed arguments
"""
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, *args, **kwargs):
return super().forward
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import (
ReplicaId,
ShardedStateDict,
ShardedTensorFactory,
)
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@dataclass
class MLPSubmodules:
linear_fc1: Union[ModuleSpec, type] = None
linear_fc2: Union[ModuleSpec, type] = None
class MLP(MegatronModule):
"""
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
Returns an output and a bias to be added to the output.
If config.add_bias_linear is False, the bias returned is None.
We use the following notation:
h: hidden size
p: number of tensor model parallel partitions
b: batch size
s: sequence length
"""
def __init__(
self,
config: TransformerConfig,
submodules: MLPSubmodules,
is_expert: bool = False,
input_size: int = None,
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.input_size = input_size if input_size != None else self.config.hidden_size
# If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
ffn_hidden_size = self.config.ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
self.linear_fc1 = build_module(
submodules.linear_fc1,
self.input_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=is_expert,
tp_comm_buffer_name='fc1',
)
self.activation_func = self.config.activation_func
self.linear_fc2 = build_module(
submodules.linear_fc2,
self.config.ffn_hidden_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=is_expert,
tp_comm_buffer_name='fc2',
)
def forward(self, hidden_states):
# [s, b, 4 * h/p]
intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)
if self.config.bias_activation_fusion:
if self.activation_func == F.gelu:
if self.config.gated_linear_unit:
intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
else:
assert self.config.add_bias_linear is True
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
elif self.activation_func == F.silu and self.config.gated_linear_unit:
intermediate_parallel = bias_swiglu_impl(
intermediate_parallel,
bias_parallel,
self.config.activation_func_fp8_input_store,
)
else:
raise ValueError("Only support fusion of gelu and swiglu")
else:
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
if self.config.gated_linear_unit:
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
intermediate_parallel = glu(intermediate_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, output_bias = self.linear_fc2(intermediate_parallel)
return output, output_bias
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
sharded_state_dict = {}
for name, module in self._modules.items():
sub_sd = module.sharded_state_dict(f'{prefix}{name}.', sharded_offsets, metadata)
if self.config.gated_linear_unit and name == 'linear_fc1':
assert f'{prefix}{name}.weight' in sub_sd, sub_sd.keys()
for k, v in sub_sd.items():
if k in (f'{prefix}{name}.weight', f'{prefix}{name}.bias'):
sub_sd[k] = apply_swiglu_sharded_factory(v, sharded_offsets)
sharded_state_dict.update(sub_sd)
return sharded_state_dict
def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
# We must split the tensor into 2 parts, each sharded separately.
# This requires a ShardedTensorFactory which `chunk`s during saving
# and `cat`s during loading
tp_rank = parallel_state.get_tensor_model_parallel_rank()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
swiglu_shard_axis = 0
prepend_axis_num = len(sharded_offsets)
original_shape = original_sh_ten.local_shape
original_numel = int(np.prod(original_shape))
@torch.no_grad()
def sh_ten_build_fn(
key: str, t: torch.Tensor, replica_id: ReplicaId, flattened_range: Optional[slice]
):
offset_w = (swiglu_shard_axis + prepend_axis_num, tp_rank, tp_size * 2)
offset_v = (swiglu_shard_axis + prepend_axis_num, tp_size + tp_rank, tp_size * 2)
if flattened_range is None:
tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis)
return [
ShardedTensor.from_rank_offsets(
key,
tensor_w,
*sharded_offsets,
offset_w,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
),
ShardedTensor.from_rank_offsets(
key,
tensor_v,
*sharded_offsets,
offset_v,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
),
]
else:
# Here we need to map a slice `t` (`flattened_range` specifies slice start and stop)
# of the *original* flattened tensor into slices `w` and `v` of chunked
# and flattened tensor.
# Example:
# If original tensor has (16, 5) shape and flattened_range is `slice(8, 64)`,
# then `t` has shape `(56,)` and we need to create 2 tensors:
# w: first 32 elements of `t` with flattened_range slice(8, 40)
# v: last 24 elements of `t` with flattened_range slice(0, 24)
# Global offsets are the same as in the non-flattened case
assert t.ndim == 1, (key, t.shape)
non_flat_local_shape = (original_shape[0] // 2, *original_shape[1:])
chunk_numel = original_numel // 2
result = []
if flattened_range.start < chunk_numel:
# Non-empty `w` chunk
tensor_w = t[: chunk_numel - flattened_range.start]
flattened_range_w = slice(
flattened_range.start, min(chunk_numel, flattened_range.stop)
)
assert len(tensor_w) == flattened_range_w.stop - flattened_range_w.start
result.append(
ShardedTensor.from_rank_offsets_flat(
key,
tensor_w,
non_flat_local_shape,
*sharded_offsets,
offset_w,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
flattened_range=flattened_range_w,
)
)
if flattened_range.stop > chunk_numel:
# Non-empty `v` chunk
tensor_v = t[-(flattened_range.stop - chunk_numel) :]
flattened_range_v = slice(
max(chunk_numel, flattened_range.start) - chunk_numel,
flattened_range.stop - chunk_numel,
)
assert len(tensor_v) == flattened_range_v.stop - flattened_range_v.start, (
len(tensor_v),
flattened_range_v,
)
result.append(
ShardedTensor.from_rank_offsets_flat(
key,
tensor_v,
non_flat_local_shape,
*sharded_offsets,
offset_v,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
flattened_range=flattened_range_v,
)
)
assert sum(sh_ten.data.numel() for sh_ten in result) == t.numel(), (result, t.shape)
return result
def sh_ten_merge_fn(sub_state_dict):
with torch.no_grad():
return torch.cat(sub_state_dict)
return ShardedTensorFactory(
original_sh_ten.key,
original_sh_ten.data,
sh_ten_build_fn,
sh_ten_merge_fn,
original_sh_ten.replica_id,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron Module."""
from typing import Optional, Tuple
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import (
make_sharded_tensors_for_checkpoint,
sharded_state_dict_default,
)
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
def param_is_not_shared(param):
return not hasattr(param, 'shared') or not param.shared
class MegatronModule(torch.nn.Module):
"""Base Megatron module inhertied by all Models.
Megatron specific extensions of torch Module with support
for pipelining
Args:
config (TransformerConfig): Transformer config
"""
# def __init__(self, config: TransformerConfig, share_word_embeddings=True):
def __init__(self, config: TransformerConfig):
super().__init__()
self.config = config
def state_dict_for_save_checkpoint(self, prefix: str = '', keep_vars: bool = False):
"""Override state dict for saving checkpoints Use this function to override the
state dict for saving checkpoints.
Args:
prefix (str, optional): _description_. Defaults to ''.
keep_vars (bool, optional): _description_. Defaults to False.
Returns:
_type_: _description_
"""
return self.state_dict(prefix=prefix, keep_vars=keep_vars)
def sharded_state_dict(
self,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) -> ShardedStateDict:
"""Default implementation for sharded state dict for distributed checkpointing.
General definition of sharded_state_dict simply calls `sharded_state_dict_default`
(which call sharded_state_dict method if possible or a default implementation otherwise)
recursively on all submodules.
Args:
prefix (str): prefix for the state dict keys
sharded_offsets (Tuple[Tuple[int, int, int]], optional): sharding already
applied (e.g. PP related) by sup-modules. Passed along to ShardedTensor
metadata (dict, optional): metadata passed recursively to sharded_state_dict methods
Returns:
dict: dictionary of state dict keys mapped to ShardedTensors
"""
sharded_state_dict = {}
# Save parameters
self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
sharded_state_dict = make_sharded_tensors_for_checkpoint(
sharded_state_dict, prefix, sharded_offsets=sharded_offsets
)
# Recurse into submodules
for name, module in self.named_children():
sharded_state_dict.update(
sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata)
)
return sharded_state_dict
def set_is_first_microbatch(self):
"""Sets the is_first_microbatch flag if it exists. When this flag is set, TE modules will update their fp8 parameter cache.
"""
for m in self.modules():
if hasattr(m, "is_first_microbatch"):
m.is_first_microbatch = True
def conversion_helper(val, conversion):
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_float16(val, float16_convertor):
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES):
val = float16_convertor(val)
return val
return conversion_helper(val, half_conversion)
def float16_to_fp32(val):
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
val = val.float()
return val
return conversion_helper(val, float_conversion)
class Float16Module(MegatronModule):
"""Float 16 Module.
Attributes:
config (TransformerConfig): Transformer config
fp16 (bool) : Specifies if the model runs in fp16 mode
bf16 (bool) : Specifies if the model runs in bf16 mode
Args:
config (TransformerConfig): The transformer config used to initalize the model
"""
def __init__(self, config: TransformerConfig, module: torch.nn.Module):
super(Float16Module, self).__init__(config)
self.config = config
self.fp16 = config.fp16
self.bf16 = config.bf16
if self.fp16:
self.add_module('module', module.half())
def float16_convertor(val):
return val.half()
elif self.bf16:
self.add_module('module', module.bfloat16())
def float16_convertor(val):
return val.bfloat16()
else:
raise Exception('Either config.fp16 or config.bf16 should be True.')
self.float16_convertor = float16_convertor
def set_input_tensor(self, input_tensor):
return self.module.set_input_tensor(input_tensor)
def forward(self, *inputs, **kwargs):
if parallel_state.is_pipeline_first_stage():
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
if parallel_state.is_pipeline_last_stage():
outputs = float16_to_fp32(outputs)
return outputs
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Retrieve state_dict from the module being wrapped."""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
def sharded_state_dict(self, prefix='', *args, **kwargs):
"""Retrieve sharded_state_dict from the module being wrapped."""
return self.module.sharded_state_dict(prefix, *args, **kwargs)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
# Megatron Core MoE Key Features
### Parallelism
- **Expert Parallel**
- A specific method of parallelism for MoE models, where experts are partitioned onto different workers and each worker processes a different batch of training samples, each worker process one or more experts for each MoE layer.
- **3D Parallel**: Data Parallel , Tensor Parallel, Pipeline Parallel, Sequence Parallel
- Note: When using MoE with expert parallelism and tensor parallelism, sequence parallelism must be used.
- **Richer parallel mappings**: EP can be combined with DP/TP/PP/SP for handling larger MoE variants.
- **Full distributed optimizer support.**
### Router and Load Balancing
- Router type:
- Top-K MLP router
- Load Balancing algorithms:
- Sinkhorn (S-BASE)
- Aux loss / Load balancing loss
### Performance Optimizations
- GroupedGEMM when num local experts > 1
- Supported dtype: bf16
- Performance improvements for larger MoE models
- Enable `--tp-comm-overlap` for MoE
### Token Dispatch Mechanism
- Dropless / No token drop.
- Token drop and padding.
### Ease of use
- Checkpoint converter (coming soon)
- Per-layer logging
## Upcoming features
- Enhanced cutlass GroupedGEMM kernels
- Reduced host-device syncs.
- More supported dtype: fp32/bf16/fp16
- Kernel heuristics tuned for H100/A100/A10/L40S
- BWD cutlass GroupedGEMM kernels supported
- Token permutation / unpermutation fusion
- Fused Sinkhorn Kernel
- Context Parallel with MoE
- FP8 training support
# User Guide
### MoE Related Arguments
| Item | Description |
| --- | --- |
| num-experts | Number of Experts in MoE (None means no MoE) |
| expert-model-parallel-size | Degree of expert model parallelism. Default is 1. |
| moe-grouped-gemm | When there are multiple experts per rank, compress multiple local (potentially small) gemms in a single kernel launch to improve the utilization and performance by leveraging the Grouped GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm). |
| moe-router-load-balancing-type | Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss". |
| moe-router-topk | Number of experts to route to for each token. The default is 2. |
| moe-aux-loss-coeff | Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended. Default is 0.0. |
| moe-z-loss-coeff | Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended. Default is None. |
| moe-input-jitter-eps | Add noise to the input tensor by applying jitter with a specified epsilon value. Default is None. |
| moe-token-dispatcher-type | Determines the token dispatcher type. Choices are "allgather" and "alltoall". Default is "allgather". |
| moe-per-layer-logging | Enable per-layer logging for MoE, currently supports auxiliary loss and z loss. |
| moe-expert-capacity-factor | The capacity factor for each expert, None means no token will be dropped. Default is None. |
| moe-pad-expert-input-to-capacity | Pads the input for each expert to match the expert capacity length, effective only after the --moe-expert-capacity-factor is set. |
### Usage
To train a top-2 MoE model with an auxiliary loss, include the following arguments:
```python
--num-experts 8
--expert-model-parallel-size 8
--moe-grouped-gemm
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--use-distributed-optimizer
```
To avoid out-of-memory in dropless MoE training, we can set a large capacity factor, add:
```python
--moe-expert-capacity-factor 4.0
```
To enable the token drop mechanism, such as GShard and SwitchTransformer, include the following arguments:
```python
--moe-expert-capacity-factor 1.0
--moe-pad-expert-input-to-capacity # Optional
```
## Dropless MoE training script example:
<details>
<summary>Click here. </summary>
```bash
#!/bin/bash
# Runs Mixtral 8x7B model on 32 H100/A100 GPUs
# The Dropless MoE suffers from an imbalanced token distribution at the early stage of training (the first few hundred iterations), which may lead to poor performance and out-of-memory (OOM) issues.
# To check the performance of a Dropless MoE model, we should run the model for at least 500 iterations or resume from trained checkpoints.
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6000"}
NNODES=${NNODES:-"1"}
NODE_RANK=${RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
CHECKPOINT_PATH=$1
TOKENIZER_MODEL=$2
DATA_PATH=$3
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NNODES
--node_rank $NODE_RANK
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
MODEL_ARGS=(
--disable-bias-linear
--seq-length 4096
--max-position-embeddings 32768
--num-layers 32
--hidden-size 4096
--ffn-hidden-size 14336
--num-attention-heads 32
--init-method-std 0.01
--attention-dropout 0.0
--hidden-dropout 0.0
--normalization RMSNorm
--position-embedding-type rope
--swiglu
--untie-embeddings-and-output-weights
--group-query-attention
--num-query-groups 8
--no-masked-softmax-fusion
--no-position-embedding
)
MOE_ARGS=(
--num-experts 8
--expert-model-parallel-size 8
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, None. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--moe-grouped-gemm
)
DATA_ARGS=(
--tokenizer-type Llama2Tokenizer
--tokenizer-model ${TOKENIZER_MODEL}
--data-path $DATA_PATH
--split 99990,8,2
)
TRAINING_ARGS=(
--micro-batch-size 1
--global-batch-size 128
--lr 1e-4
--train-iters 500000
--lr-decay-iters 320000
--lr-decay-style cosine
--min-lr 1.0e-5
--weight-decay 0.1
--lr-warmup-iters 500
--clip-grad 1.0
--bf16
--overlap-grad-reduce
--overlap-param-gather
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 2
--pipeline-model-parallel-size 1
--sequence-parallel
--use-distributed-optimizer
)
LOGGING_ARGS=(
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--tensorboard-dir "${CHECKPOINT_PATH}/tensorboard" \
--no-load-optim \
--no-load-rng
)
if [ -n "${WANDB_API_KEY}" ]; then
LOGGING_ARGS+=(
--wandb-project ${WANDB_PROJECT:-"Mixtral-Finetuning"}
--wandb-exp-name ${WANDB_NAME:-"Mixtral_8x7B"}
)
fi
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${MODEL_ARGS[@]} \
${MOE_ARGS[@]} \
${DATA_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${LOGGING_ARGS[@]}
```
</details>
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.jit import jit_fuser
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
)
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe import grouped_gemm_util as gg
from megatron.core.transformer.transformer_config import TransformerConfig
class GroupedMLP(MegatronModule):
"""An efficient implementation of the Experts layer using CUTLASS GroupedGEMM.
This class is designed to execute multiple experts in parallel, thereby maximizing computational efficiency.
"""
def __init__(self, num_local_experts: int, config: TransformerConfig):
super().__init__(config=config)
self.config: TransformerConfig = config
self.num_local_experts = num_local_experts
gg.assert_grouped_gemm_is_available()
assert (
config.add_bias_linear == False
), "bias in the expert layer is not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
self.expert_parallel = config.expert_model_parallel_size > 1
if self.config.gated_linear_unit:
if self.config.activation_func not in (F.silu, F.gelu):
raise ValueError("Activation function must be silu or gelu when using GroupedMLP.")
@jit_fuser
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
self.activation_func = glu
else:
self.activation_func = self.config.activation_func
# How many feature each rank holds for fc1 and fc2, respectively.
if config.moe_extended_tp:
tp_size = parallel_state.get_tensor_and_expert_parallel_world_size()
else:
tp_size = parallel_state.get_tensor_model_parallel_world_size()
fc1_output_size = self.config.ffn_hidden_size * self.num_local_experts
if config.gated_linear_unit:
# Project to 4h. If using swiglu double the output width,
# see https://arxiv.org/pdf/2002.05202.pdf
fc1_output_size *= 2
fc1_output_size_per_partition = divide(fc1_output_size, tp_size)
fc2_input_size = self.config.ffn_hidden_size * self.num_local_experts
fc2_input_size_per_partition = divide(fc2_input_size, tp_size)
# Note: The current kernel implementations of grouped_gemm
# does not support transposition with CUTLASS grouped GEMM
# (https://github.com/fanshiqing/grouped_gemm/blob/main/csrc/grouped_gemm.cu#L355-L358)
# and as a result we avoid allocate the transpose of weights.
# Initialize weight.
if config.use_cpu_initialization:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition,
self.config.hidden_size,
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight1,
self.config.hidden_size,
fc1_output_size,
fc1_output_size_per_partition,
partition_dim=1,
init_method=config.init_method,
params_dtype=config.params_dtype,
)
_initialize_affine_weight_cpu(
self.weight2,
fc2_input_size,
self.config.hidden_size,
fc2_input_size_per_partition,
partition_dim=0,
init_method=config.output_layer_init_method,
params_dtype=config.params_dtype,
)
else:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition,
self.config.hidden_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight1,
config.init_method,
partition_dim=1,
expert_parallel=self.expert_parallel,
)
_initialize_affine_weight_gpu(
self.weight2,
config.output_layer_init_method,
partition_dim=0,
expert_parallel=self.expert_parallel,
)
setattr(self.weight1, 'allreduce', not self.expert_parallel)
setattr(self.weight2, 'allreduce', not self.expert_parallel)
def forward(self, permuted_local_hidden_states, tokens_per_expert):
if permuted_local_hidden_states.nelement() != 0:
# Reshape the weights for the grouped GEMMs.
w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
fc1_output = gg.ops.gmm(
permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False
)
intermediate_parallel = self.activation_func(fc1_output)
fc2_output = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False)
else:
# No token is allocated for local experts.
assert torch.count_nonzero(tokens_per_expert) == 0
# Make sure parameters still have gradients when no tokens are routed to this set of experts.
w1 = self.weight1.view(self.config.hidden_size, -1)
w2 = self.weight2.view(-1, self.config.hidden_size)
h = torch.matmul(permuted_local_hidden_states, w1)
h = self.activation_func(h)
h = torch.matmul(h, w2)
fc2_output = h
return fc2_output, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
raise NotImplementedError(
'Currently distributed checkpointing is not supported for GroupedMLP'
)
class SequentialMLP(MegatronModule):
"""An implementation of the Experts layer using a sequence of MLP layers.
This class executes each expert sequentially.
"""
def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules):
super().__init__(config=config)
self.add_bias = config.add_bias_linear
self.moe_extended_tp = config.moe_extended_tp
self.num_local_experts = num_local_experts
self.local_experts = torch.nn.ModuleList()
for _ in range(self.num_local_experts):
expert = MLP(self.config, submodules, is_expert=True)
self.local_experts.append(expert)
def forward(self, permuted_local_hidden_states, tokens_per_expert):
output_local = torch.zeros_like(permuted_local_hidden_states)
output_bias_local = None
if self.add_bias:
output_bias_local = torch.zeros_like(permuted_local_hidden_states)
cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
# Insert zero at the begining for offset index's convenience
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
for expert_num, expert in enumerate(self.local_experts):
start = cumsum_num_tokens[expert_num]
end = cumsum_num_tokens[expert_num + 1]
hidden = permuted_local_hidden_states[start:end]
output, output_bias = expert(hidden)
output_local[start:end] = output
if self.add_bias:
output_bias = output_bias.expand_as(output)
output_bias_local[start:end, :] = output_bias
return output_local, output_bias_local
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Maps local expert to global experts. """
if self.moe_extended_tp:
raise NotImplementedError(
'Currently distributed checkpointing is not supported for moe_extended_tp'
)
sharded_state_dict = {}
num_global_experts = (
parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
)
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)
expert_sharded_prefix = f'{prefix}experts.'
for expert_local_idx, expert in enumerate(self.local_experts):
expert_global_idx = local_expert_indices_offset + expert_local_idx
expert_state_dict_prefix = f'{prefix}local_experts.{expert_local_idx}.'
expert_sharded_offsets = (
*sharded_offsets,
(len(sharded_offsets), expert_global_idx, num_global_experts),
)
expert_state_dict = expert.sharded_state_dict(
expert_state_dict_prefix, expert_sharded_offsets, metadata
)
# Remove expert layers indexing from sharded keys
replace_prefix_for_sharding(
expert_state_dict, expert_state_dict_prefix, expert_sharded_prefix
)
# Adjust replica ids - replication along DP modulo EP
for k, sh_ten in expert_state_dict.items():
replica_id = sh_ten.replica_id
assert (
len(replica_id) == 3
), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
sh_ten.replica_id = (
*replica_id[:2],
parallel_state.get_data_modulo_expert_parallel_rank(),
)
sharded_state_dict.update(expert_state_dict)
return sharded_state_dict
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
try:
import grouped_gemm
except ImportError:
grouped_gemm = None
def grouped_gemm_is_available():
return grouped_gemm is not None
def assert_grouped_gemm_is_available():
assert grouped_gemm_is_available(), (
"Grouped GEMM is not available. Please run "
"`pip install git+https://github.com/fanshiqing/grouped_gemm@v1.0`."
)
ops = grouped_gemm.ops if grouped_gemm_is_available() else None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.moe.token_dispatcher import (
MoEAllGatherTokenDispatcher,
MoEAlltoAllTokenDispatcher,
)
from megatron.core.transformer.transformer_config import TransformerConfig
class BaseMoELayer(MegatronModule, ABC):
"""Base class for a mixture of experts layer.
Args:
config (TransformerConfig): Configuration object for the transformer model.
"""
def __init__(self, config: TransformerConfig, layer_number: int = None):
super(BaseMoELayer, self).__init__(config)
self.config = config
self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size()
assert self.expert_parallel_size > 0, "Expected non-negative expert parallel size"
if self.config.moe_extended_tp:
self.num_local_experts = self.config.num_moe_experts
local_expert_indices_offset = 0
else:
assert self.config.num_moe_experts % self.expert_parallel_size == 0
self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)
self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.num_local_experts)
]
assert all(map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices))
self.router = None
self.experts = None
self.token_dispatcher = None
self.layer_number = layer_number
@abstractmethod
def forward(self, hidden_states):
pass
def set_layer_number(self, layer_number: int):
self.layer_number = layer_number
self.router.set_layer_number(layer_number)
class MoELayer(BaseMoELayer):
"""Mixture of experts Layer **currently only supports no token dropping**.
Args:
BaseMoELayer (MegatronModule): Base class for MoE layers
"""
def __init__(
self, config: TransformerConfig, submodules: MLPSubmodules = None, layer_number: int = None
):
self.submodules = submodules
super(MoELayer, self).__init__(config=config, layer_number=layer_number)
self.router = TopKRouter(config=self.config)
if self.config.moe_grouped_gemm:
self.experts = GroupedMLP(self.num_local_experts, self.config)
else:
assert isinstance(self.submodules, MLPSubmodules)
self.experts = SequentialMLP(self.num_local_experts, self.config, self.submodules)
if config.moe_token_dispatcher_type == "allgather":
self.token_dispatcher = MoEAllGatherTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
elif config.moe_token_dispatcher_type == "alltoall":
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
else:
raise ValueError(
f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
)
self.moe_layer_recompute = config.moe_layer_recompute
def forward(self, hidden_states: torch.Tensor):
if (
self.training
and self.config.tensor_model_parallel_size > 1
and not self.config.sequence_parallel
):
raise ValueError(
"During training, performance may degrade if MoE and tensor parallelism"
"are enabled without also enabling sequence parallelism."
)
# process MoE
def custom_forward(hidden_states):
probs, indices = self.router(hidden_states)
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, indices
)
expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
return output, mlp_bias
if self.moe_layer_recompute:
output, mlp_bias = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
else:
output, mlp_bias = custom_forward(hidden_states)
return output, mlp_bias
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import math
import torch
from megatron.core import parallel_state
def switch_load_balancing_loss_func(
probs: torch.Tensor, tokens_per_expert: torch.Tensor, topk: int, moe_aux_loss_coeff: float
):
"""Calculate the auxiliary loss for better load balacing.
Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.
Args:
probs (torch.Tensor): The softmax probs output by the router for each token. [num_tokens, num_experts]
tokens_per_expert (torch.Tensor): The number of assigned tokens for each expert. [num_experts]
Returns:
torch.Tensor: The auxiliary loss for load balancing.
"""
num_tokens = probs.shape[0] * topk
num_experts = probs.shape[1]
probs_mean_per_expert = probs.mean(dim=0)
aux_loss = torch.sum(probs_mean_per_expert * tokens_per_expert) * (
num_experts / num_tokens * moe_aux_loss_coeff
)
return aux_loss
def z_loss_func(logits, z_loss_coeff):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
Args:
logits (torch.Tensor): The logits of the router.
Returns:
torch.Tensor: The logits after applying the z-loss.
"""
z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
return z_loss
def sinkhorn(cost: torch.Tensor, tol: float = 0.0001):
"""Sinkhorn based MoE routing function"""
cost = torch.exp(cost)
d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
eps = 0.00000001
error = 1e9
d1_old = d1
while error > tol:
d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
error = torch.mean(torch.abs(d1_old - d1))
d1_old = d1
return d1 * cost * d0.unsqueeze(1)
def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_capacity=None):
"""
Calculate the capacity of each expert.
Args:
num_tokens (int): num of the input tokens.
num_experts (int): num of the experts.
capacity_factor (float): Capacity factor.
min_capacity (int, optional): Minimum capacity. Defaults to None.
Returns:
Tensor: Capacity of each expert.
"""
capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
if min_capacity is not None and capacity < min_capacity:
capacity = min_capacity
return capacity
class MoEAuxLossAutoScaler(torch.autograd.Function):
"""An AutoScaler that compute and scales the grad for auxiliary loss.
"""
main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
@staticmethod
def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
"""Preserve the aux_loss by storing it in the context to avoid garbage collection.
Args:
output (torch.Tensor): The output tensor.
aux_loss (torch.Tensor): The auxiliary loss tensor.
Returns:
torch.Tensor: The output tensor.
"""
ctx.save_for_backward(aux_loss)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
"""Compute and scale the gradient for auxiliary loss..
Args:
grad_output (torch.Tensor): The gradient of the output.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient.
"""
(aux_loss,) = ctx.saved_tensors
aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
return grad_output, scaled_aux_loss_grad
@staticmethod
def set_loss_scale(scale: torch.Tensor):
"""set the scale of the aux loss.
Args:
scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
"""
MoEAuxLossAutoScaler.main_loss_backward_scale = scale
def permute(tokens, indices, num_out_tokens: int = None, padded_mode: bool = False):
"""Permute the tokens based on the indices. Token with the same index will be grouped together.
The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately.
Args:
tokens (torch.Tensor): The input token tensor.
indices (torch.Tensor): The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk].
num_out_tokens (int, optional): The effective output token count, when enabling the capacity factor, should equal the number of tokens not dropped. By default, set to None, meaning no tokens are dropped.
padded_mode (bool, optional): If True, indicating the indices are padded to [num_expert, capacity] to denote selected tokens per expert. Defaults to False.
Returns:
torch.Tensor: The permuted tensor.
torch.Tensor: The sorted_indices corresponding permuted tensor.
"""
if padded_mode:
return permute_with_padded_tokens(tokens, indices)
if indices.dim() == 1:
topk = 1
else:
topk = indices.size(1)
flatten_indices = indices.view(-1)
sorted_indices = torch.argsort(flatten_indices, stable=True)
if num_out_tokens is not None:
sorted_indices = sorted_indices[:num_out_tokens]
permuted_tokens = tokens.index_select(0, sorted_indices // topk)
return permuted_tokens, sorted_indices
def unpermute(
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
probs: torch.Tensor = None,
padded_mode: bool = False,
restore_shape: torch.Size = None,
):
"""Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities.
Args:
permuted_tokens (torch.Tensor): The tensor of permuted tokens to be unpermuted.
sorted_indices (torch.Tensor): The tensor of sorted indices used to unpermute the tokens.
probs (torch.Tensor, optional): The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities.
padded_mode (bool, optional): If True, indicating the indices are padded to [num_expert, capacity] to denote selected tokens per expert. Defaults to False.
restore_shape (torch.Size, optional): The input shape before permutation, only used in padding mode. Defaults to None.
Returns:
torch.Tensor: The unpermuted tokens, optionally merged with probabilities.
"""
if padded_mode:
return unpermute_with_padded_tokens(
permuted_tokens, sorted_indices, probs, restore_shape=restore_shape
)
assert sorted_indices.numel() == permuted_tokens.size(0)
if probs is not None:
# Unpermute and merge the tokens with their probabilities
num_unpermuted_tokens = probs.numel()
topk = probs.size(1)
else:
# Unpermute the tokens without merge
num_unpermuted_tokens = permuted_tokens.size(0)
topk = 1
unpermuted_tokens = torch.zeros(
[num_unpermuted_tokens, permuted_tokens.shape[-1]],
dtype=permuted_tokens.dtype,
device=permuted_tokens.device,
)
unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
if probs is not None:
unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
unpermuted_tokens = unpermuted_tokens.sum(dim=1)
return unpermuted_tokens
def permute_with_padded_tokens(tokens, indices):
"""Permute the tokens based on the indices, only used in padding mode.
The input indices shape is [num_expert, capacity], it indicates which tokens were selected by each expert separately.
Args:
tokens (torch.Tensor): The input token tensor.
indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert.
Returns:
torch.Tensor: The permuted tensor.
torch.Tensor: The sorted_indices corresponding permuted tensor.
"""
permuted_tokens = tokens.index_select(dim=0, index=indices.view(-1))
return permuted_tokens, indices
def unpermute_with_padded_tokens(
permuted_tokens: torch.Tensor,
indices: torch.Tensor,
probs: torch.Tensor,
restore_shape: torch.Size,
) -> torch.Tensor:
"""
Unpermutes a padded permuted tokens based on sorted indices and merges the tokens with their corresponding probabilities.
This function takes a tensor of permuted tokens and reorders them according to the provided indices. It also combines the tokens with their associated probabilities.
Parameters:
permuted_tokens (torch.Tensor): A 2D tensor containing permuted tokens.
indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert.
probs (torch.Tensor): A tensor with the same shape as indices, containing probabilities corresponding to each token.
restore_shape (torch.Size): The target shape for the unpermuted tokens tensor.
Returns:
torch.Tensor: A tensor of unpermuted tokens, merged with their probabilities.
"""
# Ensure permuted_tokens is 2D
assert permuted_tokens.dim() == 2, f"Got {permuted_tokens.dim()}D."
# Reshape and expand probabilities and indices to match permuted_tokens
probs = probs.view(-1).unsqueeze(-1)
indices = indices.view(-1, 1).expand(-1, permuted_tokens.shape[1])
assert (
permuted_tokens.shape == indices.shape
), "Shape mismatch between permuted_tokens and indices."
# Combine tokens with their probabilities
combined_output = probs * permuted_tokens
# Prepare a tensor of zeros with the desired output shape
empty_tokens = torch.zeros(
restore_shape,
dtype=combined_output.dtype,
device=combined_output.device,
requires_grad=True,
)
# Scatter the combined tokens back to their original positions
unpermuted_tokens = torch.scatter_add(empty_tokens, 0, indices, combined_output)
return unpermuted_tokens
def topk_softmax_with_capacity(
logits: torch.Tensor,
topk: int,
capacity_factor: float = None,
pad_to_capacity: bool = False,
drop_policy: str = "probs",
):
"""Apply capacity and padding to the top-k selection.
Args:
logits (torch.Tensor): Logits tensor.
topk (int): The number of experts to select for each token.
capacity_factor (int): The capacity factor of each expert. Will drop tokens if the number of tokens exceeds the capacity.
pad_to_capacity (bool): Whether to need padding in token drop mode.
drop_policy (str): The policy to drop tokens. Can be either "prob" or "position". If "prob", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Probs, indices and tokens_per_expert tensor.
(1) If there's no token padding, the shape of probs and indices is [tokens, top_k], indicating the selected experts for each token.
(2) If there's token padding, the shape of probs and indices is [num_expert, capacity], indicating the tokens selected for each expert.
"""
# TODO: Add Pre softmax.
assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
num_tokens = logits.shape[0]
num_experts = logits.shape[1]
scores, top_indices = torch.topk(logits, k=topk, dim=1)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
if capacity_factor is None:
# TopK without capacity
tokens_per_expert = torch.histc(top_indices, bins=num_experts, min=0, max=num_experts)
return probs, top_indices, tokens_per_expert
else:
# TopK with capacity
expert_capacity = get_capacity(
num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor,
)
# TopK selection, Maskout unused experts
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
topk_mask = torch.zeros_like(logits).scatter(1, top_indices, 1)
# Maskout exceeded tokens
if drop_policy == "probs":
capacity_probs, capacity_indices = torch.topk(
topk_masked_gates, k=expert_capacity, dim=0, sorted=False
)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
elif drop_policy == "position":
_, capacity_indices = torch.topk(topk_mask, k=expert_capacity, dim=0, sorted=False)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
capacity_probs = torch.gather(topk_masked_gates, 0, capacity_indices)
else:
raise ValueError(f"Invalid drop_policy: {drop_policy}")
if pad_to_capacity:
final_probs, final_indices = (
capacity_probs.T.contiguous(),
capacity_indices.T.contiguous(),
)
tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
else:
# Get exceed mask and maskout exceeded probs and indices
final_mask = torch.logical_and(topk_mask, capacity_mask)
drop_mask = torch.logical_not(final_mask)
exceed_mask = torch.gather(drop_mask, 1, top_indices)
final_probs = probs * torch.logical_not(exceed_mask)
final_indices = top_indices.clone().masked_fill_(
exceed_mask, torch.iinfo(torch.long).max
)
tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
return final_probs, final_indices, tokens_per_expert_before_capacity
def save_to_aux_losses_tracker(name: str, loss: torch.Tensor, layer_number: int, num_layers: int):
"""Save the auxiliary loss for logging.
Args:
name (str): The name of the loss.
loss (torch.Tensor): The loss tensor.
layer_number (int): Layer index of the loss.
num_layers (int): The number of total layers.
"""
# Skip aux loss logging if layer_number is None.
if layer_number is None:
return
if name not in parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER:
parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name] = torch.zeros(
num_layers, device=loss.device
)
parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name][layer_number - 1] += loss.detach()
def clear_aux_losses_tracker():
"""Clear the auxiliary losses."""
for name in parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER:
parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name].zero_()
def get_aux_losses_tracker():
"""Return the auxiliary losses."""
return parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER
def aggregate_aux_losses_tracker_across_pipeline_parallel():
"""Sum aux losses across PP."""
for name in parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER:
loss = parallel_state._MOE_AUX_LOSSES_LOGGING_TRACKER[name]
torch.distributed.all_reduce(loss, group=parallel_state.get_pipeline_model_parallel_group())
def track_moe_metrics(
loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False
):
# Aux loss logging
aggregate_aux_losses_tracker_across_pipeline_parallel()
if writer is not None:
aux_losses = {k: v.float() * loss_scale for k, v in get_aux_losses_tracker().items()}
for name, loss_list in aux_losses.items():
if total_loss_dict is not None:
if name not in total_loss_dict:
total_loss_dict[name] = loss_list.mean()
else:
total_loss_dict[name] += loss_list.mean()
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# polutes the runs list, so we just add each as a scalar
writer.add_scalar(name, loss_list.mean(), iteration)
if per_layer_logging:
for i, loss in enumerate(loss_list.tolist()):
writer.add_scalar(f"moe/{name}_layer_{i}", loss, iteration)
# W&B logging lacks support for logging multiple scalars simultaneously.
# As a workaround, we log each scalar individually first, then we can create
# a custom panel to manually group them to a single plot.
if wandb_writer:
wandb_writer.log({f"{name}": loss_list.mean()}, iteration)
if per_layer_logging:
wandb_writer.log(
{
f"moe/{name}_layer_{i}": loss
for i, loss in enumerate(loss_list.tolist())
},
iteration,
)
clear_aux_losses_tracker()
class moe_gather(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, map_):
ctx.input_size = input_.size()
ctx.map = map_
return torch.gather(input_, 0, map_)
@staticmethod
def backward(ctx, grad_output):
input_size = ctx.input_size
map_ = ctx.map
output = torch.zeros(
input_size, dtype=grad_output.dtype, device=torch.cuda.current_device()
)
output.scatter_add_(0, map_, grad_output)
return output, None, None
class moe_scatter(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, map_, output_size=None):
ctx.map = map_
if output_size is not None:
output = torch.zeros(
output_size, dtype=input_.dtype, device=torch.cuda.current_device()
)
else:
output = torch.zeros_like(input_)
output.scatter_add_(0, map_, input_)
return output
@staticmethod
def backward(ctx, grad_output):
map_ = ctx.map
grad_input = torch.gather(grad_output, 0, map_)
return grad_input, None, None, None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
import torch
from megatron.core import parallel_state
from megatron.core.tensor_parallel import (
gather_from_sequence_parallel_region,
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
)
from megatron.core.tensor_parallel.random import (
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import (
MoEAuxLossAutoScaler,
save_to_aux_losses_tracker,
sinkhorn,
switch_load_balancing_loss_func,
topk_softmax_with_capacity,
z_loss_func,
)
from megatron.core.transformer.transformer_config import TransformerConfig
class Router(ABC, MegatronModule):
"""Base Router class"""
def __init__(self, config: TransformerConfig) -> None:
"""
Initialize the Router module.
Args:
config (TransformerConfig): Configuration object for the Transformer model.
"""
super().__init__(config)
self.config = config
self.num_experts = self.config.num_moe_experts
self.moe_aux_loss_func = None
self.layer_number = None
# Initialize the gate weights.
self.weight = torch.nn.Parameter(
torch.empty((self.config.num_moe_experts, self.config.hidden_size))
)
with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
config.init_method(self.weight)
setattr(self.weight, 'sequence_parallel', config.sequence_parallel)
def gating(self, input: torch.Tensor):
"""Forward pass of the router gate.
Args:
input (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Logits tensor.
"""
logits = torch.nn.functional.linear(input, self.weight)
return logits
@abstractmethod
def routing(self, logits: torch.Tensor):
"""Routing function.
Args:
logits (torch.Tensor): Logits tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of tensors representing max probs and the indices.
"""
raise NotImplementedError("Routing function not implemented.")
@abstractmethod
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
"""
raise NotImplementedError("Forward function not implemented.")
def set_layer_number(self, layer_number: int):
"""Set the layer number for the router."""
self.layer_number = layer_number
class TopKRouter(Router):
"""Route each token to the top-k experts."""
def __init__(self, config: TransformerConfig,) -> None:
"""Initialize the zero token dropping router.
Args:
config (TransformerConfig): The configuration for the transformer model.
"""
super().__init__(config=config)
self.topk = self.config.moe_router_topk
self.routing_type = self.config.moe_router_load_balancing_type
self.input_jitter = None
def sinkhorn_load_balancing(self, logits: torch.Tensor):
"""Apply sinkhorn routing to the logits tensor.
Args:
logits (torch.Tensor): The logits tensor.
Returns:
torch.Tensor: The logits tensor after applying sinkhorn routing.
"""
def _sinkhorn_activation(logits):
if self.topk == 1:
logits = torch.sigmoid(logits)
else: # k > 1
logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
return logits
assert self.config.moe_aux_loss_coeff == 0, "Sinkhorn routing does not support aux loss."
if self.training:
with torch.no_grad():
norm_logits = sinkhorn(
logits.to(dtype=torch.float32)
) # explicit fp32 conversion for stability
_, indices = torch.topk(norm_logits, k=self.topk, dim=1)
logits = _sinkhorn_activation(logits)
scores = torch.gather(logits, 1, indices)
else:
logits = _sinkhorn_activation(logits)
scores, indices = torch.topk(logits, k=self.topk, dim=1)
return scores, indices
def aux_loss_load_balancing(self, logits: torch.Tensor):
"""Apply loss-based load balancing to the logits tensor.
Args:
logits (torch.Tensor): the logits tensor after gating, shape: [num_tokens, num_experts].
Returns:
probs (torch.Tensor): the probabilities tensor after load balancing.
indices (torch.Tensor): the indices tensor after top-k selection.
"""
probs, indices, tokens_per_expert = topk_softmax_with_capacity(
logits,
self.topk,
capacity_factor=self.config.moe_expert_capacity_factor,
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
)
# Apply load balancing loss
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
probs = self.apply_load_balancing_loss(scores, tokens_per_expert, activation=probs)
return probs, indices
def apply_load_balancing_loss(
self,
probs: torch.Tensor,
num_local_tokens_per_expert: torch.Tensor,
activation: torch.Tensor,
):
"""Applies auxiliary loss to the MoE layer.
Args:
probs (torch.Tensor): The probs output by the router for each token. [num_tokens, num_experts]
num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert. [num_experts]
activation (torch.Tensor): The activation tensor to attach the gradient function to.
Returns:
torch.Tensor: The activation tensor with the attached gradient function.
"""
moe_aux_loss_coeff = (
self.config.moe_aux_loss_coeff / parallel_state.get_tensor_model_parallel_world_size()
)
aux_loss = switch_load_balancing_loss_func(
probs, num_local_tokens_per_expert, self.topk, moe_aux_loss_coeff
)
save_to_aux_losses_tracker(
"load_balancing_loss",
aux_loss / moe_aux_loss_coeff,
self.layer_number,
self.config.num_layers,
)
activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
return activation
def apply_z_loss(self, logits):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
Args:
logits (torch.Tensor): The logits of the router.
Returns:
torch.Tensor: The logits after applying the z-loss.
"""
if self.config.moe_z_loss_coeff is not None:
moe_z_loss_coeff = (
self.config.moe_z_loss_coeff / parallel_state.get_tensor_model_parallel_world_size()
)
z_loss = z_loss_func(logits, moe_z_loss_coeff)
logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
save_to_aux_losses_tracker(
"z_loss",
z_loss / self.config.moe_z_loss_coeff,
self.layer_number,
self.config.num_layers,
)
return logits
def apply_input_jitter(self, input: torch.Tensor):
"""Add noise to the input tensor.
Refer to https://arxiv.org/abs/2101.03961.
Args:
input (Tensor): Input tensor.
Returns:
Tensor: Jittered input.
"""
if self.config.moe_input_jitter_eps is not None:
eps = self.config.moe_input_jitter_eps
if self.input_jitter is None:
self.input_jitter = torch.distributions.uniform.Uniform(
torch.tensor(1.0 - eps, device=input.device),
torch.tensor(1.0 + eps, device=input.device),
).rsample
return input * self.input_jitter(input.shape)
else:
return input
def routing(self, logits: torch.Tensor):
"""Top-k routing function
Args:
logits (torch.Tensor): Logits tensor after gating.
Returns:
probs (torch.Tensor): the probabilities tensor after load balancing.
indices (torch.Tensor): the indices tensor after top-k selection.
"""
logits = logits.view(-1, self.config.num_moe_experts)
# Apply Z-Loss
logits = self.apply_z_loss(logits)
if (
parallel_state.get_tensor_model_parallel_world_size() > 1
and self.config.moe_token_dispatcher_type == "alltoall"
):
# Gather the logits from the TP region
logits = gather_from_sequence_parallel_region(logits)
if self.routing_type == "sinkhorn":
scores, indices = self.sinkhorn_load_balancing(logits)
elif self.routing_type == "aux_loss":
scores, indices = self.aux_loss_load_balancing(logits)
elif self.routing_type == "none":
# A naive top-k routing without load balancing
scores, indices, _ = topk_softmax_with_capacity(
logits,
self.topk,
capacity_factor=self.config.moe_expert_capacity_factor,
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
)
else:
raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")
return scores, indices
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
"""
self.hidden = input.shape[-1]
# Apply input jitter
input = self.apply_input_jitter(input)
logits = self.gating(input)
logits = logits.view(-1, self.config.num_moe_experts)
scores, indices = self.routing(logits)
return scores, indices
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from abc import abstractmethod
from typing import List, Optional, Tuple
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.tensor_parallel.mappings import _gather_along_first_dim_expert_parallel
from megatron.core.transformer.moe.moe_utils import moe_gather, moe_scatter, permute, unpermute
from megatron.core.transformer.transformer_config import TransformerConfig
class MoETokenDispatcher:
"""
MoE Token Dispatcher
"""
def __init__(self, config: TransformerConfig) -> None:
"""
Initialize the MoE Token Dispatcher.
"""
self.config = config
@abstractmethod
def token_permutation(
self, tokens: torch.Tensor, indices: torch.Tensor,
):
"""Dispatch tokens to experts.
Args:
tokens (torch.Tensor): Input tokens.
indices (torch.Tensor): indices tensor.
Returns:
torch.Tensor: Tokens tensor.
"""
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
def token_unpermutation(
self, expert_output: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor,
):
"""Restores the expert output to its original ordering.
Args:
expert_output (torch.Tensor): The output tensor from the expert models.
probs (torch.Tensor): Each token's score with each expert.
indices (torch.Tensor): The indices used to reorder the expert output.
Returns:
(torch.Tensor, torch.Tensor): Unpermuted activation and optional bias.
"""
raise NotImplementedError("Restore function not implemented.")
class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
"""
AllGather Based Token dispatcher.
"""
def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig,
) -> None:
"""
Initialize the zero token dropping router.
"""
super().__init__(config=config)
self.num_local_experts = num_local_experts
assert self.num_local_experts > 0, "Expected at least one expert"
self.local_expert_indices = local_expert_indices
assert len(self.local_expert_indices) > 0, "Expected at least one local expert index"
self.router_topk = config.moe_router_topk
self.add_bias = config.add_bias_linear
# self.local_probs: probs of global token assignment to local experts.
self.local_probs = None
# self.indices: The indices of `local_indices` (which holds the un-sorted expert indices of tokens that local expert can process) that give its sorted order along dim 0.
self.indices = None
# self.global_local_map: 2D tensor. A mask of mapping between global and local tokens where each element is True if it's between the local_expert_indices. Only useful when cross device token permutation is enabled and **AllGahter** is performed.
self.global_local_map = None
def token_permutation(
self, hidden_states: torch.Tensor, max_prob: torch.Tensor, max_ind: torch.Tensor
):
"""Dispatch tokens to local experts. It's composed of two stages:
(1) Permute the tokens across the expert parallel devices. After this stage,
each device receives all of the tokens assigned to its local set of experts
in its local HBM.
(2) Permute the tokens locally so that they are grouped by their expert
assignment. After the stage (1), the tokens are grouped by which device
they came from. We re-order them locally for subsequent efficient computation.
Args:
hidden_states: input tokens of shape [SeqLen/TP, MBS, HiddenSize]
max_prob: probs of local token assignment to global experts.
max_ind: token assignment to local experts.
Returns:
permuted_local_hidden_states: Permutation of tokens to local experts group.
tokens_per_expert: the number of tokens each local expert to process.
"""
self.hidden_shape = hidden_states.shape
# [S/TP, B, H] -> [S*B/TP, H]
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
# Permute the tokens across the expert parallel devices.
if (self.config.tensor_model_parallel_size > 1) or (
self.config.expert_model_parallel_size > 1
):
with torch.no_grad():
global_indices = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
max_ind
)
# Create a mask of mapping between global and local tokens where each
# element is True if it's between the local_expert_indices
global_local_mask = (global_indices >= self.local_expert_indices[0]) & (
global_indices <= self.local_expert_indices[-1]
)
local_indices = global_indices.masked_select(global_local_mask)
if self.router_topk > 1: # k > 1
global_probs = tensor_parallel.gather_from_sequence_parallel_region_to_moe(max_prob)
self.local_probs = global_probs.masked_select(global_local_mask)
else:
self.local_probs = max_prob
# [S*B/TP, H] -> [S*B, H]
global_hidden_states = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
hidden_states, use_global_buffer=True
)
# Reshape global_local_mask to be compatible with Tensor.gather
global_local_map = global_local_mask.nonzero()[:, 0]
self.global_local_map = global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1])
local_hidden_states = moe_gather.apply(global_hidden_states, self.global_local_map)
else:
if self.router_topk > 1:
global_local_mask = torch.ones_like(max_ind).bool()
local_indices = max_ind.masked_select(global_local_mask)
self.local_probs = max_prob.masked_select(global_local_mask)
global_local_map = global_local_mask.nonzero()[:, 0]
self.global_local_map = global_local_map.view(-1, 1).expand(
-1, hidden_states.shape[-1]
)
local_hidden_states = torch.gather(hidden_states, 0, self.global_local_map)
else:
local_indices = max_ind
self.local_probs = max_prob
local_hidden_states = hidden_states
self.global_local_map = None
with torch.no_grad():
# The indices of local_indices that give its sorted order along dim 0.
self.indices = torch.argsort(local_indices, dim=0)
tokens_per_expert = torch.histc(
local_indices,
bins=self.num_local_experts,
min=self.local_expert_indices[0],
max=self.local_expert_indices[-1],
)
tokens_per_expert = tokens_per_expert.cpu().to(torch.long)
# Stage2: permute the tokens locally so that they are grouped by their expert assignment
# Reshape indices to be compatible with Tensor.gather
self.indices = self.indices.view(-1, 1).expand(-1, hidden_states.shape[-1])
if self.num_local_experts > 1:
permuted_local_hidden_states = moe_gather.apply(local_hidden_states, self.indices)
else:
permuted_local_hidden_states = local_hidden_states
return (
permuted_local_hidden_states,
tokens_per_expert,
)
def token_unpermutation(
self, hidden_states: torch.Tensor, bias: torch.Tensor = None,
):
"""
Reverse process of `dispatch()` which permutes the ouput of local
experts locallay and across expert parallel rank into the original order to
produce the final output.
Args:
hidden_states: 2D tensor of shape [sum_tokens_of_all_local_experts, HiddenSize],
ouput of local experts.
bias (optional): The bias tensor.
Returns:
output_total: un-permuted updated hidden states output from all local experts
with shape of [SeqLen/TP, MBS, HiddenSize]
"""
# Stage1: unpermute the tokens and bias locally respectively.
scores = self.local_probs.to(dtype=hidden_states.dtype)
if self.num_local_experts > 1:
assert self.indices.shape == hidden_states.shape
unpermuted_local_hidden = moe_scatter.apply(hidden_states, self.indices)
else:
unpermuted_local_hidden = hidden_states
# Scale the expert output prior to reduction and subsequent to local unpermutation if k > 1.
if self.router_topk > 1:
unpermuted_local_hidden = unpermuted_local_hidden * scores.view(-1, 1)
unpermuted_local_bias = None
if self.add_bias:
assert bias is not None
unpermuted_local_bias = torch.zeros_like(hidden_states)
assert self.indices.shape == bias.shape
unpermuted_local_bias = unpermuted_local_bias.scatter(0, self.indices, bias)
if self.router_topk > 1:
unpermuted_local_bias = unpermuted_local_bias * scores.view(-1, 1)
output_total = unpermuted_local_hidden
output_bias_total = unpermuted_local_bias
# Unpermute the tokens across expert parallel devices.
if (self.config.tensor_model_parallel_size > 1) or (
self.config.expert_model_parallel_size > 1
):
assert (
self.global_local_map is not None
), "global_local_map is necessary for `AllGather`."
ep_group_size = parallel_state.get_tensor_and_expert_parallel_world_size()
# hidden_shape: [SeqLen/TP, MBS, HiddenSize], glboal_num_tokens = SeqLen/TP*MBS*(TP*EP)
global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1] * ep_group_size
global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
assert self.global_local_map.shape == unpermuted_local_hidden.shape
unpermuted_global_hidden = moe_scatter.apply(
unpermuted_local_hidden, self.global_local_map, global_hidden_shape
)
output_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
unpermuted_global_hidden
)
if self.add_bias:
# Unpermute the bias across expert parallel devices.
unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
unpermuted_global_bias = unpermuted_global_bias.scatter_add(
0, self.global_local_map, unpermuted_local_bias
)
output_bias_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
unpermuted_global_bias
)
# bias is duplicated across tensor parallelism ranks;
# reduce scatter reduces bias across tensor parallel_ranks
output_bias_total = (
output_bias_total / parallel_state.get_tensor_model_parallel_world_size()
)
else:
if self.router_topk > 1:
global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1]
global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
unpermuted_global_hidden = torch.zeros(
global_hidden_shape,
dtype=hidden_states.dtype,
device=torch.cuda.current_device(),
)
output_total = unpermuted_global_hidden.scatter_add(
0, self.global_local_map, unpermuted_local_hidden
)
if self.add_bias:
unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
output_bias_total = unpermuted_global_bias.scatter_add(
0, self.global_local_map, unpermuted_local_bias
)
if self.router_topk == 1:
output_total = output_total * scores
output_total = output_total.view(self.hidden_shape)
if self.add_bias:
assert output_bias_total is not None
if self.router_topk == 1:
output_bias_total = output_bias_total * scores
output_bias_total = output_bias_total.view(self.hidden_shape)
else:
output_bias_total = None
return output_total, output_bias_total
class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"""
AlltoAll Based Token dispatcher.
"""
def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig,
) -> None:
"""
Initialize the AlltoAll token dispatcher.
Args:
num_local_experts (int): Number of local experts on the current device.
local_expert_indices (List[int]): Indices of local experts on the current device.
config (TransformerConfig): Configuration for the transformer model.
"""
super().__init__(config=config)
self.hidden_shape = None
self.num_input_tokens = None
self.num_local_experts = num_local_experts
self.num_experts = config.num_moe_experts
assert self.num_local_experts > 0, "Expected at least one expert"
self.local_expert_indices = local_expert_indices
assert (
len(self.local_expert_indices) == self.num_local_experts
), "Invalid local expert indices"
self.router_topk = config.moe_router_topk
self.add_bias = config.add_bias_linear
self.ep_size = config.expert_model_parallel_size
self.probs = None
self.input_splits = None
self.output_splits = None
self.num_global_tokens_per_local_expert = None
# Token drop and padding.
# We need to keep track of the token num if we drop tokens without padding them.
self.num_out_tokens = None
# Drop and pad the input to capacity.
self.drop_and_pad = self.config.moe_pad_expert_input_to_capacity
if self.drop_and_pad:
assert self.config.moe_expert_capacity_factor is not None
self.capacity = None
def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
"""
Preprocess token indices for AlltoAll communication and token permutation. This method computes the number of tokens assigned to each expert based on the input indices.
It also initializes the necessary data structures for AlltoAll communication, such as input
and output splits, and the mapping between global tokens and local experts.
Args:
indices (torch.Tensor): Tensor of indices mapping tokens to experts.
Returns:
torch.Tensor: Tensor containing the number of tokens assigned to local expert.
"""
num_local_tokens_per_expert = torch.histc(
indices, bins=self.num_experts, min=0, max=self.num_experts
)
# num_local_tokens_per_expert: [num_experts]
ep_size = self.config.expert_model_parallel_size
if self.drop_and_pad:
# probs: [num_experts, capacity]
self.capacity = self.probs.size(1)
num_tokens_per_local_expert = torch.full(
(self.num_local_experts,), self.capacity * self.ep_size, dtype=torch.long
)
return num_tokens_per_local_expert
elif self.config.moe_expert_capacity_factor is not None:
self.num_out_tokens = num_local_tokens_per_expert.sum().cpu()
if ep_size > 1:
# ===================================================
# Calculate input_splits, output_splits for alltoall-v.
# ===================================================
self.input_splits = (
num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
.sum(axis=1)
.to(torch.device("cpu"))
.numpy()
)
num_global_tokens_per_expert = _gather_along_first_dim_expert_parallel(
num_local_tokens_per_expert
).reshape(ep_size, self.num_experts)
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[
:, self.local_expert_indices
]
self.output_splits = (
self.num_global_tokens_per_local_expert.sum(axis=-1).to(torch.device("cpu")).numpy()
)
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0).to(
torch.device("cpu"), non_blocking=True
)
# ===================================================
# num_global_tokens_per_expert: [ep_size, num_experts]
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
# num_tokens_per_local_expert: [num_local_experts]
# ===================================================
else:
self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
-1, self.num_experts
)
num_tokens_per_local_expert = num_local_tokens_per_expert.to(
torch.device("cpu"), non_blocking=True
)
if self.num_local_experts > 1:
expert_ids_per_ep_rank = torch.tensor(
[i % self.num_local_experts for i in range(self.config.num_moe_experts)],
dtype=torch.int32,
device=torch.cuda.current_device(),
)
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()
)
return num_tokens_per_local_expert
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
Args:
hidden_states (torch.Tensor): Input token embeddings.
probs (torch.Tensor): Probs of tokens assigned to experts.
indices (torch.Tensor): Indices of tokens assigned to experts.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- Permuted token embeddings for local experts.
- Number of tokens per expert.
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
self.hidden_shape = hidden_states.shape
self.probs = probs
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert indices.dim() == 2, "Expected 2D tensor for indices"
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self.preprocess(indices)
# Perform tensor parallel AlltoAll communication
# hidden_states: [S*B/TP, H] -> [S*B, H/TP]
if parallel_state.get_tensor_model_parallel_world_size() > 1:
hidden_states = tensor_parallel.all_to_all_sp2hp(hidden_states)
# Permutation 1: input to AlltoAll input
self.hiddden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states,
indices,
num_out_tokens=self.num_out_tokens,
padded_mode=self.drop_and_pad,
)
# Perform expert parallel AlltoAll communication
global_input_tokens = tensor_parallel.all_to_all(
parallel_state.get_expert_model_parallel_group(),
permutated_local_input_tokens,
self.output_splits,
self.input_splits,
)
# Permutation 2: Sort alltoall output by local experts when num_local_experts > 1.
if self.num_local_experts > 1:
if not self.drop_and_pad:
global_input_tokens, self.reversed_global_input_permutation_mapping = permute(
global_input_tokens, self.global_input_tokens_local_experts_indices
)
else:
global_input_tokens = global_input_tokens.reshape(
self.ep_size, self.num_local_experts, self.capacity, -1
)
global_input_tokens = (
global_input_tokens.transpose(0, 1)
.reshape(self.num_local_experts * self.ep_size * self.capacity, -1)
.contiguous()
)
# Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens.
# global_input_tokens: [SEQL, H/TP] -> [SEQL, H]
if parallel_state.get_tensor_model_parallel_world_size() > 1:
global_input_tokens = tensor_parallel.all_gather_last_dim_from_tensor_parallel_region(
global_input_tokens
)
return global_input_tokens, tokens_per_expert
def token_unpermutation(
self, hidden_states: torch.Tensor, bias: torch.Tensor = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Reverse the token permutation to restore the original order.
Args:
hidden_states (torch.Tensor): Output from local experts.
bias (torch.Tensor, optional): Bias tensor (not supported).
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- Unpermuted token embeddings in the original order.
- None (bias is not supported).
"""
assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
# Perform tensor parallel Reduce-Scatter
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
if parallel_state.get_tensor_model_parallel_world_size() > 1:
hidden_states = tensor_parallel.reduce_scatter_last_dim_to_tensor_parallel_region(
hidden_states
)
# Unpermutation 2: expert output to AlltoAll input
if self.num_local_experts > 1:
if not self.drop_and_pad:
hidden_states = unpermute(
hidden_states, self.reversed_global_input_permutation_mapping,
)
else:
hidden_states = hidden_states.reshape(
self.num_local_experts, self.ep_size, self.capacity, -1
)
hidden_states = (
hidden_states.transpose(0, 1)
.reshape(self.ep_size * self.num_local_experts * self.capacity, -1)
.contiguous()
)
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
permutated_local_input_tokens = tensor_parallel.all_to_all(
parallel_state.get_expert_model_parallel_group(),
hidden_states,
self.input_splits,
self.output_splits,
)
# Unpermutation 1: AlltoAll output to output
output = unpermute(
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
probs=self.probs,
padded_mode=self.drop_and_pad,
restore_shape=self.hiddden_shape_before_permute,
)
# Perform tensor parallel AlltoAll communication
# output: [S*B, H/TP] -> [S*B/TP, H]
if parallel_state.get_tensor_model_parallel_world_size() > 1:
output = tensor_parallel.all_to_all_hp2sp(output)
# Reshape the output tensor
output = output.view(self.hidden_shape)
return output, None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import types
from dataclasses import dataclass, field
from typing import Tuple, Union
import torch
@dataclass
class ModuleSpec:
"""This is a Module Specification dataclass.
Specification defines the location of the module (to import dynamically)
or the imported module itself. It also defines the params that need to be
passed to initialize the module.
Args:
module (Union[Tuple, type]): A tuple describing the location of the
module class e.g. `(module.location, ModuleClass)` or the imported
module class itself e.g. `ModuleClass` (which is already imported
using `from module.location import ModuleClass`).
params (dict): A dictionary of params that need to be passed while init.
"""
module: Union[Tuple, type]
params: dict = field(default_factory=lambda: {})
submodules: type = None
def import_module(module_path: Tuple[str]):
"""Import a named object from a module in the context of this function.
TODO: make this importer module more robust, at least make sure there
are no side effects of using this as is
"""
base_path, name = module_path
try:
module = __import__(base_path, globals(), locals(), [name])
except ImportError as e:
print(f"couldn't import module due to {e}")
return None
return vars(module)[name]
def get_module(spec_or_module: Union[ModuleSpec, type], **additional_kwargs):
# If a module clas is already provided return it as is
if isinstance(spec_or_module, (type, types.FunctionType)):
return spec_or_module
# If the module is provided instead of module path, then return it as is
if isinstance(spec_or_module.module, (type, types.FunctionType)):
return spec_or_module.module
# Otherwise, return the dynamically imported module from the module path
return import_module(spec_or_module.module)
def build_module(spec_or_module: Union[ModuleSpec, type], *args, **kwargs):
# If the passed `spec_or_module` is
# a `Function`, then return it as it is
# NOTE: to support an already initialized module add the following condition
# `or isinstance(spec_or_module, torch.nn.Module)` to the following if check
if isinstance(spec_or_module, types.FunctionType):
return spec_or_module
# If the passed `spec_or_module` is actually a spec (instance of
# `ModuleSpec`) and it specifies a `Function` using its `module`
# field, return the `Function` as it is
if isinstance(spec_or_module, ModuleSpec) and isinstance(
spec_or_module.module, types.FunctionType
):
return spec_or_module.module
# Check if a module class is provided as a spec or if the module path
# itself is a class
if isinstance(spec_or_module, type):
module = spec_or_module
elif hasattr(spec_or_module, "module") and isinstance(spec_or_module.module, type):
module = spec_or_module.module
else:
# Otherwise, dynamically import the module from the module path
module = import_module(spec_or_module.module)
# If the imported module is actually a `Function` return it as it is
if isinstance(module, types.FunctionType):
return module
# Finally return the initialized module with params from the spec as well
# as those passed as **kwargs from the code
# Add the `submodules` argument to the module init call if it exists in the
# spec.
if hasattr(spec_or_module, "submodules") and spec_or_module.submodules is not None:
kwargs["submodules"] = spec_or_module.submodules
try:
return module(
*args, **spec_or_module.params if hasattr(spec_or_module, "params") else {}, **kwargs
)
except Exception as e:
# improve the error message since we hide the module name in the line above
import sys
tb = sys.exc_info()[2]
raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
sys.exc_info()[2]
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import re
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Tuple, Union
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.packed_seq_params import PackedSeqParams
try:
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDelayedScaling,
TENorm,
get_cpu_offload_context,
te_checkpoint,
)
except ImportError:
TEDelayedScaling = None
TENorm = None
get_cpu_offload_context = None
te_checkpoint = None
#print("Do not support transformer_engine")
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import BaseTransformerLayer, TransformerLayer
from megatron.core.transformer.utils import sharded_state_dict_default
from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor
def get_num_layers_to_build(config: TransformerConfig) -> int:
num_layers_per_pipeline_rank = (
config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
# Interleaved pipeline parallelism:
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
num_layers_to_build = num_layers_per_virtual_rank
else:
# Non-interleaved pipeline parallelism:
# Each stage gets a contiguous set of layers.
num_layers_to_build = num_layers_per_pipeline_rank
return num_layers_to_build
@dataclass
class TransformerBlockSubmodules:
layer_specs: List[ModuleSpec] = None
def _get_block_submodules(
config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec],
) -> TransformerBlockSubmodules:
# Transformer block submodules.
if isinstance(spec, TransformerBlockSubmodules):
return spec
# ModuleSpec here is generally assumed to be for a transformer layer that
# is implemented in `transformer_layer.py` or if it subclasses
# `BaseTransformerLayer` from the `transformer_layer.py` file.
elif isinstance(spec, ModuleSpec):
if issubclass(spec.module, TransformerBlock):
return spec.submodules
elif issubclass(spec.module, BaseTransformerLayer):
num_layers = get_num_layers_to_build(config)
return TransformerBlockSubmodules(layer_specs=[spec] * num_layers)
else:
raise Exception(f"specialize for {spec.module.__name__}.")
else:
raise Exception(f"specialize for {type(spec).__name__}.")
class TransformerBlock(MegatronModule):
"""Transformer class."""
def __init__(
self,
config: TransformerConfig,
spec: Union[TransformerBlockSubmodules, ModuleSpec],
post_layer_norm: bool = True,
pre_process: bool = True,
post_process: bool = True,
):
super().__init__(config=config)
self.submodules = _get_block_submodules(config, spec)
self.post_layer_norm = post_layer_norm
self.pre_process = pre_process
self.post_process = post_process
# Dictionary to store CUDA graphs. Number of items in the dictionary = len(self.layers).
# Item `i` in the dictionary is a list of `N` CUDA graphs for layer 'i' where N is the
# number of microbatches. Multiple CUDA graphs per layer is required to support
# pipelining which requires running FWD graph of multiple microbatches before BWD graph.
self.cuda_graphs = {}
self.current_microbatch = -1
# required for pipeline parallel schedules
self.input_tensor = None
self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
if get_cpu_offload_context is not None:
(
self.offload_context,
self.group_prefetch_offload_commit_async,
) = get_cpu_offload_context(
self.config.cpu_offloading,
self.config.cpu_offloading_num_layers,
self.config.cpu_offloading_activations,
self.config.cpu_offloading_weights,
)
self.config._cpu_offloading_context = (
self.offload_context if self.config.cpu_offloading else None
)
else:
assert (
self.config.cpu_offloading == False
), "CPU Offloading is enabled when TE is not present"
self.offload_context, self.group_prefetch_offload_commit_async = nullcontext(), None
self.config._cpu_offloading_context = None
self._build_layers()
self.num_layers_per_pipeline_rank = len(self.layers)
def _build_layers(self):
# Transformer layers.
# @jcasper can we improve how we deal with layer_number?
# currently it's only used in CoreAttention?
# if self.apply_query_key_layer_scaling:
# coeff = self.layer_number
# self.norm_factor *= coeff
def build_layer(layer_spec, layer_number):
return build_module(layer_spec, config=self.config, layer_number=layer_number,)
# offset is implicit in TransformerLayer
self.layers = torch.nn.ModuleList(
[
build_layer(layer_spec, i + 1)
for i, layer_spec in enumerate(self.submodules.layer_specs)
]
)
# # TODO: add back standalone_embedding_stage
# if self.num_layers == 0:
# # When a standalone embedding stage is used (e.g.,
# # args.standalone_embedding_stage == True), virtual pipeline ranks
# # on pipeline rank 0 will have zero transformer layers assigned to
# # them. This results in the model's input and output tensors to be
# # the same, which will cause failure for certain output tensor
# # optimizations (e.g., pipeline output deallocation). To remedy
# # this, we assign a 'no-op' layer on these ranks, which will
# # disconnect the input tensor from the output tensor.
# self.num_layers = 1
# self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
# else:
# self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process and self.post_layer_norm:
# Final layer norm before output.
self.final_layernorm = TENorm(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
def _get_layer(self, layer_number: int):
return self.layers[layer_number]
def _checkpointed_forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
context: Tensor,
context_mask: Tensor,
rotary_pos_emb: Tensor,
packed_seq_params: PackedSeqParams,
):
"""Forward method with activation checkpointing."""
def custom(start: int, end: int):
def custom_forward(
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
):
for index in range(start, end):
layer = self._get_layer(index)
hidden_states, context = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=None,
packed_seq_params=packed_seq_params,
)
return hidden_states, context
return custom_forward
def checkpoint_handler(forward_func):
if self.config.fp8:
return te_checkpoint(
forward_func,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
else:
return tensor_parallel.checkpoint(
forward_func,
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
if self.config.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage reducing checkpoints.
l = 0
while l < self.num_layers_per_pipeline_rank:
hidden_states, context = checkpoint_handler(
custom(l, l + self.config.recompute_num_layers)
)
l += self.config.recompute_num_layers
elif self.config.recompute_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method fully use the device memory removing redundant re-computation.
recompute_skip_num_layers = 0
for l in range(self.num_layers_per_pipeline_rank):
# Skip recomputation when input grad computation is not needed.
# Need to have at least one input tensor with gradient computation
# for re-enterant autograd engine.
if self.config.fp8 and not hidden_states.requires_grad:
recompute_skip_num_layers += 1
if (
l >= recompute_skip_num_layers
and l < self.config.recompute_num_layers + recompute_skip_num_layers
):
hidden_states, context = checkpoint_handler(custom(l, l + 1))
else:
hidden_states, context = custom(l, l + 1)(
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states
def set_input_tensor(self, input_tensor: Tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
context: Tensor = None,
context_mask: Tensor = None,
rotary_pos_emb: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
):
# hidden_states (float): [s, b, h]
# attention_mask (bool): [1, 1, s, s]
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = make_viewless_tensor(
inp=hidden_states, requires_grad=True, keep_graph=True,
)
if self.config.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
if self.config.fp8:
import transformer_engine # To keep out TE dependency when not training in fp8
if self.config.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif self.config.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
fp8_recipe = TEDelayedScaling(
config=self.config,
fp8_format=fp8_format,
override_linear_precision=(False, False, not self.config.fp8_wgrad),
)
fp8_group = None
if parallel_state.model_parallel_is_initialized():
fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True)
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
)
else:
fp8_context = nullcontext()
with rng_context and fp8_context:
# Forward pass.
if self.config.recompute_granularity == 'full' and self.training:
hidden_states = self._checkpointed_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
)
else:
for l_no, layer in enumerate(self.layers):
with self.offload_context:
if (len(self.cuda_graphs) == 0) or (not self.training):
hidden_states, context = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
)
# CUDA graph doesn't output context and is expected to be None
assert (
(context is None)
or (not self.config.enable_cuda_graph)
or (not self.training)
)
else:
# CUDA graph replay for layer `l_no` and microbatch `self.current_microbatch`
# CUDA graph requires positional arguments with the exception of is_first_microbatch.
# Also CUDA graph accepts only Tensor inputs and outputs. Hence, the arg list and
# returned list is limited to `hidden_states`.
assert (len(self.cuda_graphs) > l_no) and (
self.current_microbatch < len(self.cuda_graphs[l_no])
)
hidden_states = self.cuda_graphs[l_no][self.current_microbatch](
hidden_states, is_first_microbatch=(self.current_microbatch == 0),
)
if (
torch.is_grad_enabled()
and self.config.cpu_offloading
and self.group_prefetch_offload_commit_async is not None
):
hidden_states = self.group_prefetch_offload_commit_async(hidden_states)
# Final layer norm.
if self.post_process and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: dict = None
) -> ShardedStateDict:
assert not sharded_offsets, "Unexpected sharded offsets"
non_homogeneous_layers = metadata is not None and metadata.get(
'non_homogeneous_layers', False
)
sharded_state_dict = {}
layer_prefix = f'{prefix}layers.'
num_layers = self.config.num_layers
for layer in self.layers:
offset = layer._get_layer_offset()
global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1
state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.' # module list index in TransformerBlock
if non_homogeneous_layers:
sharded_prefix = f'{layer_prefix}{global_layer_offset}.'
sharded_pp_offset = []
else:
sharded_prefix = layer_prefix
sharded_pp_offset = [
(0, global_layer_offset, num_layers)
] # PP sharding offset for ShardedTensors
layer_sharded_state_dict = layer.sharded_state_dict(
state_dict_prefix, sharded_pp_offset, metadata
)
replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
sharded_state_dict.update(layer_sharded_state_dict)
# Add modules other than self.layers
for name, module in self.named_children():
if not module is self.layers:
sharded_state_dict.update(
sharded_state_dict_default(
module, f'{prefix}{name}.', sharded_offsets, metadata
)
)
return sharded_state_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment