"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "f2eec77b14f5e937656982cec9c994904b6edf66"
Commit c0a3a425 authored by Rick Ho

change expert_fn structure

parent a88d1124
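
This commit removes the `expert_fn` constructor argument: `FMoE.expert_fn` is now a method that either forwards to a single fused expert module or slices the dispatched tokens across a list of per-expert modules. A minimal, self-contained sketch of that dispatch logic (plain PyTorch, written for illustration, not the code in this commit):

```python
import torch
from torch import nn


def expert_fn(experts, inp, fwd_expert_count):
    """Dispatch tokens to experts, mirroring the method this commit adds.

    `experts` is either a single nn.Module that itself accepts
    (inp, fwd_expert_count), or a list of per-expert modules.  `inp` holds
    the dispatched tokens laid out contiguously per expert, and
    `fwd_expert_count[i]` is the number of tokens routed to expert i.
    """
    if isinstance(experts, nn.Module):
        # Fused case: one module handles all local experts by itself.
        return experts(inp, fwd_expert_count)
    outputs, base_idx = [], 0
    for i, count in enumerate(fwd_expert_count):
        batch_size = int(count)
        outputs.append(experts[i](inp[base_idx:base_idx + batch_size]))
        base_idx += batch_size
    return torch.cat(outputs, dim=0)


# Toy demo: 3 experts receiving 2, 4 and 1 tokens respectively.
experts = [nn.Linear(16, 16) for _ in range(3)]
counts = torch.tensor([2, 4, 1])
tokens = torch.randn(int(counts.sum()), 16)
print(expert_fn(experts, tokens, counts).shape)  # torch.Size([7, 16])
```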
@@ -112,8 +112,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
 class FMoE(nn.Module):
     r'''
-    A general moe implementation that supports an arbitrary module as the expert
-    Either `expert` or `expert_fn` is required.
+    A general moe implementation that supports an arbitrary module as the
+    expert.
     * `num_expert` stands for the number of experts on **each** worker.
     * `world_size` stands for the total number of workers that contains
     different experts.
@@ -126,12 +126,9 @@ class FMoE(nn.Module):
     * `gate` is a gate class which can found in `fmoe.gates`.
     * `expert` can be specified as a module class, it is used to generate
     `num_expert` expert modules.
-    * `expert_fn` is specified as a callable object or a function, it will be
-    called during forward, giving the input tensor (contiguous) and the array of
-    the number of input feature to each expert as input.
     '''
     def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
-            top_k=2, gate=NaiveGate, expert=None, expert_fn=None):
+            top_k=2, gate=NaiveGate, expert=None):
         super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model
@@ -145,19 +142,20 @@ class FMoE(nn.Module):
             self.mp_rank = mp_group.rank()
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
-        if expert_fn is None:
-            assert expert is not None, 'Either expert or expert_fn should be set'
+        if expert is not None:
             self.experts = [expert(d_model) for _ in range(num_expert)]
-            def expert_fn(inp, fwd_expert_count):
-                outputs = []
-                base_idx = 0
-                for i in range(self.num_expert):
-                    batch_size = fwd_expert_count[i].item()
-                    inp_slice = inp[base_idx:base_idx + batch_size]
-                    outputs.append(self.experts[i](inp_slice))
-                    base_idx += batch_size
-                return torch.cat(outputs, dim=0)
-            self.expert_fn = expert_fn
+
+    def expert_fn(self, inp, fwd_expert_count):
+        if isinstance(self.experts, nn.Module):
+            return self.experts(inp, fwd_expert_count)
+        outputs = []
+        base_idx = 0
+        for i in range(self.num_expert):
+            batch_size = fwd_expert_count[i].item()
+            inp_slice = inp[base_idx:base_idx + batch_size]
+            outputs.append(self.experts[i](inp_slice))
+            base_idx += batch_size
+        return torch.cat(outputs, dim=0)

     def mark_parallel_comm(self):
         r'''
@@ -193,7 +191,8 @@ class FMoE(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)
         # to: (BxLxtop_k) x d_model
         inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
-        x = _fmoe_general_global_forward(inp, gate_top_k_idx, self.expert_fn,
+        expert_fn = lambda inp, fec: self.expert_fn(inp, fec)
+        x = _fmoe_general_global_forward(inp, gate_top_k_idx, expert_fn,
                 self.num_expert, self.world_size)
         # to: (BxL) x top_k x d_model
         x = x.view(-1, self.top_k, self.d_model)
...
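
As documented in the docstring above, experts are now supplied through the `expert=` argument (a module class) rather than `expert_fn=`. A hedged usage sketch follows; it assumes `FMoE` is importable from `fmoe.layers` (the file path is not visible in this diff) and that the single-worker defaults shown in the signature construct cleanly on CPU. `MyExpert` is a hypothetical expert class:

```python
import torch
from torch import nn

from fmoe.layers import FMoE  # assumed import path; not shown in this diff


class MyExpert(nn.Module):
    """Hypothetical per-expert block; FMoE builds num_expert instances of it."""

    def __init__(self, d_model):
        super().__init__()
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        return torch.relu(self.fc(x))


# Single-worker defaults: world_size=1, NaiveGate, no mp_group.
moe = FMoE(num_expert=4, d_model=16, expert=MyExpert)

# Exercise the new method directly: tokens are laid out contiguously per
# expert, and fwd_expert_count[i] says how many of them belong to expert i.
counts = torch.tensor([3, 2, 1, 2])
tokens = torch.randn(int(counts.sum()), 16)
out = moe.expert_fn(tokens, counts)
print(out.shape)  # torch.Size([8, 16])
```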
@@ -49,11 +49,8 @@ class FMoETransformerMLP(FMoE):
             top_k=2,
             pre_lnorm=False
     ):
-        def expert_fn(inp, gate):
-            return self.experts(inp, gate)
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
-                top_k=top_k, world_size=world_size, mp_group=mp_group,
-                expert_fn=expert_fn)
+                top_k=top_k, world_size=world_size, mp_group=mp_group)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
...
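
`FMoETransformerMLP` no longer builds an `expert_fn` closure; it simply assigns `self.experts = _Expert(...)`, and because that is a single `nn.Module`, the base class's `expert_fn` calls it directly with `(inp, fwd_expert_count)`. The sketch below uses a hypothetical `FusedExperts` stand-in (the real `_Expert` is not shown in this diff, and its internals are not implied here) to illustrate the interface such a fused module must expose:

```python
import torch
from torch import nn


class FusedExperts(nn.Module):
    """Hypothetical stand-in for _Expert: one module holding all local experts.

    Its forward takes (inp, fwd_expert_count), which is what the base class's
    expert_fn now passes along when self.experts is a single nn.Module.
    """

    def __init__(self, num_expert, d_model, d_hidden):
        super().__init__()
        self.fc1 = nn.ModuleList([nn.Linear(d_model, d_hidden) for _ in range(num_expert)])
        self.fc2 = nn.ModuleList([nn.Linear(d_hidden, d_model) for _ in range(num_expert)])

    def forward(self, inp, fwd_expert_count):
        # Slice the contiguous token buffer per expert and apply that expert's MLP.
        outputs, base = [], 0
        for i, count in enumerate(fwd_expert_count.tolist()):
            x = inp[base:base + count]
            outputs.append(self.fc2[i](torch.relu(self.fc1[i](x))))
            base += count
        return torch.cat(outputs, dim=0)


experts = FusedExperts(num_expert=2, d_model=8, d_hidden=32)
counts = torch.tensor([5, 3])
tokens = torch.randn(8, 8)
# Because FusedExperts is an nn.Module, FMoE.expert_fn would take the
# isinstance(self.experts, nn.Module) branch and call it directly:
out = experts(tokens, counts)
print(out.shape)  # torch.Size([8, 8])
```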