Unverified commit 53ad423f, authored by jiahanc, committed by GitHub
Browse files

[Perf] enable flashinfer rotary_embedding custom ops in DeepSeek rotary (#30729)


Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
parent 889f8bb2
...@@ -38,7 +38,10 @@ class RotaryEmbeddingBase(CustomOp): ...@@ -38,7 +38,10 @@ class RotaryEmbeddingBase(CustomOp):
# and current_platform.is_cuda() # and current_platform.is_cuda()
# and has_flashinfer() # and has_flashinfer()
# and self.head_size in [64, 128, 256, 512]) # and self.head_size in [64, 128, 256, 512])
self.use_flashinfer = False
# Check if use_flashinfer is already set
if not hasattr(self, "use_flashinfer"):
self.use_flashinfer = False
cache = self._compute_cos_sin_cache() cache = self._compute_cos_sin_cache()
if not self.use_flashinfer: if not self.use_flashinfer:
......
...@@ -6,6 +6,7 @@ import math ...@@ -6,6 +6,7 @@ import math
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from .base import RotaryEmbeddingBase from .base import RotaryEmbeddingBase
from .common import ( from .common import (
...@@ -56,6 +57,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase): ...@@ -56,6 +57,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
* attn_factor * attn_factor
) )
self.use_flashinfer = (
self.enabled()
and dtype in (torch.float16, torch.bfloat16)
and current_platform.is_cuda()
and has_flashinfer()
and head_size in [64, 128, 256, 512]
)
super().__init__( super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
) )
...@@ -162,4 +170,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase): ...@@ -162,4 +170,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
key: torch.Tensor | None = None, key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None, offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets) if self.use_flashinfer:
torch.ops.vllm.flashinfer_rotary_embedding(
torch.add(positions, offsets) if offsets is not None else positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
else:
return self.forward_native(positions, query, key, offsets)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment