Unverified commit fb6cc7b0 authored by zhangdonghao-zdh, committed by GitHub

Fix RotaryEmbedding for fp32 input (#11843)

parent 8374a96e
@@ -112,7 +112,7 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)
-        if (
+        if dtype == torch.float32 or (
             (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
             and not (_is_cpu and _is_cpu_amx_available)
             and not _is_xpu
@@ -254,7 +254,11 @@ class RotaryEmbedding(CustomOp):
         offsets: Optional[torch.Tensor] = None,
         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
+        if (
+            _is_cuda
+            and (self.head_size in [64, 128, 256, 512])
+            and self.dtype != torch.float32
+        ):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
...
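In short, the change adds a dtype guard on both paths: the cos/sin cache stays in fp32 where needed, and forward_cuda skips the fused rope kernel for fp32 inputs, falling back to the native implementation. Below is a minimal, self-contained sketch of that dispatch, not the SGLang source; apply_rope_native and rope_forward are hypothetical names, and the assumption, consistent with the guard added above, is that the fused kernel (apply_rope_with_cos_sin_cache_inplace) only handles half-precision inputs:

import torch

def apply_rope_native(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Standard "rotate_half" rotary embedding: split the head dim in two and
    # rotate each (x1, x2) pair by the position-dependent angle.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

def rope_forward(x, cos, sin, head_size, is_cuda):
    use_fused = (
        is_cuda
        and head_size in (64, 128, 256, 512)
        and x.dtype != torch.float32  # the dtype guard this commit adds
    )
    if use_fused:
        # Stand-in for the fused half-precision kernel; fp32 never reaches it.
        raise NotImplementedError("fused kernel handles fp16/bf16 only")
    return apply_rope_native(x, cos, sin)

# fp32 input now routes to the native path instead of the fused kernel:
q = torch.randn(1, 8, 64, dtype=torch.float32)  # (batch, seq, head_size)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
freqs = torch.outer(torch.arange(8, dtype=torch.float32), inv_freq)  # (seq, head_size // 2)
out = rope_forward(q, freqs.cos(), freqs.sin(), head_size=64, is_cuda=True)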