Unverified Commit dc003c30 authored by Xuanlei Zhao's avatar Xuanlei Zhao Committed by GitHub
Browse files

[moe] merge moe into main (#4978)

* update moe module
* support openmoe
parent 8993c8a8
import random
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.booster.plugin.hybrid_parallel_plugin import (
HybridParallelAMPOptimizer,
HybridParallelModule,
HybridParallelNaiveOptimizer,
HybridParallelPlugin,
get_param_info,
init_pipeline_optimizer,
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoeCheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.zero.low_level import LowLevelZeroOptimizer
PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
moe_extra_dp_process_group: Optional[ProcessGroup] = None,
):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(
optimizer=optimizer,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
clip_grad_norm=clip_grad_norm,
verbose=verbose,
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
moe_extra_dp_process_group=moe_extra_dp_process_group,
)
class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
Plugin for Moe Hybrid Parallel Training.
Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import HybridParallelPlugin
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
Defaults to 'fp16'.
zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2].
When set to 0, ZeRO will not be used. Defaults to 0.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
"""
def __init__(
self,
tp_size: int,
pp_size: int,
extra_dp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
enable_sequence_overlap: bool = False,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
broadcast_buffers: bool = True,
ddp_bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
# sync moe in outer dp group, and sync other param in global dp group
if extra_dp_size > 1:
ep_size = self.dp_size // extra_dp_size
if use_ep_inside:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size)
self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}")
else:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size)
self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}")
else:
self.moe_extra_dp_group = None
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
)
self.ddp_config = dict(
broadcast_buffers=broadcast_buffers,
bucket_cap_mb=ddp_bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
self.zero_config = dict(
reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
)
self.max_norm = max_norm
def prepare_dataloader(
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
def get_checkpoint_io(self) -> MoeCheckpintIO:
self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
def configure(
self,
model: Module,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config,
)
self.checkpoint_io.link_master_and_working_param(
optimizer.working_to_master_map, optimizer.master_to_working_map
)
else:
optimizer = HybridParallelNaiveOptimizer(
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_extra_dp_process_group=self.moe_extra_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
**self.amp_config,
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler
from .config import Config, ConfigException
# from .moe_context import MOE_CONTEXT
__all__ = [
"Config",
"ConfigException",
......
from typing import Tuple
import torch
import torch.distributed as dist
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.legacy.tensor import ProcessGroup
def _check_sanity():
from colossalai.legacy.core import global_context as gpc
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.")
class MoeParallelInfo:
"""Moe parallelism information, storing parallel sizes and groups."""
def __init__(self, ep_size: int, dp_size: int):
_check_sanity()
self.ep_size = ep_size
self.dp_size = dp_size
self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
self.ep_group = self.pg.tp_process_group()
self.dp_group = self.pg.dp_process_group()
class MoeContext(metaclass=SingletonMeta):
"""MoE parallel context manager. This class manages different
parallel groups in MoE context and MoE loss in training.
"""
def __init__(self):
self.world_size = 1
# Users may want to set maximum expert parallel size smaller than the world size
# since very low bandwidth across nodes may constrain the performance of MoE
# When we have a maximum expert parallel size, we have a minimum data parallel size naturally
self.max_ep_size = 1
self.min_dp_size = 1
self.aux_loss = None
self.use_kernel_optim = True
self.has_setup = False
self._parallel_info_dict = dict()
@property
def parallel_info_dict(self):
return self._parallel_info_dict
@property
def is_initialized(self):
return self.has_setup
def setup(self, seed: int, use_kernel_optim: bool = True):
assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
_check_sanity()
assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
self.world_size = dist.get_world_size()
from colossalai.legacy.core import global_context as gpc
self.max_ep_size = gpc.config.get("max_ep_size", self.world_size)
assert (
self.world_size % self.max_ep_size == 0
), "Maximum expert parallel size must be a factor of the number of GPUs"
self.min_dp_size = self.world_size // self.max_ep_size
# Enabling kernel optimization may raise error in some cases
# Users can close kernel optimization manually
self.use_kernel_optim = use_kernel_optim
from .random import moe_set_seed
moe_set_seed(seed)
self.has_setup = True
def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
"""Calculate the Data Parallel Group and Expert Parallel Group.
Parameters
----------
num_experts : int
The number experts
Returns
-------
int, MoeParallelInfo
number of local experts, the MoeParallelInfo of the current ep_size
"""
gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
assert gt_flag or lt_flag, (
"Automatic experts placement dose not not support expert number"
" is not a multiple of ep size or vice versa."
)
# If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
# there are multiple experts in each GPU and each GPU has different experts
# So it's data parallel size is 1
# Otherwise, there is only one expert in each GPU
# The data parallel size should be calculated
dp_size = 1 if gt_flag else self.max_ep_size // num_experts
ep_size = self.max_ep_size // dp_size
# Calculate the number of experts for each GPU
num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
# Don't forget to multiply minimum data parallel size
dp_size *= self.min_dp_size
if not (ep_size in self.parallel_info_dict):
self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size)
return num_local_experts, self.parallel_info_dict[ep_size]
def set_kernel_not_use(self):
self.use_kernel_optim = False
def reset_loss(self):
self.aux_loss = 0
def add_loss(self, loss):
self.aux_loss += loss
def get_loss(self):
return self.aux_loss
MOE_CONTEXT = MoeContext()
from functools import reduce
from typing import Any, Tuple
import torch
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
PRECISION_MAP = {
"fp32": (0, torch.float32),
"fp16": (1, torch.float16),
"bf16": (2, torch.bfloat16),
}
@triton.jit
def _llama_act_combine_forward(
X_GATE1,
X_GATE2,
X_UP,
Y,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X_GATE1 += row * stride
X_GATE2 += row * stride
X_UP += row * stride
Y += row * stride
# do activation and combine, and store in y
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
# Write output
tl.store(Y + cols, y, mask=mask)
@triton.jit
def _llama_act_combine_backward(
X_GATE1,
X_GATE2,
X_UP,
X_GATE1_GRAD,
X_GATE2_GRAD,
X_UP_GRAD,
Y_GRAD,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X_GATE1 += row * stride
X_GATE2 += row * stride
X_UP += row * stride
X_GATE1_GRAD += row * stride
X_GATE2_GRAD += row * stride
X_UP_GRAD += row * stride
Y_GRAD += row * stride
# do activation and combine, and store in y
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)
# forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid
x_up_grad = x_gate2_act * x_gate1
x_gate1_grad = x_gate2_act * x_up
# grad(x*sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * [1 − sigmoid(x)]
# = sigmoid(x) * {1 + x * [(1 − sigmoid(x)]}
x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid))
# Write output
tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask)
tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask)
tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask)
class LlamaActCombine(torch.autograd.Function):
"""
act(x_gate) * x_up
Args:
x_gate (torch.Tensor): (b, l, 2d) x_gate
x_up (torch.Tensor): (b, l, d) x_up
activation (str): only support swiglu
precision (str): fp32, fp16, bf16
"""
@staticmethod
@custom_fwd
def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor:
"""
act(x_gate) * x_up
Args:
x_gate (torch.Tensor): (b, l, 2d) x gate
x_up (torch.Tensor): (b, l, d) x up
activation (str): only support swiglu
"""
assert activation == "swiglu", "Only swiglu is supported"
# split x gate
assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2"
x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1)
x_gate1 = x_gate1.contiguous()
x_gate2 = x_gate2.contiguous()
if not x_up.is_contiguous():
x_up = x_up.contiguous()
# assert shape
assert x_gate1.shape == x_gate2.shape == x_up.shape
# add ctx for backward
if x_gate.requires_grad:
ctx.save_for_backward(x_gate1, x_gate2, x_up)
# allocate output
y = torch.empty_like(x_up)
M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1]
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x_gate.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# restore setting
ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
# enqueue kernel
_llama_act_combine_forward[(M,)](x_gate1,
x_gate2,
x_up,
y,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
return y
@staticmethod
@custom_bwd
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]:
# restore from ctx
(x_gate1, x_gate2, x_up) = ctx.saved_tensors
M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps
# init grad
y_grad = grad_outputs[0]
x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
x_gate2), torch.empty_like(x_up)
# enqueue kernel
_llama_act_combine_backward[(M,)](x_gate1,
x_gate2,
x_up,
x_gate1_grad,
x_gate2_grad,
x_up_grad,
y_grad,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
return x_gate_grad, x_up_grad, None, None
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._moe_gradient_handler import MoeGradientHandler
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler
......@@ -10,6 +9,5 @@ __all__ = [
"DataParallelGradientHandler",
"ZeROGradientHandler",
"PipelineSharedModuleGradientHandler",
"MoeGradientHandler",
"SequenceParallelGradientHandler",
]
......@@ -16,7 +16,6 @@ from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from colossalai.context import Config, ConfigException
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
from colossalai.legacy.amp.naive_amp import NaiveAMPModel
......@@ -36,7 +35,6 @@ from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.utils.moe import sync_moe_model_param
def get_default_parser():
......@@ -323,8 +321,6 @@ def initialize(
if not use_zero:
if is_using_sequence():
sync_model_param(model, ParallelMode.SEQUENCE_DP)
elif MOE_CONTEXT.is_initialized:
sync_moe_model_param(model)
elif is_using_ddp():
sync_model_param(model, ParallelMode.DATA)
else:
......@@ -377,14 +373,6 @@ def initialize(
"added even though not specified in the configuration",
ranks=[0],
)
elif is_using_ddp() and MOE_CONTEXT.is_initialized:
gradient_handler_cfg = [dict(type="MoeGradientHandler")]
if verbose:
logger.info(
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0],
)
elif is_using_sequence():
model = DDP(
model,
......
from .checkpoint import MoeCheckpintIO
from .experts import MLPExperts
from .layers import SparseMLP
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator
__all__ = [
"MLPExperts",
"MoeRouter",
"Top1Router",
"Top2Router",
"TopKRouter",
"NormalNoiseGenerator",
"UniformNoiseGenerator",
"SparseMLP",
"MoeCheckpintIO",
]
......@@ -3,62 +3,83 @@ from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
COL_MOE_KERNEL_FLAG = False
from colossalai.moe.manager import MOE_MANAGER
try:
from colossalai._C import moe
except:
moe = None
MOE_KERNEL = None
def build_moe_if_not_prebuilt():
# load moe kernel during runtime if not pre-built
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
def load_moe():
global MOE_KERNEL
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
MOE_KERNEL = MOEBuilder().load()
class AllGather(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
def forward(
ctx: Any,
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
assert ctx is not None or not overlap
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.unsqueeze(0)
return inputs.unsqueeze(0), None
buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
dist.all_gather(buffer_list, inputs, group=group)
return outputs
if not overlap:
dist.all_gather(buffer_list, inputs, group=group)
return outputs, None
else:
handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
return outputs, handle
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
return (
ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
None,
None,
)
class ReduceScatter(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
def forward(
ctx: Any,
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
assert ctx is not None or not overlap
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0)
return inputs.squeeze(0), None
if not inputs.is_contiguous():
inputs = inputs.contiguous()
......@@ -66,12 +87,21 @@ class ReduceScatter(torch.autograd.Function):
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs
if not overlap:
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs, None
else:
handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
return outputs, handle
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
# TODO: support async backward
return (
AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
None,
None,
)
class AllToAll(torch.autograd.Function):
......@@ -80,49 +110,78 @@ class AllToAll(torch.autograd.Function):
"""
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
def forward(
ctx: Any,
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
if ctx is not None:
ctx.comm_grp = group
if not inputs.is_contiguous():
inputs = inputs.contiguous()
if dist.get_world_size(group) == 1:
return inputs
return inputs, None
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs, group=group)
return output
if not overlap:
dist.all_to_all_single(output, inputs, group=group)
return output, None
else:
handle = dist.all_to_all_single(output, inputs, group=group, async_op=True)
return output, handle
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
return (
AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0],
None,
None,
)
class MoeDispatch(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
dtype = tokens.dtype
if MOE_KERNEL is None:
load_moe()
if tokens.dtype != torch.float32:
tokens = tokens.to(torch.float32)
expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
if expert_input.dtype != dtype:
expert_input = expert_input.to(dtype)
ctx.save_for_backward(mask, dest_idx)
ctx.s = s
ctx.h = h
ctx.ec = ec
ctx.dtype = dtype
return expert_input
@staticmethod
@custom_bwd
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
if output_grad.dtype != torch.float32:
output_grad = output_grad.to(torch.float32)
d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
if d_tokens.dtype != ctx.dtype:
d_tokens = d_tokens.to(ctx.dtype)
return d_tokens, None, None, None
class MoeCombine(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32
......@@ -130,42 +189,87 @@ class MoeCombine(torch.autograd.Function):
e = logits.size(1)
c = ec // e
h = expert_tokens.size(-1)
dtype = expert_tokens.dtype
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
fp16_flag = expert_tokens.dtype == torch.float16
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens
if expert_tokens.dtype != torch.float32:
expert_tokens = expert_tokens.to(torch.float32)
if MOE_KERNEL is None:
load_moe()
output = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx)
if output.dtype != dtype:
output = output.to(dtype)
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e
ctx.c = c
ctx.h = h
ctx.fp16_flag = fp16_flag
ctx.dtype = dtype
return output
@staticmethod
@custom_bwd
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
if tokens_grad.dtype != torch.float32:
tokens_grad = tokens_grad.to(torch.float32)
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits,
mask, dest_idx)
if d_expert.dtype != ctx.dtype:
d_expert = d_expert.to(ctx.dtype)
return d_expert, d_logits, None, None, None
def moe_cumsum(inputs: Tensor):
def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and COL_MOE_KERNEL_FLAG:
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
return moe.cumsum_sub_one(inputs)
if flag and use_kernel:
if MOE_KERNEL is None:
load_moe()
return MOE_KERNEL.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
class MoeInGradScaler(torch.autograd.Function):
"""
Scale the gradient back by the number of experts
because the batch size increases in the moe stage
"""
@staticmethod
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
if ctx is not None:
ctx.ep_size = ep_size
return inputs
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.ep_size != 1:
grad = grad * ctx.ep_size
return grad, None
class MoeOutGradScaler(torch.autograd.Function):
"""
Scale the gradient by the number of experts
because the batch size increases in the moe stage
"""
@staticmethod
def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
ctx.ep_size = ep_size
return inputs
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.ep_size != 1:
grad = grad / ctx.ep_size
return grad, None
import logging
import os
from copy import deepcopy
from pathlib import Path
from typing import Iterator, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
from colossalai.checkpoint_io.utils import (
StateDictSharder,
gather_distributed_param,
get_model_base_filenames,
is_safetensors_available,
load_shard_state_dict,
load_state_dict_into_model,
save_config_file,
save_state_dict_shards,
)
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
class MoeCheckpintIO(HybridParallelCheckpointIO):
def __init__(
self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
) -> None:
assert zero_stage in [
0,
1,
2,
], f"zero_stage should be 0 or 1 or 2, got {zero_stage}"
super().__init__(dp_group, pp_group, tp_group, zero_stage)
self.parallel = MOE_MANAGER.parallel
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
"""
Preprocess state_dict before loading and slice the state_dict of MOE tensors.
"""
for name, param in state_dict.items():
if ".experts." in name:
if name in dict(model.named_parameters()):
model_param = dict(model.named_parameters())[name]
if is_moe_tensor(model_param):
ep_rank = get_ep_rank(model_param)
ep_size = get_ep_size(model_param)
expert_num = param.shape[0] // ep_size
assert param.shape[0] % ep_size == 0
param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
state_dict[name] = param
dist.barrier()
return state_dict
def _model_sharder(
self,
state_dict: nn.Module,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
) -> Iterator[Tuple[OrderedDict, int]]:
# An internel method that breaks state_dict of model into shards within limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
for name, param in state_dict.items():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None:
state_dict = torch.load(checkpoint)
state_dict = self.pre_load_model(model, state_dict)
model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False)
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on same device might be stored in different files.
"""
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False
# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
state_dict = self.pre_load_model(model, state_dict)
missing_keys = []
load_state_dict_into_model(
model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True,
)
loaded_file.add(filename)
# Load parameters.
for name, _ in model.named_parameters():
_load(name)
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def pre_save_model(self, model: nn.Module) -> dict:
state_dict = model.state_dict()
for name, param in model.named_parameters():
if ".experts." in name and is_moe_tensor(param):
ep_group = get_ep_group(param)
ep_rank = get_ep_rank(param)
ep_size = get_ep_size(param)
dp_rank = get_dp_rank(param)
if dp_rank == 0:
param = param.data.cuda()
all_param = [deepcopy(param) for _ in range(ep_size)]
# gather param from every ep rank
dist.all_gather(all_param, param, group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
state_dict[name] = all_param.cpu()
if self.pp_size > 1:
if self.dp_rank == 0:
out = [None for _ in range(self.pp_size)]
dist.all_gather_object(out, state_dict, group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:
new_state_dict.update(o)
state_dict = new_state_dict
dist.barrier()
return state_dict
def save_unsharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
):
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
torch.save(state_dict, checkpoint)
dist.barrier()
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- Multiple files that store state tensors of models.
The filenames are in the form of "pytorch_model.<prefix>-000XX.bin"
Args:
model (nn.Module): Model on local device to be saved.
checkpoint (str): Checkpointing path which should be a directory path.
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
prefix (str, optional): Perfix of file to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard)
# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0:
return
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
dist.barrier()
# ========================================================
# Abstract methods for optimizer loading/saving implementation
# ========================================================
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
raise NotImplementedError()
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
raise NotImplementedError()
def save_sharded_optimizer(
self,
optimizer: Optimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
):
raise NotImplementedError()
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
raise NotImplementedError()
import math
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info
if HAS_TRITON:
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
class MLPExperts(nn.Module):
"""
SparseMLP is a multi-layer perceptron with sparse expert parallel layers.
Args:
num_experts (int): The number of experts
hidden_size (int): The hidden size of MLP
intermediate_size (int): The intermediate size of MLP
expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP.
activation (optional): The activation function of MLP
drop_rate (float, optional): The drop rate of MLP
gated (bool, optional): Whether to use gated MLP
use_kernel (bool, optional): Whether to use kernel optimization
"""
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
expert_parallel: Optional[str] = None,
activation: Optional[Callable] = None,
drop_rate: Optional[float] = 0,
gated: Optional[bool] = False,
use_kernel: Optional[bool] = False,
):
super().__init__()
assert expert_parallel in ["EP", "TP", None]
self.expert_parallel = expert_parallel
self.num_total_experts = num_experts
self.gated = gated
self.use_kernel = use_kernel
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
# get expert parallel info
if expert_parallel is not None:
self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
num_experts, use_tp=True if expert_parallel == "TP" else False)
# get settings for different parallel
self.ep_size = get_ep_size(self)
if expert_parallel == "TP":
intermediate_size = intermediate_size // self.ep_size
num_experts = self.num_total_experts
else:
num_experts = self.num_local_experts
else:
self.num_local_experts = self.num_total_experts
self.ep_size = 1
if gated:
self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))
self.act_name = activation
self.act = get_activation(activation)
self.drop = nn.Dropout(p=drop_rate)
if expert_parallel is not None:
for param in self.parameters():
set_moe_tensor_info(param, self.moe_info)
# init param
self.reset_parameters()
@torch.no_grad()
def reset_parameters(self):
# expert param should be different
if self.expert_parallel is not None:
seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)
else:
seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
with seed_ctx:
if self.gated:
torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size))
torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size))
else:
torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
def forward(
self,
x: torch.Tensor,
param_slice: Tuple[slice] = (slice(None),),
use_sparse: bool = True,
) -> torch.Tensor:
"""
forward: hidden_size --> intermediate_size --> hidden_size
Args:
x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
Returns:
torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
"""
x = MoeInGradScaler.apply(x, self.ep_size)
e = x.size(1)
h = x.size(-1)
x = x.transpose(0, 1)
inshape = x.shape
x = x.reshape(e, -1, h)
if self.use_kernel and use_sparse:
seq_len = x.shape[1]
with torch.no_grad():
mask = x[:, :, 0] != 0.0
mask = torch.sum(mask, dim=-1)
x_list = []
for i in range(e):
x_list.append(x[i, :mask[i]])
x = x_list
if self.gated:
x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
if self.use_kernel and HAS_TRITON and self.act_name == "swiglu":
x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)]
else:
x = [self.act(x_gate[i]) * x_up[i] for i in range(e)]
else:
x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)]
x = [self.act(x[i]) for i in range(e)]
x = [self.drop(x[i]) for i in range(e)]
x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]
if self.use_kernel and use_sparse:
for i in range(e):
x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0)
x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
x = MoeOutGradScaler.apply(x, self.ep_size)
return x
import dataclasses
import math
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe.experts import MLPExperts
from colossalai.moe.load_balance import LoadBalancer
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.routers import MoeRouter, get_router_cls
from colossalai.moe.utils import get_noise_generator
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size
class SparseMLP(nn.Module):
"""A class for users to create MoE modules in their models.
Args:
dim_model (int): Hidden dimension of training model
num_experts (int): The number experts
top_k (int, optional): The number of experts for dispatchment of each token
capacity_factor_train (float, optional): Capacity factor in routing during training
capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
min_capacity (int, optional): The minimum number of the capacity of each expert
noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'.
'Jitter' can be found in `Switch Transformer paper`_.
'Gaussian' can be found in `ViT-MoE paper`_.
drop_tks (bool, optional): Whether drops tokens in evaluation
use_residual (bool, optional): Makes this MoE layer a Residual MoE.
More information can be found in `Microsoft paper`_.
residual_instance (nn.Module, optional): The instance of residual module in Residual MoE
expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer
expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
expert_args (optional): The args of expert when no instance is given
.. _Switch Transformer paper:
https://arxiv.org/abs/2101.03961
.. _ViT-MoE paper:
https://arxiv.org/abs/2106.05974
.. _Microsoft paper:
https://arxiv.org/abs/2201.05596
"""
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
router_capacity_factor_train: Optional[float] = 1.25,
router_capacity_factor_eval: Optional[float] = 2.0,
router_min_capacity: Optional[int] = 4,
router_noisy_policy: Optional[str] = None,
router_drop_tks: Optional[bool] = True,
mlp_activation: Optional[str] = None,
mlp_gated: Optional[bool] = False,
enable_load_balance: Optional[bool] = False,
load_balance_tolerance: Optional[float] = 0.1,
load_balance_beam_width: Optional[int] = 8,
load_balance_group_swap_factor: Optional[float] = 0.4,
enable_kernel: Optional[bool] = False,
enable_comm_overlap: Optional[bool] = False,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.gated = mlp_gated
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel()
# moe router
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
router_cls = get_router_cls(router_top_k)
self.topk = router_top_k
self.router: MoeRouter = router_cls(
capacity_factor_train=router_capacity_factor_train,
capacity_factor_eval=router_capacity_factor_eval,
min_capacity=router_min_capacity,
noisy_func=noisy_func,
drop_tks=router_drop_tks,
)
# gate
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))
# moe experts
self.experts = MLPExperts(
num_experts=self.num_experts,
expert_parallel=self.expert_parallel,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
activation=mlp_activation,
gated=mlp_gated,
use_kernel=self.enable_kernel,
)
# get parallel settings
if self.expert_parallel is not None:
self.ep_group = get_ep_group(self.experts)
self.ep_size = get_ep_size(self.experts)
self.dp_group = get_dp_group(self.experts)
else:
self.ep_group = None
self.dp_group = None
self.num_local_experts = self.experts.num_local_experts
# load balance
self.enable_load_balance = enable_load_balance
if self.enable_load_balance == True:
self.load_balancer = LoadBalancer(
experts=self.experts,
gate=self.gate_weight,
local_expert_num=self.num_local_experts,
expert_num=self.num_experts,
ep_group=self.ep_group,
dp_group=self.dp_group,
tolerance=load_balance_tolerance,
beam_width=load_balance_beam_width,
group_swap_factor=load_balance_group_swap_factor,
)
# init param
self.reset_parameters()
@torch.no_grad()
def reset_parameters(self):
torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))
def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size)
Returns:
torch.Tensor: The output tensor of shape (batch_size, seq_len, hidden_size)
"""
# reshape the input tokens
tokens = inputs.reshape(-1, self.hidden_size)
# the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
# update expert load
if self.enable_load_balance == True:
with torch.no_grad():
# TODO: optimize computation
expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1]
# TODO: bincount introduces synchronize, fix it
expert_load = torch.bincount(expert_load.view(-1))
self.load_balancer.update_load(expert_load)
# the result from the router
route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
# dispatch_data: (num_experts, capacity, hidden_size)
if self.enable_kernel:
dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
else:
sec_mask_f = route_result_list[1].type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
expert_output = self._ep_process(dispatch_data, overlap=self.enable_comm_overlap)
elif self.expert_parallel == "TP":
expert_output = self._tp_process(dispatch_data, overlap=self.enable_comm_overlap)
elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n"
"Please use Experts build function.")
if self.enable_kernel:
expert_output = expert_output.reshape(-1, self.hidden_size)
ans = MoeCombine.apply(expert_output, *route_result_list)
else:
combine_weights = route_result_list[0].type_as(inputs)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
return ans
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
expert_in = expert_in.unsqueeze(0)
expert_out = self.experts(expert_in)
return expert_out
def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
"""
Expert Parallel
Args:
dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
Returns:
torch.Tensor: (num_experts, capacity, hidden_size)
"""
if not overlap or dist.get_world_size(self.ep_group) == 1:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input)
expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
return expert_output
else:
@dataclasses.dataclass
class Capsule:
data: torch.Tensor
handle: Any = None
NUM_CHUNK = 4
NUM_STAGES = 4
assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
dispatch_data = dispatch_data.reshape(*input_shape)
chunk_data = torch.split(dispatch_data, chunk_size, dim=2)
output = torch.empty_like(dispatch_data)
offset = 0
_expert_in, expert_in, _expert_out, expert_out = None, None, None, None
for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[:, :, offset:offset + chunk_size, :] = expert_out.data
offset += chunk_size
expert_out = None
# all2all last output
if _expert_out is not None:
expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
_expert_out = None
# all2all next input
if 0 <= i < NUM_CHUNK:
_expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True))
# compute
if expert_in is not None:
expert_in.handle.wait()
_expert_out = Capsule(data=self.experts(expert_in.data), handle=None)
expert_in = None
if _expert_in is not None:
expert_in = _expert_in
_expert_in = None
return output
def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
"""
without overlap:
| C |
| A | | R |
with overlap:
| C1 || C2 || C3 || C4 |
| A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 |
where C is computation, A is all gather, R is reduce scatter.
Args:
dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
Returns:
torch.Tensor: (num_experts, capacity, hidden_size)
"""
if not overlap or dist.get_world_size(self.ep_group) == 1:
expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0]
expert_out = self.experts(expert_in)
expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
return expert_out
else:
@dataclasses.dataclass
class Capsule:
data: torch.Tensor
handle: Any
indices: Tuple
NUM_CHUNK = 4
NUM_STAGES = 4
assert (dispatch_data.shape[0] % NUM_CHUNK == 0
), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
chunk_size = dispatch_data.shape[0] // NUM_CHUNK
chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
output = torch.empty_like(dispatch_data)
def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
return (slice(idx * chunk_size, (idx + 1) * chunk_size),)
_expert_in, expert_in, _expert_out, expert_out = None, None, None, None
for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[expert_out.indices] = expert_out.data
expert_out = None
# reduce scatter last output
if _expert_out is not None:
expert_out = Capsule(
*ReduceScatter.apply(_expert_out.data, self.ep_group, True),
indices=_expert_out.indices,
)
_expert_out = None
# all gather next input
if 0 <= i < NUM_CHUNK:
_expert_in = Capsule(
*AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
indices=get_chunk_slice(i, chunk_size),
)
# compute
if expert_in is not None:
expert_in.handle.wait()
_expert_out = Capsule(
self.experts(expert_in.data, expert_in.indices),
handle=None,
indices=expert_in.indices,
)
expert_in = None
if _expert_in is not None:
expert_in = _expert_in
_expert_in = None
return output
def apply_load_balance(model: nn.Module, optim: Any) -> None:
"""
apply load balance to every experts in the model
"""
def _apply_recursive(module: nn.Module):
for _, sub_module in module.named_children():
if isinstance(sub_module, SparseMLP):
if sub_module.enable_load_balance == True:
sub_module.load_balancer.balance_load(optim)
_apply_recursive(sub_module)
torch.cuda.empty_cache()
_apply_recursive(model)
torch.cuda.empty_cache()
from copy import deepcopy
from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor, nn
from torch.distributed import ProcessGroup
from colossalai.cluster import ProcessGroupMesh
from colossalai.moe.experts import MLPExperts
from colossalai.moe.manager import MOE_MANAGER
from colossalai.zero.low_level import LowLevelZeroOptimizer
class LoadBalancer:
def __init__(
self,
experts: MLPExperts,
gate: nn.Parameter,
local_expert_num: int,
expert_num: int,
ep_group: ProcessGroup,
dp_group: ProcessGroup,
tolerance: Optional[float] = 0.1,
beam_width: Optional[int] = 8,
group_swap_factor: Optional[float] = 0.4,
) -> None:
self.experts: MLPExperts = experts
self.gate: nn.Parameter = gate
self.moe_ep_group: ProcessGroup = ep_group
self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks
self.moe_dp_group: ProcessGroup = dp_group
self.tolerance = tolerance
self.beam_width = beam_width
self.group_swap_factor = group_swap_factor
self.local_expert_num = local_expert_num
self.expert_num = expert_num
self.local_load = None
# TODO: use a global process group mesh
pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size
global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size)
self.global_dp_group = global_dp_group.get_group_along_axis(1)
self.global_dp_rank = dist.get_rank(self.global_dp_group)
self.global_dp_size = dist.get_world_size(self.global_dp_group)
def _clear_load(self) -> None:
self.local_load = None
def _sync_load(self) -> Tensor:
new_load = self.local_load.clone().detach()
# all reduce load between ep group
dist.all_reduce(new_load, group=self.moe_ep_group)
# all reduce load between dp group
dist.all_reduce(new_load, group=self.moe_dp_group)
return new_load
@staticmethod
def _get_diff_from_avg(data: List, group: int, avg: float) -> float:
return abs(sum(data[group]) / len(data[group]) - avg)
@staticmethod
def _swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None:
data[group_i][index_i], data[group_j][index_j] = (
data[group_j][index_j],
data[group_i][index_i],
)
@staticmethod
def _normalize_data(data: List) -> List:
max_value = max(max(sublist) for sublist in data)
data = [[i / max_value for i in sublist] for sublist in data]
return data
@staticmethod
def _get_swap_loss(
group_swap_factor: float,
swap_list: List,
group_i: int,
index_i: int,
group_j: int,
index_j: int,
) -> float:
"""
Get swap loss. The swap loss is used to avoid the situation that
the same index is swapped twice and the same group is swapped for multiple times.
"""
swap_loss = 0
for swap in swap_list:
for group_id, index_id in zip([group_i, group_j], [index_i, index_j]):
# the group has been swapped
if group_id in [swap[0], swap[2]]:
# the index has been swapped
# we want to avoid the situation that the same index is swapped twice
if index_id in [swap[1], swap[3]]:
swap_loss += 1e5
# the index has not been swapped
# this is acceptable but as less as possible
else:
swap_loss += group_swap_factor
return swap_loss
@staticmethod
def _check_convergence(data: List, avg: float, tolerance: float):
"""
Check whether the data is converged after swap.
"""
for sublist in data:
if abs(sum(sublist) / len(sublist) - avg) > tolerance * avg:
return False
return True
def _beam_search(
self,
inputs: Tuple[List, float, List],
beam_width: int,
avg: float,
group_swap_factor: float,
) -> List:
"""
Beam search for the best swap combination.
Specifically, we swap two elements from two groups and calculate the score.
The score is the difference between the origin group sum and the new group sum.
The larger the score, the better the swap combination.
Args:
inputs (Tuple): (data, origin_score, swap_list)
beam_width (int): beam width for beam search
avg (float): average value of the data
group_swap_factor (float): group loss for group swap loss
Returns:
List: results list
"""
data, origin_score, swap_list = inputs
results = []
group_num = len(data)
group_size = len(data[0])
origin_diff_list = [self._get_diff_from_avg(data, i, avg) for i in range(group_num)]
for group_num_i in range(group_num):
for group_size_i in range(group_size):
for group_num_j in range(group_num_i + 1, group_num):
for group_size_j in range(group_size):
new_data = deepcopy(data)
# calculate origin group sum
origin_diff = origin_diff_list[group_num_i] + origin_diff_list[group_num_j]
# swap data
self._swap_data(
new_data,
group_num_i,
group_size_i,
group_num_j,
group_size_j,
)
# calculate new group sum
new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg(
new_data, group_num_j, avg
)
# caculate score
new_score = origin_diff - new_diff
if new_score > 0:
new_score = origin_score + new_score
# get swap loss
swap_loss = self._get_swap_loss(
group_swap_factor,
swap_list,
group_num_i,
group_size_i,
group_num_j,
group_size_j,
)
new_score = new_score - swap_loss
# update swap list
new_swap_list = swap_list + [(group_num_i, group_size_i, group_num_j, group_size_j)]
results.append((new_data, new_score, new_swap_list))
# sort results
results.sort(key=lambda x: x[1], reverse=True)
# select top k results
results = results[:beam_width]
return results
def _load_to_list(self, load: Tensor) -> List:
load_len = len(load)
assert load_len % self.local_expert_num == 0
load_list = []
tmp_list = []
for i in range(len(load)):
tmp_list.append(float(load[i]))
if (i + 1) % self.local_expert_num == 0:
load_list.append(tmp_list)
tmp_list = []
return load_list
def _search_balance(
self,
data: List,
tolerance: Optional[float] = 0.1,
beam_width: Optional[int] = 8,
group_swap_factor: Optional[float] = 0.4,
return_swapped_data: Optional[bool] = False,
) -> Tuple[List, List]:
"""
Search for the best swap combination to balance the data within the specified tolerance.
And return the balanced data and the swap list. The swap list is used to record the swap.
The swap list is a list of tuples. Each tuple is a swap operation.
Args:
data (List): expert load list.
E.g. [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
This means there are 4 devices and each devices has 2 experts.
The value is the load of the expert.
tolerance (float): tolerance for balance.
beam_width (int): beam width for beam search.
group_swap_factor (float): group swap factor for group swap loss.
The bigger it is, the less times a group will be swapped.
return_swapped_data (bool): whether to return the swapped data.
Returns:
Tuple: (balanced data, swap list).
The swap list is a list of tuples. Each tuple is a swap operation.
E.g. [(0, 0, 1, 0), (...), (...)]. The first tuple means
the first expert of the first device is swapped with the first expert
of the second device.
"""
norm_data = self._normalize_data(data)
avg = sum(sum(sublist) / len(sublist) for sublist in norm_data) / len(norm_data)
results = [(norm_data, 0, [])]
stop_flag = False
while stop_flag == False:
new_results = []
best_score = results[0][1]
for i in range(len(results)):
new_results.extend(self._beam_search(results[i], beam_width, avg, group_swap_factor))
if len(new_results) == 0:
stop_flag = True
break
new_results.sort(key=lambda x: x[1], reverse=True)
new_best_score = new_results[0][1]
if new_best_score == best_score:
stop_flag = True
break
new_results = new_results[:beam_width]
results = new_results
for i in results:
if self._check_convergence(results[0][0], avg, tolerance):
stop_flag = True
break
swap_list = results[0][2]
if return_swapped_data:
out = deepcopy(data)
for swap in swap_list:
self._swap_data(out, *swap)
return out, swap_list
else:
return swap_list
@staticmethod
def _swap_expert_single_tensor(
weight: nn.Parameter,
expert_idx: int,
comm_group: ProcessGroup,
send_first: bool,
comm_rank: int,
):
# exchange weight
local_weight = weight.data[expert_idx]
new_weight = torch.empty_like(local_weight)
if send_first:
dist.send(local_weight, dst=comm_rank, group=comm_group)
dist.recv(new_weight, src=comm_rank, group=comm_group)
else:
dist.recv(new_weight, src=comm_rank, group=comm_group)
dist.send(local_weight, dst=comm_rank, group=comm_group)
weight.data[expert_idx] = new_weight
def _swap_expert_param_and_optim(
self,
weight: nn.Parameter,
expert_idx: int,
comm_group: ProcessGroup,
send_first: bool,
comm_rank: int,
optim: LowLevelZeroOptimizer,
):
# need to update master and working param if master param exists
# else just update working param
if weight in optim.optim.state:
master_weight_ptr = None
working_weight_ptr = weight
exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
else:
master_weight_ptr = optim._param_store.working_to_master_param[id(weight)]
working_weight_ptr = weight
exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]
# exchange weight
self._swap_expert_single_tensor(
working_weight_ptr,
expert_idx,
comm_group,
send_first,
comm_rank,
)
if master_weight_ptr is not None:
# TODO: exchange master weight, skip for now
# master weight is shared by dp group
tmp = working_weight_ptr.view(-1).split(
working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group)
)[dist.get_rank(self.moe_dp_group)]
master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype))
# exchange optim
self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank)
self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank)
def _gather_global_dp_group(self, data: Tensor) -> Tensor:
data_list = [torch.zeros_like(data) for _ in range(self.global_dp_size)]
dist.all_gather(data_list, data, group=self.global_dp_group)
data_list = torch.cat(data_list, dim=0)
return data_list
def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None:
"""
Swap moe param and optim.
We use different strategies to swap expert and gate.
For expert, we exchange the param and optim of the expert by p2p.
For gate, we all gather the gate choose the part we want.
Args:
swap_list (List)
optim (LowLevelZeroOptimizer)
"""
# get all experts weights
local_rank = dist.get_rank(self.moe_ep_group)
if self.experts.gated:
weight_list = [self.experts.wi_up, self.experts.wi_gate]
else:
weight_list = [self.experts.wi]
weight_list.append(self.experts.wo)
# gate optim should be obtained first
gate_shape = self.gate.shape
# get master weight and optim
master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)]
gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
# gather
global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape)
global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape)
global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape)
assert (
self.gate.shape
== global_master_gate_weight.shape
== global_gate_exp_avg.shape
== global_gate_exp_avg_sq.shape
)
for swap in swap_list:
source_group, source_idx, target_group, target_idx = swap
source_rank = self.moe_ep_ranks[source_group]
target_rank = self.moe_ep_ranks[target_group]
# exchange expert
if local_rank in [source_group, target_group]:
for weight in weight_list:
if local_rank == source_group:
self._swap_expert_param_and_optim(
weight,
source_idx,
self.moe_ep_group,
True,
target_rank,
optim,
)
elif local_rank == target_group:
self._swap_expert_param_and_optim(
weight,
target_idx,
self.moe_ep_group,
False,
source_rank,
optim,
)
# exchange gate
source_expert_pos = source_group * self.local_expert_num + source_idx
target_expert_pos = target_group * self.local_expert_num + target_idx
for gate in [
self.gate,
global_master_gate_weight,
global_gate_exp_avg,
global_gate_exp_avg_sq,
]:
origin_source = gate.data[source_expert_pos].clone().detach()
origin_target = gate.data[target_expert_pos].clone().detach()
gate.data[source_expert_pos], gate.data[target_expert_pos] = (
origin_target,
origin_source,
)
# update gate
global_master_gate_weight = global_master_gate_weight.view(-1).split(
global_master_gate_weight.numel() // self.global_dp_size
)[self.global_dp_rank]
master_gate_weight.data.copy_(global_master_gate_weight)
global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // self.global_dp_size)[
self.global_dp_rank
]
gate_exp_avg.data.copy_(global_gate_exp_avg)
global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split(
global_gate_exp_avg_sq.numel() // self.global_dp_size
)[self.global_dp_rank]
gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq)
@torch.no_grad()
def update_load(self, load: Tensor) -> None:
if len(load) != self.expert_num:
padding_size = self.expert_num - len(load)
padding = torch.zeros(padding_size, dtype=load.dtype, device=load.device)
load = torch.cat((load, padding), dim=0)
if self.local_load is None:
self.local_load = load
else:
self.local_load += load
@torch.no_grad()
def balance_load(self, optim: LowLevelZeroOptimizer) -> None:
# prepare load
load = self._sync_load()
load = self._load_to_list(load)
# search balance
swap_list = self._search_balance(load)
if dist.get_rank() == 0:
if len(swap_list) > 0:
print(f"[Load Balance] Applying expert swap...")
else:
print(f"[Load Balance] Invalid swap, skip...")
# swap expert and gate
self._swap_moe_param(swap_list, optim)
# clear load
self._clear_load()
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.registry import LOSSES
from colossalai.moe.manager import MOE_MANAGER
@LOSSES.register_module
class MoeCrossEntropyLoss(_Loss):
r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
......@@ -45,11 +43,10 @@ class MoeCrossEntropyLoss(_Loss):
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
"""
main_loss = self.loss(*args)
aux_loss = MOE_CONTEXT.get_loss()
aux_loss = MOE_MANAGER.get_loss()
return main_loss + self.aux_weight * aux_loss
@LOSSES.register_module
class MoeLoss(_Loss):
"""A wrapper class for any loss module to add with auxiliary loss.
......@@ -77,5 +74,5 @@ class MoeLoss(_Loss):
The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
"""
main_loss = self.loss_fn(*args, **kwargs)
aux_loss = MOE_CONTEXT.get_loss()
aux_loss = MOE_MANAGER.get_loss()
return main_loss + self.aux_weight * aux_loss
from typing import Tuple
import torch
import torch.distributed as dist
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.moe_tensor.api import get_moe_info
from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
class MoeManager(metaclass=SingletonMeta):
"""MoE manager. This class manages different
parallel groups in MoE context and MoE loss in training.
"""
def __init__(self):
self.parallel = None
self.seed = None
self.mode = None
self.use_ep_inside = None
self.world_size = None
self._parallel_info_dict = dict()
# router
self.router_aux_loss = []
self.router_z_loss = []
# fixed mode
self.pp_size = None
self.dp_size = None
self.ep_size = None
# dynamic mode
# Users may want to set maximum expert parallel size smaller than the world size
# since very low bandwidth across nodes may constrain the performance of MoE
# When we have a maximum expert parallel size, we have a minimum data parallel size naturally
self.max_ep_size = None
self.has_setup = False
@property
def parallel_info_dict(self):
return self._parallel_info_dict
@property
def is_initialized(self):
return self.has_setup
def setup(
self,
seed: int,
parallel: str = None,
mode: str = "dynamic",
max_ep_size: int = 8,
fixed_dp_size: int = 0,
fixed_ep_size: int = 0,
fixed_pp_size: int = 0,
use_ep_inside: bool = True,
) -> None:
"""
Setup MoE distributed context.
Args:
seed (int): Random seed. Defaults to 42.
use_kernel_optim (bool, optional): Use cuda kernel. Defaults to True.
parallel (bool, optional): Parallel mode, should be EP, TP or None. Defaults to None.
mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic".
In fixed mode, the ep size and dp size is fixed.
In dynamic mode, the ep size and dp size will be changed according to num experts.
max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8.
fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0.
fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0.
fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True.
"""
assert (not self.is_initialized), "MoE distributed context shouldn't be set up again"
assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
self.seed = seed + dist.get_rank()
self.parallel = parallel
self.use_ep_inside = use_ep_inside
self.world_size = dist.get_world_size()
# init by mode
self.mode = mode
assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
if self.mode == "dynamic":
self.max_ep_size = min(max_ep_size, self.world_size)
else:
assert (fixed_dp_size > 0 and fixed_ep_size > 0
and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0"
assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int)
and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int"
self.ep_size = fixed_ep_size
self.dp_size = fixed_dp_size
self.pp_size = fixed_pp_size
self.has_setup = True
def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
"""Calculate the Data Parallel Group and Expert Parallel Group.
Parameters
----------
num_experts : int
The number experts
Returns
-------
int, MoeParallelInfo
number of local experts, the MoeParallelInfo of the current ep_size
"""
if self.mode == "dynamic":
gt_flag = (num_experts % self.max_ep_size == 0) # check whether num_experts is greater
lt_flag = (self.max_ep_size % num_experts == 0) # check whether num_experts is less
assert gt_flag or lt_flag, ("Automatic experts placement dose not not support expert number"
" is not a multiple of ep size or vice versa.")
dp_size = 1 if gt_flag else self.world_size // num_experts
ep_size = min(self.world_size // dp_size, self.max_ep_size)
dp_size = self.world_size // ep_size
pp_size = 1
else:
dp_size = self.dp_size
ep_size = self.ep_size
pp_size = self.pp_size
# Calculate the number of experts for each GPU
if use_tp:
num_local_experts = num_experts
else:
if self.mode == "dynamic":
num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
else:
num_local_experts = num_experts // ep_size
if not (ep_size in self.parallel_info_dict):
self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)
if dist.get_rank() == 0:
if self.use_ep_inside:
print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}")
else:
print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}")
return num_local_experts, self.parallel_info_dict[ep_size]
def reset_loss(self):
self.router_aux_loss, self.router_z_loss = [], []
def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0):
self.router_aux_loss.append(aux_loss)
self.router_z_loss.append(z_loss)
def get_loss(self):
cur_loss = self.router_aux_loss, self.router_z_loss
return cur_loss
def get_parallel(self):
return self.parallel
MOE_MANAGER = MoeManager()
import math
from abc import ABC
from typing import Callable, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from colossalai.moe._operation import moe_cumsum
from colossalai.moe.manager import MOE_MANAGER
from colossalai.utils import get_current_device
class MoeRouter(nn.Module, ABC):
"""Base class for all MoE routers.
Args:
k_value (int): The value of top_k.
capacity_factor_train (float): Capacity factor in routing of training.
capacity_factor_eval (float): Capacity factor in routing of evaluation.
min_capacity (int): The minimum number of the capacity of each expert.
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
drop_tks (bool, optional): Whether drops tokens in evaluation
"""
def __init__(self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
use_kernel: bool = False):
super().__init__()
self.k_value = k_value
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity
self.noisy_func = noisy_func
self.drop_tks = drop_tks
self._aux_loss = None
self._z_loss = None
self.use_kernel = use_kernel
def get_capacity(self, logits_shape):
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return int(capacity)
def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
"""Computes auxiliary load balancing loss as in Switch Transformer.
See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
implements the loss function presented in equations (4) - (6). It aims to
penalize those cases where the routing between experts is unbalanced.
Args:
router_probs: Probability assigned to each expert per token. Shape:
<float32>[num_groups, tokens_per_group, num_experts].
expert_indices: <int>[num_groups, tokens_per_group, num_selected_experts]
indices identifying the top num_selected_experts for a given token.
"""
assert self._aux_loss is None
if router_probs.dim() == expert_indices.dim() == 2:
router_probs = router_probs.unsqueeze(0)
expert_indices = expert_indices.unsqueeze(0)
assert router_probs.dim() == expert_indices.dim() == 3, \
"router_probs must be 3D tensor and expert_indices must be 4D tensor"
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_indices, num_experts)
# For a given token, determine if it was routed to a given expert.
# Shape: [num_groups, tokens_per_group, num_experts]
expert_mask = expert_mask.max(dim=-2)[0]
tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
self._aux_loss = aux_loss
def set_z_loss(self, router_logits: torch.Tensor):
"""Compute router z-loss.
The router z-loss was introduced in Designing Effective Sparse Expert Models
(https://arxiv.org/abs/2202.08906). It encourages router logits to remain
small in an effort to improve stability.
Args:
router_logits: <float>[num_groups, tokens_per_group, num_experts] router logits.
"""
assert self._z_loss is None
if router_logits.dim() == 2:
router_logits = router_logits.unsqueeze(0)
assert router_logits.dim() == 3, "router_logits must be 3D tensor"
num_groups, tokens_per_group, _ = router_logits.shape
log_z = torch.logsumexp(router_logits, dim=-1)
z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
self._z_loss = z_loss
def pop_router_loss(self) -> torch.Tensor:
assert self._aux_loss is not None
MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
self._aux_loss = None
self._z_loss = None
class Top1Router(MoeRouter):
"""Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
function can be found in the paper about Switch Transformer of Google.
Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
min_capacity (int, optional): The minimum number of the capacity of each expert.
select_policy (str, optional): The policy about tokens selection.
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
drop_tks (bool, optional): Whether drops tokens in evaluation
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
self.select_policy = select_policy
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0,
device=get_current_device())).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
Returns:
1. use_kernel is False:
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
2. use_kernel is True:
...
"""
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
# caculate router loss
self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
if self.select_policy == "random":
rand_mask = mask * self.uniform(mask.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
elif self.select_policy == "first":
ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
mask = mask * torch.lt(ranks, capacity)
else:
raise NotImplementedError("Not support such select policy yet.")
ranks = torch.sum(mask * ranks, dim=-1)
if use_kernel:
mask = torch.sum(mask, dim=-1)
mask = torch.stack([mask], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
return probs, mask, dest_idx, num_experts * capacity
else:
ranks = F.one_hot(ranks, num_classes=capacity)
weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return combine_weights, sec_mask
class Top2Router(MoeRouter):
"""Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
function can be found in the paper about ViT-MoE.
Args:
capacity_factor_train (float, optional): Capacity factor in routing of training.
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
min_capacity (int, optional): The minimum number of the capacity of each expert
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
drop_tks (bool, optional): Whether drops tokens in evaluation.
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
Returns:
1. use_kernel is False:
The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
2. use_kernel is True:
...
"""
if self.noisy_func is not None and self.training:
inputs = self.noisy_func(inputs)
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
top1_idx = torch.argmax(probs, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = (mask1 + mask2) # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# caculate loss
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
mask1 *= torch.lt(rank1, capacity)
mask2 *= torch.lt(rank2, capacity)
rank1 = torch.sum(mask1 * rank1, dim=-1)
rank2 = torch.sum(mask2 * rank2, dim=-1)
if use_kernel:
mask1 = torch.sum(mask1, dim=-1)
mask2 = torch.sum(mask2, dim=-1)
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
return probs, mask, dest_idx, num_experts * capacity
else:
# >>> original code
# weight1 = mask1 * probs.type_as(inputs)
# weight2 = mask2 * probs.type_as(inputs)
# rank1_sc = F.one_hot(rank1, num_classes=capacity)
# rank2_sc = F.one_hot(rank2, num_classes=capacity)
# cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
# cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
# cb_weight = cb_weight1 + cb_weight2
# sec_mask = cb_weight.bool()
weight1 = mask1 * probs.type_as(inputs)
weight2 = mask2 * probs.type_as(inputs)
cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
indices = torch.arange(0, inputs.shape[0], device=inputs.device)
cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]
return cb_weight, sec_mask
class TopKRouter(MoeRouter):
"""Masked matmul router using tokens choose top-k experts assignment.
NOTE: this is modified from flaxformer.
This router uses the same mechanism as in Switch Transformer
(https://arxiv.org/abs/2101.03961) and V-MoE
(https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
sorted by router_probs and then routed to their choice of expert until the
expert's expert_capacity is reached. There is no guarantee that each token is
processed by an expert, or that each expert receives at least one token.
Attributes:
num_selected_experts: Maximum number of experts to which each token is
routed. Tokens may be routed to fewer experts if particular experts are
oversubscribed / reach capacity.
"""
def __init__(self,
num_selected_experts: int,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
drop_tks)
def forward(
self,
router_probs: torch.Tensor,
expert_capacity: int,
) -> Tuple:
"""Computes masks for the top-k experts per token.
Args:
router_probs: <float32>[num_groups, tokens_per_group, num_experts]
probabilities used to determine the routing of tokens to the experts.
Returns:
Dispatch and combine arrays for routing with masked matmuls.
"""
# TODO: add parallel group
num_groups, _, num_experts = router_probs.shape
# Top-k router probability and corresponding expert indices for each token.
# Shape: [num_groups, tokens_per_group, num_selected_experts].
expert_gate, expert_index = torch.topk(router_probs, self.k_value)
self.set_aux_loss(router_probs, expert_index, num_experts)
self.pop_router_loss()
# Make num_selected_experts the leading axis to ensure that top-1 choices
# have priority over top-2 choices, which have priority over top-3 choices,
# etc.
expert_index = torch.transpose(expert_index, 1, 2)
# Shape: [num_groups, num_selected_experts * tokens_per_group]
expert_index = expert_index.reshape(num_groups, -1)
# Create mask out of indices.
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
# Experts have a fixed capacity that we cannot exceed. A token's priority
# within the expert's buffer is given by the masked, cumulative capacity of
# its target expert.
# Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
# Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
token_priority = torch.transpose(token_priority, 1, 2)
# For each token, across all selected experts, select the only non-negative
# (unmasked) priority. Now, for group G routing to expert E, token T has
# non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
# is its targeted expert.
# Shape: [num_groups, tokens_per_group, num_experts].
token_priority = torch.max(token_priority, dim=2)[0]
# Token T can only be routed to expert E if its priority is positive and
# less than the expert capacity. One-hot matrix will ignore indices outside
# the range [0, expert_capacity).
# Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
# The combine array will be used for combining expert outputs, scaled by the
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
# expert_capacity].
combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
return combine_array, dispatch_mask
def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
if not grouped:
if top_k == 1:
return Top1Router
elif top_k == 2:
return Top2Router
else:
raise NotImplementedError("top_k > 2 is not supported yet")
else:
return TopKRouter
import contextlib
from typing import Any, Callable, Dict, List
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
from colossalai.utils import get_current_device
from .experts import FFNExperts, TPExperts
class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None):
return self.data.clone()
......@@ -56,16 +61,117 @@ class UniformNoiseGenerator:
def autocast_softmax(logit: torch.Tensor, dim: int):
if logit.dtype != torch.float32:
logit = logit.float()
return F.softmax(logit, dim=dim)
return F.softmax(logit, dim=dim, detype=torch.float32)
def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
mep_size = MOE_CONTEXT.max_ep_size
if num_experts % mep_size == 0 or mep_size % num_experts == 0:
return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
elif d_ff % mep_size == 0:
return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
def get_noise_generator(noise_type: str, num_experts: int) -> Callable:
if noise_type is None:
return None
elif noise_type == "Jitter":
noisy_func = UniformNoiseGenerator()
elif noise_type == "Gaussian":
noisy_func = NormalNoiseGenerator(num_experts)
else:
raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")
raise NotImplementedError("Unsupported input noisy policy")
return noisy_func
def get_activation(act: str) -> Callable:
if act is None or act == "relu":
return torch.nn.ReLU()
elif act == "gelu":
return torch.nn.GELU()
elif act == "swiglu":
return SwiGLU
else:
raise NotImplementedError("Unsupported activation function")
def SwiGLU(x):
"""Gated linear unit activation function.
Args:
x : input array
axis: the axis along which the split should be computed (default: -1)
"""
size = x.shape[-1]
assert size % 2 == 0, "axis size must be divisible by 2"
x1, x2 = torch.split(x, size // 2, -1)
return x1 * (x2 * torch.sigmoid(x2))
@contextlib.contextmanager
def skip_init():
"""
skip param random init
"""
def _skip_init(*args, **kwargs):
pass
init_func = {
"constant_": torch.nn.init.constant_,
"uniform_": torch.nn.init.uniform_,
"normal_": torch.nn.init.normal_,
"kaiming_uniform_": torch.nn.init.kaiming_uniform_,
"kaiming_normal_": torch.nn.init.kaiming_normal_,
"xavier_normal_": torch.nn.init.xavier_normal_,
"xavier_uniform_": torch.nn.init.xavier_uniform_,
"trunc_normal_": torch.nn.init.trunc_normal_,
}
for method_name, original_init in init_func.items():
setattr(torch.nn.init, method_name, _skip_init)
yield
for method_name, original_init in init_func.items():
setattr(torch.nn.init, method_name, original_init)
return
def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
"""Returns a parameter dictionary, the key of which is the expert parallel
size of every parameter. Since the parameters in data parallelism is replicated
in each GPU, we set their ep_size to 1.
Args:
model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict.
"""
epsize_param_dict = dict()
for param in model.parameters():
if not is_moe_tensor(param):
ep_size = 1 # set ep_size to 1 for dp parameters
else:
ep_size = get_ep_size(param)
if ep_size not in epsize_param_dict:
epsize_param_dict[ep_size] = []
epsize_param_dict[ep_size].append(param)
return epsize_param_dict
def sync_moe_model_param(model: nn.Module):
"""Make sure model parameters are consistent in MoE parallel context.
Args:
model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
"""
param_dict = get_moe_epsize_param_dict(model)
# synchronize the parameters whose dp_group is the whole world
if 1 in param_dict:
for param in param_dict[1]:
dist.broadcast(param, src=0)
for ep_size in param_dict:
# When ep_size = world_size, communication is not needed
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
for param in param_dict[ep_size]:
src_rank = get_dp_group_ranks(param)[0]
dist.broadcast(param, src=src_rank, group=get_dp_group(param))
def set_moe_args(config: Any, args: dict):
for k, v in args.items():
setattr(config, k, v)
# from .moe import *
from .utils import *
from .checkpoint import load_moe_model, save_moe_model
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, MoeModule
from .routers import MoeRouter, Top1Router, Top2Router
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
__all__ = [
"Experts",
"FFNExperts",
"TPExperts",
"Top1Router",
"Top2Router",
"MoeLayer",
"NormalNoiseGenerator",
"UniformNoiseGenerator",
"build_ffn_experts",
"MoeModule",
"MoeRouter",
"save_moe_model",
"load_moe_model",
]
import torch
import torch.distributed as dist
import torch.nn as nn
from .experts import MoeExperts
def save_moe_model(model: nn.Module, save_path: str):
state_dict = model.state_dict()
if dist.get_rank() == 0:
torch.save(state_dict, save_path)
dist.barrier()
def load_moe_model(model: nn.Module, load_path: str):
state_dict = torch.load(load_path)
for prefix, module in model.named_modules():
if prefix.endswith(".moe_layer.experts"):
# this module should be an Experts instance
assert isinstance(module, MoeExperts)
ep_rank = dist.get_rank(module.dist_info.ep_group)
num_local = module.num_local_experts
for i in range(num_local):
expert_id = ep_rank * num_local + i
for name, _ in module.experts[i].named_parameters():
cur_key = f"{prefix}.experts.{i}.{name}"
param_key = f"{prefix}.experts.{expert_id}.{name}"
load_param = state_dict[param_key]
state_dict[cur_key] = load_param
for name, _ in module.experts[0].named_parameters():
pop_pre = f"{prefix}.experts."
pop_suf = f".{name}"
for i in range(num_local, module.num_total_experts):
pop_key = f"{pop_pre}{i}{pop_suf}"
state_dict.pop(pop_key)
model.load_state_dict(state_dict)
import math
from copy import deepcopy
from typing import Type
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.legacy.context import ParallelMode, seed
from colossalai.legacy.zero.init_ctx import no_shard_zero_decrator
from colossalai.utils import get_current_device
class MoeExperts(nn.Module):
"""Basic class for experts in MoE. It stores what kind of communication experts use
to exchange tokens, how many experts in a single GPU and parallel information such as
expert parallel size, data parallel size and their distributed communication groups.
"""
def __init__(self, comm_name: str, num_experts: int):
super().__init__()
assert comm_name in {
"all_to_all",
"all_gather",
}, "This kind of communication has not been implemented yet.\n Please use Experts build function."
self.comm_name = comm_name
self.num_total_experts = num_experts
# Get the configuration of experts' deployment and parallel information from moe context
self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
@no_shard_zero_decrator(is_replicated=False)
class Experts(MoeExperts):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is a instance of the class, 'expert' in initialization parameters.
Args:
expert_cls (:class:`torch.nn.Module`): The class of all experts
num_experts (int): The number of experts
expert_args: Args used to initialize experts, the args could be found in corresponding expert class
"""
def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
super().__init__("all_to_all", num_experts)
# Use seed to make every expert different from others
with seed(ParallelMode.TENSOR):
self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)])
# Attach parallel information for all parameters in Experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__("moe_info", self.dist_info)
def forward(self, inputs: torch.Tensor):
# Split inputs for each expert
expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
expert_output = []
# Get outputs from each expert
for i in range(self.num_local_experts):
expert_output.append(self.experts[i](expert_input[i]))
# Concatenate all outputs together
output = torch.cat(expert_output, dim=1).contiguous()
return output
def state_dict(self, destination=None, prefix="", keep_vars=False):
assert keep_vars == False, "Only support keep_vars=False now"
dp_rank = dist.get_rank(self.dist_info.dp_group)
ep_rank = dist.get_rank(self.dist_info.ep_group)
submodule_dict = dict()
example_submodule = None
for name, subm in self.experts.named_modules():
if subm is self.experts:
continue
module_number = self.num_local_experts * ep_rank + int(name)
submodule_dict[module_number] = subm
example_submodule = subm
if dp_rank == 0:
local_prefix = prefix + "experts."
buffer_module = deepcopy(example_submodule)
for i in range(self.num_total_experts):
source_rank = i // self.num_local_experts
current_prefix = local_prefix + str(i) + "."
comm_module = submodule_dict.get(i, buffer_module)
for name, param in comm_module.named_parameters():
dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
if ep_rank == 0:
destination[current_prefix + name] = param.data.cpu()
dist.barrier()
class FFNExperts(MoeExperts):
"""Use torch.bmm to speed up for multiple experts."""
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_to_all", num_experts)
self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
for param in self.parameters():
param.__setattr__("moe_info", self.dist_info)
def forward(self, inputs): # inputs [g, el, c, h]
el = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(el, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
out_inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, out_inter, self.w2)
with seed(ParallelMode.TENSOR):
outputs = self.drop(out_model) # outputs [el, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs
class TPExperts(MoeExperts):
"""Use tensor parallelism to split each expert evenly, which can deploy experts in
case that the number of experts can't be divide by maximum expert parallel size or
maximum expert parallel size can't be divide by the number of experts.
"""
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_gather", MOE_CONTEXT.max_ep_size)
assert d_ff % MOE_CONTEXT.max_ep_size == 0, "d_ff should be divide by maximum expert parallel size"
p_ff = d_ff // MOE_CONTEXT.max_ep_size
self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
self.w1.__setattr__("moe_info", self.dist_info)
self.w2.__setattr__("moe_info", self.dist_info)
self.b1.__setattr__("moe_info", self.dist_info)
def forward(self, inputs): # inputs [g, e, c, h]
e = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(e, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
out_inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, out_inter, self.w2)
outputs = self.drop(out_model) # outputs [e, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs # outputs [g, e, c, h]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment