Unverified Commit 74b6908f authored by Rick Ho, committed by GitHub

Merge pull request #90 from xptree/multi_input

MoE with multiple inputs and multiple outputs
parents b652e8d8 0abea7b2
r"""
FMoE core layer
"""
import tree
import torch
import torch.nn as nn

@@ -10,7 +11,6 @@ from .functions import AllGather, Slice
from .gates import NaiveGate


def mark_module_parallel_comm(module, comm):
    r"""
    Mark all parameters in `module` as doing data parallel in `comm`, where
@@ -42,22 +42,37 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    topk = 1
    if len(gate.shape) == 2:
        topk = gate.shape[1]

    def scatter_func(tensor):
        return MOEScatter.apply(
            tensor,
            pos // topk,
            local_expert_count,
            global_expert_count,
            fwd_batch_size,
            world_size,
        )

    x = tree.map_structure(scatter_func, inp)

    x = expert_fn(x, fwd_expert_count)

    out_batch_size = tree.flatten(inp)[0].shape[0]
    if len(gate.shape) == 2:
        out_batch_size *= gate.shape[1]

    def gather_func(tensor):
        return MOEGather.apply(
            tensor,
            pos,
            local_expert_count,
            global_expert_count,
            out_batch_size,
            world_size,
        )

    outp = tree.map_structure(gather_func, x)
    return outp
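The key change above is that `inp` may now be any nested structure of tensors (for example a dict or a list) rather than a single tensor: `tree.map_structure` applies the per-tensor scatter/gather functions to every leaf while preserving the structure, and `tree.flatten` reads the batch size from the first leaf. The `tree` module here is presumably the dm-tree package. A minimal, self-contained sketch of that behaviour, with made-up shapes and independent of the MoE kernels:

import torch
import tree  # dm-tree

# Two inputs that share the batch dimension, as the new forward path expects.
inp = {"x": torch.randn(4, 8), "y": torch.randn(4, 8)}

# map_structure applies the function to every leaf tensor and keeps the dict shape.
doubled = tree.map_structure(lambda t: t * 2, inp)
assert set(doubled.keys()) == {"x", "y"}

# flatten returns the leaves in a deterministic order; the code above uses the
# first leaf to determine out_batch_size.
assert tree.flatten(inp)[0].shape[0] == 4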
class FMoE(nn.Module):
@@ -84,7 +99,7 @@ class FMoE(nn.Module):
        num_expert=32,
        d_model=1024,
        world_size=1,
        mp_group=None,  # being deprecated
        slice_group=None,
        moe_group=None,
        top_k=2,
@@ -101,7 +116,7 @@ class FMoE(nn.Module):
        self.slice_group = slice_group
        if mp_group is not None:
            print("[Warning] mp_group is being deprecated")
            self.slice_group = mp_group
        if self.slice_group is None:
            self.slice_size = 1
@@ -116,8 +131,7 @@ class FMoE(nn.Module):
            self.experts_fused = False
            self.num_expert = num_expert = len(expert)
        elif expert is not None:
            self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
            self.experts_fused = False
        else:
            self.experts_fused = True
@@ -159,52 +173,109 @@ class FMoE(nn.Module):
        mark_module_parallel_comm(self.experts, comm)
        mark_module_parallel_comm(self.gate, "gate")
    def forward(self, moe_inp):
        r"""
        The FMoE module first computes the gate output, and then conducts the
        MoE forward pass according to the gate. The scores of the selected
        experts, given by the gate, are multiplied with the experts' output
        tensors as weights.
        """
        moe_inp_batch_size = tree.flatten(
            tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
        )
        assert all(
            [batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size]
        ), "MoE inputs must have the same batch size"

        if self.world_size > 1:

            def ensure_comm_func(tensor):
                ensure_comm(tensor, self.moe_group)

            tree.map_structure(ensure_comm_func, moe_inp)
        if self.slice_size > 1:

            def slice_func(tensor):
                return Slice.apply(
                    tensor, self.slice_rank, self.slice_size, self.slice_group
                )

            moe_inp = tree.map_structure(slice_func, moe_inp)

        gate_top_k_idx, gate_score = self.gate(moe_inp)

        if self.gate_hook is not None:
            self.gate_hook(gate_top_k_idx, gate_score, None)

        # delete masked tensors
        if self.mask is not None and self.mask_dict is not None:
            # TODO: to fix
            def delete_mask_func(tensor):
                # to: (BxL') x d_model
                tensor = tensor[mask == 0, :]
                return tensor

            mask = self.mask.view(-1)
            moe_inp = tree.map_structure(delete_mask_func, moe_inp)
            gate_top_k_idx = gate_top_k_idx[mask == 0, :]

        fwd = _fmoe_general_global_forward(
            moe_inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
        )

        # recover deleted tensors
        if self.mask is not None and self.mask_dict is not None:

            def recover_func(tensor):
                # to: (BxL') x top_k x dim
                dim = tensor.shape[-1]
                tensor = tensor.view(-1, self.top_k, dim)
                # to: (BxL) x top_k x d_model
                x = torch.zeros(
                    mask.shape[0],
                    self.top_k,
                    dim,
                    device=tensor.device,
                    dtype=tensor.dtype,
                )
                # recover
                x[mask == 0] = tensor
                for k, v in self.mask_dict.items():
                    x[mask == k] = v
                return x

            moe_outp = tree.map_structure(recover_func, fwd)
        else:

            def view_func(tensor):
                dim = tensor.shape[-1]
                tensor = tensor.view(-1, self.top_k, dim)
                return tensor

            moe_outp = tree.map_structure(view_func, fwd)

        gate_score = gate_score.view(-1, 1, self.top_k)

        def bmm_func(tensor):
            dim = tensor.shape[-1]
            tensor = torch.bmm(gate_score, tensor).reshape(-1, dim)
            return tensor

        moe_outp = tree.map_structure(bmm_func, moe_outp)

        if self.slice_size > 1:

            def all_gather_func(tensor):
                return AllGather.apply(
                    tensor, self.slice_rank, self.slice_size, self.slice_group
                )

            moe_outp = tree.map_structure(all_gather_func, moe_outp)

        moe_outp_batch_size = tree.flatten(
            tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)
        )
        assert all(
            [batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size]
        ), "MoE outputs must have the same batch size"
        return moe_outp
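For the masked path above: rows whose mask value is nonzero are removed before the expert computation and written back afterwards from `mask_dict`. A toy illustration of the recover step, with made-up shapes and values (not part of the diff):

import torch

top_k, dim = 2, 4
mask = torch.tensor([0, 1, 0, 2])            # 0 = routed to experts, nonzero = masked class
kept = torch.full((2, top_k, dim), 3.0)      # expert outputs for the two kept rows
mask_dict = {1: torch.zeros(dim), 2: torch.ones(dim)}

out = torch.zeros(mask.shape[0], top_k, dim)
out[mask == 0] = kept                        # scatter the computed rows back in place
for k, v in mask_dict.items():
    out[mask == k] = v                       # fill masked rows from the dictionary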
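The weighted combination performed by `bmm_func` reduces the top_k expert outputs of every token with its gate scores; the new test file added below exercises this whole path end to end with dict and list inputs. A small sketch of just that reduction, with made-up shapes:

import torch

B, top_k, d = 4, 2, 16
gate_score = torch.softmax(torch.randn(B, top_k), dim=-1).view(B, 1, top_k)
expert_out = torch.randn(B, top_k, d)        # one output row per selected expert

# (B, 1, top_k) @ (B, top_k, d) -> (B, 1, d), then flattened to (B, d)
combined = torch.bmm(gate_score, expert_out).reshape(-1, d)
assert combined.shape == (B, d)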
import sys

import pytest
import torch
import torch.nn as nn
import numpy as np

from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.linear import FMoELinear
from fmoe.megatron.layers import _megatron_init_method


def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().max()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > precision:
            sys.stderr.write(f"=========== {name} moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write(f"=========== {name} raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            sys.stderr.write(f"=========== {name} diff ==============\n")
            sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
            assert False
class MyExpert(nn.Module):
    r"""
    An expert using 2 FMoELinear modules to speed up the computation of experts
    within one worker.
    """

    def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
        super().__init__()
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        r"""
        First expand the input to 4h (the hidden size is variable, but is
        called h4 for convenience), then apply the activation, and finally
        shrink back to h.
        """
        if type(inp) == dict:
            x = inp["x"]
            y = inp["y"]
        elif type(inp) == list:
            x = inp[0]
            y = inp[1]
        else:
            raise NotImplementedError

        x = self.htoh4(x, fwd_expert_count)
        x = self.activation(x)
        x = self.h4toh(x, fwd_expert_count)

        y = self.htoh4(y, fwd_expert_count)
        y = self.activation(y)
        y = self.h4toh(y, fwd_expert_count)

        if type(inp) == dict:
            ret = {"x": x, "y": y}
        elif type(inp) == list:
            ret = [x, y]

        return ret
class MyGate(NaiveGate):
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__(d_model, num_expert, world_size, top_k)

    def forward(self, inp, return_all_scores=False):
        if type(inp) == dict:
            x = inp["x"]
        elif type(inp) == list:
            x = inp[0]
        else:
            raise NotImplementedError
        return super().forward(x, return_all_scores)


class MyMoE(FMoE):
    def __init__(
        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ):
        super().__init__(
            num_expert=num_expert,
            d_model=d_model,
            gate=MyGate,
            world_size=world_size,
            mp_group=mp_group,
            top_k=top_k,
        )
        self.experts = MyExpert(num_expert, d_model, d_hidden, activation)

        rng = np.random.default_rng(1234)
        _megatron_init_method(self.experts.htoh4, rng, 1.0)
        _megatron_init_method(self.experts.h4toh, rng, 1.0)
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize(
    "data_type", ["torch.FloatTensor", "torch.DoubleTensor", "torch.HalfTensor"]
)
@pytest.mark.parametrize("list_input", [False, True])
def test_fmoe_mimo_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type,
    list_input,
    activation=torch.nn.functional.gelu,
):
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    moe = MyMoE(
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=4 * d_model,
        world_size=world_size,
        mp_group=mp_group,
        top_k=top_k,
        activation=activation,
    ).cuda()

    x = torch.rand(batch_size, d_model).cuda()
    inp = [x, x.clone()] if list_input else {"x": x, "y": x.clone()}
    moe_out = moe(inp)

    if list_input:
        _assert_numerical(["x"], [moe_out[0]], [moe_out[1]], rank)
    else:
        _assert_numerical(["x"], [moe_out["x"]], [moe_out["y"]], rank)
if __name__ == "__main__":
    test_fmoe_mimo_linear(
        batch_size=2,
        num_expert=2,
        d_model=2,
        top_k=2,
        d_hidden=16,
        rank=0,
        world_size=1,
        mp_group=None,
        dp_group=None,
        world_group=None,
        data_type=torch.float32,
        list_input=False,
    )