Unverified commit fb6cc7b0, authored by zhangdonghao-zdh and committed by GitHub

Fix RotaryEmbedding for fp32 input (#11843)

parent 8374a96e
@@ -112,7 +112,7 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)
-        if (
+        if dtype == torch.float32 or (
             (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
             and not (_is_cpu and _is_cpu_amx_available)
             and not _is_xpu
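Reading the first hunk: the constructor-time check that previously depended only on the platform flags and head size now also triggers whenever the dtype is torch.float32, so fp32 models take the non-fused fallback setup regardless of hardware. A minimal, hypothetical sketch of that predicate is below; the helper name and explicit flag parameters are illustrative only and not part of the actual code, which reads module-level flags such as `_is_cuda` directly.

```python
# Hypothetical helper mirroring the constructor-time condition after this change.
import torch


def _needs_fallback_cache(
    dtype: torch.dtype,
    head_size: int,
    is_cuda: bool,
    is_npu: bool,
    is_cpu: bool,
    cpu_amx_available: bool,
    is_xpu: bool,
) -> bool:
    # fp32 now short-circuits the check, so the fused-kernel setup is skipped entirely.
    return dtype == torch.float32 or (
        (not (is_cuda or is_npu) or head_size not in [64, 128, 256, 512])
        and not (is_cpu and cpu_amx_available)
        and not is_xpu
    )


# On CUDA with head_size=128 the old condition evaluated False; fp32 now forces True.
print(_needs_fallback_cache(torch.float32, 128, True, False, False, False, False))   # True
print(_needs_fallback_cache(torch.bfloat16, 128, True, False, False, False, False))  # False
```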
@@ -254,7 +254,11 @@ class RotaryEmbedding(CustomOp):
         offsets: Optional[torch.Tensor] = None,
         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
+        if (
+            _is_cuda
+            and (self.head_size in [64, 128, 256, 512])
+            and self.dtype != torch.float32
+        ):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
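The second hunk applies the same guard at forward time: the fused `apply_rope_with_cos_sin_cache_inplace` path is now reserved for CUDA with a supported head size and a non-fp32 dtype, while fp32 inputs presumably fall back to the native RoPE implementation. A hedged sketch of that dispatch decision is below; the function name and the printed comments are illustrative assumptions, not the project's API.

```python
# Hypothetical sketch of the forward-time dispatch after this change.
import torch


def _use_fused_rope_kernel(is_cuda: bool, head_size: int, dtype: torch.dtype) -> bool:
    # Fused CUDA RoPE kernel only for supported head sizes and non-fp32 dtypes;
    # everything else is assumed to route to the native implementation.
    return is_cuda and (head_size in [64, 128, 256, 512]) and dtype != torch.float32


print(_use_fused_rope_kernel(True, 128, torch.float16))  # True  -> fused kernel path
print(_use_fused_rope_kernel(True, 128, torch.float32))  # False -> native fallback
```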