Commit 7d8e0338 authored by Xuanlei Zhao's avatar Xuanlei Zhao Committed by ver217
Browse files

[moe] init mixtral impl

parent c53ddda8
...@@ -67,7 +67,11 @@ class MLPExperts(nn.Module): ...@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
self.ep_size = 1 self.ep_size = 1
if gated: if gated:
self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2)) self.wi_gate = nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
)
)
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else: else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
......
...@@ -51,6 +51,8 @@ class SparseMLP(nn.Module): ...@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
router_top_k: int = 1, router_top_k: int = 1,
router_loss: bool = True,
router_norm: bool = False,
router_capacity_factor_train: float = 1.25, router_capacity_factor_train: float = 1.25,
router_capacity_factor_eval: float = 2.0, router_capacity_factor_eval: float = 2.0,
router_min_capacity: int = 4, router_min_capacity: int = 4,
...@@ -65,15 +67,19 @@ class SparseMLP(nn.Module): ...@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
enable_kernel: bool = False, enable_kernel: bool = False,
enable_comm_overlap: bool = False, enable_comm_overlap: bool = False,
enable_hierarchical_comm: bool = False, enable_hierarchical_comm: bool = False,
return_gate_logits: bool = False,
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
self.num_experts = num_experts self.num_experts = num_experts
self.gated = mlp_gated self.gated = mlp_gated
self.return_gate_logits = return_gate_logits
self.enable_kernel = enable_kernel self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel() self.expert_parallel = MOE_MANAGER.get_parallel()
self.router_loss = router_loss
self.router_norm = router_norm
# moe router # moe router
noisy_func = get_noise_generator(router_noisy_policy, num_experts) noisy_func = get_noise_generator(router_noisy_policy, num_experts)
...@@ -150,9 +156,8 @@ class SparseMLP(nn.Module): ...@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
tokens = inputs.reshape(-1, self.hidden_size) tokens = inputs.reshape(-1, self.hidden_size)
# the data type of the inputs in the gating should be fp32 # the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float) gate_logits = F.linear(tokens, self.gate_weight)
fp32_weight = self.gate_weight.to(torch.float) gate_output = gate_logits.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
# update expert load # update expert load
if self.enable_load_balance == True: if self.enable_load_balance == True:
...@@ -165,7 +170,12 @@ class SparseMLP(nn.Module): ...@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
# the result from the router # the result from the router
used_capacity, *route_result_list = self.router( used_capacity, *route_result_list = self.router(
inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group) inputs=gate_output,
use_kernel=self.enable_kernel,
ep_group=self.ep_group,
use_loss=self.router_loss,
use_norm=self.router_norm,
)
# dispatch_data: (num_experts, capacity, hidden_size) # dispatch_data: (num_experts, capacity, hidden_size)
if self.enable_kernel: if self.enable_kernel:
...@@ -177,22 +187,15 @@ class SparseMLP(nn.Module): ...@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
# expert_output: (num_groups, num_experts, capacity, hidden_size) # expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP": if self.expert_parallel == "EP":
expert_output = self._ep_process( expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
elif self.expert_parallel == "TP": elif self.expert_parallel == "TP":
expert_output = self._tp_process( expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
elif self.expert_parallel is None: elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data) expert_output = self._local_process(dispatch_data)
else: else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n" raise NotImplementedError(
"Please use Experts build function.") "This kind of communication has not been implemented yet.\n" "Please use Experts build function."
)
if self.enable_kernel: if self.enable_kernel:
expert_output = expert_output.reshape(-1, self.hidden_size) expert_output = expert_output.reshape(-1, self.hidden_size)
...@@ -204,6 +207,10 @@ class SparseMLP(nn.Module): ...@@ -204,6 +207,10 @@ class SparseMLP(nn.Module):
ans = torch.matmul(combine_weights, expert_output) ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape) ans = ans.reshape(inputs.shape)
if self.return_gate_logits:
return ans, gate_logits
else:
return ans return ans
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
...@@ -212,10 +219,7 @@ class SparseMLP(nn.Module): ...@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
return expert_out return expert_out
def _ep_process( def _ep_process(
self, self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Expert Parallel Expert Parallel
...@@ -228,10 +232,14 @@ class SparseMLP(nn.Module): ...@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
""" """
if not overlap or dist.get_world_size(self.ep_group) == 1: if not overlap or dist.get_world_size(self.ep_group) == 1:
if self.ep_hierarchical_group is not None: if self.ep_hierarchical_group is not None:
expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank) expert_input = HierarchicalAllToAll.apply(
dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
)
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input) expert_output = self.experts(expert_input)
expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank) expert_output = HierarchicalAllToAll.apply(
expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
)
return expert_output return expert_output
else: else:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
...@@ -249,7 +257,7 @@ class SparseMLP(nn.Module): ...@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4 NUM_CHUNK = 4
NUM_STAGES = 4 NUM_STAGES = 4
assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet" assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
chunk_size = dispatch_data.shape[1] // NUM_CHUNK chunk_size = dispatch_data.shape[1] // NUM_CHUNK
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
dispatch_data = dispatch_data.reshape(*input_shape) dispatch_data = dispatch_data.reshape(*input_shape)
...@@ -262,13 +270,15 @@ class SparseMLP(nn.Module): ...@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
for i in range(NUM_CHUNK + NUM_STAGES - 1): for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None: if expert_out is not None:
expert_out.handle.wait() expert_out.handle.wait()
output[:, :, offset:offset + chunk_size, :] = expert_out.data output[:, :, offset : offset + chunk_size, :] = expert_out.data
offset += chunk_size offset += chunk_size
expert_out = None expert_out = None
# all2all last output # all2all last output
if _expert_out is not None: if _expert_out is not None:
expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),) expert_out = Capsule(
*AllToAll.apply(_expert_out.data, self.ep_group, True),
)
_expert_out = None _expert_out = None
# all2all next input # all2all next input
...@@ -288,10 +298,7 @@ class SparseMLP(nn.Module): ...@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
return output return output
def _tp_process( def _tp_process(
self, self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
""" """
without overlap: without overlap:
...@@ -326,8 +333,9 @@ class SparseMLP(nn.Module): ...@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4 NUM_CHUNK = 4
NUM_STAGES = 4 NUM_STAGES = 4
assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ assert (
"arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" dispatch_data.shape[0] % NUM_CHUNK == 0
), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
chunk_size = dispatch_data.shape[0] // NUM_CHUNK chunk_size = dispatch_data.shape[0] // NUM_CHUNK
chunk_data = torch.split(dispatch_data, chunk_size, dim=0) chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
output = torch.empty_like(dispatch_data) output = torch.empty_like(dispatch_data)
......
...@@ -150,7 +150,14 @@ class Top1Router(MoeRouter): ...@@ -150,7 +150,14 @@ class Top1Router(MoeRouter):
high=torch.tensor(1.0, device=get_accelerator().get_current_device()), high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample ).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_loss: bool = False,
use_norm: bool = False,
) -> Tuple:
""" """
Args: Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
...@@ -207,7 +214,7 @@ class Top1Router(MoeRouter): ...@@ -207,7 +214,7 @@ class Top1Router(MoeRouter):
weight = mask * probs.type_as(inputs) weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool() sec_mask = combine_weights.bool()
return used_capacity, combine_weights, sec_mask return used_capacity, combine_weights, sec_mask, probs
class Top2Router(MoeRouter): class Top2Router(MoeRouter):
...@@ -240,7 +247,14 @@ class Top2Router(MoeRouter): ...@@ -240,7 +247,14 @@ class Top2Router(MoeRouter):
drop_tks=drop_tks, drop_tks=drop_tks,
) )
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_norm: bool = False,
use_loss: bool = True,
) -> Tuple:
""" """
Args: Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
...@@ -257,6 +271,10 @@ class Top2Router(MoeRouter): ...@@ -257,6 +271,10 @@ class Top2Router(MoeRouter):
assert inputs.dtype == torch.float assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1) probs = F.softmax(inputs, dim=-1)
if use_norm:
routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
num_experts = probs.size(-1) num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape) capacity = self.get_capacity(inputs.shape)
...@@ -270,6 +288,7 @@ class Top2Router(MoeRouter): ...@@ -270,6 +288,7 @@ class Top2Router(MoeRouter):
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# calculate loss # calculate loss
if use_loss:
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts) self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs) self.set_z_loss(inputs)
......
...@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable: ...@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
return torch.nn.GELU() return torch.nn.GELU()
elif act == "swiglu": elif act == "swiglu":
return SwiGLU return SwiGLU
elif act == "silu":
return torch.nn.SiLU()
else: else:
raise NotImplementedError("Unsupported activation function") raise NotImplementedError("Unsupported activation function")
......
...@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# because they have different parallel strategy # because they have different parallel strategy
# so we need to store them separately in param_groups # so we need to store them separately in param_groups
# instead of working_groups # instead of working_groups
moe_params = list() self.working_moe_params = list()
# iterate over the param group in the optimizer # iterate over the param group in the optimizer
# partition these param groups for data parallel training # partition these param groups for data parallel training
...@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self.moe_extra_dp_pg is None: if self.moe_extra_dp_pg is None:
# skip moe param # skip moe param
if is_moe_tensor(param): if is_moe_tensor(param):
moe_params.append(param) self.working_moe_params.append(param)
continue continue
group_params.append(param) group_params.append(param)
...@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# managed by this data parallel rank # managed by this data parallel rank
param_group["params"] = master_param_current_rank param_group["params"] = master_param_current_rank
# if there are moe params, store in additional group in optim # if there are moe params, store in additional group in optim
if len(moe_params) > 0: if len(self.working_moe_params) > 0:
self._sync_master_param = False
param_group = dict() param_group = dict()
# create fp32 master param
for key, value in self.optim.param_groups[0].items(): for key, value in self.optim.param_groups[0].items():
if key != "params": if key != "params":
param_group[key] = value param_group[key] = value
param_group["params"] = moe_params self.master_moe_params = []
for param in self.working_moe_params:
self.master_moe_params.append(param.clone().to(torch.float32).detach())
# create mapping from master to working for optimizer io
self.moe_master_to_working_map = {}
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
# add to optim
param_group["params"] = self.master_moe_params
self.optim.param_groups.append(param_group) self.optim.param_groups.append(param_group)
# initialize communication stream for # initialize communication stream for
...@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# update the params in the optimizer # update the params in the optimizer
self.optim.param_groups[group_id]["params"] = real_master_params[group_id] self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
# update param for moe ep
# move grad to master param and compute norm
if len(self.working_moe_params) > 0:
moe_grads = []
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
if master_moe_param.grad is not None:
raise RuntimeError("Moe param should not have grad here")
grad = working_moe_param.grad
# no need to copy fp32 grad if master_weights is False
if self._master_weights:
grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
master_moe_param.grad = grad
working_moe_param.grad = None
moe_grads.append(grad)
grad_partition_groups.append(grad)
norm_group = self._compute_grad_norm(gradients=moe_grads)
norm_groups.append(norm_group)
self.optim.param_groups[-1]["params"] = self.master_moe_params
del moe_grads
# unscale and clip grads # unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups) global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(grad_partition_groups, global_norm) self._unscale_and_clip_grads(grad_partition_groups, global_norm)
# TODO: we should store master param for ep
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.data = param.data.to(torch.float32)
param.grad = param.grad.to(torch.float32)
# update the parameters # update the parameters
self.optim.step() self.optim.step()
# release the moe gradm # release moe grad
if len(self.param_groups) > len(self._working_param_groups): if len(self.working_moe_params) > 0:
for param in self.param_groups[-1]["params"]: for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
param.grad = None master_moe_param.grad = None
param.data = param.data.to(self._dtype) working_moe_param.data = (
master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
)
# release the grad # release the grad
grad_partition_groups = [] grad_partition_groups = []
...@@ -640,6 +666,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -640,6 +666,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
def sync_moe_master_param(self):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r""" r"""
Compute and return the gradient norm for gradient clipping. Compute and return the gradient norm for gradient clipping.
......
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.testing import assert_close
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.moe import SparseMLP from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict from colossalai.moe.utils import get_moe_epsize_param_dict
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
def delete_moe_info(model):
    """Remove the ``moe_info`` attribute from every parameter of ``model`` that has one.

    Parameters without a ``moe_info`` attribute are left untouched.

    Args:
        model: Object exposing ``named_parameters()``.
    """
    tagged = [p for _, p in model.named_parameters() if hasattr(p, "moe_info")]
    for p in tagged:
        del p.moe_info
class MoeModel(nn.Module): class MoeModel(nn.Module):
...@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None): ...@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None):
for i in range(world_size - 1): for i in range(world_size - 1):
a = tensor_list[i] a = tensor_list[i]
b = tensor_list[i + 1] b = tensor_list[i + 1]
assert not torch.allclose(a, b), \ assert not torch.allclose(a, b), (
(f"expected tensors on rank {i} and {i + 1} not to be equal " f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
f"but they are, {a} vs {b}") )
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
    """Run one forward and one backward pass with ``model`` in training mode.

    Args:
        model: Module to run; switched to training mode via ``model.train()``.
        data: Input batch passed to the model.
        label: Targets, passed to ``criterion`` or, when no criterion is given,
            to the model itself as ``model(data, label)``.
        criterion: Loss callable applied as ``criterion(output, label)``. When
            falsy, ``model(data, label)`` is taken to return the loss directly.
        optimizer: Only used when ``model`` is a ``LowLevelZeroModel``, in which
            case the backward pass goes through ``optimizer.backward(loss)``.
        enable_autocast: Whether the forward pass runs under CUDA autocast.

    Returns:
        The model output. Without a criterion this is the value returned by
        ``model(data, label)`` (the loss before the float cast).
    """
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            # The model computes its own loss; bind it to ``y`` as well so the
            # return value is defined on this path too.
            y = model(data, label)
            loss = y
        loss = loss.float()

    if isinstance(model, LowLevelZeroModel):
        optimizer.backward(loss)
    else:
        loss.backward()
    return y
def sync_local_from_ep(local_model: "SparseMLP", ep_model: "SparseMLP", assert_grad_flag: bool = False) -> None:
    """Sync (or verify) the parameters of a local model against an expert-parallel model.

    The parameters of both models are walked in lockstep. Parameters whose name
    does not contain ``"experts"`` are compared or copied directly. Expert
    parameters are first all-gathered over the parameter's EP group and
    concatenated along dim 0 (presumably the expert dimension; confirm against
    the expert parameter layout) before being compared or copied.

    Args:
        local_model (MoeModule): Model that receives the parameters, or is checked.
        ep_model (MoeModule): Expert-parallel model acting as the source.
        assert_grad_flag (bool): If True, nothing is copied; instead both the
            parameters and their gradients are asserted to be close. If False,
            the (gathered) parameters are copied into ``local_model``.

    Raises:
        AssertionError: If parameter names do not line up, or if
            ``assert_grad_flag`` is set and values or gradients differ.
    """
    for (local_name, local_param), (ep_name, ep_param) in zip(
        local_model.named_parameters(), ep_model.named_parameters()
    ):
        # A real message instead of ``print(...)``, which evaluates to None.
        assert local_name in ep_name, f"{local_name} != {ep_name}"

        if "experts" not in local_name:
            if assert_grad_flag:
                assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
                assert torch.allclose(local_param.grad, ep_param.grad)
            else:
                local_param.data.copy_(ep_param.data)
            continue

        # gather param from ep model
        ep_size = get_ep_size(ep_param)
        ep_group = get_ep_group(ep_param)
        param_list = [torch.zeros_like(ep_param) for _ in range(ep_size)]
        dist.all_gather(param_list, ep_param, group=ep_group)
        all_param = torch.cat(param_list, dim=0)

        if assert_grad_flag:
            # gather grad from ep model and compare both values and gradients
            grad_list = [torch.zeros_like(ep_param) for _ in range(ep_size)]
            dist.all_gather(grad_list, ep_param.grad, group=ep_group)
            all_grad = torch.cat(grad_list, dim=0)
            assert torch.allclose(local_param, all_param)
            assert torch.allclose(local_param.grad, all_grad)
        else:
            local_param.data.copy_(all_param.data)
def loose_close(a, b, dtype: torch.dtype = torch.float32):
    """Assert that ``a`` and ``b`` are close after casting both to ``dtype``.

    float16 and bfloat16 get explicit, relaxed tolerances; for any other dtype
    ``rtol`` and ``atol`` are passed as ``None`` so ``assert_close`` applies its
    own defaults.

    Args:
        a: First tensor; detached and cast to ``dtype``.
        b: Second tensor; detached, cast to ``dtype`` and moved to ``a``'s device.
        dtype: Dtype in which the comparison is carried out.

    Raises:
        AssertionError: Raised by ``assert_close`` when the tensors differ
            beyond the tolerance.
    """
    # (rtol, atol) per reduced-precision dtype
    tolerances = {
        torch.float16: (5e-2, 5e-4),
        torch.bfloat16: (4e-3, 4e-3),
    }
    rtol, atol = tolerances.get(dtype, (None, None))

    lhs = a.detach().to(dtype)
    rhs = b.detach().to(dtype).to(lhs.device)

    assert_close(lhs, rhs, rtol=rtol, atol=atol)
...@@ -4,102 +4,75 @@ import torch ...@@ -4,102 +4,75 @@ import torch
import colossalai import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
def split_ddp_grad(grad, world_size): def run_zero_test(local_rank, stage=1):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def run_zero_test(local_rank, world_size, stage=1):
criterion = torch.nn.CrossEntropyLoss() criterion = torch.nn.CrossEntropyLoss()
zero_model = MoeModel() MOE_MANAGER.__init__()
optimizer = torch.optim.Adam(zero_model.parameters()) MOE_MANAGER.setup(parallel="EP")
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") moe_model = MoeModel().bfloat16()
booster = Booster(plugin=plugin) moe_optimizer = torch.optim.Adam(moe_model.parameters())
zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer) moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
torch_model = MoeModel() moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data) MOE_MANAGER.__init__()
torch_model = torch_model.cuda() MOE_MANAGER.setup(parallel=None)
grad_handler = MoeGradientHandler(torch_model) zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
# assert zero model zero_optimizer = torch.optim.Adam(zero_model.parameters())
for (torch_name, torch_param), (zero_name, zero_param) in zip( zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
torch_model.named_parameters(), zero_model.module.named_parameters() zero_booster = Booster(plugin=zero_plugin)
): zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
assert zero_name == torch_name sync_local_from_ep(zero_model, moe_model)
assert torch.allclose(zero_param.data, torch_param.data)
data = torch.randn(16, 4).bfloat16().cuda()
data = torch.randn(16, 4).cuda()
label = torch.randint(0, 4, (16,)).cuda() label = torch.randint(0, 4, (16,)).cuda()
torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer) moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
assert torch.allclose(torch_out, zero_out) assert torch.allclose(zero_out, moe_out)
grad_handler.handle_gradient()
for (zero_name, zero_param), (torch_name, torch_param) in zip( for (moe_name, moe_param), (zero_name, zero_param) in zip(
zero_model.module.named_parameters(), torch_model.named_parameters() moe_model.module.named_parameters(), zero_model.module.named_parameters()
): ):
assert zero_name == torch_name assert moe_name == zero_name
zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
if hasattr(zero_param, "moe_info"): zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
assert len(zero_grad_list) == 0 if hasattr(moe_param, "moe_info"):
assert torch.allclose(zero_param.grad, torch_param.grad) assert len(moe_grad_list) == 0
if stage == 1:
zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
else:
zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
assert torch.allclose(
moe_param.grad, zero_grad, atol=1e-5
), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
else: else:
assert len(zero_grad_list) > 0 assert len(moe_grad_list) > 0
torch_grad_list = split_ddp_grad(torch_param.grad, world_size) assert len(moe_grad_list) == len(zero_grad_list)
if stage == 2: for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
torch_grad_list = torch_grad_list[local_rank : local_rank + 1] assert torch.allclose(moe_grad, zero_grad)
assert len(zero_grad_list) == len(torch_grad_list)
for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
assert torch.allclose(zero_grad, torch_grad)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port, stage):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.setup(parallel="EP")
seed_all(42 + rank) seed_all(42 + rank)
run_zero_test(rank, world_size, stage=1) run_zero_test(rank, stage=stage)
run_zero_test(rank, world_size, stage=2)
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_moe_zero_model(world_size): def test_moe_zero_model(world_size, stage):
spawn(run_dist, world_size) spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__": if __name__ == "__main__":
test_moe_zero_model(world_size=2) test_moe_zero_model(world_size=2, stage=1)
...@@ -4,89 +4,80 @@ import torch ...@@ -4,89 +4,80 @@ import torch
import colossalai import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep
def split_ddp_grad(grad, world_size): def run_zero_test(local_rank, stage=1):
with torch.no_grad(): criterion = torch.nn.CrossEntropyLoss()
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
moe_model = MoeModel().bfloat16()
moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
sync_local_from_ep(zero_model, moe_model)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
zero_booster = Booster(plugin=zero_plugin)
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.named_parameters(), zero_model.named_parameters()
):
if ".experts." in moe_name:
continue
assert moe_name == zero_name
assert torch.allclose(
moe_param.data, zero_param.data
), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"
def run_zero_optim_test(local_rank, world_size, stage=1): for _ in range(1):
criterion = torch.nn.CrossEntropyLoss() data = torch.randn(2, 4).bfloat16().cuda()
label = torch.randint(0, 4, (2,)).cuda()
zero_model = MoeModel() moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
zero_optimizer = torch.optim.Adam(zero_model.parameters()) zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") assert torch.allclose(zero_out, moe_out)
booster = Booster(plugin=plugin) moe_optimizer.step()
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
torch_model = MoeModel()
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data)
torch_optimizer = torch.optim.Adam(torch_model.parameters())
torch_model = torch_model.cuda()
grad_handler = MoeGradientHandler(torch_model)
for _ in range(2):
data = torch.randn(16, 4).cuda() / (local_rank + 1)
label = torch.randint(0, 4, (16,)).cuda()
run_fwd_bwd(torch_model, data, label, criterion, None)
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
grad_handler.handle_gradient()
torch_optimizer.step()
zero_optimizer.step() zero_optimizer.step()
for (torch_name, torch_param), (zero_name, zero_param) in zip( for (moe_name, moe_param), (zero_name, zero_param) in zip(
torch_model.named_parameters(), zero_model.named_parameters() moe_model.named_parameters(), zero_model.named_parameters()
): ):
assert torch.allclose( assert moe_name == zero_name
torch_param.data, zero_param.data if is_moe_tensor(moe_param):
), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" param_size = moe_param.shape[0]
zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)
torch_optimizer.zero_grad() moe_optimizer.zero_grad()
zero_optimizer.zero_grad() zero_optimizer.zero_grad()
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port, stage):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.setup(parallel="EP") seed_all(42 + rank)
run_zero_optim_test(rank, world_size, stage=1) run_zero_test(rank, stage=stage)
run_zero_optim_test(rank, world_size, stage=2)
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size): def test_moe_zero_optim(world_size, stage):
spawn(run_dist, world_size) spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__": if __name__ == "__main__":
test_moe_zero_optim(world_size=2) test_moe_zero_optim(world_size=2, stage=1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment