r"""
FMoE's parallel linear layer
"""
import math

import torch
import torch.nn as nn
from torch.autograd import Function

import fmoe_cuda


class MOELinear(Function):
    r"""
    Computes linear operators within one GPU on different experts
    simultaneously.

    The actual computation is delegated to the `fmoe_cuda` extension
    (`linear_forward` / `linear_backward`); this class only wires those
    kernels into autograd.
    """

    @staticmethod
    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
        r"""
        Run the batched expert linear forward pass.

        Args:
            ctx: autograd context used to stash tensors for `backward`.
            global_input_buf: input tensor handed to the kernel.
                NOTE(review): presumably rows are grouped by expert in the
                order given by `fwd_expert_count` - confirm against caller.
            fwd_expert_count: per-expert counts handed to the kernel
                (not differentiated; its gradient is always `None`).
            weight: expert weights.
            bias: optional expert biases, or `None`.

        Returns:
            The output tensor produced by `fmoe_cuda.linear_forward`.
        """
        global_output_buf = fmoe_cuda.linear_forward(
            global_input_buf, fwd_expert_count, weight, bias
        )
        # `bias` may be None; save_for_backward accepts None entries and
        # hands them back unchanged in ctx.saved_tensors.
        ctx.save_for_backward(global_input_buf, fwd_expert_count, weight, bias)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        r"""
        Compute gradients w.r.t. the input, the weight and (if any) the bias.

        Returns a 4-tuple aligned with the arguments of `forward`:
        `(grad_input, None, grad_weight, grad_bias)`.
        """
        input_buf, fwd_expert_count, weight, bias = ctx.saved_tensors
        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
            grad_out, input_buf, fwd_expert_count, weight, bias
        )

        # The kernel always returns three values; when forward ran without a
        # bias tensor, its bias gradient is discarded.
        if not torch.is_tensor(bias):
            grad_bias = None

        return grad_inp_buf, None, grad_weight, grad_bias


class FMoELinear(nn.Module):
    r"""
    A linear layer that contains multiple experts.
    As multiple experts can be placed on the same worker, the computation can
    be performed in parallel to increase the performance.
    The FMoELinear module provides such function.

    Args:
        num_expert: number of experts held by this module.
        in_feat: input feature size of every expert.
        out_feat: output feature size of every expert.
        bias: if `True`, a zero-initialized bias of shape
            `(num_expert, out_feat)` is created; otherwise `self.bias`
            is registered as `None`.
        rank: stored on the module and reported by `extra_repr`; it is not
            otherwise used in this file.
    """

    def __init__(
        self,
        num_expert: int,
        in_feat: int,
        out_feat: int,
        bias: bool = True,
        rank: int = 0,
    ):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rank = rank

        # Allocated uninitialized; reset_parameters() below fills it in.
        self.weight = nn.Parameter(torch.empty(num_expert, out_feat, in_feat))
        if bias:
            self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def forward(self, inp, fwd_expert_count):
        r"""
        Call MOE function.

        The input is first cast to the tensor type of `self.weight`, then
        passed to `MOELinear` together with `fwd_expert_count`, the weight
        and the (possibly `None`) bias.
        """
        x = MOELinear.apply(
            inp.type_as(self.weight), fwd_expert_count, self.weight, self.bias
        )
        return x

    def extra_repr(self) -> str:
        r"""
        Return the one-line summary shown inside this module's repr.
        """
        # Adjacent literals are concatenated by the parser, so no
        # line-continuation whitespace ends up in the resulting text.
        return (
            "num_expert={}, in_features={}, "
            "out_features={}, bias={}, rank={}"
        ).format(
            self.num_expert,
            self.in_feat,
            self.out_feat,
            self.bias is not None,
            self.rank,
        )

    def reset_parameters(self):
        r"""
        (Re-)initialize the expert weights in place.
        """
        # Approach is the same as in torch.nn.Linear
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
        # The bias is not touched here: it is created as zeros in __init__
        # and left that way, similar to megatron.
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))