Unverified Commit 4473d81f authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Make sure RoPE frequencies are in FP32 (#875)



Make sure RoPE frequencies are in FP32
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 9ff2c076
......@@ -1432,6 +1432,8 @@ class FusedRoPEFunc(torch.autograd.Function):
tensor_format: str = "sbhd",
cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
if freqs.dtype != torch.float32:
freqs = freqs.float()
if tensor_format == "sbhd":
output = tex.fused_rope_forward(t, freqs, False)
elif tensor_format == "bshd":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment