Unverified Commit dbdc9a77 authored by HELSON's avatar HELSON Committed by GitHub
Browse files

added Multiply Jitter and capacity factor eval for MOE (#434)

parent b03b3ae9
...@@ -11,6 +11,7 @@ from colossalai.utils import get_current_device ...@@ -11,6 +11,7 @@ from colossalai.utils import get_current_device
from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
from .experts import MoeExperts from .experts import MoeExperts
from .utils import autocast_softmax from .utils import autocast_softmax
from typing import Callable
class Top1Router(nn.Module): class Top1Router(nn.Module):
...@@ -18,21 +19,35 @@ class Top1Router(nn.Module): ...@@ -18,21 +19,35 @@ class Top1Router(nn.Module):
for routing usage. More deailted function can be found in the paper about Switch Transformer for routing usage. More deailted function can be found in the paper about Switch Transformer
of Google. of Google.
:param capacity_factor: Capacity factor in routing :param capacity_factor_train: Capacity factor in routing of training
:param capacity_factor_eval: Capacity factor in routing of evaluation
:param min_capacity: The minimum number of the capacity of each expert :param min_capacity: The minimum number of the capacity of each expert
:param select_policy: The policy about tokens selection
:param noisy_func: Noisy function used in logits :param noisy_func: Noisy function used in logits
:param drop_tks: Whether drops tokens in evaluation
:type capacity_factor: float :type capacity_factor_train: float, optional
:type min_capacity: int :type capacity_factor_eval: float, optional
:type min_capacity: int, optional
:type select_policy: str, optional
:type noisy_func: Callable, optional :type noisy_func: Callable, optional
:type drop_tks: bool, optional
""" """
def __init__(self, capacity_factor: float, min_capacity: int = 0, select_policy: str = "first", noisy_func=None): def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Callable = None,
drop_tks: bool = True):
super().__init__() super().__init__()
self.capacity_factor = capacity_factor self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity self.min_capacity = min_capacity
self.select_policy = select_policy self.select_policy = select_policy
self.noisy_func = noisy_func self.noisy_func = noisy_func
self.drop_tks = drop_tks
assert select_policy in {"first", "random"} assert select_policy in {"first", "random"}
if select_policy == "random": if select_policy == "random":
...@@ -44,7 +59,8 @@ class Top1Router(nn.Module): ...@@ -44,7 +59,8 @@ class Top1Router(nn.Module):
self, self,
logits_shape, logits_shape,
): ):
capacity = math.floor(self.capacity_factor * logits_shape[-2] / logits_shape[-1]) capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2 capacity += capacity % 2
capacity = max(capacity, self.min_capacity) capacity = max(capacity, self.min_capacity)
assert capacity > 0 assert capacity > 0
...@@ -53,15 +69,13 @@ class Top1Router(nn.Module): ...@@ -53,15 +69,13 @@ class Top1Router(nn.Module):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False): def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
if self.noisy_func is not None and self.training: if self.noisy_func is not None and self.training:
inputs_noisy = self.noisy_func(inputs) inputs = self.noisy_func(inputs)
else:
inputs_noisy = inputs
logits = autocast_softmax(inputs, dim=-1) logits = autocast_softmax(inputs, dim=-1)
num_experts = logits.size(-1) num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape) capacity = self.get_capacity(logits.shape)
top1_idx = torch.argmax(inputs_noisy, dim=-1) top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
if self.training: if self.training:
...@@ -69,14 +83,14 @@ class Top1Router(nn.Module): ...@@ -69,14 +83,14 @@ class Top1Router(nn.Module):
ce = torch.mean(mask.float(), dim=0) ce = torch.mean(mask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) l_aux = num_experts * torch.sum(me * ce)
moe_env.add_loss(l_aux) moe_env.add_loss(l_aux)
else: elif not self.drop_tks:
max_num = torch.max(torch.sum(mask, dim=0)) max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL)) dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
capacity = max_num.item() capacity = max_num.item()
else:
pass
if not self.training: if self.select_policy == "random":
ranks = moe_cumsum(mask)
elif self.select_policy == "random":
rand_mask = mask * self.uniform(mask.shape) rand_mask = mask * self.uniform(mask.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
...@@ -106,21 +120,40 @@ class Top2Router(nn.Module): ...@@ -106,21 +120,40 @@ class Top2Router(nn.Module):
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More deailted function can be found in the paper about ViT-MoE. for routing usage. More deailted function can be found in the paper about ViT-MoE.
:param capacity_factor: Capacity factor in routing :param capacity_factor_train: Capacity factor in routing of training
:param capacity_factor_eval: Capacity factor in routing of evaluation
:param min_capacity: The minimum number of the capacity of each expert
:param noisy_func: Noisy function used in logits :param noisy_func: Noisy function used in logits
:param drop_tks: Whether drops tokens in evaluation
:type capacity_factor: float :type capacity_factor_train: float, optional
:type capacity_factor_eval: float, optional
:type min_capacity: int, optional
:type noisy_func: Callable, optional :type noisy_func: Callable, optional
:type drop_tks: bool, optional
""" """
def __init__(self, capacity_factor: float, noisy_func=None): def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Callable = None,
drop_tks: bool = True):
super().__init__() super().__init__()
self.capacity_factor = capacity_factor self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
self.min_capacity = min_capacity
self.noisy_func = noisy_func self.noisy_func = noisy_func
self.drop_tks = drop_tks
def get_capacity(self, logits_shape): def get_capacity(
capacity = math.floor(2 * self.capacity_factor * logits_shape[-2] / logits_shape[-1]) self,
logits_shape,
):
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2 capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0 assert capacity > 0
return capacity return capacity
...@@ -143,12 +176,14 @@ class Top2Router(nn.Module): ...@@ -143,12 +176,14 @@ class Top2Router(nn.Module):
if self.training: if self.training:
me = torch.mean(logits, dim=0) me = torch.mean(logits, dim=0)
ce = torch.mean(cmask.float(), dim=0) ce = torch.mean(cmask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0 l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1
moe_env.add_loss(l_aux) moe_env.add_loss(l_aux)
else: elif not self.drop_tks:
max_num = torch.max(torch.sum(cmask, dim=0)) max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL)) dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
capacity = max_num.item() capacity = max_num.item()
else:
pass
rank1 = moe_cumsum(mask1) # rank1: [s, e] rank1 = moe_cumsum(mask1) # rank1: [s, e]
rank2 = moe_cumsum(mask2) rank2 = moe_cumsum(mask2)
......
...@@ -25,6 +25,27 @@ class NormalNoiseGenerator: ...@@ -25,6 +25,27 @@ class NormalNoiseGenerator:
return inputs + noisy return inputs + noisy
class UniformNoiseGenerator:
"""Generates a random noisy mask for logtis tensor.
copied from mesh tensorflow:
Multiply values by a random number between 1-epsilon and 1+epsilon.
Makes models more resilient to rounding errors introduced by bfloat16.
This seems particularly important for logits.
:param eps: Epsilon in generator
:type eps: float
"""
def __init__(self, eps: float):
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
high=torch.tensor(1.0 + eps,
device=get_current_device())).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.uniform(inputs.shape)
return inputs * noisy
def autocast_softmax(inputs: torch.Tensor, dim: int): def autocast_softmax(inputs: torch.Tensor, dim: int):
assert inputs.dtype in {torch.float16, torch.float32} assert inputs.dtype in {torch.float16, torch.float32}
fp16_flag = (inputs.dtype == torch.float16) fp16_flag = (inputs.dtype == torch.float16)
......
...@@ -84,7 +84,9 @@ class Widenet(nn.Module): ...@@ -84,7 +84,9 @@ class Widenet(nn.Module):
def __init__(self, def __init__(self,
num_experts: int, num_experts: int,
capacity_factor: float, capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
drop_tks: bool = True,
img_size: int = 224, img_size: int = 224,
patch_size: int = 16, patch_size: int = 16,
in_chans: int = 3, in_chans: int = 3,
...@@ -109,7 +111,10 @@ class Widenet(nn.Module): ...@@ -109,7 +111,10 @@ class Widenet(nn.Module):
d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate)) d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
noisy_func = NormalNoiseGenerator(num_experts) noisy_func = NormalNoiseGenerator(num_experts)
shared_router = Top2Router(capacity_factor, noisy_func=noisy_func) shared_router = Top2Router(capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
noisy_func=noisy_func,
drop_tks=drop_tks)
shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate) shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
# stochastic depth decay rule # stochastic depth decay rule
...@@ -142,7 +147,9 @@ class ViTMoE(nn.Module): ...@@ -142,7 +147,9 @@ class ViTMoE(nn.Module):
def __init__(self, def __init__(self,
num_experts: int, num_experts: int,
capacity_factor: float, capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
drop_tks: bool = True,
img_size: int = 224, img_size: int = 224,
patch_size: int = 16, patch_size: int = 16,
in_chans: int = 3, in_chans: int = 3,
...@@ -164,8 +171,10 @@ class ViTMoE(nn.Module): ...@@ -164,8 +171,10 @@ class ViTMoE(nn.Module):
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR) embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
noisy_func = NormalNoiseGenerator(num_experts) noisy_func = NormalNoiseGenerator(num_experts)
router = Top2Router(capacity_factor, noisy_func=noisy_func) router = Top2Router(capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
noisy_func=noisy_func,
drop_tks=drop_tks)
assert depth % 2 == 0 assert depth % 2 == 0
# stochastic depth decay rule # stochastic depth decay rule
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment