Commit d3dd8642 authored by Rayyyyy

First add
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Tuple, Union
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.core import tensor_parallel
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
# from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
# NOTE: the classes below use the TE linear layers and the gelu variants directly;
# the import paths here are the usual Megatron-LM locations (assumed).
from megatron.core.transformer.custom_layers.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
from megatron.model.utils import openai_gelu, erf_gelu
@dataclass
class MLPSubmodules:
linear_fc1: Union[ModuleSpec, type] = None
linear_fc2: Union[ModuleSpec, type] = None
class MLP(MegatronModule):
"""
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
Returns an output and a bias to be added to the output.
If config.add_bias_linear is False, the bias returned is None.
We use the following notation:
h: hidden size
p: number of tensor model parallel partitions
b: batch size
s: sequence length
"""
def __init__(
self, config: TransformerConfig, submodules: MLPSubmodules, is_expert: bool = False
):
super().__init__(config=config)
self.config: TransformerConfig = config
# If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
ffn_hidden_size = self.config.ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
self.linear_fc1 = TEColumnParallelLinear(
self.config.hidden_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
bias=self.config.add_bias_linear,
skip_bias_add=True,
)
if self.config.gated_linear_unit:
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
self.activation_func = glu
else:
self.activation_func = self.config.activation_func
self.linear_fc2 = TERowParallelLinear(
self.config.ffn_hidden_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
skip_bias_add=True,
)
def forward(self, hidden_states):
# [s, b, 4 * h/p]
intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)
if self.config.bias_gelu_fusion:
assert self.config.add_bias_linear is True
assert self.activation_func == F.gelu
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, output_bias = self.linear_fc2(intermediate_parallel)
return output, output_bias
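# Shape sketch for the gated path above (illustrative only, not part of the module):
# with gated_linear_unit=True, linear_fc1 produces [s, b, 2 * ffn/p]; glu() chunks the
# last dim in half and returns activation_func(x[0]) * x[1], i.e.
#     x = torch.randn(s, b, 2 * ffn_per_partition)
#     a, g = torch.chunk(x, 2, dim=-1)      # two [s, b, ffn/p] halves
#     y = activation_func(a) * g            # back to [s, b, ffn/p]
# so linear_fc2 can map [s, b, ffn/p] back to [s, b, h].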
class ParallelMLP(MegatronModule):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self, config):
super(ParallelMLP, self).__init__(config=config)
args = get_args()
self.add_bias = config.add_bias_linear
if args.num_shared_experts is not None:
ffn_hidden_size = config.ffn_hidden_size * args.num_shared_experts
self.moe_intermediate_size = ffn_hidden_size
if config.gated_linear_unit:
ffn_hidden_size *= 2
self.num_shared_experts = args.num_shared_experts
self.isolate_shared_experts = args.isolate_shared_experts
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
ffn_hidden_size,
config=config,
init_method=config.init_method,
bias=self.add_bias,
gather_output=False,
skip_bias_add=True,
)
self.bias_gelu_fusion = False
self.activation_func = None
self.swiglu = args.swiglu
if args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu
elif args.swiglu:
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.activation_func = swiglu
elif args.squared_relu:
def squared_relu(x):
return torch.pow(F.relu(x), 2)
self.activation_func = squared_relu
else:
self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
# Project back to h.
self.shared_experts = torch.nn.ModuleList()
if self.isolate_shared_experts:
for i in range(self.num_shared_experts):
self.shared_experts.append(tensor_parallel.RowParallelLinear(
config.ffn_hidden_size,
config.hidden_size,
config=config,
init_method=config.output_layer_init_method,
bias=self.add_bias,
input_is_parallel=True)
)
else:
self.shared_experts.append(tensor_parallel.RowParallelLinear(
self.moe_intermediate_size,
config.hidden_size,
config=config,
init_method=config.output_layer_init_method,
bias=self.add_bias,
input_is_parallel=True)
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.bias_gelu_fusion:
assert self.add_bias is True
assert self.activation_func == F.gelu
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
if self.isolate_shared_experts:
intermediate_parallel = torch.chunk(intermediate_parallel, self.num_shared_experts, dim=-1)
output = torch.zeros_like(hidden_states)
for expert_num, expert in enumerate(self.shared_experts):
output += expert(intermediate_parallel[expert_num])[0]
else:
for expert_num, expert in enumerate(self.shared_experts):
output = expert(intermediate_parallel)[0]
return output, None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Megatron Module"""
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.transformer_config import TransformerConfig
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
def param_is_not_shared(param):
return not hasattr(param, 'shared') or not param.shared
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support
for pipelining."""
# def __init__(self, config: TransformerConfig, share_word_embeddings=True):
def __init__(self, config: TransformerConfig):
super().__init__()
self.config = config
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(prefix=prefix, keep_vars=keep_vars)
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val`
#is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16/bf16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES):
val = float16_convertor(val)
return val
return conversion_helper(val, half_conversion)
def float16_to_fp32(val):
"""Convert fp16/bf16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
val = val.float()
return val
return conversion_helper(val, float_conversion)
class Float16Module(MegatronModule):
def __init__(self, config: TransformerConfig, module: torch.nn.Module):
super(Float16Module, self).__init__(config)
self.config = config
self.fp16 = config.fp16
self.bf16 = config.bf16
if self.fp16:
self.add_module('module', module.half())
def float16_convertor(val):
return val.half()
elif self.bf16:
self.add_module('module', module.bfloat16())
def float16_convertor(val):
return val.bfloat16()
else:
raise Exception('Either config.fp16 or config.bf16 should be True.')
self.float16_convertor = float16_convertor
def set_input_tensor(self, input_tensor):
return self.module.set_input_tensor(input_tensor)
def forward(self, *inputs, **kwargs):
if parallel_state.is_pipeline_first_stage():
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
if parallel_state.is_pipeline_last_stage():
outputs = float16_to_fp32(outputs)
return outputs
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
# Megatron Core MoE Key Features
### Parallelism
- **Expert Parallel**
- A specific method of parallelism for MoE models, where experts are partitioned across different workers and each worker processes a different batch of training samples; each worker handles one or more experts per MoE layer.
- **3D Parallel**: Data Parallel, Tensor Parallel, Pipeline Parallel, Sequence Parallel
- Note: When using MoE with expert parallelism and tensor parallelism, sequence parallelism must be used.
- **Richer parallel mappings**: EP can be combined with DP/TP/PP/SP for handling larger MoE variants; see the sizing sketch after this list.
- **Distributed optimizer.**
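A minimal sizing sketch of how these parallelisms compose (illustrative numbers that mirror the detailed script later in this guide, not additional configuration):

```python
# Hypothetical sizing check: 2 nodes x 8 GPUs, TP=4, PP=1, EP=4, 8 experts.
world_size = 16
tp, pp, ep = 4, 1, 4
dp = world_size // (tp * pp)         # 4 data-parallel replicas
num_experts = 8
assert num_experts % ep == 0         # experts must divide evenly across EP ranks
local_experts = num_experts // ep    # 2 experts hosted on each expert-parallel rank
```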
### Router and Load Balancing
- Router type:
- Top-K MLP router
- Expert Choice router (coming soon)
- Load Balancing algorithms:
- Sinkhorn (S-BASE)
- Aux loss / Load balancing loss
### Performance Optimizations
- GroupedGEMM when num local experts > 1
- Supported dtype: bf16
### Token Dispatch Mechanism
- Dropless / No token drop.
- Token drop. (coming soon)
### Ease of use
- Checkpoint converter (coming soon)
## Upcoming features
- Enhanced cutlass GroupedGEMM kernels
- Reduced host-device syncs.
- More supported dtype: fp32/bf16/fp16
- Kernel heuristics tuned for A100/A10/L40S
- BWD cutlass GroupedGEMM kernels supported
- Token permutation / unpermutation fusion
- Fused Sinkhorn Kernel
- Context Parallel with MoE
- FP8 training support
- Enable `--tp-comm-overlap` for MoE
- Distributed optimizer for MoE params.
# User Guide
### MoE Related Arguments
| Item | Description |
| --- | --- |
| num-experts | Number of Experts in MoE (None means no MoE) |
| expert-model-parallel-size | Degree of expert model parallelism. |
| moe-grouped-gemm | When there are multiple experts per rank, compresses multiple local GEMMs into a single kernel launch to improve utilization and performance, leveraging the Grouped GEMM feature introduced in CUTLASS 2.8. |
| moe-router-load-balancing-type | Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss". |
| moe-router-topk | Number of experts to route to for each token. The default is 2. |
| moe-aux-loss-coeff | Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended. |
| moe-z-loss-coeff | Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended. |
| moe-input-jitter-eps | Add noise to the input tensor by applying jitter with a specified epsilon value. |
| moe-token-dropping | This feature involves selectively dropping and padding tokens for each expert to achieve a specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note: Currently unsupported. |
### Example
To train a top-2 MoE model with an auxiliary loss, include the following arguments:
```bash
--num-experts 8
--expert-model-parallel-size 8
--moe-grouped-gemm
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--use-distributed-optimizer
```
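As a quick sanity check of these flags (illustrative arithmetic, not an extra argument): `--num-experts 8` with `--expert-model-parallel-size 8` gives 8 / 8 = 1 expert per expert-parallel rank; with `--expert-model-parallel-size 4`, as in the detailed script below, each rank hosts 2 local experts, which is the case where `--moe-grouped-gemm` batches the local GEMMs into a single kernel launch.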
## A detailed MoE script:
<details>
<summary>Click here. </summary>
```bash
#!/bin/bash
# Runs Mixtral 8x7B model on 16 A100 GPUs
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6000"}
NNODES=${NNODES:-"1"}
NODE_RANK=${RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
CHECKPOINT_PATH=$1
TOKENIZER_MODEL=$2
DATA_PATH=$3
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NNODES
--node_rank $NODE_RANK
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
MODEL_ARGS=(
--use-mcore-models
--disable-bias-linear
--seq-length 2048
--max-position-embeddings 32768
--num-layers 32
--hidden-size 4096
--ffn-hidden-size 14336
--num-attention-heads 32
--init-method-std 0.01
--attention-dropout 0.0
--hidden-dropout 0.0
--normalization RMSNorm
--position-embedding-type rope
--swiglu
--untie-embeddings-and-output-weights
--group-query-attention
--num-query-groups 8
--no-masked-softmax-fusion
--no-position-embedding
)
MOE_ARGS=(
--num-experts 8
--expert-model-parallel-size 4
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--moe-grouped-gemm
)
DATA_ARGS=(
--tokenizer-type Llama2Tokenizer
--tokenizer-model ${TOKENIZER_MODEL}
--data-path $DATA_PATH
--split 99990,8,2
)
TRAINING_ARGS=(
--micro-batch-size 1
--global-batch-size 128
--lr 1e-4
--train-iters 500000
--lr-decay-iters 320000
--lr-decay-style cosine
--min-lr 1.0e-5
--weight-decay 0.1
--lr-warmup-iters 500
--clip-grad 1.0
--bf16
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 4
--pipeline-model-parallel-size 1
--sequence-parallel
--use-distributed-optimizer
)
LOGGING_ARGS=(
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--tensorboard-dir "${CHECKPOINT_PATH}/tensorboard" \
--no-load-optim \
--no-load-rng
)
if [ -n "${WANDB_API_KEY}" ]; then
LOGGING_ARGS+=(
--wandb-project ${WANDB_PROJECT:-"Mixtral-Finetuning"}
--wandb-exp-name ${WANDB_NAME:-"Mixtral_8x7B"}
)
fi
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${MODEL_ARGS[@]} \
${MOE_ARGS[@]} \
${DATA_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${LOGGING_ARGS[@]}
```
</details>
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
import numpy as np
import torch
from torch.nn.parameter import Parameter
from megatron.core import parallel_state
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
)
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe import grouped_gemm_util as gg
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.tensor_parallel.mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
class GroupedMLP(MegatronModule):
"""An efficient implementation of the Experts layer using CUTLASS GroupedGEMM.
This class is designed to execute multiple experts in parallel, thereby maximizing computational efficiency.
"""
def __init__(self, num_local_experts: int, config: TransformerConfig):
super().__init__(config=config)
self.config: TransformerConfig = config
self.num_local_experts = num_local_experts
gg.assert_grouped_gemm_is_available()
assert (
config.add_bias_linear == False
), "bias in the expert layer is not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
self.expert_parallel = False
if self.config.gated_linear_unit:
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
self.activation_func = glu
else:
self.activation_func = self.config.activation_func
# How many features each rank holds for fc1 and fc2, respectively.
tp_size = parallel_state.get_tensor_model_parallel_world_size()
fc1_output_size = self.config.ffn_hidden_size * self.num_local_experts
if config.gated_linear_unit:
# Project to 4h. If using swiglu double the output width,
# see https://arxiv.org/pdf/2002.05202.pdf
fc1_output_size *= 2
fc1_output_size_per_partition = divide(fc1_output_size, tp_size)
fc2_input_size = self.config.ffn_hidden_size * self.num_local_experts
fc2_input_size_per_partition = divide(fc2_input_size, tp_size)
# Note: The current kernel implementation of grouped_gemm
# does not support transposition with CUTLASS grouped GEMM
# (https://github.com/fanshiqing/grouped_gemm/blob/main/csrc/grouped_gemm.cu#L355-L358)
# and as a result we avoid allocating the transpose of the weights.
# Initialize weight.
if config.use_cpu_initialization:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition,
self.config.hidden_size,
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight1,
self.config.hidden_size,
fc1_output_size,
fc1_output_size_per_partition,
partition_dim=1,
init_method=config.init_method,
params_dtype=config.params_dtype,
)
_initialize_affine_weight_cpu(
self.weight2,
fc2_input_size,
self.config.hidden_size,
fc2_input_size_per_partition,
partition_dim=0,
init_method=config.output_layer_init_method,
params_dtype=config.params_dtype,
)
else:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition,
self.config.hidden_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight1,
config.init_method,
partition_dim=1,
expert_parallel=self.expert_parallel,
)
_initialize_affine_weight_gpu(
self.weight2,
config.output_layer_init_method,
partition_dim=0,
expert_parallel=self.expert_parallel,
)
setattr(self.weight1, 'allreduce', not self.expert_parallel)
setattr(self.weight2, 'allreduce', not self.expert_parallel)
def forward(self, permuted_local_hidden_states, tokens_per_expert):
if permuted_local_hidden_states.nelement() != 0:
# Reshape the weights for the grouped GEMMs.
w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
fc1_output = gg.ops.gmm(
permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False
)
intermediate_parallel = self.activation_func(fc1_output)
fc2_output_parallel = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False)
fc2_output = reduce_from_tensor_model_parallel_region(fc2_output_parallel)
else:
# No tokens are allocated to the local experts.
assert torch.count_nonzero(tokens_per_expert) == 0
fc2_output = permuted_local_hidden_states
return fc2_output, None
def sharded_state_dict(self, prefix='', sharded_offsets=()):
raise NotImplementedError(
'Currently distributed checkpointing is not supported for GroupedMLP'
)
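# Grouped-GEMM shape sketch (illustrative, assuming the grouped_gemm package's gmm op):
# weight1 is stored as [hidden, num_local_experts * ffn/tp] and viewed per expert as
# [num_local_experts, hidden, ffn/tp]; permuted_local_hidden_states is [sum_tokens, hidden]
# with tokens already sorted by expert, and tokens_per_expert ([num_local_experts]) tells
# gg.ops.gmm how many consecutive rows belong to each expert's GEMM. For example,
# tokens_per_expert = [5, 3] means rows 0:5 use w1[0] and rows 5:8 use w1[1].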
class SequentialMLP(MegatronModule):
"""An implementation of the Experts layer using a sequence of MLP layers.
This class executes each expert sequentially.
"""
def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules):
super().__init__(config=config)
self.add_bias = config.add_bias_linear
self.num_local_experts = num_local_experts
self.local_experts = torch.nn.ModuleList()
for _ in range(self.num_local_experts):
expert = MLP(self.config, submodules, is_expert=True)
self.local_experts.append(expert)
def forward(self, permuted_local_hidden_states, tokens_per_expert):
output_local = torch.zeros_like(permuted_local_hidden_states)
output_bias_local = None
if self.add_bias:
output_bias_local = torch.zeros_like(permuted_local_hidden_states)
cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
# Insert zero at the beginning for offset index's convenience
zero_tensor = torch.zeros(1, dtype=torch.long)
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
for expert_num, expert in enumerate(self.local_experts):
start = cumsum_num_tokens[expert_num]
end = cumsum_num_tokens[expert_num + 1]
hidden = permuted_local_hidden_states[start:end]
output, output_bias = expert(hidden)
output_local[start:end] = output
if self.add_bias:
output_bias = output_bias.expand_as(output)
output_bias_local[start:end, :] = output_bias
return output_local, output_bias_local
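# Worked example of the cumsum-based slicing above (hypothetical token counts):
# tokens_per_expert = [2, 3] -> cumsum with a leading zero = [0, 2, 5], so
# expert 0 sees permuted_local_hidden_states[0:2] and expert 1 sees rows [2:5];
# the outputs are written back into the same row ranges of output_local.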
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
try:
import grouped_gemm
except ImportError:
grouped_gemm = None
def grouped_gemm_is_available():
return grouped_gemm is not None
def assert_grouped_gemm_is_available():
assert grouped_gemm_is_available(), (
"Grouped GEMM is not available. Please run "
"`pip install git+https://github.com/fanshiqing/grouped_gemm@main`."
)
ops = grouped_gemm.ops if grouped_gemm_is_available() else None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
import torch
from megatron import get_args
from megatron.core import parallel_state
from megatron.core.transformer.mlp import MLPSubmodules, ParallelMLP
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.moe.token_dispatcher import MoEDroplessTokenDispatcher
from megatron.core.transformer.transformer_config import TransformerConfig
class BaseMoELayer(MegatronModule, ABC):
"""Base class for a mixture of experts layer.
Args:
config (TransformerConfig): Configuration object for the transformer model.
"""
def __init__(self, config: TransformerConfig):
super(BaseMoELayer, self).__init__(config)
self.config = config
self.expert_parallel_size = 1
assert self.expert_parallel_size > 0, "Expected a positive expert parallel size"
assert self.config.num_moe_experts % self.expert_parallel_size == 0
self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size
local_expert_indices_offset = (
0
)
self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.num_local_experts)
]
assert all(map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices))
self.router = None
self.experts = None
self.token_dispatcher = None
@abstractmethod
def forward(self, hidden_states):
pass
class MoELayer(BaseMoELayer):
"""Mixture of experts Layer **currently only supports no token dropping**.
Args:
BaseMoELayer (MegatronModule): Base class for MoE layers
"""
def __init__(self, config: TransformerConfig, submodules: MLPSubmodules = None):
self.submodules = submodules
super(MoELayer, self).__init__(config=config)
args = get_args()
self.use_fp32_router = args.use_fp32_router
self.router = TopKRouter(config=self.config)
self.num_shared_experts = args.num_shared_experts
if self.num_shared_experts is not None:
self.mlp = ParallelMLP(config)
if self.config.moe_grouped_gemm:
self.experts = GroupedMLP(self.num_local_experts, self.config)
else:
assert isinstance(self.submodules, MLPSubmodules)
self.experts = SequentialMLP(self.num_local_experts, self.config, self.submodules)
self.token_dispatcher = MoEDroplessTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
def forward(self, hidden_states: torch.Tensor):
# process MoE
# hidden_states: [SeqLen/TP, MBS, hidden_states]
# scores, indices: [SeqLen/TP * MBS, TopK]
scores, indices = self.router(hidden_states)
if self.use_fp32_router:
scores = scores.to(hidden_states.dtype)
(
dispatched_input,
tokens_per_expert,
scores,
indices,
global_local_map,
) = self.token_dispatcher.token_permutation(hidden_states, scores, indices)
expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
output, mlp_bias = self.token_dispatcher.token_unpermutation(
expert_output, scores, indices, global_local_map, mlp_bias
)
if self.num_shared_experts is not None:
output = output + self.mlp(hidden_states)[0]
return output, mlp_bias
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
def switch_load_balancing_loss_func(gates, mask, moe_aux_loss_coeff):
"""Calculate the auxiliary loss for better load balacing.
Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.
Args:
gates (torch.Tensor): The gates tensor representing the routing probabilities for each expert. [SeqLen * MBS, num_moe_experts]
mask (torch.Tensor): The 2D mask tensor indicating which experts are selected. [SeqLen * MBS, num_moe_experts]
Returns:
torch.Tensor: The auxiliary loss for load balancing.
"""
num_moe_experts = mask.size(-1)
# gate_mean: [num_moe_experts]
gates_mean = gates.mean(dim=0)
top_k = mask[0].count_nonzero()
# selection_mean: [num_moe_experts]
selection_mean = mask.float().mean(dim=0) / top_k
# torch.tensor scale
aux_loss = torch.sum(gates_mean * selection_mean) * num_moe_experts
aux_loss *= moe_aux_loss_coeff
return aux_loss
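# Worked example (hypothetical values) of the loss above: with 4 tokens, 2 experts and
# top-1 routing, suppose gates.mean(dim=0) = [0.6, 0.4] and 3 of the 4 tokens pick
# expert 0, so selection_mean = [0.75, 0.25] / 1. Then
#     aux_loss = (0.6*0.75 + 0.4*0.25) * 2 * moe_aux_loss_coeff = 1.1 * moe_aux_loss_coeff,
# while a perfectly balanced router (uniform gates and selections) gives 1.0 * moe_aux_loss_coeff.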
def z_loss_func(logits, z_loss_coeff):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
Args:
logits (torch.Tensor): The logits of the router.
Returns:
torch.Tensor: The logits after applying the z-loss.
"""
# z_loss: scale
z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
return z_loss
def sinkhorn(cost: torch.Tensor, tol: float = 0.0001):
"""Sinkhorn based MoE routing function"""
cost = torch.exp(cost)
d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
eps = 0.00000001
error = 1e9
d1_old = d1
while error > tol:
d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
error = torch.mean(torch.abs(d1_old - d1))
d1_old = d1
return d1 * cost * d0.unsqueeze(1)
class MoEAuxLossAutoScaler(torch.autograd.Function):
"""An AutoScaler that compute and scales the grad for auxiliary loss.
"""
main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
@staticmethod
def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
"""Preserve the aux_loss by storing it in the context to avoid garbage collection.
Args:
output (torch.Tensor): The output tensor.
aux_loss (torch.Tensor): The auxiliary loss tensor.
Returns:
torch.Tensor: The output tensor.
"""
ctx.save_for_backward(aux_loss)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
"""Compute and scale the gradient for auxiliary loss..
Args:
grad_output (torch.Tensor): The gradient of the output.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient.
"""
(aux_loss,) = ctx.saved_tensors
aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
return grad_output, scaled_aux_loss_grad
@staticmethod
def set_loss_scale(scale: torch.Tensor):
"""set the scale of the aux loss.
Args:
scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
"""
MoEAuxLossAutoScaler.main_loss_backward_scale = scale
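# Usage sketch (illustrative): the router calls
#     activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
# which returns `activation` unchanged in the forward pass, but in the backward pass
# emits grad(aux_loss) = ones_like(aux_loss) * main_loss_backward_scale alongside the
# incoming grad_output, so the auxiliary loss contributes gradients without ever being
# added to the main loss value.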
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import math
from abc import ABC, abstractmethod
from typing import Callable, List
import torch
from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.tensor_parallel.random import (
get_cuda_rng_tracker,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import (
MoEAuxLossAutoScaler,
sinkhorn,
switch_load_balancing_loss_func,
z_loss_func,
)
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron import get_args
from megatron.core import mpu, tensor_parallel
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
try:
from einops import rearrange
except ImportError:
rearrange = None
try:
from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
from flash_attn import flash_attn_func
except ImportError:
flash_attn_unpadded_func = None
class FlashSelfAttention(torch.nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
device=None, dtype=None):
super().__init__()
assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
'e.g., with pip install flash-attn')
assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, q, k, v):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
"""
assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
assert all((i.is_cuda for i in (q,k,v)))
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = k.shape[1]
q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
device=q.device)
if self.training:
# during training q,k,v always have same seqlen
assert seqlen_k == seqlen_q
is_causal = self.causal
cu_seqlens_k = cu_seqlens_q
dropout_p = self.dropout_p
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = seqlen_q == seqlen_k
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
device=q.device)
dropout_p = 0
output = flash_attn_unpadded_func(
q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
dropout_p,
softmax_scale=self.softmax_scale, causal=is_causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
class CoreAttention(MegatronModule):
def __init__(self, layer_number, config,
attn_mask_type=AttnMaskType.padding):
super(CoreAttention, self).__init__(config)
self.fp16 = config.fp16
self.bf16 = config.bf16
args = get_args()
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
self.sequence_parallel = config.sequence_parallel
projection_size = args.seq_length
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = projection_size
self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_router_heads)
self.num_attention_heads_per_partition = config.num_attention_heads
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
config.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
def forward(self, query_layer, key_layer,
value_layer, attention_mask):
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
if self.training == False:
self.hidden_size_per_partition = query_layer.size(2)
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
# preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
(output_size[0]*output_size[1], output_size[2], output_size[3]),
query_layer.dtype, "mpu")
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
# =========================
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class ParallelAttention_router(MegatronModule):
def __init__(self, config, layer_number=0,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention_router, self).__init__(config)
args = get_args()
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = config.params_dtype
self.sequence_parallel = config.sequence_parallel
self.flash_attn_drop = args.flash_attn_drop
self.use_lf_gate = args.use_lf_gate
self.hidden_size = config.hidden_size
self.use_flash_attn = args.use_flash_attn
self.use_fp32_router = args.use_fp32_router
projection_size = args.num_experts
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = core.utils.divide(
args.seq_length , args.num_attention_router_heads)
self.num_attention_router_heads = args.num_attention_router_heads
self.num_attention_heads_per_partition = config.num_attention_heads
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
3 * projection_size,
config=config,
init_method=config.init_method,
bias=args.add_bias_linear,
gather_output=True)
self.core_attention = CoreAttention(self.layer_number, config,
AttnMaskType.padding)
self.checkpoint_core_attention = config.recompute_granularity == 'selective'
def forward(self, hidden_states, attention_mask=None, enc_position_ids=None,
encoder_output=None, inference_params=None,
rotary_pos_emb=None):
is_first_step = False
before_hidden_states = None
if self.attention_type == AttnType.self_attn:
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
seq_length = query_layer.size(0)
batch_size = query_layer.size(1)
expert_num = query_layer.size(2)
if self.training == False:
self.num_attention_router_heads = seq_length // self.hidden_size_per_attention_head
query_layer = query_layer.transpose(0, 2).contiguous().view(expert_num, batch_size, self.num_attention_router_heads, self.hidden_size_per_attention_head)
key_layer = key_layer.transpose(0, 2).contiguous().view(expert_num, batch_size, self.num_attention_router_heads, self.hidden_size_per_attention_head)
value_layer = value_layer.transpose(0, 2).contiguous().view(expert_num, batch_size, self.num_attention_router_heads, self.hidden_size_per_attention_head)
if self.use_fp32_router:
context_layer = self.core_attention(
query_layer.float(), key_layer.float(), value_layer.float(), None)
else:
context_layer = self.core_attention(
query_layer, key_layer, value_layer, None)
router_output = context_layer.transpose(0, 2).contiguous()
return router_output
class Router(ABC, MegatronModule):
"""Base Router class"""
def __init__(self, config: TransformerConfig) -> None:
"""
Initialize the Router module.
Args:
config (TransformerConfig): Configuration object for the Transformer model.
"""
super().__init__(config)
self.config = config
args = get_args()
self.use_attention_router = args.use_attention_router
self.num_moe_experts = self.config.num_moe_experts
self.moe_aux_loss_func = None
if self.use_attention_router:
self.attention_router = ParallelAttention_router(config)
else:
self.weight = torch.nn.Parameter(
torch.empty((self.config.num_moe_experts, self.config.hidden_size))
)
if args.process_checkpoint:
config.init_method(self.weight)
else:
with get_cuda_rng_tracker().fork():
config.init_method(self.weight)
setattr(self.weight, 'sequence_parallel', config.sequence_parallel)
def gating(self, input: torch.Tensor):
"""Forward pass of the router gate.
Args:
input (torch.Tensor): Input tensor. [SeqLen/TP, MBS, HiddenSize]
Returns:
torch.Tensor: Logits tensor.
"""
# logits: [SeqLen/TP, MBS, num_moe_experts]
if self.use_attention_router:
logits = self.attention_router(input)
else:
logits = torch.nn.functional.linear(input, self.weight)
return logits
@abstractmethod
def routing(self, logits: torch.Tensor):
"""Routing function.
Args:
logits (torch.Tensor): Logits tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of tensors representing max probs and the indices.
"""
raise NotImplementedError("Routing function not implemented.")
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor. [SeqLen/TP, MBS, HiddenSize]
Returns:
Tuple[torch.Tensor, torch.Tensor]: scores and indices.
"""
# self.hidden [SeqLen/TP, MBS, HiddenSize]
self.hidden = input.shape[-1]
# logits [SeqLen/TP, MBS, num_moe_experts]
logits = self.gating(input)
# logits [SeqLen/TP * MBS, num_moe_experts]
logits = logits.view(-1, self.config.num_moe_experts)
scores, indices = self.routing(logits)
return scores, indices
class TopKRouter(Router):
"""Route each token to the top-k experts."""
def __init__(self, config: TransformerConfig,) -> None:
"""Initialize the zero token dropping router.
Args:
config (TransformerConfig): The configuration for the transformer model.
"""
super().__init__(config=config)
assert config.moe_token_dropping is False
self.topk = self.config.moe_router_topk
self.routing_type = self.config.moe_router_load_balancing_type
self.moe_aux_loss_func = switch_load_balancing_loss_func
self.input_jitter = None
def sinkhorn_load_balancing(self, logits: torch.Tensor):
"""Apply sinkhorn routing to the logits tensor.
Args:
logits (torch.Tensor): The logits tensor.
Returns:
torch.Tensor: The logits tensor after applying sinkhorn routing.
"""
def _sinkhorn_activation(logits):
if self.topk == 1:
logits = torch.sigmoid(logits)
else: # k > 1
logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
return logits
assert self.config.moe_aux_loss_coeff == 0, "Sinkhorn routing does not support aux loss."
if self.training:
with torch.no_grad():
norm_logits = sinkhorn(
logits.to(dtype=torch.float32)
) # explicit fp32 conversion for stability
_, indices = torch.topk(norm_logits, k=self.topk, dim=1)
logits = _sinkhorn_activation(logits)
scores = torch.gather(logits, 1, indices)
else:
logits = _sinkhorn_activation(logits)
scores, indices = torch.topk(logits, k=self.topk, dim=1)
return scores, indices
def aux_loss_load_balancing(self, logits: torch.Tensor):
"""Apply loss-based load balancing to the logits tensor.
Args:
logits (torch.Tensor): The logits tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The scores and the indices tensor after applying load balancing.
"""
# Take the top-k: top_logits, indices [SeqLen/TP * MBS, TopK]
top_logits, indices = torch.topk(logits, k=self.topk, dim=1)
# scores [SeqLen/TP * MBS, TopK]
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)
# Apply load balancing loss
# probs [SeqLen/TP * MBS, num_moe_experts]
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
# scores: [SeqLen/TP * MBS, TopK]
scores = self.apply_aux_loss(self.moe_aux_loss_func, probs, indices, activation=scores)
return scores, indices
def apply_aux_loss(
self,
loss_func: Callable,
probs: torch.Tensor,
indices: torch.Tensor,
activation: torch.Tensor,
):
"""Applies auxiliary loss to the MoE layer.
Args:
loss_func (callable): The loss function to be used. switch_load_balancing_loss_func
probs (torch.Tensor): The probabilities output by the MoE layer. [SeqLen/TP * MBS, num_moe_experts]
indices (torch.Tensor): The indices of the selected experts. [SeqLen/TP * MBS, TopK]
activation (torch.Tensor): The activation tensor to attach the gradient function to. [SeqLen/TP * MBS, TopK]
Returns:
torch.Tensor: The activation tensor with the attached gradient function.
"""
mask = torch.nn.functional.one_hot(indices, num_classes=self.num_moe_experts).sum(dim=1)
aux_loss = loss_func(probs, mask, self.config.moe_aux_loss_coeff)
activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
return activation
def apply_z_loss(self, logits):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
Args:
logits (torch.Tensor): The logits of the router.
Returns:
torch.Tensor: The logits after applying the z-loss.
"""
if self.config.moe_z_loss_coeff is not None:
z_loss = z_loss_func(logits, self.config.moe_z_loss_coeff)
logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
return logits
def apply_input_jitter(self, input: torch.Tensor):
"""Add noise to the input tensor.
Refer to https://arxiv.org/abs/2101.03961.
Args:
input (Tensor): Input tensor.
Returns:
Tensor: Jittered input.
"""
if self.config.moe_input_jitter_eps is not None:
eps = self.config.moe_input_jitter_eps
if self.input_jitter is None:
self.input_jitter = torch.distributions.uniform.Uniform(
torch.tensor(1.0 - eps, device=input.device),
torch.tensor(1.0 + eps, device=input.device),
).rsample
return input * self.input_jitter(input.shape)
else:
return input
def routing(self, logits: torch.Tensor):
"""Top-k routing function
Args:
logits (torch.Tensor): Logits tensor. [SeqLen/TP * MBS, num_moe_experts]
Returns:
Tuple[torch.Tensor, torch.Tensor]: Probs and the indices tensor.
"""
logits = logits.view(-1, self.config.num_moe_experts)
# Apply Z-Loss, ST-MOE
# L_z(x) = \frac{1}{B} \sum_{i=1}^B \left( log \sum_{j=1}^N e^{x_j^{(i)}} \right)^2
logits = self.apply_z_loss(logits)
# Apply input jitter ST-MOE
logits = self.apply_input_jitter(logits)
if self.routing_type == "sinkhorn":
scores, indices = self.sinkhorn_load_balancing(logits)
elif self.routing_type == "aux_loss":
scores, indices = self.aux_loss_load_balancing(logits)
elif self.routing_type == "none":
# A naive top-k routing without load balancing
# top_logits, indices [SeqLen/TP * MBS, TopK]
top_logits, indices = torch.topk(logits, k=self.topk, dim=1)
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)
else:
raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")
# scores, indices: [SeqLen/TP * MBS, TopK]
return scores, indices
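# Routing sketch with hypothetical numbers: for logits of shape [num_tokens, num_moe_experts]
# and topk = 2, e.g. logits[0] = [2.0, 0.5, 1.0, -1.0], torch.topk picks experts (0, 2) with
# top_logits [2.0, 1.0]; softmax over just those two gives scores[0] ~= [0.73, 0.27]. The
# z-loss is attached to the logits and the aux loss to the scores via MoEAuxLossAutoScaler.apply.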
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from abc import abstractmethod
from typing import List
import torch
# from megatron.core import parallel_state, tensor_parallel
# from megatron.core.parallel_state import get_tensor_and_expert_parallel_group
from megatron.core.transformer.transformer_config import TransformerConfig
class MoETokenDispatcher:
"""
MoE Token Dispatcher
"""
def __init__(self, config: TransformerConfig) -> None:
"""
Initialize the MoE Token Dispatcher.
"""
self.config = config
@abstractmethod
def token_permutation(
self, tokens: torch.Tensor, indices: torch.Tensor,
):
"""Dispatch tokens to experts.
Args:
tokens (torch.Tensor): Input tokens.
indices (torch.Tensor): indices tensor.
Returns:
torch.Tensor: Tokens tensor.
"""
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
def token_unpermutation(
self, expert_output: torch.Tensor, scores: torch.Tensor, indices: torch.Tensor,
):
"""Restores the expert output to its original ordering.
Args:
expert_output (torch.Tensor): The output tensor from the expert models.
scores (torch.Tensor): Each token's score with each expert.
indices (torch.Tensor): The indices used to reorder the expert output.
Returns:
(torch.Tensor, torch.Tensor): Unpermuted activation and optional bias.
"""
raise NotImplementedError("Restore function not implemented.")
class MoEDroplessTokenDispatcher(MoETokenDispatcher):
"""
Token dispatcher without token dropping.
"""
def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig,
) -> None:
"""
Initialize the zero token dropping router.
"""
super().__init__(config=config)
self.num_local_experts = num_local_experts
assert self.num_local_experts > 0, "Expected at least one expert"
self.local_expert_indices = local_expert_indices
assert len(self.local_expert_indices) > 0, "Expected at least one local expert index"
self.router_topk = config.moe_router_topk
self.add_bias = config.add_bias_linear
def token_permutation(
self, hidden_states: torch.Tensor, max_prob: torch.Tensor, max_ind: torch.Tensor
):
"""Dispatch tokens to local experts. It's composed of two stages:
(1) Permute the tokens across the expert parallel devices. After this stage,
each device receives all of the tokens assigned to its local set of experts
in its local HBM.
(2) Permute the tokens locally so that they are grouped by their expert
assignment. After the stage (1), the tokens are grouped by which device
they came from. We re-order them locally for subsequent efficient computation.
Args:
hidden_states: input tokens of shape [SeqLen/TP, MBS, HiddenSize]
max_prob: probs of token assignment to local experts (scores), [SeqLen/TP * MBS, TopK]
max_ind: token assignment to local experts (indices), [SeqLen/TP * MBS, TopK]
Returns:
permuted_local_hidden_states: Permutation of tokens to the local experts' group.
tokens_per_expert: the number of tokens each local expert has to process.
indices: The indices of `local_indices` (which holds the un-sorted expert
indices of tokens that the local experts can process) that give its sorted order along dim 0.
global_local_map (optional): 2D tensor. A mask mapping between global and local tokens, where each
element is True if it falls within the local_expert_indices. Only useful
when cross-device token permutation is enabled and **AllGather** is performed.
"""
self.hidden_shape = hidden_states.shape
# [S/TP, B, H] -> [S*B/TP, H]
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
if self.router_topk > 1:
# [S*B/TP, TopK] ones
global_local_map = torch.ones_like(max_ind).bool()
# local_indices: [S*B*TopK/TP]
local_indices = max_ind.masked_select(global_local_map)
# local_probs: [S*B*TopK/TP]
local_probs = max_prob.masked_select(global_local_map)
# global_local_map: [S*B*TopK/TP]
global_local_map = global_local_map.nonzero()[:, 0]
# global_local_map: [S*B*TopK/TP, H]
global_local_map = global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1])
# local_hidden_states: [S*B*TopK/TP, H]
local_hidden_states = torch.gather(hidden_states, 0, global_local_map)
else:
local_indices = max_ind
local_probs = max_prob
local_hidden_states = hidden_states
global_local_map = None
with torch.no_grad():
# The indices of local_indices that give its sorted order along dim 0.
indices = torch.argsort(local_indices, dim=0)
tokens_per_expert = torch.histc(
local_indices,
bins=self.num_local_experts,
min=self.local_expert_indices[0],
max=self.local_expert_indices[-1],
)
tokens_per_expert = tokens_per_expert.cpu().to(torch.long)
# Stage2: permute the tokens locally so that they are grouped by their expert assignment
# Reshape indices to be compatible with Tensor.gather
# indices: [S*B*TopK/TP, H]
indices = indices.view(-1, 1).expand(-1, hidden_states.shape[-1])
# permuted_local_hidden_states: [S*B*TopK/TP, H]
permuted_local_hidden_states = torch.gather(local_hidden_states, 0, indices)
return (
permuted_local_hidden_states,
tokens_per_expert,
local_probs,
indices,
global_local_map,
)
def token_unpermutation(
self,
hidden_states: torch.Tensor,
scores: torch.Tensor,
indices: torch.Tensor,
global_local_map: torch.Tensor = None,
bias: torch.Tensor = None,
):
"""
Reverse process of `dispatch()` which permutes the output of the local
experts locally and across expert parallel ranks back into the original order to
produce the final output.
Args:
hidden_states: 2D tensor of shape [sum_tokens_of_all_local_experts, HiddenSize],
output of the local experts.
scores: 2D tensor of the probs of token assignment to local experts.
indices: 2D tensor of the indices of `local_indices` (which holds the un-sorted expert
indices of tokens that the local experts can process) that give its sorted order along dim 0.
global_local_map (optional): 2D tensor, a mask mapping between global and local tokens, where each
element is True if it falls within the local_expert_indices. Only useful
when cross-device token permutation is enabled and **AllGather** is performed.
bias (optional): The bias tensor.
Returns:
output_total: un-permuted updated hidden states output from all local experts
with shape of [SeqLen/TP, MBS, HiddenSize]
"""
# Stage1: unpermute the tokens and bias locally respectively.
scores = scores.to(dtype=hidden_states.dtype)
unpermuted_local_hidden = torch.zeros_like(hidden_states)
assert indices.shape == hidden_states.shape
unpermuted_local_hidden = unpermuted_local_hidden.scatter(0, indices, hidden_states)
# Scale the expert output prior to reduction and subsequent to local unpermutation if k > 1.
if self.router_topk > 1:
unpermuted_local_hidden = unpermuted_local_hidden * scores.view(-1, 1)
unpermuted_local_bias = None
if self.add_bias:
assert bias is not None
unpermuted_local_bias = torch.zeros_like(hidden_states)
assert indices.shape == bias.shape
unpermuted_local_bias = unpermuted_local_bias.scatter(0, indices, bias)
if self.router_topk > 1:
unpermuted_local_bias = unpermuted_local_bias * scores.view(-1, 1)
output_total = unpermuted_local_hidden
output_bias_total = unpermuted_local_bias
if self.router_topk > 1:
global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1]
global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
unpermuted_global_hidden = torch.zeros(
global_hidden_shape,
dtype=hidden_states.dtype,
device=torch.cuda.current_device(),
)
output_total = unpermuted_global_hidden.scatter_add(
0, global_local_map, unpermuted_local_hidden
)
if self.add_bias:
unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
output_bias_total = unpermuted_global_bias.scatter_add(
0, global_local_map, unpermuted_local_bias
)
if self.router_topk == 1:
output_total = output_total * scores
output_total = output_total.view(self.hidden_shape)
if self.add_bias:
assert output_bias_total is not None
if self.router_topk == 1:
output_bias_total = output_bias_total * scores
output_bias_total = output_bias_total.view(self.hidden_shape)
else:
output_bias_total = None
return output_total, output_bias_total
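# End-to-end permutation sketch (hypothetical, top-2 routing, 3 tokens, 2 local experts):
# max_ind = [[0, 1], [1, 1], [0, 0]] flattens to local_indices = [0, 1, 1, 1, 0, 0];
# argsort groups the token copies by expert, giving tokens_per_expert = [3, 3].
# token_unpermutation scatters the expert outputs back with the same indices, scales each
# copy by its routing score, and scatter_adds the top-k copies of every token into one
# [SeqLen/TP * MBS, HiddenSize] tensor before reshaping to [SeqLen/TP, MBS, HiddenSize].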
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import types
from dataclasses import dataclass, field
from typing import Tuple, Union
import torch
@dataclass
class ModuleSpec:
"""This is a Module Specification dataclass.
Specification defines the location of the module (to import dynamically)
or the imported module itself. It also defines the params that need to be
passed to initialize the module.
Args:
module (Union[Tuple, type]): A tuple describing the location of the
module class e.g. `(module.location, ModuleClass)` or the imported
module class itself e.g. `ModuleClass` (which is already imported
using `from module.location import ModuleClass`).
params (dict): A dictionary of params that need to be passed while init.
"""
module: Union[Tuple, type]
params: dict = field(default_factory=lambda: {})
submodules: type = None
def import_module(module_path: Tuple[str]):
"""Import a named object from a module in the context of this function.
TODO: make this importer module more robust, at least make sure there
are no side effects of using this as is
"""
base_path, name = module_path
try:
module = __import__(base_path, globals(), locals(), [name])
except ImportError as e:
print(f"couldn't import module due to {e}")
return None
return vars(module)[name]
def get_module(spec_or_module: Union[ModuleSpec, type], **additional_kwargs):
# If a module class is already provided return it as is
if isinstance(spec_or_module, (type, types.FunctionType)):
return spec_or_module
# If the module is provided instead of module path, then return it as is
if isinstance(spec_or_module.module, (type, types.FunctionType)):
return spec_or_module.module
# Otherwise, return the dynamically imported module from the module path
return import_module(spec_or_module.module)
def build_module(spec_or_module: Union[ModuleSpec, type], *args, **kwargs):
# If the passed `spec_or_module` is
# a `Function`, then return it as it is
# NOTE: to support an already initialized module add the following condition
# `or isinstance(spec_or_module, torch.nn.Module)` to the following if check
if isinstance(spec_or_module, types.FunctionType):
return spec_or_module
# If the passed `spec_or_module` is actually a spec (instance of
# `ModuleSpec`) and it specifies a `Function` using its `module`
# field, return the `Function` as it is
if isinstance(spec_or_module, ModuleSpec) and isinstance(
spec_or_module.module, types.FunctionType
):
return spec_or_module.module
# Check if a module class is provided as a spec or if the module path
# itself is a class
if isinstance(spec_or_module, type):
module = spec_or_module
elif hasattr(spec_or_module, "module") and isinstance(spec_or_module.module, type):
module = spec_or_module.module
else:
# Otherwise, dynamically import the module from the module path
module = import_module(spec_or_module.module)
# If the imported module is actually a `Function` return it as it is
if isinstance(module, types.FunctionType):
return module
# Finally return the initialized module with params from the spec as well
# as those passed as **kwargs from the code
# Add the `submodules` argument to the module init call if it exists in the
# spec.
if hasattr(spec_or_module, "submodules") and spec_or_module.submodules is not None:
kwargs["submodules"] = spec_or_module.submodules
try:
return module(
*args, **spec_or_module.params if hasattr(spec_or_module, "params") else {}, **kwargs
)
except Exception as e:
# improve the error message since we hide the module name in the line above
import sys
tb = sys.exc_info()[2]
raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
sys.exc_info()[2]
)
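# Usage sketch (hypothetical spec, not from this repo): a ModuleSpec can point at a class
# by import path and carry its constructor params, e.g.
#     spec = ModuleSpec(module=("torch.nn", "Linear"), params={"in_features": 8, "out_features": 8})
#     layer = build_module(spec)   # imports torch.nn.Linear and calls it with the params
# A plain class (or function) can also be passed directly, in which case build_module
# instantiates (or returns) it without any dynamic import.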
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from contextlib import nullcontext
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.utils import make_viewless_tensor
class TransformerBlock(MegatronModule):
"""Transformer class."""
def __init__(
self,
config: TransformerConfig,
self_attn_mask_type=AttnMaskType.padding,
post_layer_norm=True,
pre_process=True,
post_process=True,
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.self_attn_mask_type = self_attn_mask_type
self.post_layer_norm = post_layer_norm
self.pre_process = pre_process
self.post_process = post_process
# required for pipeline parallel schedules
self.input_tensor = None
self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
# TODO: Maybe we can create a build_transformer_block method here instead
self.num_layers_per_pipeline_rank = (
self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
self._build_layers()
def _build_layers(self):
# Transformer layers.
# @jcasper can we improve how we deal with layer_number?
# currently it's only used in CoreAttention?
# if self.apply_query_key_layer_scaling:
# coeff = self.layer_number
# self.norm_factor *= coeff
def build_layer(layer_number):
return TransformerLayer(
config=self.config, layer_number=layer_number, self_attn_mask_type=self.self_attn_mask_type,
)
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
total_num_layers = self.config.num_layers
num_layers_per_virtual_rank = self.num_layers_per_pipeline_rank // vp_size
            total_virtual_chunks = total_num_layers // vp_size
offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)
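            # Worked example (matches the 8-layer, 2-stage, 2-virtual-stage comment
            # above): num_layers_per_virtual_rank = (8 // 2) // 2 = 2 and
            # total_virtual_chunks = 8 // 2 = 4, so pipeline rank 1 gets
            # offset = 0 * 4 + 1 * 2 = 2 (layers [2, 3]) for vp_rank 0 and
            # offset = 1 * 4 + 1 * 2 = 6 (layers [6, 7]) for vp_rank 1.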
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(num_layers_per_virtual_rank)]
)
else:
# Each stage gets a contiguous set of layers.
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
offset = pipeline_rank * self.num_layers_per_pipeline_rank
else:
offset = 0
# @jcasper why is layer_number using 1 index?
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers_per_pipeline_rank)]
)
# # TODO: add back standalone_embedding_stage
# if self.num_layers == 0:
# # When a standalone embedding stage is used (e.g.,
# # args.standalone_embedding_stage == True), virtual pipeline ranks
# # on pipeline rank 0 will have zero transformer layers assigned to
# # them. This results in the model's input and output tensors to be
# # the same, which will cause failure for certain output tensor
# # optimizations (e.g., pipeline output deallocation). To remedy
# # this, we assign a 'no-op' layer on these ranks, which will
# # disconnect the input tensor from the output tensor.
# self.num_layers = 1
# self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
# else:
# self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process and self.post_layer_norm:
# Final layer norm before output.
self.final_layernorm = FusedLayerNorm(
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
persist_layer_norm=self.config.persist_layer_norm,
sequence_parallel=self.config.sequence_parallel,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask, rotary_pos_emb):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*args, **kwargs):
x_, *args = args
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, *args, **kwargs)
return x_
return custom_forward
if self.config.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
            # Checkpointing fewer, larger chunks further reduces memory usage.
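            # Worked example (illustrative): with 8 layers on this pipeline rank and
            # recompute_num_layers = 2, the chunks are layers [0, 1], [2, 3], [4, 5],
            # [6, 7]; only the input activation of each 2-layer chunk is saved and the
            # rest is recomputed in the backward pass.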
l = 0
            while l < self.num_layers_per_pipeline_rank:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.config.recompute_num_layers),
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
rotary_pos_emb,
)
                l += self.config.recompute_num_layers
elif self.config.recompute_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
            # This makes fuller use of device memory by skipping recomputation for the remaining layers.
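            # Worked example (illustrative): with 8 layers on this pipeline rank and
            # recompute_num_layers = 2, only layers 0 and 1 are checkpointed; layers
            # 2-7 keep their activations and are not recomputed.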
for l in range(self.num_layers_per_pipeline_rank):
if l < self.config.recompute_num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
rotary_pos_emb,
)
else:
hidden_states = custom(l, l + 1)(hidden_states, attention_mask, rotary_pos_emb)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask, inference_params=None, rotary_pos_emb=None):
# hidden_states (float): [s, b, h]
# attention_mask (bool): [1, 1, s, s]
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True,)
if self.config.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
if self.config.fp8:
import transformer_engine # To keep out TE dependency when not training in fp8
fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
margin=self.config.fp8_margin,
interval=self.config.fp8_interval,
fp8_format=transformer_engine.common.recipe.Format.E4M3
if self.config.fp8_e4m3 else
transformer_engine.common.recipe.Format.HYBRID,
fp8_amax_compute_algo=self.config.fp8_amax_compute_algo,
fp8_amax_history_len=self.config.fp8_amax_history_len
)
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe
)
else:
fp8_context = nullcontext()
        with rng_context, fp8_context:
# Forward pass.
if self.config.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb)
else:
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb)
# Final layer norm.
if self.post_process and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable
import torch
import torch.nn.functional as F
from megatron.core import ModelParallelConfig
from megatron.core.utils import init_method_normal, scaled_init_method_normal
@dataclass
class TransformerConfig(ModelParallelConfig):
"""Configuration object for megatron-core transformers.
Attributes:
# model architecture
num_layers (int): Number of transformer layers in a transformer block.
hidden_size (int): Transformer hidden size.
ffn_hidden_size (int): Transformer Feed-Forward Network hidden size.
                               This is set to 4*hidden_size if not provided. Defaults to None.
num_attention_heads (int): Number of transformer attention heads.
kv_channels (int): Projection weights dimension in multi-head attention.
This is set to hidden_size // num_attention_heads if not provided.
Defaults to None.
hidden_dropout (float): Dropout probability for transformer hidden state. Defaults to 0.1.
attention_dropout (float): Post attention dropout probability. Defaults to 0.1.
fp32_residual_connection (bool): If true, move residual connections to fp32.
        apply_residual_connection_post_layernorm (bool): If true, uses the original BERT residual connection ordering.
Defaults to False.
layernorm_epsilon (float): Layernorm epsilon. Defaults to 1e-5.
layernorm_zero_centered_gamma (bool): if set to 'True', the LayerNorm is adjusted to center the gamma values
around 0. This improves numerical stability. Defaults to False.
add_bias_linear (bool): Include a bias term in all linear layers (QKV projections, after core attention, and two
in MLP layer). Default is True.
gated_linear_unit (bool): Use a gated linear unit for the first linear layer in the MLP. Defaults to False.
activation_func (Callable): Activation function to use for the non-linearity in the MLP. Defaults to F.gelu.
# initialization
init_method (Callable): Method to initialize weights. Note that bias is always set to
zero. Should be a function that takes a single Tensor and
initializes it. Defaults to
megatron.core.utils.init_method_normal(init_method_std) which is
                              torch.nn.init.normal_ with mean=0.0 and std=init_method_std.
output_layer_init_method (Callable): Method to initialize weights of the output layer of
both attention and MLP blocks. Defaults to
megatron.core.utils.scaled_init_method_normal(init_method_std)
which is torch.nn.init.normal_ with mean=0.0 and
std=init_method_std / math.sqrt(2.0 * num_layers).
init_method_std (float): Standard deviation of the zero mean normal for the default
initialization method, not used if init_method and
output_layer_init_method are provided. Defaults to 0.02.
# mixed-precision
apply_query_key_layer_scaling (bool): If true, scale Q * K^T by 1 / layer-number. Defaults to True.
attention_softmax_in_fp32 (bool): If true, run attention masking and softmax in fp32.
This should be true if apply_query_key_layer_scaling is true.
# fusion
        bias_activation_fusion (bool): If true, fuses the bias addition with the activation function (e.g. GELU). Defaults to False.
masked_softmax_fusion (bool): If true, uses softmax fusion.
persist_layer_norm (bool): If true, uses the persistent fused layer norm kernel.
This kernel only supports a fixed set of hidden sizes.
Defaults to False.
bias_dropout_fusion (bool): If true, uses bias dropout fusion.
# activation recomputation
recompute_granularity (str): megatron-core supports 'selective' activation checkpointing where only the memory
intensive part of attention is checkpointed. These memory intensive activations
are also less compute intensive which makes activation checkpointing more efficient
for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer
Models: https://arxiv.org/abs/2205.05198 for more details. 'full' will checkpoint
the entire transformer layer. Must be 'selective' or 'full'. Defaults to None.
recompute_method (str): uniform will uniformly divide the total number of transformer layers in a transformer
block and recompute the input activation of each divided chunk at the specified
granularity. block will recompute the input activations for only a set number of
transformer layers per pipeline stage. The rest of the layers in the pipeline stage
will not have any activations recomputed. Must be 'uniform' or 'block'. Defaults to
None.
recompute_num_layers (int): When recompute_method is uniform, recompute_num_layers is the number of transformer
layers in each uniformly divided recompute unit. When recompute_method is block,
recompute_num_layers is the number of transformer layers to recompute within each
pipeline stage. Defaults to None.
distribute_saved_activations (bool): If true, distribute recomputed activations across the model parallel
group. Defaults to None.
# fp8 related (via Transformer Engine). For detailed info, refer the the Transformer Engine docs at
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html
fp8 (bool): Enables the use of FP8 precision through Transformer Engine.
fp8_e4m3 (bool): Enables the use of FP8 tensors in e4m3 format for both forward and backward passes.
        fp8_margin (int): Margin for the scaling factor computation.
fp8_interval (int): Controls how often the scaling factor is recomputed.
fp8_amax_history_len (int): The length of the amax history window used for scaling factor computation.
fp8_amax_compute_algo (str): Algorithm used for choosing the `amax` value for the scaling factor computation.
There are 2 predefined choices: `max` chooses the largest `amax` in the history
window, while `most_recent` always chooses the most recently seen value.
"""
# model architecture
num_layers: int = 0
hidden_size: int = 0
num_attention_heads: int = 0
num_query_groups: int = None
ffn_hidden_size: int = None
kv_channels: int = None
hidden_dropout: float = 0.1
attention_dropout: float = 0.1
fp32_residual_connection: bool = False
# @jcasper should we keep this option?
apply_residual_connection_post_layernorm: bool = False
layernorm_epsilon: float = 1e-5
layernorm_zero_centered_gamma: bool = False
add_bias_linear: bool = True
gated_linear_unit: bool = False
activation_func: Callable = F.gelu
num_moe_experts: int = None
rotary_interleaved: bool = False
# initialization
init_method: Callable = None
output_layer_init_method: Callable = None
init_method_std: float = 0.02
# mixed-precision
apply_query_key_layer_scaling: bool = True
attention_softmax_in_fp32: bool = True
# communication
# fusion
bias_activation_fusion: bool = False
masked_softmax_fusion: bool = False
persist_layer_norm: bool = False
bias_dropout_fusion: bool = False # TODO: this should be bias_dropout_add_fusion?
# activation recomputation
recompute_granularity: str = None
recompute_method: str = None
recompute_num_layers: int = None
distribute_saved_activations: bool = None
# fp8 related
    fp8: bool = False
fp8_e4m3: bool = False
fp8_margin: int = 0
fp8_interval: int = 1
fp8_amax_history_len: int = 1
fp8_amax_compute_algo: str = "most_recent"
# experimental section (TODO: move to apt. section above once stable)
    normalization: str = "LayerNorm"  # alt value supported by TE: "RMSNorm"
# MoE related
moe_router_load_balancing_type: str = "aux_loss"
moe_router_topk: int = 2
moe_grouped_gemm: bool = False
moe_aux_loss_coeff: float = 0 # 1e-2 would be a good start value for load balance loss.
moe_z_loss_coeff: float = None # 1e-3 would be a good start value for z-loss
moe_input_jitter_eps: float = None
moe_token_dropping: bool = False # TODO: Support token dropping.
def __post_init__(self):
""" Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
"""
super().__post_init__()
if self.fp16 and self.bf16:
raise ValueError(f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.')
if self.ffn_hidden_size is None:
self.ffn_hidden_size = 4 * self.hidden_size
if self.kv_channels is None:
self.kv_channels = self.hidden_size // self.num_attention_heads
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
if self.recompute_granularity is not None:
            if self.recompute_granularity not in ['full', 'selective']:
                raise ValueError(
                    f'recompute_granularity: {self.recompute_granularity} must be "full" or "selective".'
                )
if self.recompute_method is not None:
                if self.recompute_method not in ['block', 'uniform']:
                    raise ValueError(f'recompute_method: {self.recompute_method} must be "block" or "uniform".')
            elif self.recompute_granularity != 'selective':
                raise ValueError(
                    f'recompute_method must be "block" or "uniform" when recompute_granularity is {self.recompute_granularity}.'
                )
            if self.recompute_num_layers is None:
                raise ValueError(
                    f'When using recompute_granularity: {self.recompute_granularity}, recompute_num_layers must be set '
                    f'between 1 and num_layers_per_pipeline_rank: {self.num_layers // self.pipeline_model_parallel_size}.'
                )
            if self.distribute_saved_activations and self.sequence_parallel:
                raise ValueError(
                    f'distribute_saved_activations: {self.distribute_saved_activations} must be false when sequence parallel is enabled: {self.sequence_parallel}'
                )
if self.virtual_pipeline_model_parallel_size is not None:
if not self.num_layers % self.virtual_pipeline_model_parallel_size == 0:
raise ValueError(
f'num_layers: {self.num_layers} must be divisible by virtual_model_parallel_size {self.virtual_pipeline_model_parallel_size}'
)
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
if self.bias_activation_fusion:
            if self.activation_func not in [F.gelu, F.silu]:
                raise ValueError(
                    "When bias_activation_fusion is True, the activation function must be either gelu or silu."
                )
if self.activation_func == F.gelu and not self.add_bias_linear:
raise ValueError(
"When bias_activation_fusion is True and activation function is gelu, add_bias_linear must also be True."
)
if self.init_method is None:
self.init_method = init_method_normal(self.init_method_std)
if self.output_layer_init_method is None:
self.output_layer_init_method = scaled_init_method_normal(self.init_method_std, self.num_layers)
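if __name__ == "__main__":
    # Minimal sketch of the derived defaults computed in __post_init__ (assumes the
    # inherited ModelParallelConfig defaults construct cleanly on a single rank; the
    # sizes below are arbitrary illustration values).
    config = TransformerConfig(num_layers=4, hidden_size=64, num_attention_heads=4)
    print(config.ffn_hidden_size)  # 256 (4 * hidden_size when not provided)
    print(config.kv_channels)      # 16  (hidden_size // num_attention_heads)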
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.enums import AttnType, AttnMaskType
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.mlp import MLP
from megatron.core.utils import make_viewless_tensor
from megatron.core.transformer.custom_layers.transformer_engine import TELayerNorm
class TransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self, config: TransformerConfig, layer_number: int = 1, self_attn_mask_type=AttnMaskType.padding,
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.layer_number = layer_number
self.self_attn_mask_type = self_attn_mask_type
# Layernorm on the input data.
# TODO: add pytorch only layernorm
self.input_layernorm = TELayerNorm(
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
persist_layer_norm=self.config.persist_layer_norm,
sequence_parallel=self.config.sequence_parallel,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
)
# Self attention.
self.self_attention = SelfAttention(
config=self.config,
layer_number=layer_number,
attn_mask_type=self_attn_mask_type,
)
# Layernorm on the attention output
self.post_self_attn_layernorm = TELayerNorm(
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
persist_layer_norm=self.config.persist_layer_norm,
sequence_parallel=self.config.sequence_parallel,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
)
# MLP
self.mlp = MLP(config=self.config)
# @jcasper how should we handle nvfuser?
# Set bias+dropout+add fusion grad_enable execution handler.
# TORCH_MAJOR = int(torch.__version__.split('.')[0])
# TORCH_MINOR = int(torch.__version__.split('.')[1])
# use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
# self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
self.bias_dropout_add_exec_handler = torch.enable_grad
self.bias_dropout_add_func = get_bias_dropout_add(
self.training,
self.config.bias_dropout_fusion
)
# TODO: decide how to do inference_params
def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None,
inference_params=None, rotary_pos_emb=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
layernorm_output, attention_mask, inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb
)
# Residual connection.
if self.config.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# bias_dropout_add fusion returning fp32 instead of bf16
with self.bias_dropout_add_exec_handler():
layernorm_input = self.bias_dropout_add_func(
attention_output_with_bias, residual, self.config.hidden_dropout
)
# Layer norm post the self attention.
layernorm_output = self.post_self_attn_layernorm(layernorm_input)
# MLP.
mlp_output_with_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.config.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
with self.bias_dropout_add_exec_handler():
output = self.bias_dropout_add_func(
mlp_output_with_bias, residual, self.config.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(inp=output, requires_grad=output.requires_grad, keep_graph=True)
return output
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Utilities for transformer layers."""
from functools import lru_cache
from operator import itemgetter
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple, Union
import torch
from megatron import get_args
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
if get_args().perform_initialization:
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def openai_gelu(x):
return gelu_impl(x)
# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@torch.jit.script
def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
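if __name__ == "__main__":
    # Quick numerical check (illustrative): gelu_impl is the tanh approximation of
    # GELU, so it should closely match PyTorch's own tanh-approximated GELU,
    # available as F.gelu(..., approximate="tanh") in recent PyTorch releases.
    import torch.nn.functional as F

    x = torch.randn(8)
    print(torch.allclose(gelu_impl(x), F.gelu(x, approximate="tanh"), atol=1e-6))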
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Utility functions used throughout Megatron core"""
from functools import reduce
import math
import operator
import torch
from megatron.core import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def get_attr_wrapped_model(model, attr, allow_none=True):
"""Get an attribute from a wrapped model"""
if isinstance(model, list):
raise RuntimeError("_get_attr_wrapped_model given a list of models")
if allow_none:
def condition(model, attr):
return not hasattr(model, attr)
else:
def condition(model, attr):
return getattr(model, attr, None) is None
while condition(model, attr):
if not hasattr(model, "module"):
raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
model = model.module
return getattr(model, attr)
def get_model_type(model):
return get_attr_wrapped_model(model, 'model_type')
def get_model_config(model):
return get_attr_wrapped_model(model, 'config', allow_none=False)
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
    View tensors have the undesirable side-effect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
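# Worked example (illustrative): `t = torch.zeros(6).view(2, 3)` is a view, so
# `t._base` is the original 1-D tensor.  `make_viewless_tensor(t,
# requires_grad=False, keep_graph=False)` returns a tensor backed by the same
# storage whose `._base` is None, so the originally viewed tensor is no longer
# kept alive by the view relationship.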
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
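if __name__ == "__main__":
    # Minimal sketch of the helpers above (values chosen only for illustration).
    # scaled_init_method_normal shrinks the std by sqrt(2 * num_layers), the usual
    # residual-branch scaling for deep transformers.
    w = torch.empty(4, 4)
    init_method_normal(0.02)(w)             # N(0, 0.02)
    scaled_init_method_normal(0.02, 24)(w)  # N(0, 0.02 / sqrt(48))
    print(divide(1024, 8))                  # 128; raises if not divisible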
CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = helpers
LIBEXT = $(shell python3-config --extension-suffix)
default: $(LIBNAME)$(LIBEXT)
%$(LIBEXT): %.cpp
$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
from . import indexed_dataset
"""AutoAugment data augmentation policy for ImageNet.
-- Begin license text.
MIT License
Copyright (c) 2018 Philip Popien
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-- End license text.
Code adapted from https://github.com/DeepVoltaire/AutoAugment.
This module implements the fixed AutoAugment data augmentation policy for ImageNet provided in
Appendix A, Table 9 of reference [1]. It does not include any of the search code for augmentation
policies.
Reference:
[1] https://arxiv.org/abs/1805.09501
"""
import random
import numpy as np
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageOps
_MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable.
class ImageNetPolicy:
"""Definition of an ImageNetPolicy.
Implements a fixed AutoAugment data augmentation policy targeted at
ImageNet training by randomly applying at runtime one of the 25 pre-defined
data augmentation sub-policies provided in Reference [1].
Usage example as a Pytorch Transform:
>>> transform=transforms.Compose([transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
"""Initialize an ImageNetPolicy.
Args:
fillcolor (tuple): RGB color components of the color to be used for
filling when needed (default: (128, 128, 128), which
corresponds to gray).
"""
# Instantiate a list of sub-policies.
# Each entry of the list is a SubPolicy which consists of
# two augmentation operations,
# each of those parametrized as operation, probability, magnitude.
# Those two operations are applied sequentially on the image upon call.
self.policies = [
SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor),
SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor),
SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor),
SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor),
SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor),
SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor),
SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor),
SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor),
SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor),
SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor),
SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
]
def __call__(self, img):
"""Define call method for ImageNetPolicy class."""
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
"""Define repr method for ImageNetPolicy class."""
return "ImageNetPolicy"
class SubPolicy:
"""Definition of a SubPolicy.
A SubPolicy consists of two augmentation operations,
each of those parametrized as operation, probability, magnitude.
The two operations are applied sequentially on the image upon call.
"""
def __init__(
self,
operation1,
probability1,
magnitude_idx1,
operation2,
probability2,
magnitude_idx2,
fillcolor,
):
"""Initialize a SubPolicy.
Args:
            operation1 (str): Key specifying the first augmentation operation.
                There are fourteen key values altogether (see supported_ops below
                for the list of supported operations).
            probability1 (float): Probability within [0., 1.] of applying the
                first augmentation operation.
            magnitude_idx1 (int): Integer specifying the strength of the first
                operation as an index used to derive the magnitude from a
                range of possible values.
            operation2 (str): Key specifying the second augmentation operation.
            probability2 (float): Probability within [0., 1.] of applying the
                second augmentation operation.
            magnitude_idx2 (int): Integer specifying the strength of the
                second operation as an index used to derive the magnitude
                from a range of possible values.
fillcolor (tuple): RGB color components of the color to be used for
filling.
Returns:
"""
# List of supported operations for operation1 and operation2.
supported_ops = [
"shearX",
"shearY",
"translateX",
"translateY",
"rotate",
"color",
"posterize",
"solarize",
"contrast",
"sharpness",
"brightness",
"autocontrast",
"equalize",
"invert",
]
assert (operation1 in supported_ops) and (
operation2 in supported_ops
), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation."
assert (
0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0
), "SubPolicy: prob1 and prob2 should be within [0., 1.]."
assert (
isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10
), "SubPolicy: idx1 should be specified as an integer within [0, 10]."
assert (
isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10
), "SubPolicy: idx2 should be specified as an integer within [0, 10]."
# Define a dictionary where each key refers to a specific type of
# augmentation and the corresponding value is a range of ten possible
# magnitude values for that augmentation.
num_levels = _MAX_LEVEL + 1
ranges = {
"shearX": np.linspace(0, 0.3, num_levels),
"shearY": np.linspace(0, 0.3, num_levels),
"translateX": np.linspace(0, 150 / 331, num_levels),
"translateY": np.linspace(0, 150 / 331, num_levels),
"rotate": np.linspace(0, 30, num_levels),
"color": np.linspace(0.0, 0.9, num_levels),
"posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(
np.int
),
"solarize": np.linspace(256, 0, num_levels), # range [0, 256]
"contrast": np.linspace(0.0, 0.9, num_levels),
"sharpness": np.linspace(0.0, 0.9, num_levels),
"brightness": np.linspace(0.0, 0.9, num_levels),
"autocontrast": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"equalize": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"invert": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
}
def rotate_with_fill(img, magnitude):
"""Define rotation transformation with fill.
The input image is first rotated, then it is blended together with
a gray mask of the same size. Note that fillcolor as defined
elsewhere in this module doesn't apply here.
Args:
magnitude (float): rotation angle in degrees.
Returns:
rotated_filled (PIL Image): rotated image with gray filling for
disoccluded areas unveiled by the rotation.
"""
rotated = img.convert("RGBA").rotate(magnitude)
rotated_filled = Image.composite(
rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated
)
return rotated_filled.convert(img.mode)
# Define a dictionary of augmentation functions where each key refers
# to a specific type of augmentation and the corresponding value defines
# the augmentation itself using a lambda function.
# pylint: disable=unnecessary-lambda
func_dict = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
magnitude * img.size[0] * random.choice([-1, 1]),
0,
1,
0,
),
fillcolor=fillcolor,
),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
0,
0,
1,
magnitude * img.size[1] * random.choice([-1, 1]),
),
fillcolor=fillcolor,
),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
1 + magnitude * random.choice([-1, 1])
),
"posterize": lambda img, magnitude: ImageOps.posterize(
img, magnitude
),
"solarize": lambda img, magnitude: ImageOps.solarize(
img, magnitude
),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img),
}
# Store probability, function and magnitude of the first augmentation
# for the sub-policy.
self.probability1 = probability1
self.operation1 = func_dict[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
# Store probability, function and magnitude of the second augmentation
# for the sub-policy.
self.probability2 = probability2
self.operation2 = func_dict[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
"""Define call method for SubPolicy class."""
# Randomly apply operation 1.
if random.random() < self.probability1:
img = self.operation1(img, self.magnitude1)
# Randomly apply operation 2.
if random.random() < self.probability2:
img = self.operation2(img, self.magnitude2)
return img
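if __name__ == "__main__":
    # Minimal sketch: apply the fixed ImageNet AutoAugment policy to a synthetic
    # gray image (the image content and seed are arbitrary illustration values).
    random.seed(0)
    img = Image.new("RGB", (224, 224), (128, 128, 128))
    augmented = ImageNetPolicy()(img)
    print(augmented.size, augmented.mode)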
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""BERT Style dataset."""
import numpy as np
import torch
from megatron import (
get_args,
get_tokenizer,
mpu,
print_rank_0
)
from megatron.data.dataset_utils import (
get_samples_mapping,
get_a_and_b_segments,
truncate_segments,
create_tokens_and_tokentypes,
create_masked_lm_predictions
)
class BertDataset(torch.utils.data.Dataset):
def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed, binary_head):
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length
self.binary_head = binary_head
# Dataset.
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping(self.indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
self.max_seq_length - 3, # account for added tokens
short_seq_prob,
self.seed,
self.name,
self.binary_head)
# Vocab stuff.
tokenizer = get_tokenizer()
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
self.vocab_id_to_token_dict = tokenizer.inv_vocab
self.cls_id = tokenizer.cls
self.sep_id = tokenizer.sep
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, seq_length = self.samples_mapping[idx]
sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
return build_training_sample(sample, seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id, self.sep_id,
self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng,
self.binary_head)
def build_training_sample(sample,
target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng, binary_head):
"""Biuld training sample.
Arguments:
        sample: A list of sentences in which each sentence is a list of token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
              numpy and not python since python randint is inclusive for
              the upper bound whereas the numpy one is exclusive.
"""
if binary_head:
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B).
if binary_head:
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
np_rng)
else:
tokens_a = []
for j in range(len(sample)):
tokens_a.extend(sample[j])
tokens_b = []
is_next_random = False
# Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
len(tokens_b), max_num_tokens, np_rng)
    # Build tokens and tokentypes.
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
cls_id, sep_id)
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
# Padding.
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length)
train_sample = {
'text': tokens_np,
'types': tokentypes_np,
'labels': labels_np,
'is_random': int(is_next_random),
'loss_mask': loss_mask_np,
'padding_mask': padding_mask_np,
'truncated': int(truncated)}
return train_sample
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length):
"""Pad sequences and convert them to numpy."""
# Some checks.
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
    assert padding_length >= 0, \
        f"num_tokens ({num_tokens}) is greater than " \
        f"max_seq_length ({max_seq_length})."
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)
# Tokens and token types.
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
dtype=np.int64)
    # Labels and loss mask.
labels = [-1] * max_seq_length
loss_mask = [0] * max_seq_length
for i in range(len(masked_positions)):
assert masked_positions[i] < num_tokens
labels[masked_positions[i]] = masked_labels[i]
loss_mask[masked_positions[i]] = 1
labels_np = np.array(labels, dtype=np.int64)
loss_mask_np = np.array(loss_mask, dtype=np.int64)
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
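if __name__ == "__main__":
    # Minimal sketch of pad_and_convert_to_numpy with made-up token ids (running
    # this file directly assumes the megatron imports at the top resolve): a
    # 4-token sequence padded to length 8, where position 1 was masked and its
    # original token id was 7.
    toks, types_, labels, pad_mask, loss_mask = pad_and_convert_to_numpy(
        tokens=[101, 103, 8, 102], tokentypes=[0, 0, 0, 0],
        masked_positions=[1], masked_labels=[7], pad_id=0, max_seq_length=8)
    print(toks)       # [101 103   8 102   0   0   0   0]
    print(pad_mask)   # [1 1 1 1 0 0 0 0]
    print(labels)     # [-1  7 -1 -1 -1 -1 -1 -1]
    print(loss_mask)  # [0 1 0 0 0 0 0 0]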