"docs/vscode:/vscode.git/clone" did not exist on "ab986769f1a6401bd1d0a1faf17e85dc67c2e8c4"
Unverified commit 63c13a2c, authored by Kyungmin Lee, committed by GitHub
Browse files

fix: vllm_rotary_embedding import error when head_size is not one of 64, 128, 256, 512 (#5733)

parent 4d1e52ab
...@@ -14,8 +14,6 @@ _is_cuda = is_cuda() ...@@ -14,8 +14,6 @@ _is_cuda = is_cuda()
if _is_cuda: if _is_cuda:
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
else:
from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
...@@ -84,6 +82,12 @@ class RotaryEmbedding(CustomOp): ...@@ -84,6 +82,12 @@ class RotaryEmbedding(CustomOp):
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
if not _is_cuda: if not _is_cuda:
cache = cache.to(dtype) cache = cache.to(dtype)
if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
from vllm._custom_ops import rotary_embedding
self.vllm_rotary_embedding = rotary_embedding
self.cos_sin_cache: torch.Tensor self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False) self.register_buffer("cos_sin_cache", cache, persistent=False)
...@@ -160,7 +164,7 @@ class RotaryEmbedding(CustomOp): ...@@ -160,7 +164,7 @@ class RotaryEmbedding(CustomOp):
) )
else: else:
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
vllm_rotary_embedding( self.vllm_rotary_embedding(
positions, positions,
query, query,
key, key,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment