"templates/vscode:/vscode.git/clone" did not exist on "fd405e9a93f066cf1230ce4d53e2ade73c4a5497"
Unverified Commit f13f544a authored by Shiyu Li's avatar Shiyu Li Committed by GitHub
Browse files

Fix switch transformer mixed precision issue (#27220)

* Fix mixed precision error for switch transformer

* Fixup
parent db69bd88
......@@ -286,7 +286,7 @@ class GPTSanJapaneseSparseMLP(nn.Module):
next_states = hidden_states.clone()
for idx, expert in enumerate(self.experts.values()):
token_indices = router_mask[:, :, idx].bool()
next_states[token_indices] = expert(hidden_states[token_indices])
next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
hidden_states = router_probs * next_states
return hidden_states, (router_logits, expert_index)
......
......@@ -318,7 +318,7 @@ class SwitchTransformersSparseMLP(nn.Module):
next_states = hidden_states.clone()
for idx, expert in enumerate(self.experts.values()):
token_indices = router_mask[:, :, idx].bool()
next_states[token_indices] = expert(hidden_states[token_indices])
next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
hidden_states = router_probs * next_states
return hidden_states, (router_logits, expert_index)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment