Unverified Commit 9b85e405 authored by Ranggi Hwang, committed by GitHub

[`SwitchTransformer`] Significant performance improvement on MoE blocks (#31173)

* SwitchTransformer MoE layer performance improvement

* make fixup

* comments about shapes

* make fixup
parent 8177aa0e
@@ -294,9 +294,17 @@ class SwitchTransformersSparseMLP(nn.Module):
         # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.

         next_states = hidden_states.clone()
-        for idx, expert in enumerate(self.experts.values()):
-            token_indices = router_mask[:, :, idx].bool()
-            next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
+
+        router_mask = router_mask.bool()
+        batch_size, seq_len, num_experts = router_mask.shape
+        idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
+        idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
+            0
+        ].tolist()  # length: number of "activated" experts / value: expert index
+        for idx in idx_mask:
+            next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
+                hidden_states[router_mask[:, :, idx]]
+            )

         hidden_states = router_probs * next_states
         return hidden_states, (router_logits, expert_index)
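
For context on why this change helps, here is a minimal standalone sketch of the dispatch pattern the new code adopts: instead of looping over every expert, it first counts how many tokens each expert received and only calls the experts that are actually "activated". Everything below (the toy sizes, the `experts` ModuleDict, and the random routing) is illustrative and not taken from the PR.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy setup (illustrative only): 8 experts, but tokens are routed to just 2 of them.
batch_size, seq_len, hidden_dim, num_experts = 4, 16, 32, 8
experts = nn.ModuleDict({f"expert_{i}": nn.Linear(hidden_dim, hidden_dim) for i in range(num_experts)})

hidden_states = torch.randn(batch_size, seq_len, hidden_dim)
# One-hot routing mask of shape (batch_size, seq_len, num_experts): each token goes to one expert.
expert_index = torch.randint(0, 2, (batch_size, seq_len))  # only experts 0 and 1 receive tokens
router_mask = F.one_hot(expert_index, num_classes=num_experts).bool()

next_states = hidden_states.clone()

# Old pattern: call every expert, even those that received no tokens.
# for idx in range(num_experts):
#     token_indices = router_mask[:, :, idx]
#     next_states[token_indices] = experts[f"expert_{idx}"](hidden_states[token_indices])

# New pattern: find the experts that actually received tokens, then loop only over those.
tokens_per_expert = router_mask.reshape(-1, num_experts).sum(dim=0)  # shape: (num_experts,)
active_experts = torch.nonzero(tokens_per_expert, as_tuple=True)[0].tolist()
for idx in active_experts:
    token_indices = router_mask[:, :, idx]  # boolean mask over (batch_size, seq_len)
    next_states[token_indices] = experts[f"expert_{idx}"](hidden_states[token_indices])
```

With only two of the eight toy experts receiving tokens, the loop body runs twice instead of eight times; skipping the unused experts is where the performance improvement named in the PR title comes from.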