"tests/test_tensor/model/test_module_spec.py" did not exist on "ee50497db2a4754f4bc064710a27dc759fc98d79"
Unverified commit efef43b5, authored by Frank Lee and committed by GitHub

Merge pull request #5372 from hpcaitech/exp/mixtral

parents 4c03347f 06db94fb
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple
import torch
import torch.distributed as dist
@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
if ctx.ep_size != 1:
grad = grad / ctx.ep_size
return grad, None
def _all_to_all(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
async_op: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if async_op is True
"""
outputs_shape = list(inputs.shape)
if output_split_sizes is not None:
outputs_shape[0] = sum(output_split_sizes)
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
inputs = inputs.contiguous()
outputs = outputs.contiguous()
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
return outputs, handle
class AllToAllUneven(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inputs,
input_split_sizes=None,
output_split_sizes=None,
group=None,
overlap: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
ctx.input_split_sizes = input_split_sizes
ctx.output_split_sizes = output_split_sizes
ctx.group = group
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
@staticmethod
def backward(ctx: Any, *grad_outputs):
return (
_all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
None,
None,
None,
None,
)
def all_to_all_uneven(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
overlap: bool = False,
):
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
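For context, a minimal usage sketch of the uneven all-to-all wrapper added above. It assumes an already-initialized 2-rank NCCL process group (e.g. launched with torchrun); the split sizes and tensor shapes are made up for illustration and are not part of the commit.

# Hypothetical 2-rank sketch of all_to_all_uneven (defined above); setup values are assumptions.
import torch
import torch.distributed as dist

def demo_uneven_all_to_all():
    rank = dist.get_rank()
    if rank == 0:
        inputs = torch.randn(3, 8, device="cuda")  # 3 rows live on rank 0
        input_split_sizes, output_split_sizes = [1, 2], [1, 4]
    else:
        inputs = torch.randn(5, 8, device="cuda")  # 5 rows live on rank 1
        input_split_sizes, output_split_sizes = [4, 1], [2, 1]
    # input_split_sizes[i]: rows this rank sends to rank i
    # output_split_sizes[i]: rows this rank receives from rank i
    outputs, handle = all_to_all_uneven(inputs, input_split_sizes, output_split_sizes, group=dist.group.WORLD)
    assert handle is None  # overlap=False means the call is synchronous
    assert outputs.shape[0] == sum(output_split_sizes)  # 5 rows on rank 0, 3 rows on rank 1
    return outputs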
@@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
torch.cuda.empty_cache()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"index located at {save_index_file}."
)
dist.barrier()
torch.cuda.empty_cache()
# ========================================================
# Abstract methods for optimizer loading/saving implementation
@@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
@@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
id_map[param_id] = param
# Read checkpoint index file.
@@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
updated_groups.append(new_pg)
# ep extra group
if MOE_MANAGER.parallel == "EP":
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1][
"params"
]  # Only keep the parameters kept by current pipeline stage.
for param in new_pg["params"]:
param.data = param.data.to(torch.float32)
# ep param group
if len(optimizer.optim.param_groups) > len(saved_groups):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
@@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
@@ -400,26 +400,33 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
param,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True,
)
optimizer.optim.state[param] = sharded_state
# Then shard the loaded optimizer states if using tp/zero.
for pid, state in list(state_dict.items()):
if pid in id_map:
param = id_map[pid]
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif (
hasattr(optimizer, "moe_master_to_working_map")
and id(param) in optimizer.moe_master_to_working_map
):
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
working_param,
current_shape=working_param.shape,
original_shape=original_shape,
device="cpu",
inplace=True,
)
state_dict[pid] = sharded_state
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
@@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
@@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that stores state tensors
"""
torch.cuda.empty_cache()
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
torch.cuda.empty_cache()
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
...
@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
self.ep_size = 1
if gated:
self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
self.wi_gate = nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
)
)
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
...
@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
router_loss: bool = True,
router_norm: bool = False,
router_capacity_factor_train: float = 1.25,
router_capacity_factor_eval: float = 2.0,
router_min_capacity: int = 4,
@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
enable_hierarchical_comm: bool = False,
return_gate_logits: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.gated = mlp_gated
self.return_gate_logits = return_gate_logits
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel()
self.router_loss = router_loss
self.router_norm = router_norm
# moe router
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
tokens = inputs.reshape(-1, self.hidden_size)
# the data type of the inputs in the gating should be fp32
fp32_input = tokens.to(torch.float)
fp32_weight = self.gate_weight.to(torch.float)
gate_output = F.linear(fp32_input, fp32_weight)
gate_logits = F.linear(tokens, self.gate_weight)
gate_output = gate_logits.to(torch.float)
# update expert load
if self.enable_load_balance == True:
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
# the result from the router
used_capacity, *route_result_list = self.router(
inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
inputs=gate_output,
use_kernel=self.enable_kernel,
ep_group=self.ep_group,
use_loss=self.router_loss,
use_norm=self.router_norm,
)
# dispatch_data: (num_experts, capacity, hidden_size)
if self.enable_kernel:
@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
expert_output = self._ep_process(
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel == "TP":
expert_output = self._tp_process(
dispatch_data,
used_capacity,
overlap=self.enable_comm_overlap
)
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n"
"Please use Experts build function.")
raise NotImplementedError(
"This kind of communication has not been implemented yet.\n" "Please use Experts build function."
)
if self.enable_kernel:
expert_output = expert_output.reshape(-1, self.hidden_size)
@@ -204,6 +207,10 @@ class SparseMLP(nn.Module):
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
if self.return_gate_logits:
return ans, gate_logits
else:
return ans
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
return expert_out
def _ep_process(
self,
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
Expert Parallel
@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
""" """
if not overlap or dist.get_world_size(self.ep_group) == 1: if not overlap or dist.get_world_size(self.ep_group) == 1:
if self.ep_hierarchical_group is not None: if self.ep_hierarchical_group is not None:
expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank) expert_input = HierarchicalAllToAll.apply(
dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
)
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input) expert_output = self.experts(expert_input)
expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank) expert_output = HierarchicalAllToAll.apply(
expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
)
return expert_output return expert_output
else: else:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4
assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
dispatch_data = dispatch_data.reshape(*input_shape)
@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[:, :, offset:offset + chunk_size, :] = expert_out.data
output[:, :, offset : offset + chunk_size, :] = expert_out.data
offset += chunk_size
expert_out = None
# all2all last output
if _expert_out is not None:
expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
expert_out = Capsule(
*AllToAll.apply(_expert_out.data, self.ep_group, True),
)
_expert_out = None
# all2all next input
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
return output
def _tp_process(
self,
dispatch_data: torch.Tensor,
used_capacity: torch.Tensor,
overlap: bool = False
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
without overlap:
@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4
assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
"arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
assert (
dispatch_data.shape[0] % NUM_CHUNK == 0
), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
chunk_size = dispatch_data.shape[0] // NUM_CHUNK
chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
output = torch.empty_like(dispatch_data)
...
@@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC):
self._z_loss = None
self.use_kernel = use_kernel
def get_capacity(self, logits_shape):
def get_capacity(self, num_tokens, num_experts, ep_group=None):
if ep_group is not None:
num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
dist.all_reduce(num_tokens_tensor, group=ep_group)
num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
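For a concrete sense of the reworked capacity rule, a standalone re-computation follows; the numbers (top-2 routing, capacity factor 1.25, 16 tokens per rank, 8 experts, 2-rank EP group) are illustrative only and not taken from the commit.

# Illustrative-only re-computation of the capacity rule above, outside the router class.
import math

def capacity(k_value, capacity_factor, num_tokens, num_experts, min_capacity=4):
    cap = math.floor(k_value * capacity_factor * num_tokens / num_experts)
    cap += cap % 2  # bump odd values up to the next even number
    return max(cap, min_capacity)

# With an EP group, the all-reduce sums the per-rank token counts (2 * 16 = 32) and
# divides by the world size again, so num_tokens stays 16 per rank here.
# capacity = floor(2 * 1.25 * 16 / 8) = 5 -> 6 after even-rounding, above min_capacity=4.
print(capacity(k_value=2, capacity_factor=1.25, num_tokens=16, num_experts=8))  # prints 6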
@@ -150,7 +154,14 @@ class Top1Router(MoeRouter):
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_loss: bool = False,
use_norm: bool = False,
) -> Tuple:
""" """
Args: Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
...@@ -168,7 +179,8 @@ class Top1Router(MoeRouter): ...@@ -168,7 +179,8 @@ class Top1Router(MoeRouter):
assert inputs.dtype == torch.float assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1) probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1) num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape) num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(inputs, dim=-1) top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
...@@ -207,7 +219,7 @@ class Top1Router(MoeRouter): ...@@ -207,7 +219,7 @@ class Top1Router(MoeRouter):
weight = mask * probs.type_as(inputs) weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool() sec_mask = combine_weights.bool()
return used_capacity, combine_weights, sec_mask return used_capacity, combine_weights, sec_mask, probs
class Top2Router(MoeRouter):
@@ -240,7 +252,14 @@ class Top2Router(MoeRouter):
drop_tks=drop_tks,
)
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_norm: bool = False,
use_loss: bool = True,
) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -257,8 +276,13 @@ class Top2Router(MoeRouter):
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
if use_norm:
routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
num_tokens = inputs.size(0)
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(probs, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
@@ -270,6 +294,7 @@ class Top2Router(MoeRouter):
cmask = cmask.float() / 2.0  # div 2 to normalize it to 1
# calculate loss
if use_loss:
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
...
@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
return torch.nn.GELU()
elif act == "swiglu":
return SwiGLU
elif act == "silu":
return torch.nn.SiLU()
else:
raise NotImplementedError("Unsupported activation function")
...
@@ -26,3 +26,5 @@ class MoeParallelInfo:
self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
self.ep_rank = self.pg.coordinate(self.ep_axis)
self.dp_rank = self.pg.coordinate(self.dp_axis)
@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# because they have different parallel strategy
# so we need to store them separately in param_groups
# instead of working_groups
moe_params = list()
self.working_moe_params = list()
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self.moe_extra_dp_pg is None:
# skip moe param
if is_moe_tensor(param):
moe_params.append(param)
self.working_moe_params.append(param)
continue
group_params.append(param)
@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# managed by this data parallel rank
param_group["params"] = master_param_current_rank
# if there are moe params, store in additional group in optim
if len(moe_params) > 0:
if len(self.working_moe_params) > 0:
self._sync_master_param = False
param_group = dict()
# create fp32 master param
for key, value in self.optim.param_groups[0].items():
if key != "params":
param_group[key] = value
param_group["params"] = moe_params
self.master_moe_params = []
for param in self.working_moe_params:
self.master_moe_params.append(param.clone().to(torch.float32).detach())
# create mapping from master to working for optimizer io
self.moe_master_to_working_map = {}
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
# add to optim
param_group["params"] = self.master_moe_params
self.optim.param_groups.append(param_group)
# initialize communication stream for
@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# update the params in the optimizer
self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
# update param for moe ep
# move grad to master param and compute norm
if len(self.working_moe_params) > 0:
moe_grads = []
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
if master_moe_param.grad is not None:
raise RuntimeError("Moe param should not have grad here")
grad = working_moe_param.grad
# no need to copy fp32 grad if master_weights is False
if self._master_weights:
grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
master_moe_param.grad = grad
working_moe_param.grad = None
moe_grads.append(grad)
grad_partition_groups.append(grad)
norm_group = self._compute_grad_norm(gradients=moe_grads)
norm_groups.append(norm_group)
self.optim.param_groups[-1]["params"] = self.master_moe_params
del moe_grads
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
# TODO: we should store master param for ep
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.data = param.data.to(torch.float32)
param.grad = param.grad.to(torch.float32)
# update the parameters
self.optim.step()
# release the moe gradm
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]["params"]:
param.grad = None
param.data = param.data.to(self._dtype)
# release moe grad
if len(self.working_moe_params) > 0:
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.grad = None
working_moe_param.data = (
master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
)
# release the grad
grad_partition_groups = []
@@ -885,9 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
if hasattr(self, "master_moe_params"):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.copy_(working_moe_param)
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
if hasattr(self, "moe_master_to_working_map"):
return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
return self._param_store.master_to_working_param
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
def delete_moe_info(model):
for _, param in model.named_parameters():
if hasattr(param, "moe_info"):
delattr(param, "moe_info")
class MoeModel(nn.Module):
@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None):
for i in range(world_size - 1):
a = tensor_list[i]
b = tensor_list[i + 1]
assert not torch.allclose(a, b), \
(f"expected tensors on rank {i} and {i + 1} not to be equal "
f"but they are, {a} vs {b}")
assert not torch.allclose(a, b), (
f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
)
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
assert local_name in ep_name, print(f"{local_name} != {ep_name}")
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)
def loose_close(a, b, dtype: torch.dtype = torch.float32):
rtol = None
atol = None
if dtype is torch.float16:
rtol = 5e-2
atol = 5e-4
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)
assert_close(a, b, rtol=rtol, atol=atol)
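A small, self-contained illustration of the tolerances that the loose_close helper above applies; the tensors and the specific values are made up for the example and are not part of the commit.

# Standalone sketch of the bfloat16 tolerances used by loose_close (illustrative values).
import torch
from torch.testing import assert_close

a = torch.tensor([1.000, 2.000])
b = torch.tensor([1.003, 1.998])

# For bfloat16 inputs, loose_close compares with rtol=4e-3 and atol=4e-3,
# so small per-element discrepancies like the ones above still pass:
# |a - b| <= atol + rtol * |b| holds element-wise here.
assert_close(a, b, rtol=4e-3, atol=4e-3)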
@@ -12,7 +12,6 @@ import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
sys.path.append(
@@ -95,6 +94,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
@@ -103,6 +103,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=dist.get_world_size(),
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
@@ -111,6 +112,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=2,
zero_stage=2,
extra_dp_size=2,
custom_policy=OpenMoeForCausalLMPolicy(),
@@ -120,6 +122,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=2,
ep_size=2,
zero_stage=1,
microbatch_size=1,
custom_policy=OpenMoeForCausalLMPolicy(),
@@ -130,27 +133,6 @@ def get_model(parallel):
def _test_moe_checkpoint(rank, parallel):
if parallel == None:
MOE_MANAGER.setup(
parallel=None,
)
elif parallel == "ep":
MOE_MANAGER.setup(
parallel="EP",
)
elif parallel == "ep_zero":
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=2,
)
elif parallel == "hybrid":
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=1,
fixed_ep_size=2,
fixed_pp_size=2,
)
model1, booster1, optim1 = get_model(parallel)
model2, booster2, optim2 = get_model(parallel)
model3, booster3, optim3 = get_model(parallel)
@@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel):
_test_moe_checkpoint(rank, parallel)
@pytest.mark.skip(reason="This is tested in ColossalMOE")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
...
@@ -4,15 +4,21 @@ import torch
from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
@pytest.mark.parametrize(["router", "num_groups"], [
(Top1Router(), 1),
(Top2Router(), 1),
# (TopKRouter(num_selected_experts=3), 4),
])
@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
(4, 5, 8),
(3, 4, 4),
])
@pytest.mark.parametrize(
["router", "num_groups"],
[
(Top1Router(), 1),
(Top2Router(), 1),
# (TopKRouter(num_selected_experts=3), 4),
],
)
@pytest.mark.parametrize(
["batch_size", "seq_len", "num_experts"],
[
(4, 5, 8),
(3, 4, 4),
],
)
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
x = torch.randn((batch_size * seq_len, num_experts)).cuda()
if num_groups > 1:
@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
router.train()
if isinstance(router, TopKRouter):
_, combine_array, dispatch_mask = router(x, expert_capacity=2)
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
_, combine_array, dispatch_mask = router(x)
combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
router.eval()
if isinstance(router, TopKRouter):
_, combine_array, dispatch_mask = router(x, expert_capacity=2)
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
_, combine_array, dispatch_mask = router(x)
combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
...
@@ -4,102 +4,75 @@ import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
def split_ddp_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def run_zero_test(local_rank, world_size, stage=1):
criterion = torch.nn.CrossEntropyLoss()
zero_model = MoeModel()
optimizer = torch.optim.Adam(zero_model.parameters())
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
booster = Booster(plugin=plugin)
zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer)
torch_model = MoeModel()
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data)
torch_model = torch_model.cuda()
grad_handler = MoeGradientHandler(torch_model)
# assert zero model
for (torch_name, torch_param), (zero_name, zero_param) in zip(
torch_model.named_parameters(), zero_model.module.named_parameters()
):
assert zero_name == torch_name
assert torch.allclose(zero_param.data, torch_param.data)
data = torch.randn(16, 4).cuda()
label = torch.randint(0, 4, (16,)).cuda()
torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer)
assert torch.allclose(torch_out, zero_out)
grad_handler.handle_gradient()
for (zero_name, zero_param), (torch_name, torch_param) in zip(
zero_model.module.named_parameters(), torch_model.named_parameters()
):
assert zero_name == torch_name
zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
if hasattr(zero_param, "moe_info"):
assert len(zero_grad_list) == 0
assert torch.allclose(zero_param.grad, torch_param.grad)
else:
assert len(zero_grad_list) > 0
torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
if stage == 2:
torch_grad_list = torch_grad_list[local_rank : local_rank + 1]
assert len(zero_grad_list) == len(torch_grad_list)
for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
assert torch.allclose(zero_grad, torch_grad)
def run_zero_test(local_rank, stage=1):
criterion = torch.nn.CrossEntropyLoss()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
moe_model = MoeModel().bfloat16()
moe_optimizer = torch.optim.Adam(moe_model.parameters())
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
zero_optimizer = torch.optim.Adam(zero_model.parameters())
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
zero_booster = Booster(plugin=zero_plugin)
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
sync_local_from_ep(zero_model, moe_model)
data = torch.randn(16, 4).bfloat16().cuda()
label = torch.randint(0, 4, (16,)).cuda()
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
assert torch.allclose(zero_out, moe_out)
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.module.named_parameters(), zero_model.module.named_parameters()
):
assert moe_name == zero_name
moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
if hasattr(moe_param, "moe_info"):
assert len(moe_grad_list) == 0
if stage == 1:
zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
else:
zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
assert torch.allclose(
moe_param.grad, zero_grad, atol=1e-5
), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
else:
assert len(moe_grad_list) > 0
assert len(moe_grad_list) == len(zero_grad_list)
for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
assert torch.allclose(moe_grad, zero_grad)
def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, stage):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.setup(parallel="EP")
run_zero_test(rank, world_size, stage=1)
run_zero_test(rank, world_size, stage=2)
seed_all(42 + rank)
run_zero_test(rank, stage=stage)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
def test_moe_zero_model(world_size, stage):
spawn(run_dist, world_size)
spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__":
test_moe_zero_model(world_size=2)
test_moe_zero_model(world_size=2, stage=1)
@@ -4,89 +4,80 @@ import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep
def split_ddp_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def run_zero_optim_test(local_rank, world_size, stage=1):
criterion = torch.nn.CrossEntropyLoss()
zero_model = MoeModel()
zero_optimizer = torch.optim.Adam(zero_model.parameters())
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
booster = Booster(plugin=plugin)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
torch_model = MoeModel()
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data)
torch_optimizer = torch.optim.Adam(torch_model.parameters())
torch_model = torch_model.cuda()
grad_handler = MoeGradientHandler(torch_model)
for _ in range(2):
data = torch.randn(16, 4).cuda() / (local_rank + 1)
label = torch.randint(0, 4, (16,)).cuda()
run_fwd_bwd(torch_model, data, label, criterion, None)
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
grad_handler.handle_gradient()
torch_optimizer.step()
zero_optimizer.step()
for (torch_name, torch_param), (zero_name, zero_param) in zip(
torch_model.named_parameters(), zero_model.named_parameters()
):
assert torch.allclose(
torch_param.data, zero_param.data
), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
torch_optimizer.zero_grad()
zero_optimizer.zero_grad()
def run_zero_test(local_rank, stage=1):
criterion = torch.nn.CrossEntropyLoss()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
moe_model = MoeModel().bfloat16()
moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
sync_local_from_ep(zero_model, moe_model)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
zero_booster = Booster(plugin=zero_plugin)
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.named_parameters(), zero_model.named_parameters()
):
if ".experts." in moe_name:
continue
assert moe_name == zero_name
assert torch.allclose(
moe_param.data, zero_param.data
), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"
for _ in range(1):
data = torch.randn(2, 4).bfloat16().cuda()
label = torch.randint(0, 4, (2,)).cuda()
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
assert torch.allclose(zero_out, moe_out)
moe_optimizer.step()
zero_optimizer.step()
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.named_parameters(), zero_model.named_parameters()
):
assert moe_name == zero_name
if is_moe_tensor(moe_param):
param_size = moe_param.shape[0]
zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)
moe_optimizer.zero_grad()
zero_optimizer.zero_grad()
def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, stage):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.setup(parallel="EP")
run_zero_optim_test(rank, world_size, stage=1)
run_zero_optim_test(rank, world_size, stage=2)
seed_all(42 + rank)
run_zero_test(rank, stage=stage)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size):
def test_moe_zero_optim(world_size, stage):
spawn(run_dist, world_size)
spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__":
test_moe_zero_optim(world_size=2)
test_moe_zero_optim(world_size=2, stage=1)