Commit 2d81858b authored by Jiezhong Qiu

allow nested data struct as moe input

parent b652e8d8
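This change wraps the per-tensor operations in `_fmoe_general_global_forward` and `FMoE.forward` with `tree.map_structure` (dm-tree), so the MoE input may be a nested structure of tensors (e.g. a tuple or dict) rather than a single tensor; as the code stands, the gate and the expert function are called on the nested structure directly and must handle it themselves. A toy sketch of the leaf-wise pattern the patch relies on, with a plain permutation standing in for the real MOEScatter/MOEGather ops:

import torch
import tree  # dm-tree, the same package this module imports

inp = {"h": torch.randn(4, 8), "g": torch.randn(4, 8)}  # nested MoE input
pos = torch.tensor([2, 0, 3, 1])  # a toy routing permutation

def toy_scatter(t):
    # reorder the rows of one leaf tensor; the real code calls MOEScatter here
    return t.index_select(0, pos)

x = tree.map_structure(toy_scatter, inp)
# the nesting is preserved: x is again a dict with the same keys and shapes
assert set(x.keys()) == {"h", "g"} and x["h"].shape == (4, 8)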
r"""
FMoE core layer
"""
import tree
import torch
import torch.nn as nn
@@ -10,7 +11,6 @@ from .functions import AllGather, Slice
from .gates import NaiveGate
def mark_module_parallel_comm(module, comm):
r"""
Mark all parameters in `module` as doing data parallel in `comm`, where
@@ -42,22 +42,39 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
topk = 1
if len(gate.shape) == 2:
topk = gate.shape[1]
x = MOEScatter.apply(
inp, pos // topk,
local_expert_count, global_expert_count, fwd_batch_size, world_size
)
def scatter_func(inp_tensor):
tensor = MOEScatter.apply(
inp_tensor,
pos // topk,
local_expert_count,
global_expert_count,
fwd_batch_size,
world_size,
)
return tensor
x = tree.map_structure(scatter_func, inp)
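    # The scatter (and the gather below) is applied leaf by leaf, but
    # expert_fn receives the whole, possibly nested, structure at once,
    # so custom expert functions must accept that structure themselves.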
x = expert_fn(x, fwd_expert_count)
out_batch_size = inp.shape[0]
if len(gate.shape) == 2:
out_batch_size *= gate.shape[1]
x = MOEGather.apply(
x, pos,
local_expert_count, global_expert_count,
out_batch_size, world_size
)
return x
    def gather_func(outp_tensor):
tensor = MOEGather.apply(
outp_tensor,
pos,
local_expert_count,
global_expert_count,
out_batch_size,
world_size,
)
return tensor
    outp = tree.map_structure(gather_func, x)
return outp
class FMoE(nn.Module):
@@ -84,7 +101,7 @@ class FMoE(nn.Module):
num_expert=32,
d_model=1024,
world_size=1,
mp_group=None, # being deprecated
mp_group=None, # being deprecated
slice_group=None,
moe_group=None,
top_k=2,
@@ -101,7 +118,7 @@ class FMoE(nn.Module):
self.slice_group = slice_group
if mp_group is not None:
print('[Warning] mp_group is being deprecated')
print("[Warning] mp_group is being deprecated")
self.slice_group = mp_group
if self.slice_group is None:
self.slice_size = 1
@@ -116,8 +133,7 @@ class FMoE(nn.Module):
self.experts_fused = False
self.num_expert = num_expert = len(expert)
elif expert is not None:
self.experts = nn.ModuleList([expert(d_model)
for _ in range(num_expert)])
self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
self.experts_fused = False
else:
self.experts_fused = True
@@ -159,52 +175,94 @@ class FMoE(nn.Module):
mark_module_parallel_comm(self.experts, comm)
mark_module_parallel_comm(self.gate, "gate")
def forward(self, inp):
def forward(self, moe_inp, non_moe_inp=None):
r"""
        The FMoE module first computes the gate output, and then conducts the
        MoE forward pass according to the gate. The scores of the selected
        experts, given by the gate, are multiplied with the experts' output
        tensors as weights.
"""
if self.world_size > 1:
ensure_comm(inp, self.moe_group)
def ensure_comm_func(tensor):
ensure_comm(tensor, self.moe_group)
tree.map_structure(ensure_comm_func, moe_inp)
if self.slice_size > 1:
inp = Slice.apply(inp, self.slice_rank,
self.slice_size, self.slice_group)
gate_top_k_idx, gate_score = self.gate(inp)
def slice_func(tensor):
return Slice.apply(
tensor, self.slice_rank, self.slice_size, self.slice_group
)
moe_inp = tree.map_structure(slice_func, moe_inp)
gate_top_k_idx, gate_score = self.gate(moe_inp)
if self.gate_hook is not None:
self.gate_hook(gate_top_k_idx, gate_score, None)
# TODO: to fix
def delete_mask_func(tensor):
# to: (BxL') x d_model
tensor = tensor[mask == 0, :]
return tensor
# delete masked tensors
if self.mask is not None and self.mask_dict is not None:
mask = self.mask.view(-1)
# to: (BxL') x d_model
inp = inp[mask == 0, :]
moe_inp = tree.map_structure(delete_mask_func, moe_inp)
gate_top_k_idx = gate_top_k_idx[mask == 0, :]
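        # At this point only tokens with mask == 0 remain in moe_inp and
        # gate_top_k_idx; tokens with mask == k are re-inserted after the
        # expert computation from mask_dict[k] (see recover_func below).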
fwd = _fmoe_general_global_forward(
inp, gate_top_k_idx,
self.expert_fn, self.num_expert, self.world_size
moe_inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
)
# recover deleted tensors
if self.mask is not None and self.mask_dict is not None:
# to: (BxL') x top_k x d_model
fwd = fwd.view(-1, self.top_k, self.d_model)
# to: (BxL) x top_k x d_model
x = torch.zeros(mask.shape[0], self.top_k, self.d_model, device=fwd.device, dtype=fwd.dtype)
# recover
x[mask == 0] = fwd
for k, v in self.mask_dict.items():
x[mask == k] = v
def recover_func(tensor):
# to: (BxL') x top_k x dim
dim = tensor.shape[-1]
tensor = tensor.view(-1, self.top_k, dim)
# to: (BxL) x top_k x d_model
x = torch.zeros(
mask.shape[0],
self.top_k,
dim,
device=tensor.device,
dtype=tensor.dtype,
)
# recover
x[mask == 0] = tensor
for k, v in self.mask_dict.items():
x[mask == k] = v
return x
moe_outp = tree.map_structure(recover_func, fwd)
else:
x = fwd.view(-1, self.top_k, self.d_model)
gate_score = gate_score.view(x.shape[0], 1, self.top_k)
x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
def view_func(tensor):
dim = tensor.shape[-1]
tensor = tensor.view(-1, self.top_k, dim)
return tensor
moe_outp = tree.map_structure(view_func, fwd)
gate_score = gate_score.view(-1, 1, self.top_k)
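        # Weight and combine the top_k expert outputs with their gate scores:
        # (N, 1, top_k) bmm (N, top_k, dim) -> (N, 1, dim), reshaped to (N, dim).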
def bmm_func(tensor):
dim = tensor.shape[-1]
tensor = torch.bmm(gate_score, tensor).reshape(-1, dim)
return tensor
moe_outp = tree.map_structure(bmm_func, moe_outp)
if self.slice_size > 1:
x = AllGather.apply(x, self.slice_rank,
self.slice_size, self.slice_group)
return x
def all_gather_func(tensor):
return AllGather.apply(
tensor, self.slice_rank, self.slice_size, self.slice_group
)
moe_outp = tree.map_structure(all_gather_func, moe_outp)
return moe_outp