Commit 1e911dbd authored by zhuwenwen's avatar zhuwenwen
Browse files

[kernels] add rotary_embedding_deepseek_fuse

off rotary_embedding_deepseek_fuse
parent 63f1c793
......@@ -135,9 +135,10 @@ if current_platform.is_rocm():
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
torch.float16, torch.bfloat16
]
# use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
# torch.float16, torch.bfloat16
# ]
use_aiter = False
if use_aiter and with_fused_add:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
......
......@@ -37,6 +37,9 @@ from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
......@@ -842,6 +845,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
cache = torch.cat((cos, sin), dim=-1)
return cache
def rotary_embedding_deepseek_fuse(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
from lightop import op
op.rotary_embedding_deepseek_fuse(positions, query, key, head_size, cos_sin_cache, is_neox_style)
def rotary_embedding_deepseek_fuse_fake(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
pass
direct_register_custom_op(
op_name="rotary_embedding_deepseek_fuse",
op_func=rotary_embedding_deepseek_fuse,
mutates_args=[],
fake_impl=rotary_embedding_deepseek_fuse_fake,
)
def forward(
self,
positions: torch.Tensor,
......@@ -880,6 +901,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
BLOCK_SIZE=BLOCK_SIZE,
num_warps=1)
# if envs.VLLM_USE_LIGHTOP:
if False:
torch.ops.vllm.rotary_embedding_deepseek_fuse(positions, query, key, self.head_size, self.cos_sin_cache, self.is_neox_style)
else:
call(query)
call(key)
return query, key
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment