Unverified Commit ff7333c7 authored by Colin's avatar Colin Committed by GitHub
Browse files

mask and experts list (#2)

- Mask out selected token tensors during the FMoE forward pass
- Pass a list of expert classes to specify which experts to use and in what order
parent 28ba2d28
......@@ -205,7 +205,7 @@ class FMoE(nn.Module):
gate_top_k_idx, gate_score = self.gate(inp)
# delete masked tensors
if self.mask != None and self.mask_dict != None:
if self.mask is not None and self.mask_dict is not None:
mask = self.mask.view(-1)
# to: (BxL') x d_model
inp = inp[mask == 0, :]
......@@ -218,11 +218,11 @@ class FMoE(nn.Module):
)
# recover deleted tensors
if self.mask != None and self.mask_dict != None:
if self.mask is not None and self.mask_dict is not None:
# to: (BxL') x top_k x d_model
fwd = fwd.view(-1, self.top_k, self.d_model)
# to: (BxL) x top_k x d_model
x = torch.zeros(mask.shape[0], self.top_k, self.d_model)
x = torch.zeros(mask.shape[0], self.top_k, self.d_model, device=fwd.device, dtype=fwd.dtype)
# recover
x[mask == 0] = fwd
for k, v in self.mask_dict.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment