Commit 8f67b530 authored by Rick Ho
Browse files

Use nn.ModuleList to wrap up customized experts

parent 704092b1
...@@ -143,10 +143,14 @@ class FMoE(nn.Module): ...@@ -143,10 +143,14 @@ class FMoE(nn.Module):
self.top_k = top_k self.top_k = top_k
self.gate = gate(d_model, num_expert, world_size, top_k) self.gate = gate(d_model, num_expert, world_size, top_k)
if expert is not None: if expert is not None:
self.experts = [expert(d_model) for _ in range(num_expert)] self.experts = nn.ModuleList([expert(d_model)
for _ in range(num_expert)])
self.experts_fused = False
else:
self.experts_fused = True
def expert_fn(self, inp, fwd_expert_count): def expert_fn(self, inp, fwd_expert_count):
if isinstance(self.experts, nn.Module): if self.experts_fused:
return self.experts(inp, fwd_expert_count) return self.experts(inp, fwd_expert_count)
outputs = [] outputs = []
base_idx = 0 base_idx = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment