Unverified Commit c455b771 authored by RickyChen / 陳昭儒's avatar RickyChen / 陳昭儒 Committed by GitHub
Browse files

[Bugfix][CPU] Fix RotaryEmbedding fallback causing gibberish with --enforce-eager (#31643)


Signed-off-by: default avatarrickychen-infinirc <ricky.chen@infinirc.com>
parent eefa713a
...@@ -67,8 +67,9 @@ class CustomOp(nn.Module): ...@@ -67,8 +67,9 @@ class CustomOp(nn.Module):
return self.forward_native(*args, **kwargs) return self.forward_native(*args, **kwargs)
def forward_cpu(self, *args, **kwargs): def forward_cpu(self, *args, **kwargs):
# By default, we assume that CPU ops are compatible with CUDA ops. # By default, we assume that CPU ops are compatible with the
return self.forward_cuda(*args, **kwargs) # PyTorch-native implementation.
return self.forward_native(*args, **kwargs)
def forward_tpu(self, *args, **kwargs): def forward_tpu(self, *args, **kwargs):
# By default, we assume that TPU ops are compatible with the # By default, we assume that TPU ops are compatible with the
......
...@@ -250,6 +250,28 @@ class RotaryEmbedding(RotaryEmbeddingBase): ...@@ -250,6 +250,28 @@ class RotaryEmbedding(RotaryEmbeddingBase):
) )
return query, key return query, key
def forward_cpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
from vllm import _custom_ops as ops
self._match_cos_sin_cache_dtype(query)
# ops.rotary_embedding() is an in-place operation
# that updates the query and key tensors.
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}" s += f", max_position_embeddings={self.max_position_embeddings}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment