Unverified Commit 5fca839f authored by Yih-Dar, committed by GitHub
Browse files

Fix device issue in `SwitchTransformers` (#24352)



* fix

* Update src/transformers/models/switch_transformers/modeling_switch_transformers.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 3b5a56e5
...@@ -798,7 +798,7 @@ class SwitchTransformersBlock(nn.Module): ...@@ -798,7 +798,7 @@ class SwitchTransformersBlock(nn.Module):
if isinstance(hidden_states, tuple): if isinstance(hidden_states, tuple):
hidden_states, router_tuple = hidden_states hidden_states, router_tuple = hidden_states
else: else:
router_tuple = (torch.tensor([0]),) router_tuple = (torch.tensor([0], device=hidden_states.device),)
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment