Commit 2d81858b authored by Jiezhong Qiu

allow nested data struct as moe input

parent b652e8d8
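The change routes every per-tensor operation in the MoE path through `tree.map_structure`, so the layer input can be a nested structure (dict, tuple, list) of tensors rather than a single tensor. A minimal sketch of that behavior, assuming `tree` here is the dm-tree package and using made-up leaf names:

```python
# Illustration only (not part of this diff): dm-tree's map_structure applies a
# function to every tensor leaf of a nested structure and returns a result
# with the same layout.
import tree  # dm-tree
import torch

nested_inp = {"hidden": torch.randn(4, 16), "extra": (torch.randn(4, 8),)}  # hypothetical leaves
doubled = tree.map_structure(lambda t: t * 2, nested_inp)
# `doubled` keeps the dict/tuple layout; only the tensor leaves were transformed.
```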
r""" r"""
FMoE core layer FMoE core layer
""" """
import tree
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -10,7 +11,6 @@ from .functions import AllGather, Slice ...@@ -10,7 +11,6 @@ from .functions import AllGather, Slice
from .gates import NaiveGate from .gates import NaiveGate
def mark_module_parallel_comm(module, comm): def mark_module_parallel_comm(module, comm):
r""" r"""
Mark all parameters in `module` as doing data parallel in `comm`, where Mark all parameters in `module` as doing data parallel in `comm`, where
@@ -42,22 +42,39 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    topk = 1
    if len(gate.shape) == 2:
        topk = gate.shape[1]

    def scatter_func(inp_tensor):
        tensor = MOEScatter.apply(
            inp_tensor,
            pos // topk,
            local_expert_count,
            global_expert_count,
            fwd_batch_size,
            world_size,
        )
        return tensor

    x = tree.map_structure(scatter_func, inp)

    x = expert_fn(x, fwd_expert_count)

    out_batch_size = inp.shape[0]
    if len(gate.shape) == 2:
        out_batch_size *= gate.shape[1]

    def gatter_func(outp_tensor):
        tensor = MOEGather.apply(
            outp_tensor,
            pos,
            local_expert_count,
            global_expert_count,
            out_batch_size,
            world_size,
        )
        return tensor

    outp = tree.map_structure(gatter_func, x)
    return outp
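The hunk above computes the routing metadata (`pos`, the expert counts, `fwd_batch_size`) once and then closes over it in `scatter_func`/`gatter_func`, so the same routing is applied to every leaf of a nested input. A small self-contained sketch of that closure pattern, with an index permutation standing in for `MOEScatter.apply`:

```python
# Sketch of the closure-over-shared-routing pattern, not the fastmoe API:
# the routing (here a simple permutation) is computed once, then applied to
# every tensor leaf via map_structure.
import tree
import torch

def scatter_all(inp, perm):
    def scatter_one(t):          # stand-in for MOEScatter.apply
        return t[perm]
    return tree.map_structure(scatter_one, inp)

perm = torch.tensor([2, 0, 3, 1])
out = scatter_all({"x": torch.arange(4.0), "y": torch.arange(8.0).view(4, 2)}, perm)
# out["x"] == tensor([2., 0., 3., 1.]); the rows of out["y"] are permuted the same way.
```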
class FMoE(nn.Module):
@@ -101,7 +118,7 @@ class FMoE(nn.Module):
        self.slice_group = slice_group
        if mp_group is not None:
            print("[Warning] mp_group is being deprecated")
            self.slice_group = mp_group
        if self.slice_group is None:
            self.slice_size = 1
@@ -116,8 +133,7 @@ class FMoE(nn.Module):
            self.experts_fused = False
            self.num_expert = num_expert = len(expert)
        elif expert is not None:
            self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
            self.experts_fused = False
        else:
            self.experts_fused = True
@@ -159,52 +175,94 @@ class FMoE(nn.Module):
        mark_module_parallel_comm(self.experts, comm)
        mark_module_parallel_comm(self.gate, "gate")

    def forward(self, moe_inp, non_moe_inp=None):
        r"""
        The FMoE module first computes gate output, and then conducts the MoE
        forward pass according to the gate. The score of each selected expert,
        given by the gate, is multiplied with that expert's output tensor as a
        weight.
        """
        if self.world_size > 1:

            def ensure_comm_func(tensor):
                ensure_comm(tensor, self.moe_group)

            tree.map_structure(ensure_comm_func, moe_inp)
        if self.slice_size > 1:

            def slice_func(tensor):
                return Slice.apply(
                    tensor, self.slice_rank, self.slice_size, self.slice_group
                )

            moe_inp = tree.map_structure(slice_func, moe_inp)

        gate_top_k_idx, gate_score = self.gate(moe_inp)

        if self.gate_hook is not None:
            self.gate_hook(gate_top_k_idx, gate_score, None)

        # TODO: to fix
        def delete_mask_func(tensor):
            # to: (BxL') x d_model
            tensor = tensor[mask == 0, :]
            return tensor

        # delete masked tensors
        if self.mask is not None and self.mask_dict is not None:
            mask = self.mask.view(-1)
            moe_inp = tree.map_structure(delete_mask_func, moe_inp)
            gate_top_k_idx = gate_top_k_idx[mask == 0, :]

        fwd = _fmoe_general_global_forward(
            moe_inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
        )

        # recover deleted tensors
        if self.mask is not None and self.mask_dict is not None:

            def recover_func(tensor):
                # to: (BxL') x top_k x dim
                dim = tensor.shape[-1]
                tensor = tensor.view(-1, self.top_k, dim)
                # to: (BxL) x top_k x d_model
                x = torch.zeros(
                    mask.shape[0],
                    self.top_k,
                    dim,
                    device=tensor.device,
                    dtype=tensor.dtype,
                )
                # recover
                x[mask == 0] = tensor
                for k, v in self.mask_dict.items():
                    x[mask == k] = v
                return x

            moe_outp = tree.map_structure(recover_func, fwd)
        else:

            def view_func(tensor):
                dim = tensor.shape[-1]
                tensor = tensor.view(-1, self.top_k, dim)
                return tensor

            moe_outp = tree.map_structure(view_func, fwd)

        gate_score = gate_score.view(-1, 1, self.top_k)

        def bmm_func(tensor):
            dim = tensor.shape[-1]
            tensor = torch.bmm(gate_score, tensor).reshape(-1, dim)
            return tensor

        moe_outp = tree.map_structure(bmm_func, moe_outp)

        if self.slice_size > 1:

            def all_gather_func(tensor):
                return AllGather.apply(
                    tensor, self.slice_rank, self.slice_size, self.slice_group
                )

            moe_outp = tree.map_structure(all_gather_func, moe_outp)
        return moe_outp
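After the expert outputs come back, `view_func` reshapes each leaf to `(batch, top_k, dim)` and `bmm_func` mixes the `top_k` expert outputs with the gate scores. A small numeric sketch of that weighting step on a single plain-tensor leaf, with made-up shapes:

```python
# Illustration of the view + bmm weighting used above, on one tensor leaf.
import torch

batch, top_k, dim = 2, 2, 3
expert_out = torch.randn(batch * top_k, dim)                 # what the gather returns per leaf
gate_score = torch.softmax(torch.randn(batch, top_k), dim=-1)

x = expert_out.view(-1, top_k, dim)                          # (batch, top_k, dim)
w = gate_score.view(-1, 1, top_k)                            # (batch, 1, top_k)
combined = torch.bmm(w, x).reshape(-1, dim)                  # (batch, dim): weighted sum over top_k
```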