r""" FMoE core layer """ import torch import torch.nn as nn from .functions import prepare_forward, ensure_comm from .functions import MOEScatter, MOEGather from .functions import AllGather, Slice from .gates import NaiveGate def mark_module_parallel_comm(module, comm): r""" Mark all parameters in `module` as doing data parallel in `comm`, where `comm` may be one of `'world', 'dp', 'none'`. """ for p in module.parameters(): setattr(p, "dp_comm", comm) def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size): r""" A private function that performs the following steps to complete the MoE computation. * Count the number of tokens from each worker to each expert. * Send the features to their target position so that input features to each expert are contiguous in memory. * Perform the forward computation of the experts using `expert_fn` * Gather the output features of experts back, and reorder them as sentences. Intermediate results like expert counts are hidden from users by this function. """ ( pos, local_expert_count, global_expert_count, fwd_expert_count, fwd_batch_size, ) = prepare_forward(gate, num_expert, world_size) topk = 1 if len(gate.shape) == 2: topk = gate.shape[1] x = MOEScatter.apply( inp, pos // topk, local_expert_count, global_expert_count, fwd_batch_size, world_size ) x = expert_fn(x, fwd_expert_count) out_batch_size = inp.shape[0] if len(gate.shape) == 2: out_batch_size *= gate.shape[1] x = MOEGather.apply( x, pos, local_expert_count, global_expert_count, out_batch_size, world_size ) return x class FMoE(nn.Module): r""" A general moe implementation that supports an arbitrary module as the expert. * `num_expert` stands for the number of experts on **each** worker. * `world_size` stands for the total number of workers that contains different experts. * `slice_group` can be a torch's communication group, indicating that specific model parallel is applied across the group, and workers in the group hold the same copy of input feature, and requires the same copy of the output. For each worker, FMoE only computes the output of a certain slice of the input batch, and will all-gather the outputs after computation. * `top_k` stands for the number of experts each token is going to. * `gate` is a gate class which can found in `fmoe.gates`. * `expert` can be specified as a module class, it is used to generate `num_expert` expert modules. 
""" def __init__( self, num_expert=32, d_model=1024, world_size=1, mp_group=None, # being deprecated slice_group=None, moe_group=None, top_k=2, gate=NaiveGate, expert=None, gate_hook=None, mask=None, mask_dict=None, ): super().__init__() self.num_expert = num_expert self.d_model = d_model self.world_size = world_size self.slice_group = slice_group if mp_group is not None: print('[Warning] mp_group is being deprecated') self.slice_group = mp_group if self.slice_group is None: self.slice_size = 1 self.slice_rank = 0 else: self.slice_size = self.slice_group.size() self.slice_rank = self.slice_group.rank() self.top_k = top_k if type(expert) is list: self.experts = nn.ModuleList([e(d_model) for e in expert]) self.experts_fused = False self.num_expert = num_expert = len(expert) elif expert is not None: self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)]) self.experts_fused = False else: self.experts_fused = True self.gate = gate(d_model, num_expert, world_size, top_k) self.gate_hook = gate_hook self.mask = mask self.mask_dict = mask_dict self.moe_group = moe_group def expert_fn(self, inp, fwd_expert_count): r""" The default expert function which either calls the experts as a whole or as separate experts. """ if self.experts_fused: return self.experts(inp, fwd_expert_count) outputs = [] base_idx = 0 for i in range(self.num_expert): batch_size = fwd_expert_count[i].item() inp_slice = inp[base_idx : base_idx + batch_size] outputs.append(self.experts[i](inp_slice)) base_idx += batch_size return torch.cat(outputs, dim=0) def mark_parallel_comm(self, expert_dp_comm="none"): r""" Automatically mark the data parallel comms of the parameters within the module. This can be typically called at the end of the __init__ function in child classes. """ if self.experts is not None: comm = expert_dp_comm if isinstance(self.experts, list): for e in self.experts: mark_module_parallel_comm(e, comm) else: mark_module_parallel_comm(self.experts, comm) mark_module_parallel_comm(self.gate, "gate") def forward(self, inp): r""" The FMoE module first computes gate output, and then conduct MoE forward according to the gate. The score of the selected gate given by the expert is multiplied to the experts' output tensors as a weight. """ if self.world_size > 1: ensure_comm(inp, self.moe_group) if self.slice_size > 1: inp = Slice.apply(inp, self.slice_rank, self.slice_size, self.slice_group) gate_top_k_idx, gate_score = self.gate(inp) if self.gate_hook is not None: self.gate_hook(gate_top_k_idx, gate_score, None) # delete masked tensors if self.mask is not None and self.mask_dict is not None: mask = self.mask.view(-1) # to: (BxL') x d_model inp = inp[mask == 0, :] gate_top_k_idx = gate_top_k_idx[mask == 0, :] fwd = _fmoe_general_global_forward( inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size ) # recover deleted tensors if self.mask is not None and self.mask_dict is not None: # to: (BxL') x top_k x d_model fwd = fwd.view(-1, self.top_k, self.d_model) # to: (BxL) x top_k x d_model x = torch.zeros(mask.shape[0], self.top_k, self.d_model, device=fwd.device, dtype=fwd.dtype) # recover x[mask == 0] = fwd for k, v in self.mask_dict.items(): x[mask == k] = v else: x = fwd.view(-1, self.top_k, self.d_model) gate_score = gate_score.view(x.shape[0], 1, self.top_k) x = torch.bmm(gate_score, x).reshape(-1, self.d_model) if self.slice_size > 1: x = AllGather.apply(x, self.slice_rank, self.slice_size, self.slice_group) return x