Commit 219df6e6 authored by 1SAA's avatar 1SAA Committed by Frank Lee
Browse files

Optimized MoE layer and fixed some bugs;

Decreased moe tests;

Added FFNExperts and ViTMoE model
parent 3dba0705
......@@ -9,6 +9,6 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v13.0.0
rev: v13.0.1
hooks:
- id: clang-format
......@@ -56,6 +56,7 @@ class MoeEnv:
self.data_parallel_size = None
self.model_parallel_size = None
self.aux_loss = None
self.enable_cuda = True
def setup(self, moe_model_size):
from .core import global_context as gpc
......@@ -71,6 +72,9 @@ class MoeEnv:
def is_initialized(self):
return self.model_parallel_size is not None
def set_cuda_false(self):
self.enable_cuda = False
def reset_loss(self):
self.aux_loss = 0
......
......@@ -5,7 +5,7 @@
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh>
#include "ATen/cuda/DeviceUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
......
#include <torch/extension.h>
torch::Tensor moe_dispatch_cuda_forward(
int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor moe_dispatch_cuda_backward(
int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor moe_combine_cuda_forward(
int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx);
std::vector<torch::Tensor> moe_combine_cuda_backward(
int s, int e, int c, int h,
torch::Tensor tokens_grad,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor moe_dispatch_forward(
int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(batch_tokens);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_dispatch_cuda_forward(
s, ec, h,
batch_tokens, mask, dest_idx);
}
torch::Tensor moe_dispatch_backward(
int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_grad);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_dispatch_cuda_backward(
s, ec, h,
expert_grad, mask, dest_idx);
}
torch::Tensor moe_combine_forward(
int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_tokens);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_combine_cuda_forward(
s, e, c, h,
expert_tokens, logits, mask, dest_idx);
}
std::vector<torch::Tensor> moe_combine_backward(
int s, int e, int c, int h,
torch::Tensor tokens_grad,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(tokens_grad);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_combine_cuda_backward(
s, e, c, h,
tokens_grad, expert_tokens, logits, mask, dest_idx);
}
torch::Tensor moe_cumsum(torch::Tensor mask) {
CHECK_INPUT(mask);
return cumsum_sub_one_in_dim0(mask);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cumsum_sub_one", &moe_cumsum,
"Fast cumsum operation in dim0");
m.def("dispatch_forward", &moe_dispatch_forward,
"Forward operation in MoE dispatch function");
m.def("dispatch_backward", &moe_dispatch_backward,
"Backward operation in MoE dispatch function");
m.def("combine_forward", &moe_combine_forward,
"Combine operation in MoE combine function");
m.def("combine_backward", &moe_combine_backward,
"Combine operation in MoE combine function");
}
This diff is collapsed.
from ._operation import AllToAll
from .layers import Experts, MoeLayer, \
NormalNoiseGenerator, Top1Router, Top2Router
from .experts import Experts, FFNExperts
from .layers import MoeLayer, Top1Router, Top2Router
from .utils import NormalNoiseGenerator
__all__ = [
'AllToAll', 'Experts', 'Top1Router', 'Top2Router',
'MoeLayer', 'NormalNoiseGenerator'
]
\ No newline at end of file
__all__ = ['Experts', 'FFNExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator']
......@@ -6,16 +6,26 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from typing import Any, Tuple
U_CUDA_MODE = False
try:
import colossal_moe_cuda
U_CUDA_MODE = True
except ImportError:
print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
"""
@staticmethod
def forward(ctx: Any,
inputs: Tensor,
parallel_mode: ParallelMode) -> Tensor:
ctx.parallel_mode = parallel_mode
if ctx is not None:
ctx.parallel_mode = parallel_mode
if not inputs.is_contiguous():
inputs = inputs.contiguous()
......@@ -26,4 +36,79 @@ class AllToAll(torch.autograd.Function):
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllToAll.apply(*grad_outputs, ctx.parallel_mode), None
return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None
class MoeDispatch(torch.autograd.Function):
@staticmethod
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)
expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
ctx.save_for_backward(mask, dest_idx)
ctx.s = s
ctx.h = h
ctx.ec = ec
return expert_input
@staticmethod
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
d_tokens = colossal_moe_cuda.dispatch_backward(
ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
return d_tokens, None, None, None
class MoeCombine(torch.autograd.Function):
@staticmethod
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32
s = logits.size(0)
e = logits.size(1)
c = ec // e
h = expert_tokens.size(-1)
fp16_flag = (expert_tokens.dtype == torch.float16)
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = colossal_moe_cuda.combine_forward(s, e, c, h,
cb_input, logits,
mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e
ctx.c = c
ctx.h = h
ctx.fp16_flag = fp16_flag
return output
@staticmethod
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
else tokens_grad
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = colossal_moe_cuda.combine_backward(
ctx.s, ctx.e, ctx.c, ctx.h,
cb_grad, cb_input, logits, mask, dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
return d_expert, d_logits, None, None, None
def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and U_CUDA_MODE:
return colossal_moe_cuda.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
import math
import torch
import torch.nn as nn
from colossalai.global_variables import moe_env
from colossalai.context import ParallelMode, seed
from colossalai.utils import get_current_device
class Experts(nn.Module):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is a instence of the class, 'expert' in initialization parameters.
:param expert: The class of all experts
:param num_experts: The number of experts
:param expert_args: Args used to initialize experts
:type num_experts: int
"""
def __init__(self, expert, num_experts, **expert_args):
super().__init__()
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
num_local_experts = num_experts // moe_env.model_parallel_size
with seed(ParallelMode.MOE_MODEL):
self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
self.num_local_experts = num_local_experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__('moe_param', True)
def forward(self, inputs):
expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
expert_output = []
for i in range(self.num_local_experts):
expert_output.append(self.experts[i](expert_input[i]))
output = torch.cat(expert_output, dim=1).contiguous()
return output
class FFNExperts(nn.Module):
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__()
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
num_local_experts = num_experts // moe_env.model_parallel_size
self.w1 = nn.Parameter(torch.empty(num_local_experts, d_model, d_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(num_local_experts, 1, d_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(num_local_experts, d_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(num_local_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
for param in self.parameters():
param.__setattr__('moe_param', True)
def forward(self, inputs): # x [g, el, c, h]
el = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(el, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, inter, self.w2)
with seed(ParallelMode.TENSOR):
outputs = self.drop(out_model) # outputs [el, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs
......@@ -3,70 +3,13 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.global_variables import moe_env
from colossalai.context import ParallelMode, seed
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device
from ._operation import AllToAll
class NormalNoiseGenerator:
"""Generates a random noisy mask for logtis tensor.
All noise is generated from a normal distribution (0, 1 / E^2), where
E = the number of experts.
:param num_experts: The number of experts
:type num_experts: int
"""
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
return inputs + noisy
class Experts(nn.Module):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is a instence of the class, 'expert' in initialization parameters.
:param expert: The class of all experts
:param num_experts: The number of experts
:param expert_args: Args used to initialize experts
:type num_experts: int
"""
def __init__(self, expert, num_experts, **expert_args):
super().__init__()
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
num_local_experts = num_experts // moe_env.model_parallel_size
with seed(ParallelMode.MOE_MODEL):
self.experts = nn.ModuleList([
expert(**expert_args) for _ in range(num_local_experts)])
self.num_local_experts = num_local_experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__('moe_param', 1)
def forward(self, inputs):
expert_input = torch.chunk(inputs, self.num_local_experts, dim=0)
expert_output = []
for i in range(self.num_local_experts):
expert_output.append(self.experts[i](expert_input[i]))
output = torch.cat(expert_output, dim=0)
return output
from ._operation import U_CUDA_MODE, AllToAll, MoeDispatch, MoeCombine, moe_cumsum
from .utils import autocast_softmax
class Top1Router(nn.Module):
......@@ -83,63 +26,79 @@ class Top1Router(nn.Module):
:type noisy_func: Callable, optional
"""
def __init__(self,
capacity_factor: float,
min_capacity: int,
noisy_func=None):
def __init__(self, capacity_factor: float, min_capacity: int = 0, select_policy: str = "first", noisy_func=None):
super().__init__()
self.capacity_factor = capacity_factor
self.min_capacity = min_capacity
self.select_policy = select_policy
self.noisy_func = noisy_func
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0, device=get_current_device())).rsample
def get_capacity(self, logits_shape):
capacity = math.ceil(self.capacity_factor *
logits_shape[0] / logits_shape[1])
if capacity < self.min_capacity:
capacity = self.min_capacity
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0,
device=get_current_device())).rsample
def get_capacity(
self,
logits_shape,
):
capacity = math.floor(self.capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return capacity
def forward(self, inputs):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
if self.noisy_func is not None:
inputs_noisy = self.noisy_func(inputs)
else:
inputs_noisy = inputs
logits = F.softmax(inputs, dim=1)
num_experts = logits.shape[1]
logits = autocast_softmax(inputs, dim=-1)
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
expert_idx = torch.argmax(inputs_noisy, dim=1)
expert_mask = F.one_hot(expert_idx, num_classes=num_experts)
expert_mask_f = expert_mask.float()
exp_counts = torch.sum(expert_mask, dim=0).detach().to('cpu')
me = torch.mean(logits, dim=0)
ce = torch.mean(expert_mask_f, dim=0)
l_aux = torch.sum(me * ce) * num_experts
moe_env.add_loss(l_aux)
rand_mask = expert_mask * self.uniform(logits.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
dispatch_mask = \
expert_mask * torch.zeros_like(expert_mask).scatter_(0, dispatch_idx, 1)
top1_idx = torch.argmax(inputs_noisy, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
locations = torch.cumsum(dispatch_mask, dim=0) - 1
locations = torch.sum(dispatch_mask * locations, dim=1)
locations = F.one_hot(locations, num_classes=capacity)
if self.training:
me = torch.mean(logits, dim=0)
ce = torch.mean(mask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce)
moe_env.add_loss(l_aux)
else:
max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
capacity = max_num.item()
if not self.training:
ranks = moe_cumsum(mask)
elif self.select_policy == "random":
rand_mask = mask * self.uniform(mask.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
ranks = moe_cumsum(mask)
elif self.select_policy == "first":
ranks = moe_cumsum(mask)
mask = mask * torch.lt(ranks, capacity)
else:
raise NotImplementedError("Not support such select policy yet.")
logits = logits * dispatch_mask
combine_weights = logits.unsqueeze(2) * locations.unsqueeze(1)
ranks = torch.sum(mask * ranks, dim=-1)
sec_mask = combine_weights.bool()
return combine_weights, sec_mask, exp_counts
if cuda_mode:
mask = torch.sum(mask, dim=-1)
mask = torch.stack([mask], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
return logits, mask, dest_idx, num_experts * capacity
else:
ranks = F.one_hot(ranks, num_classes=capacity)
weight = mask * logits.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return combine_weights, sec_mask
class Top2Router(nn.Module):
......@@ -159,53 +118,67 @@ class Top2Router(nn.Module):
self.noisy_func = noisy_func
def get_capacity(self, logits_shape):
capacity = math.ceil(2 * self.capacity_factor *
logits_shape[0] / logits_shape[1])
capacity = math.floor(2 * self.capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
assert capacity > 0
return capacity
def forward(self, inputs):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
# inputs: [s, h]
if self.noisy_func is not None:
inputs = self.noisy_func(inputs)
logits = F.softmax(inputs, dim=-1)
logits = autocast_softmax(inputs, dim=-1) # logits: [s, e]
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
_, expert_idx = torch.topk(logits, k=2, dim=-1, largest=True, sorted=True)
top1_idx = expert_idx[:, 0]
top2_idx = expert_idx[:, 1]
top1_idx = torch.argmax(logits, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = (mask1 + mask2) # loss: [s, e]
if self.training:
me = torch.mean(logits, dim=0)
ce = torch.mean(cmask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0
moe_env.add_loss(l_aux)
else:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
capacity = max_num.item()
mask1 = F.one_hot(top1_idx, num_classes=num_experts)
mask2 = F.one_hot(top2_idx, num_classes=num_experts)
rank1 = moe_cumsum(mask1) # rank1: [s, e]
rank2 = moe_cumsum(mask2)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
loss_mask = (mask1 + mask2)
exp_counts = torch.sum(loss_mask, dim=0).detach().to('cpu')
me = torch.mean(logits, dim=0)
ce = torch.mean(loss_mask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0
moe_env.add_loss(l_aux)
mask1 *= torch.lt(rank1, capacity)
mask2 *= torch.lt(rank2, capacity)
locations1 = torch.cumsum(mask1, dim=0) - 1
locations2 = torch.cumsum(mask2, dim=0) - 1
locations2 += torch.sum(mask1, dim=0, keepdim=True)
rank1 = torch.sum(mask1 * rank1, dim=-1)
rank2 = torch.sum(mask2 * rank2, dim=-1)
mask1 *= torch.lt(locations1, capacity)
mask2 *= torch.lt(locations2, capacity)
if cuda_mode:
mask1 = torch.sum(mask1, dim=-1)
mask2 = torch.sum(mask2, dim=-1)
weight1 = mask1 * logits
weight2 = mask2 * logits
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
locations1 = torch.sum(mask1 * locations1, dim=1)
locations2 = torch.sum(mask2 * locations2, dim=1)
locations1_sc = F.one_hot(locations1, num_classes=capacity)
locations2_sc = F.one_hot(locations2, num_classes=capacity)
return logits, mask, dest_idx, num_experts * capacity
else:
weight1 = mask1 * logits.type_as(inputs)
weight2 = mask2 * logits.type_as(inputs)
rank1_sc = F.one_hot(rank1, num_classes=capacity)
rank2_sc = F.one_hot(rank2, num_classes=capacity)
combine_weights1 = weight1.unsqueeze(2) * locations1_sc.unsqueeze(1)
combine_weights2 = weight2.unsqueeze(2) * locations2_sc.unsqueeze(1)
combine_weights = combine_weights1 + combine_weights2
sec_mask = combine_weights.bool()
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
cb_weight = cb_weight1 + cb_weight2
sec_mask = cb_weight.bool()
return combine_weights, sec_mask, exp_counts
return cb_weight, sec_mask
class MoeLayer(nn.Module):
......@@ -225,52 +198,47 @@ class MoeLayer(nn.Module):
:type experts: nn.Module
"""
def __init__(self,
dim_model: int,
num_experts: int,
router: nn.Module,
experts: nn.Module):
def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: nn.Module):
super().__init__()
self.d_model = dim_model
self.num_experts = num_experts
self.gate = nn.Linear(dim_model, num_experts, device=get_current_device())
self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
self.router = router
self.experts = experts
self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
def _router_part(self, tokens: torch.Tensor):
gate_output = self.gate(tokens)
return self.router(gate_output)
def router_part(self, tokens: torch.Tensor):
autocast_context = torch.is_autocast_enabled()
if not autocast_context:
return self._router_part(tokens)
else:
with autocast(enabled=False):
if tokens.dtype == torch.float16:
input_tokens = tokens.float()
else:
input_tokens = tokens
return self._router_part(input_tokens)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
tokens = inputs.reshape(-1, self.d_model)
combine_weights, sec_mask, exp_counts = self.router_part(tokens)
def expert_part(self, expert_input: torch.Tensor):
expert_input = AllToAll.apply(expert_input, ParallelMode.MOE_MODEL)
sec_mask_f = sec_mask.type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
input_shape = expert_input.shape
dispatch_data = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
expert_input = expert_input.reshape(moe_env.model_parallel_size,
self.num_experts // moe_env.model_parallel_size, -1, self.d_model)
expert_output = self.experts(dispatch_data)
expert_output = self.experts(expert_input)
expert_output = expert_output.reshape(input_shape)
expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
return expert_output
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
tokens = inputs.reshape(-1, self.d_model)
gate_output = self.gate(tokens)
router_res = self.router(gate_output, self.cuda_mode)
ret = torch.matmul(combine_weights, expert_output)
ret = ret.reshape(inputs.shape)
if self.cuda_mode:
logits, mask, dest_idx, ec = router_res
expert_input = MoeDispatch.apply(tokens, mask, dest_idx, ec)
expert_output = self.expert_part(expert_input)
ret = MoeCombine.apply(expert_output, logits, mask, dest_idx, ec)
else:
combine_weights, sec_mask = router_res
sec_mask_f = sec_mask.type_as(inputs)
expert_input = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
expert_output = self.expert_part(expert_input)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ret = torch.matmul(combine_weights, expert_output)
ret = ret.reshape(inputs.shape)
return ret
import torch
import torch.nn.functional as F
from colossalai.utils import get_current_device
class NormalNoiseGenerator:
"""Generates a random noisy mask for logtis tensor.
All noise is generated from a normal distribution (0, 1 / E^2), where
E = the number of experts.
:param num_experts: The number of experts
:type num_experts: int
"""
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
return inputs + noisy
def autocast_softmax(inputs: torch.Tensor, dim: int):
assert inputs.dtype in {torch.float16, torch.float32}
fp16_flag = (inputs.dtype == torch.float16)
sm_input = inputs.to(torch.float32) if fp16_flag else inputs
sm_output = F.softmax(sm_input, dim)
return sm_output
......@@ -4,7 +4,7 @@ import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
WrappedDropout as Dropout, WrappedDropPath as DropPath
from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator
from colossalai.nn.layer.moe import FFNExperts, MoeLayer, Top2Router, NormalNoiseGenerator
from .util import moe_sa_args, moe_mlp_args
from ..helper import TransformerLayer
from colossalai.global_variables import moe_env
......@@ -81,6 +81,7 @@ class VanillaFFN(nn.Module):
class Widenet(nn.Module):
def __init__(self,
num_experts: int,
capacity_factor: float,
......@@ -98,43 +99,33 @@ class Widenet(nn.Module):
drop_path: float = 0.):
super().__init__()
embedding = VanillaPatchEmbedding(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_size=d_model)
embedding = VanillaPatchEmbedding(img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_size=d_model)
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
shared_sa = VanillaSelfAttention(**moe_sa_args(
d_model=d_model, n_heads=num_heads, d_kv=d_kv,
attention_drop=attention_drop, drop_rate=drop_rate))
d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
noisy_func = NormalNoiseGenerator(num_experts)
shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
shared_experts = Experts(expert=VanillaFFN,
num_experts=num_experts,
**moe_mlp_args(
d_model=d_model,
d_ff=d_ff,
drop_rate=drop_rate
))
shared_experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
blocks = [
TransformerLayer(
att=shared_sa,
ffn=MoeLayer(dim_model=d_model, num_experts=num_experts,
router=shared_router, experts=shared_experts),
norm1=nn.LayerNorm(d_model, eps=1e-6),
norm2=nn.LayerNorm(d_model, eps=1e-6),
droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)
)
for i in range(depth)
TransformerLayer(att=shared_sa,
ffn=MoeLayer(dim_model=d_model,
num_experts=num_experts,
router=shared_router,
experts=shared_experts),
norm1=nn.LayerNorm(d_model, eps=1e-6),
norm2=nn.LayerNorm(d_model, eps=1e-6),
droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)) for i in range(depth)
]
norm = nn.LayerNorm(d_model, eps=1e-6)
self.linear = VanillaClassifier(in_features=d_model,
num_classes=num_classes)
self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm)
......@@ -145,3 +136,64 @@ class Widenet(nn.Module):
x = torch.mean(x, dim=1)
x = self.linear(x)
return x
class ViTMoE(nn.Module):
def __init__(self,
num_experts: int,
capacity_factor: float,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
depth: int = 12,
d_model: int = 768,
num_heads: int = 12,
d_kv: int = 64,
d_ff: int = 3072,
attention_drop: float = 0.,
drop_rate: float = 0.1,
drop_path: float = 0.):
super().__init__()
embedding = VanillaPatchEmbedding(img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_size=d_model)
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
noisy_func = NormalNoiseGenerator(num_experts)
router = Top2Router(capacity_factor, noisy_func=noisy_func)
assert depth % 2 == 0
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
blocks = []
for i in range(depth):
sa = VanillaSelfAttention(**moe_sa_args(
d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
ffn = VanillaFFN(**moe_mlp_args(
d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
experts=FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate))
layer = TransformerLayer(att=sa,
ffn=ffn,
norm1=nn.LayerNorm(d_model, eps=1e-6),
norm2=nn.LayerNorm(d_model, eps=1e-6),
droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR))
blocks.append(layer)
norm = nn.LayerNorm(d_model, eps=1e-6)
self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
self.vitmoe = nn.Sequential(embedding, embed_dropout, *blocks, norm)
def forward(self, x):
moe_env.reset_loss()
x = self.vitmoe(x)
x = torch.mean(x, dim=1)
x = self.linear(x)
return x
......@@ -162,6 +162,10 @@ if build_cuda_ext:
['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'],
extra_cuda_flags + cc_flag))
ext_modules.append(cuda_ext_helper('colossal_moe_cuda',
['moe_cuda.cpp', 'moe_cuda_kernel.cu'],
extra_cuda_flags + cc_flag))
extra_cuda_flags = ['-maxrregcount=50']
ext_modules.append(cuda_ext_helper('colossal_layer_norm_cuda',
......
import os
from functools import partial
from pathlib import Path
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.global_variables import moe_env
BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
def check_equal(A, B, atol=1e-06):
assert torch.allclose(A, B, rtol=0, atol=atol) is True
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size,
dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top2Router(1)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
if data_type == torch.float16:
layer = layer.half()
layer.cuda_mode = False
old_out = layer(tokens)
# print(f"old output:\n{old_out}")
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad)
o_tk_grad = tokens.grad.data.clone()
o_gt_grad = layer.gate.weight.grad.data.clone()
tokens.grad.zero_()
layer.gate.weight.grad.zero_()
layer.cuda_mode = True
new_out = layer(tokens)
# print(torch.max(torch.abs(old_out - new_out)))
if data_type == torch.float32:
check_equal(old_out, new_out)
else:
check_equal(old_out, new_out, 1e-2)
# print(f"forward functions passed")
# print(f"new output:\n{new_out}")
new_out.backward(grad)
n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate.weight.grad.data.clone()
# print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad)
else:
check_equal(o_tk_grad, o_tk_grad, 1e-2)
# print(f"tokens gradient passed")
# print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
if data_type == torch.float32:
check_equal(o_gt_grad, n_gt_grad, 5e-05)
else:
check_equal(o_gt_grad, n_gt_grad, 2e-01)
# print(f"linear weight gradient passed")
@pytest.mark.dist
@pytest.mark.parametrize("rs", [131])
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
world_size = 4
run_func = partial(run_routing, world_size=world_size, port=free_port(),
rs=rs, hidden_size=hidden_size, data_type=data_type)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_top2(2, 256, torch.float16)
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, MoeLayer
from colossalai.global_variables import moe_env
BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
def check_equal(A, B, atol=1e-06):
assert torch.allclose(A, B, rtol=0, atol=atol) is True
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top1Router(1)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
if data_type == torch.float16:
layer = layer.half()
layer.cuda_mode = False
old_out = layer(tokens)
# print(f"old output:\n{old_out}")
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad)
o_tk_grad = tokens.grad.data.clone()
o_gt_grad = layer.gate.weight.grad.data.clone()
tokens.grad.zero_()
layer.gate.weight.grad.zero_()
layer.cuda_mode = True
new_out = layer(tokens)
# print(torch.max(torch.abs(old_out - new_out)))
if data_type == torch.float32:
check_equal(old_out, new_out)
else:
check_equal(old_out, new_out, 1e-2)
# print(f"forward functions passed")
# print(f"new output:\n{new_out}")
new_out.backward(grad)
n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate.weight.grad.data.clone()
# print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad)
else:
check_equal(o_tk_grad, o_tk_grad, 1e-2)
# print(f"tokens gradient passed")
# print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
if data_type == torch.float32:
check_equal(o_gt_grad, n_gt_grad, 5e-05)
else:
check_equal(o_gt_grad, n_gt_grad, 2e-01)
# print(f"linear weight gradient passed")
@pytest.mark.skip(reason="Should be activated for detailed tests")
@pytest.mark.parametrize("rs", [2, 42, 60])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
world_size = 4
run_func = partial(run_routing,
world_size=world_size,
port=free_port(),
rs=rs,
hidden_size=hidden_size,
data_type=data_type)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_top2(60, 512, torch.float16)
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.global_variables import moe_env
BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
def check_equal(A, B, atol=1e-06):
assert torch.allclose(A, B, rtol=0, atol=atol) is True
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top2Router(1)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
if data_type == torch.float16:
layer = layer.half()
layer.cuda_mode = False
old_out = layer(tokens)
# print(f"old output:\n{old_out}")
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad)
o_tk_grad = tokens.grad.data.clone()
o_gt_grad = layer.gate.weight.grad.data.clone()
tokens.grad.zero_()
layer.gate.weight.grad.zero_()
layer.cuda_mode = True
new_out = layer(tokens)
# print(torch.max(torch.abs(old_out - new_out)))
if data_type == torch.float32:
check_equal(old_out, new_out)
else:
check_equal(old_out, new_out, 1e-2)
# print(f"forward functions passed")
# print(f"new output:\n{new_out}")
new_out.backward(grad)
n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate.weight.grad.data.clone()
# print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad)
else:
check_equal(o_tk_grad, o_tk_grad, 1e-2)
# print(f"tokens gradient passed")
# print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
if data_type == torch.float32:
check_equal(o_gt_grad, n_gt_grad, 5e-05)
else:
check_equal(o_gt_grad, n_gt_grad, 2e-01)
# print(f"linear weight gradient passed")
@pytest.mark.skip(reason="Should be activated for detailed tests")
@pytest.mark.parametrize("rs", [2, 42, 60])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
world_size = 4
run_func = partial(run_routing,
world_size=world_size,
port=free_port(),
rs=rs,
hidden_size=hidden_size,
data_type=data_type)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_top2(2, 256, torch.float16)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment