Unverified Commit 97f439ae authored by mksit's avatar mksit Committed by GitHub
Browse files

Create the return value on device to avoid unnecessary copying from CPU (#26151)

parent 42791a57
......@@ -779,7 +779,7 @@ class SwitchTransformersBlock(nn.Module):
if isinstance(hidden_states, tuple):
hidden_states, router_tuple = hidden_states
else:
router_tuple = (torch.tensor([0], device=hidden_states.device),)
router_tuple = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment