Unverified Commit efef43b5 authored by Frank Lee, committed by GitHub

Merge pull request #5372 from hpcaitech/exp/mixtral

parents 4c03347f 06db94fb
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple
import torch
import torch.distributed as dist
......@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
if ctx.ep_size != 1:
grad = grad / ctx.ep_size
return grad, None
def _all_to_all(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
async_op: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if async_op is True
"""
outputs_shape = list(inputs.shape)
if output_split_sizes is not None:
outputs_shape[0] = sum(output_split_sizes)
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
inputs = inputs.contiguous()
outputs = outputs.contiguous()
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
return outputs, handle
class AllToAllUneven(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inputs,
input_split_sizes=None,
output_split_sizes=None,
group=None,
overlap: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
ctx.input_split_sizes = input_split_sizes
ctx.output_split_sizes = output_split_sizes
ctx.group = group
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
@staticmethod
def backward(ctx: Any, *grad_outputs):
return (
_all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
None,
None,
None,
None,
)
def all_to_all_uneven(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
overlap: bool = False,
):
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
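A hedged usage sketch of the new primitive (illustrative values, not part of the diff; it assumes an initialized process group, and split sizes must agree pairwise across ranks — rank i's input_split_sizes[j] has to equal rank j's output_split_sizes[i]):
# Illustrative 2-rank example: this rank sends 1 row to rank 0 and 3 rows
# to rank 1, and expects 1 + 2 rows back.
inputs = torch.randn(4, hidden_size, device="cuda")  # hidden_size is a placeholder
outputs, handle = all_to_all_uneven(
    inputs,
    input_split_sizes=[1, 3],   # rows sent to rank 0, rank 1
    output_split_sizes=[1, 2],  # rows received from rank 0, rank 1
    group=None,                 # default process group
    overlap=False,              # handle is None unless overlap=True
)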
......@@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
torch.cuda.empty_cache()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
......@@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"index located at {save_index_file}."
)
dist.barrier()
torch.cuda.empty_cache()
# ========================================================
# Abstract methods for optimizer loading/saving implementation
......@@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
......@@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
id_map[param_id] = param
# Read checkpoint index file.
......@@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
# ep extra group
if MOE_MANAGER.parallel == "EP":
# ep param group
if len(optimizer.optim.param_groups) > len(saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1][
"params"
] # Only keep the parameters kept by current pipeline stage.
for param in new_pg["params"]:
param.data = param.data.to(torch.float32)
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
......@@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
......@@ -400,26 +400,33 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
for pid, state in list(state_dict.items()):
if pid in id_map:
param = id_map[pid]
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif (
hasattr(optimizer, "moe_master_to_working_map")
and id(param) in optimizer.moe_master_to_working_map
):
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
param,
working_param,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
device="cpu",
inplace=True,
)
optimizer.optim.state[param] = sharded_state
state_dict[pid] = sharded_state
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
......@@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
......@@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
torch.cuda.empty_cache()
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
......@@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
torch.cuda.empty_cache()
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
......
......@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
self.ep_size = 1
if gated:
self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
self.wi_gate = nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
)
)
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
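The conditional width matches the conventional SwiGLU formulation, where the doubled projection is chunked into two halves and combined as silu(a) * b; a minimal reference sketch of that convention (an assumption about intent, not necessarily this module's exact forward):
def swiglu_ref(x: torch.Tensor, w_gate: torch.Tensor) -> torch.Tensor:
    # w_gate: (hidden_size, 2 * intermediate_size)
    a, b = torch.matmul(x, w_gate).chunk(2, dim=-1)
    return torch.nn.functional.silu(a) * b  # (..., intermediate_size)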
......
......@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
router_loss: bool = True,
router_norm: bool = False,
router_capacity_factor_train: float = 1.25,
router_capacity_factor_eval: float = 2.0,
router_min_capacity: int = 4,
......@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
enable_hierarchical_comm: bool = False,
return_gate_logits: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.gated = mlp_gated
self.return_gate_logits = return_gate_logits
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel()
self.router_loss = router_loss
self.router_norm = router_norm
# moe router
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
......@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
tokens = inputs.reshape(-1, self.hidden_size)
# the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
gate_logits = F.linear(tokens, self.gate_weight)
gate_output = gate_logits.to(torch.float)
# update expert load
if self.enable_load_balance == True:
......@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
# the result from the router
used_capacity, *route_result_list = self.router(
inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
inputs=gate_output,
use_kernel=self.enable_kernel,
ep_group=self.ep_group,
use_loss=self.router_loss,
use_norm=self.router_norm,
)
# dispatch_data: (num_experts, capacity, hidden_size)
if self.enable_kernel:
......@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
expert_output = self._ep_process(
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel == "TP":
expert_output = self._tp_process(
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n"
"Please use Experts build function.")
raise NotImplementedError(
"This kind of communication has not been implemented yet.\n" "Please use Experts build function."
)
if self.enable_kernel:
expert_output = expert_output.reshape(-1, self.hidden_size)
......@@ -204,6 +207,10 @@ class SparseMLP(nn.Module):
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
if self.return_gate_logits:
return ans, gate_logits
else:
return ans
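With return_gate_logits enabled the module returns a pair, so a trainer can build an auxiliary balance loss from the raw router logits outside the layer. A hedged usage sketch (constructor arguments abbreviated and illustrative):
mlp = SparseMLP(num_experts=8, hidden_size=512, intermediate_size=2048,
                router_top_k=2, return_gate_logits=True)
out, gate_logits = mlp(hidden_states)  # gate_logits: (num_tokens, num_experts)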
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
......@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
return expert_out
def _ep_process(
self,
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
Expert Parallel
......@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
"""
if not overlap or dist.get_world_size(self.ep_group) == 1:
if self.ep_hierarchical_group is not None:
expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
expert_input = HierarchicalAllToAll.apply(
dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
)
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input)
expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
expert_output = HierarchicalAllToAll.apply(
expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
)
return expert_output
else:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
......@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4
assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
dispatch_data = dispatch_data.reshape(*input_shape)
......@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[:, :, offset:offset + chunk_size, :] = expert_out.data
output[:, :, offset : offset + chunk_size, :] = expert_out.data
offset += chunk_size
expert_out = None
# all2all last output
if _expert_out is not None:
expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
expert_out = Capsule(
*AllToAll.apply(_expert_out.data, self.ep_group, True),
)
_expert_out = None
# all2all next input
......@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
return output
def _tp_process(
self,
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
without overlap:
......@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4
assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
"arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
assert (
dispatch_data.shape[0] % NUM_CHUNK == 0
), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
chunk_size = dispatch_data.shape[0] // NUM_CHUNK
chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
output = torch.empty_like(dispatch_data)
......
......@@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC):
self._z_loss = None
self.use_kernel = use_kernel
def get_capacity(self, logits_shape):
def get_capacity(self, num_tokens, num_experts, ep_group=None):
if ep_group is not None:
num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
dist.all_reduce(num_tokens_tensor, group=ep_group)
num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
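The rewritten helper makes the capacity formula explicit: capacity = floor(k * capacity_factor * num_tokens / num_experts), bumped to an even number and clamped to min_capacity, with num_tokens averaged over the EP group when one is given. A standalone sketch with illustrative numbers:
import math

def capacity_ref(k_value, capacity_factor, num_tokens, num_experts, min_capacity=4):
    cap = math.floor(k_value * capacity_factor * num_tokens / num_experts)
    cap += cap % 2  # bump odd values up to the next even number
    return max(cap, min_capacity)

# top-1 routing, 1024 tokens, 8 experts, train-time factor 1.25:
assert capacity_ref(1, 1.25, 1024, 8) == 160  # floor(1 * 1.25 * 1024 / 8) = 160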
......@@ -150,7 +154,14 @@ class Top1Router(MoeRouter):
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_loss: bool = False,
use_norm: bool = False,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
......@@ -168,7 +179,8 @@ class Top1Router(MoeRouter):
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
......@@ -207,7 +219,7 @@ class Top1Router(MoeRouter):
weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return used_capacity, combine_weights, sec_mask
return used_capacity, combine_weights, sec_mask, probs
class Top2Router(MoeRouter):
......@@ -240,7 +252,14 @@ class Top2Router(MoeRouter):
drop_tks=drop_tks,
)
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_norm: bool = False,
use_loss: bool = True,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
......@@ -257,8 +276,13 @@ class Top2Router(MoeRouter):
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
if use_norm:
routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(probs, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
......@@ -270,6 +294,7 @@ class Top2Router(MoeRouter):
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# calculate loss
if use_loss:
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
......
......@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
return torch.nn.GELU()
elif act == "swiglu":
return SwiGLU
elif act == "silu":
return torch.nn.SiLU()
else:
raise NotImplementedError("Unsupported activation function")
......
......@@ -26,3 +26,5 @@ class MoeParallelInfo:
self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
self.ep_rank = self.pg.coordinate(self.ep_axis)
self.dp_rank = self.pg.coordinate(self.dp_axis)
......@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# because they have different parallel strategy
# so we need to store them separately in param_groups
# instead of working_groups
moe_params = list()
self.working_moe_params = list()
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
......@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self.moe_extra_dp_pg is None:
# skip moe param
if is_moe_tensor(param):
moe_params.append(param)
self.working_moe_params.append(param)
continue
group_params.append(param)
......@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# managed by this data parallel rank
param_group["params"] = master_param_current_rank
# if there are moe params, store in additional group in optim
if len(moe_params) > 0:
# if there are moe params, store in additional group in optim
if len(self.working_moe_params) > 0:
self._sync_master_param = False
param_group = dict()
# create fp32 master param
for key, value in self.optim.param_groups[0].items():
if key != "params":
param_group[key] = value
param_group["params"] = moe_params
self.master_moe_params = []
for param in self.working_moe_params:
self.master_moe_params.append(param.clone().to(torch.float32).detach())
# create mapping from master to working for optimizer io
self.moe_master_to_working_map = {}
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
# add to optim
param_group["params"] = self.master_moe_params
self.optim.param_groups.append(param_group)
# initialize communication stream for
......@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# update the params in the optimizer
self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
# update param for moe ep
# move grad to master param and compute norm
if len(self.working_moe_params) > 0:
moe_grads = []
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
if master_moe_param.grad is not None:
raise RuntimeError("Moe param should not have grad here")
grad = working_moe_param.grad
# no need to copy fp32 grad if master_weights is False
if self._master_weights:
grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
master_moe_param.grad = grad
working_moe_param.grad = None
moe_grads.append(grad)
grad_partition_groups.append(grad)
norm_group = self._compute_grad_norm(gradients=moe_grads)
norm_groups.append(norm_group)
self.optim.param_groups[-1]["params"] = self.master_moe_params
del moe_grads
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
# TODO: we should store master param for ep
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.data = param.data.to(torch.float32)
param.grad = param.grad.to(torch.float32)
# update the parameters
self.optim.step()
# release the moe grad
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.grad = None
param.data = param.data.to(self._dtype)
# release moe grad
if len(self.working_moe_params) > 0:
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.grad = None
working_moe_param.data = (
master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
)
# release the grad
grad_partition_groups = []
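The MoE branch now follows the standard fp32-master pattern: the low-precision working grad is promoted onto the fp32 master copy, the step runs on the master group, and the updated master is cast back into the working tensor. A condensed sketch of the round-trip (illustrative names, ignoring the _master_weights fast path):
master.grad = working.grad.to(master.dtype).to(master.device)  # promote grad to fp32
working.grad = None
optim.step()                                     # update the fp32 master copy
master.grad = None
working.data = master.data.to(working.device).to(working.dtype)  # cast result back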
......@@ -885,9 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
if hasattr(self, "master_moe_params"):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.copy_(working_moe_param)
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
if hasattr(self, "moe_master_to_working_map"):
return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
return self._param_store.master_to_working_param
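Because the MoE master copies live outside the ZeRO param store, the getter overlays the extra mapping so checkpoint IO can resolve every master param through one dict; a minimal lookup sketch:
m2w = optimizer.get_master_to_working_map()
working = m2w[id(master_param)]  # covers both zero-sharded and MoE master params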
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
def delete_moe_info(model):
for _, param in model.named_parameters():
if hasattr(param, "moe_info"):
delattr(param, "moe_info")
class MoeModel(nn.Module):
......@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None):
for i in range(world_size - 1):
a = tensor_list[i]
b = tensor_list[i + 1]
assert not torch.allclose(a, b), \
(f"expected tensors on rank {i} and {i + 1} not to be equal "
f"but they are, {a} vs {b}")
assert not torch.allclose(a, b), (
f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
)
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
assert local_name in ep_name, f"{local_name} != {ep_name}"
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)
def loose_close(a, b, dtype: torch.dtype = torch.float32):
rtol = None
atol = None
if dtype is torch.float16:
rtol = 5e-2
atol = 5e-4
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)
assert_close(a, b, rtol=rtol, atol=atol)
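Usage mirrors torch.testing.assert_close, with tolerances keyed to the low-precision dtype under test, e.g.:
loose_close(moe_param.data, zero_param.data, dtype=torch.bfloat16)  # rtol=4e-3, atol=4e-3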
......@@ -12,7 +12,6 @@ import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
sys.path.append(
......@@ -95,6 +94,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
......@@ -103,6 +103,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=dist.get_world_size(),
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
......@@ -111,6 +112,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=2,
zero_stage=2,
extra_dp_size=2,
custom_policy=OpenMoeForCausalLMPolicy(),
......@@ -120,6 +122,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=2,
ep_size=2,
zero_stage=1,
microbatch_size=1,
custom_policy=OpenMoeForCausalLMPolicy(),
......@@ -130,27 +133,6 @@ def get_model(parallel):
def _test_moe_checkpoint(rank, parallel):
if parallel == None:
MOE_MANAGER.setup(
parallel=None,
)
elif parallel == "ep":
MOE_MANAGER.setup(
parallel="EP",
)
elif parallel == "ep_zero":
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=2,
)
elif parallel == "hybrid":
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=1,
fixed_ep_size=2,
fixed_pp_size=2,
)
model1, booster1, optim1 = get_model(parallel)
model2, booster2, optim2 = get_model(parallel)
model3, booster3, optim3 = get_model(parallel)
......@@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel):
_test_moe_checkpoint(rank, parallel)
@pytest.mark.skip(reason="This is tested in ColossalMOE")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
......
......@@ -4,15 +4,21 @@ import torch
from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
@pytest.mark.parametrize(["router", "num_groups"], [
@pytest.mark.parametrize(
["router", "num_groups"],
[
(Top1Router(), 1),
(Top2Router(), 1),
# (TopKRouter(num_selected_experts=3), 4),
])
@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
],
)
@pytest.mark.parametrize(
["batch_size", "seq_len", "num_experts"],
[
(4, 5, 8),
(3, 4, 4),
])
],
)
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
x = torch.randn((batch_size * seq_len, num_experts)).cuda()
if num_groups > 1:
......@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex
router.train()
if isinstance(router, TopKRouter):
_, combine_array, dispatch_mask = router(x, expert_capacity=2)
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
_, combine_array, dispatch_mask = router(x)
combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
router.eval()
if isinstance(router, TopKRouter):
_, combine_array, dispatch_mask = router(x, expert_capacity=2)
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
_, combine_array, dispatch_mask = router(x)
combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
......
......@@ -4,102 +4,75 @@ import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
def split_ddp_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def run_zero_test(local_rank, world_size, stage=1):
def run_zero_test(local_rank, stage=1):
criterion = torch.nn.CrossEntropyLoss()
zero_model = MoeModel()
optimizer = torch.optim.Adam(zero_model.parameters())
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
booster = Booster(plugin=plugin)
zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer)
torch_model = MoeModel()
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data)
torch_model = torch_model.cuda()
grad_handler = MoeGradientHandler(torch_model)
# assert zero model
for (torch_name, torch_param), (zero_name, zero_param) in zip(
torch_model.named_parameters(), zero_model.module.named_parameters()
):
assert zero_name == torch_name
assert torch.allclose(zero_param.data, torch_param.data)
data = torch.randn(16, 4).cuda()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
moe_model = MoeModel().bfloat16()
moe_optimizer = torch.optim.Adam(moe_model.parameters())
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
zero_optimizer = torch.optim.Adam(zero_model.parameters())
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
zero_booster = Booster(plugin=zero_plugin)
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
sync_local_from_ep(zero_model, moe_model)
data = torch.randn(16, 4).bfloat16().cuda()
label = torch.randint(0, 4, (16,)).cuda()
torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer)
assert torch.allclose(torch_out, zero_out)
grad_handler.handle_gradient()
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
assert torch.allclose(zero_out, moe_out)
for (zero_name, zero_param), (torch_name, torch_param) in zip(
zero_model.module.named_parameters(), torch_model.named_parameters()
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.module.named_parameters(), zero_model.module.named_parameters()
):
assert zero_name == torch_name
zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
if hasattr(zero_param, "moe_info"):
assert len(zero_grad_list) == 0
assert torch.allclose(zero_param.grad, torch_param.grad)
assert moe_name == zero_name
moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
if hasattr(moe_param, "moe_info"):
assert len(moe_grad_list) == 0
if stage == 1:
zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
else:
zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
assert torch.allclose(
moe_param.grad, zero_grad, atol=1e-5
), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
else:
assert len(zero_grad_list) > 0
torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
if stage == 2:
torch_grad_list = torch_grad_list[local_rank : local_rank + 1]
assert len(zero_grad_list) == len(torch_grad_list)
for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
assert torch.allclose(zero_grad, torch_grad)
assert len(moe_grad_list) > 0
assert len(moe_grad_list) == len(zero_grad_list)
for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
assert torch.allclose(moe_grad, zero_grad)
def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, stage):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.setup(parallel="EP")
seed_all(42 + rank)
run_zero_test(rank, world_size, stage=1)
run_zero_test(rank, world_size, stage=2)
run_zero_test(rank, stage=stage)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
spawn(run_dist, world_size)
def test_moe_zero_model(world_size, stage):
spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__":
test_moe_zero_model(world_size=2)
test_moe_zero_model(world_size=2, stage=1)
......@@ -4,89 +4,80 @@ import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep
def split_ddp_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def run_zero_test(local_rank, stage=1):
criterion = torch.nn.CrossEntropyLoss()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
moe_model = MoeModel().bfloat16()
moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
sync_local_from_ep(zero_model, moe_model)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
zero_booster = Booster(plugin=zero_plugin)
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.named_parameters(), zero_model.named_parameters()
):
if ".experts." in moe_name:
continue
assert moe_name == zero_name
assert torch.allclose(
moe_param.data, zero_param.data
), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"
def run_zero_optim_test(local_rank, world_size, stage=1):
criterion = torch.nn.CrossEntropyLoss()
for _ in range(1):
data = torch.randn(2, 4).bfloat16().cuda()
label = torch.randint(0, 4, (2,)).cuda()
zero_model = MoeModel()
zero_optimizer = torch.optim.Adam(zero_model.parameters())
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
booster = Booster(plugin=plugin)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
torch_model = MoeModel()
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data)
torch_optimizer = torch.optim.Adam(torch_model.parameters())
torch_model = torch_model.cuda()
grad_handler = MoeGradientHandler(torch_model)
for _ in range(2):
data = torch.randn(16, 4).cuda() / (local_rank + 1)
label = torch.randint(0, 4, (16,)).cuda()
run_fwd_bwd(torch_model, data, label, criterion, None)
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
grad_handler.handle_gradient()
torch_optimizer.step()
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
assert torch.allclose(zero_out, moe_out)
moe_optimizer.step()
zero_optimizer.step()
for (torch_name, torch_param), (zero_name, zero_param) in zip(
torch_model.named_parameters(), zero_model.named_parameters()
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.named_parameters(), zero_model.named_parameters()
):
assert torch.allclose(
torch_param.data, zero_param.data
), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
assert moe_name == zero_name
if is_moe_tensor(moe_param):
param_size = moe_param.shape[0]
zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)
torch_optimizer.zero_grad()
moe_optimizer.zero_grad()
zero_optimizer.zero_grad()
def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, stage):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.setup(parallel="EP")
run_zero_optim_test(rank, world_size, stage=1)
run_zero_optim_test(rank, world_size, stage=2)
seed_all(42 + rank)
run_zero_test(rank, stage=stage)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size):
spawn(run_dist, world_size)
def test_moe_zero_optim(world_size, stage):
spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__":
test_moe_zero_optim(world_size=2)
test_moe_zero_optim(world_size=2, stage=1)