Commit 969ef607 authored by Jiezhong Qiu

Merge the top-k loop and forward the MoE only once

parent b9084e90
@@ -70,14 +70,23 @@ class CustomizedMoEPositionwiseFF(nn.Module):
         gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False) # [.. x top_k]
         gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
+        # gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1) # (BxL) x 1 x top_k
+        # gate_top_k_idx = gate_top_k_idx.view(-1, self.top_k)
         gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1) # (BxL) x 1 x top_k
-        gate_top_k_idx = gate_top_k_idx.view(-1, self.top_k)
-        core_out = []
-        inp = inp.view(-1, self.d_model)
+        gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k)
+        #core_out = []
+        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k, dim=0) # (BxLxtop_k) x d_model
         inp = F.pad(inp, pad=(0, 1), mode='constant', value=1.0)
+        x = self.moe1(inp, gate_top_k_idx)
+        x = self.dropout(F.relu(x))
+        x = F.pad(x, pad=(0, 1), mode='constant', value=1.0)
+        x = self.moe2(x, gate_top_k_idx)
+        x = self.dropout(x) # (BxLxtop_k) x d_model
+        core_out = x.view(-1, self.top_k, self.d_model) # (BxL) x top_k x d_model
+        """
         for i in range(self.top_k):
             gate_idx = gate_top_k_idx[:, i].contiguous()
             x = self.moe1(inp, gate_idx)
@@ -88,6 +97,7 @@ class CustomizedMoEPositionwiseFF(nn.Module):
             core_out.append(x.unsqueeze(1)) # (BxL) x 1 x d_model
         core_out = torch.cat(core_out, dim=1) # (BxL) x top_k x d_model
+        """
         core_out = torch.bmm(gate_score, core_out) # (BxL) x 1 x d_model
         core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
...
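To make the equivalence this commit relies on easier to check, here is a minimal, self-contained sketch: instead of calling the MoE layer once per top-k slot in a Python loop and concatenating the results, each token row is repeated top_k times with repeat_interleave, the top-k expert indices are flattened, and a single forward yields the same (BxL) x top_k x d_model tensor, which is then mixed by gate_score via torch.bmm. The moe_linear helper, the toy expert_w weights, and the tensor sizes are illustrative stand-ins, not the repo's moe1/moe2 kernels (which also appear to fold a bias term in via the F.pad(..., value=1.0) column); dropout is omitted so the two paths can be compared numerically.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_expert, d_model, top_k = 4, 8, 2
tokens = 6                                   # flattened B x L

# Toy stand-in for the repo's moe1/moe2 kernels: each row of x is multiplied by
# the weight matrix of the expert named in idx. Only for checking the refactor's math.
expert_w = torch.randn(num_expert, d_model, d_model)

def moe_linear(x, idx):
    # x: N x d_model, idx: N (one expert id per row) -> N x d_model
    return torch.bmm(x.unsqueeze(1), expert_w[idx]).squeeze(1)

inp = torch.randn(tokens, d_model)
gate = torch.randn(tokens, num_expert)
gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=top_k, dim=-1, sorted=False)
gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)      # tokens x 1 x top_k

# Old style: one MoE forward per top-k slot, concatenate, then mix with the gate.
outs = [moe_linear(inp, gate_top_k_idx[:, i]).unsqueeze(1) for i in range(top_k)]
old = torch.bmm(gate_score, torch.cat(outs, dim=1))              # tokens x 1 x d_model

# New style (this commit): repeat each token top_k times, flatten the expert
# indices, forward the MoE once, reshape back to tokens x top_k x d_model, then mix.
inp_rep = inp.repeat_interleave(repeats=top_k, dim=0)            # (tokens*top_k) x d_model
idx_flat = gate_top_k_idx.reshape(-1)                            # (tokens*top_k)
core_out = moe_linear(inp_rep, idx_flat).view(tokens, top_k, d_model)
new = torch.bmm(gate_score, core_out)                            # tokens x 1 x d_model

print(torch.allclose(old, new))  # True

The ordering works out because repeat_interleave on dim=0 produces top_k adjacent copies of each token, which lines up with the row-major flattening of gate_top_k_idx, so viewing the single forward's output as (-1, top_k, d_model) recovers exactly the tensor the old torch.cat built.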