Commit c0a3a425 authored by Rick Ho

change expert_fn structure

parent a88d1124
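In short, `FMoE.__init__` no longer accepts an `expert_fn` callable; expert dispatch moves into an `expert_fn` method on `FMoE` itself, which either calls a single fused module stored in `self.experts` or loops over a list of per-expert modules. A minimal sketch of the new calling convention follows; `MyExpert` is a hypothetical expert module, and the import path for `FMoE` is not shown in the diff, so the constructions are left as comments.

```python
import torch.nn as nn

# Hypothetical expert module for illustration; the diff only shows that the
# `expert` class is instantiated as `expert(d_model)`.
class MyExpert(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, inp):
        return self.fc(inp)

# Old API (removed by this commit): a callable was injected at construction.
#   layer = FMoE(num_expert=4, d_model=16, expert_fn=my_callable)
#
# New API: pass an expert class; FMoE builds `num_expert` instances and
# routes to them through its own `expert_fn` method, which subclasses may
# also override.
#   layer = FMoE(num_expert=4, d_model=16, expert=MyExpert)
```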
@@ -112,8 +112,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
 class FMoE(nn.Module):
     r'''
-    A general moe implementation that supports an arbitrary module as the expert
-    Either `expert` or `expert_fn` is required.
+    A general moe implementation that supports an arbitrary module as the
+    expert.
     * `num_expert` stands for the number of experts on **each** worker.
     * `world_size` stands for the total number of workers that contains
     different experts.
@@ -126,12 +126,9 @@ class FMoE(nn.Module):
     * `gate` is a gate class which can found in `fmoe.gates`.
     * `expert` can be specified as a module class, it is used to generate
     `num_expert` expert modules.
-    * `expert_fn` is specified as a callable object or a function, it will be
-    called during forward, giving the input tensor (contiguous) and the array of
-    the number of input feature to each expert as input.
     '''
     def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
-            top_k=2, gate=NaiveGate, expert=None, expert_fn=None):
+            top_k=2, gate=NaiveGate, expert=None):
         super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model
@@ -145,10 +142,12 @@ class FMoE(nn.Module):
             self.mp_rank = mp_group.rank()
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
-        if expert_fn is None:
-            assert expert is not None, 'Either expert or expert_fn should be set'
+        if expert is not None:
             self.experts = [expert(d_model) for _ in range(num_expert)]
-            def expert_fn(inp, fwd_expert_count):
+    def expert_fn(self, inp, fwd_expert_count):
+        if isinstance(self.experts, nn.Module):
+            return self.experts(inp, fwd_expert_count)
         outputs = []
         base_idx = 0
         for i in range(self.num_expert):
@@ -157,7 +156,6 @@ class FMoE(nn.Module):
             outputs.append(self.experts[i](inp_slice))
             base_idx += batch_size
         return torch.cat(outputs, dim=0)
-        self.expert_fn = expert_fn
     def mark_parallel_comm(self):
         r'''
@@ -193,7 +191,8 @@ class FMoE(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)
         # to: (BxLxtop_k) x d_model
         inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
-        x = _fmoe_general_global_forward(inp, gate_top_k_idx, self.expert_fn,
+        expert_fn = lambda inp, fec: self.expert_fn(inp, fec)
+        x = _fmoe_general_global_forward(inp, gate_top_k_idx, expert_fn,
                 self.num_expert, self.world_size)
         # to: (BxL) x top_k x d_model
         x = x.view(-1, self.top_k, self.d_model)
......
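The middle of the per-expert loop in the new `expert_fn` method is collapsed in the hunks above. The sketch below reconstructs the dispatch pattern it implies: inputs arrive grouped by expert, each expert gets the contiguous slice sized by `fwd_expert_count`, and the outputs are concatenated back. The exact slicing lines are not shown in the diff, so this is an assumption about the collapsed code, not a copy of it.

```python
import torch
import torch.nn as nn

def dispatch_to_experts(experts, inp, fwd_expert_count):
    # Rows of `inp` arrive grouped by expert; fwd_expert_count[i] is the
    # number of rows destined for expert i.
    outputs = []
    base_idx = 0
    for i, count in enumerate(fwd_expert_count):
        inp_slice = inp[base_idx:base_idx + count]  # contiguous slice for expert i
        outputs.append(experts[i](inp_slice))
        base_idx += count
    return torch.cat(outputs, dim=0)

# Example: 3 experts, 6 rows of width 4, split 2/1/3 across the experts.
experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
out = dispatch_to_experts(experts, torch.randn(6, 4), [2, 1, 3])
print(out.shape)  # torch.Size([6, 4])
```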
@@ -49,11 +49,8 @@ class FMoETransformerMLP(FMoE):
             top_k=2,
             pre_lnorm=False
     ):
-        def expert_fn(inp, gate):
-            return self.experts(inp, gate)
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
-                top_k=top_k, world_size=world_size, mp_group=mp_group,
-                expert_fn=expert_fn)
+                top_k=top_k, world_size=world_size, mp_group=mp_group)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
......
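With this change, `FMoETransformerMLP` no longer wires up its own `expert_fn`; it only sets `self.experts` to a single fused `_Expert` module, and the base class's new `expert_fn` calls it directly through the `isinstance(self.experts, nn.Module)` check shown in the first file, while a plain Python list of experts falls through to the per-expert loop. Below is a stand-in, not the real `_Expert` (which, per the diff, takes num_expert, d_model, d_hidden, activation and rank); it only mimics the `(inp, fwd_expert_count)` calling convention the fused path relies on.

```python
import torch
import torch.nn as nn

# Illustrative fused expert in the spirit of `_Expert`: a single nn.Module
# that handles all experts at once, so it receives fwd_expert_count in
# forward().
class FusedExpert(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, inp, fwd_expert_count):
        # A real fused implementation would use fwd_expert_count to apply
        # per-expert weights to each contiguous slice of `inp`.
        return self.fc(inp)

experts = FusedExpert(8)
# Same check the new FMoE.expert_fn performs: a single nn.Module is treated
# as fused and called directly with (inp, fwd_expert_count).
if isinstance(experts, nn.Module):
    out = experts(torch.randn(5, 8), [2, 3])
    print(out.shape)  # torch.Size([5, 8])
```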