"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "e80fbd7e679e726b71af8b86159f5c4c6c474df5"
Unverified Commit 4473d81f authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Make sure RoPE frequencies are in FP32 (#875)



Make sure RoPE frequencies are in FP32
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 9ff2c076
...@@ -1432,6 +1432,8 @@ class FusedRoPEFunc(torch.autograd.Function): ...@@ -1432,6 +1432,8 @@ class FusedRoPEFunc(torch.autograd.Function):
tensor_format: str = "sbhd", tensor_format: str = "sbhd",
cu_seqlens: Union[torch.Tensor, None] = None, cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if freqs.dtype != torch.float32:
freqs = freqs.float()
if tensor_format == "sbhd": if tensor_format == "sbhd":
output = tex.fused_rope_forward(t, freqs, False) output = tex.fused_rope_forward(t, freqs, False)
elif tensor_format == "bshd": elif tensor_format == "bshd":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment