Unverified commit f13f544a, authored by Shiyu Li and committed by GitHub
Browse files

Fix switch transformer mixed precision issue (#27220)

* Fix mixed precision error for switch transformer

* Fixup
parent db69bd88
...@@ -286,7 +286,7 @@ class GPTSanJapaneseSparseMLP(nn.Module): ...@@ -286,7 +286,7 @@ class GPTSanJapaneseSparseMLP(nn.Module):
next_states = hidden_states.clone() next_states = hidden_states.clone()
for idx, expert in enumerate(self.experts.values()): for idx, expert in enumerate(self.experts.values()):
token_indices = router_mask[:, :, idx].bool() token_indices = router_mask[:, :, idx].bool()
next_states[token_indices] = expert(hidden_states[token_indices]) next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
hidden_states = router_probs * next_states hidden_states = router_probs * next_states
return hidden_states, (router_logits, expert_index) return hidden_states, (router_logits, expert_index)
......
...@@ -318,7 +318,7 @@ class SwitchTransformersSparseMLP(nn.Module): ...@@ -318,7 +318,7 @@ class SwitchTransformersSparseMLP(nn.Module):
next_states = hidden_states.clone() next_states = hidden_states.clone()
for idx, expert in enumerate(self.experts.values()): for idx, expert in enumerate(self.experts.values()):
token_indices = router_mask[:, :, idx].bool() token_indices = router_mask[:, :, idx].bool()
next_states[token_indices] = expert(hidden_states[token_indices]) next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
hidden_states = router_probs * next_states hidden_states = router_probs * next_states
return hidden_states, (router_logits, expert_index) return hidden_states, (router_logits, expert_index)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment