Commit 49810c37 authored by zhuwenwen
Browse files

[kernels] add fused_rms_norm_contiguous and rotary_embedding_deepseek_fuse

[kernels] update moe_align_block_size and moe_sum interface
parent ddf9d10c
...@@ -1773,7 +1773,7 @@ def fused_experts_impl( ...@@ -1773,7 +1773,7 @@ def fused_experts_impl(
if envs.VLLM_USE_LIGHTOP and not dpsk_fp16_quick: if envs.VLLM_USE_LIGHTOP and not dpsk_fp16_quick:
if shared_output is not None: if shared_output is not None:
op.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), op.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx], shared_output[begin_chunk_idx:end_chunk_idx], routed_scaling_factor) out_hidden_states[begin_chunk_idx:end_chunk_idx], shared_output[begin_chunk_idx:end_chunk_idx], None, routed_scaling_factor)
# else: # else:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), # ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) # out_hidden_states[begin_chunk_idx:end_chunk_idx])
......
...@@ -235,7 +235,7 @@ def moe_align_block_size( ...@@ -235,7 +235,7 @@ def moe_align_block_size(
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP:
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad, None) expert_ids, num_tokens_post_pad, None, None, None)
else: else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad) expert_ids, num_tokens_post_pad)
......
...@@ -10,6 +10,7 @@ import vllm.envs as envs ...@@ -10,6 +10,7 @@ import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool: def is_rocm_aiter_rmsnorm_enabled() -> bool:
...@@ -39,6 +40,33 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, ...@@ -39,6 +40,33 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,
return out return out
def rms_norm_opt(x: torch.Tensor, weight: torch.Tensor,
                 variance_epsilon: float) -> torch.Tensor:
    """Run RMS normalization through lightop's fused kernel.

    An output tensor is allocated with ``torch.empty_like(x)`` and passed,
    together with the input, the weight and the epsilon, to
    ``lightop.fused_rms_norm_contiguous``.

    Args:
        x: Input tensor handed to the kernel.
        weight: Weight tensor handed to the kernel.
        variance_epsilon: Epsilon value handed to the kernel.

    Returns:
        The tensor that was passed to the kernel as its output argument.
    """
    # Imported at call time so lightop is only needed when this op runs.
    # The previously imported ``vllm._custom_ops`` alias was never used here
    # and has been dropped.
    from lightop import fused_rms_norm_contiguous

    # NOTE(review): the kernel name suggests it accepts a strided input and
    # writes a contiguous result -- confirm against the lightop docs.
    out = torch.empty_like(x)
    fused_rms_norm_contiguous(out, x, weight, variance_epsilon)
    return out
def rms_norm_opt_fake(x: torch.Tensor, weight: torch.Tensor,
                      variance_epsilon: float) -> torch.Tensor:
    """Fake implementation registered alongside ``rms_norm_opt``.

    No normalization is performed: ``weight`` and ``variance_epsilon`` are
    ignored and an uninitialized tensor shaped like ``x`` is returned.
    """
    placeholder = torch.empty_like(x)
    return placeholder
# Register rms_norm_opt as a custom op together with its fake implementation.
# It returns a newly allocated tensor, so no argument is listed as mutated.
direct_register_custom_op(op_name="rms_norm_opt",
                          op_func=rms_norm_opt,
                          fake_impl=rms_norm_opt_fake,
                          mutates_args=[])
def fused_add_rms_norm( def fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
...@@ -187,6 +215,23 @@ class RMSNorm(CustomOp): ...@@ -187,6 +215,23 @@ class RMSNorm(CustomOp):
else: else:
return norm_func(x, self.weight.data, self.variance_epsilon) return norm_func(x, self.weight.data, self.variance_epsilon)
def forward_cuda_opt(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Forward pass that uses the ``rms_norm_opt`` custom op when possible.

    * With ``variance_size_override`` set, defer to ``forward_native``.
    * With a residual, call the function returned by
      ``dispatch_cuda_rmsnorm_func``.
    * Otherwise call ``torch.ops.vllm.rms_norm_opt``.
    """
    if self.variance_size_override is not None:
        return self.forward_native(x, residual)

    has_residual = residual is not None
    residual_norm = dispatch_cuda_rmsnorm_func(has_residual)
    if not has_residual:
        # Plain normalization goes through the registered custom op.
        return torch.ops.vllm.rms_norm_opt(x, self.weight.data,
                                           self.variance_epsilon)
    return residual_norm(x, residual, self.weight.data,
                         self.variance_epsilon)
def forward_apex( def forward_apex(
self, self,
x: torch.Tensor, x: torch.Tensor,
......
...@@ -37,6 +37,8 @@ from transformers import PretrainedConfig ...@@ -37,6 +37,8 @@ from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
...@@ -900,6 +902,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -900,6 +902,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
return cache return cache
def rotary_embedding_deepseek_fuse(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
                                   head_size: int, cos_sin_cache: torch.Tensor,
                                   is_neox_style: bool) -> None:
    """Forward all arguments to lightop's ``rotary_embedding_deepseek_fuse``.

    Nothing is returned; callers keep using the ``query`` and ``key`` tensors
    they passed in.
    """
    # Imported at call time so lightop is only needed when this op runs.
    from lightop import op as lightop_ops

    lightop_ops.rotary_embedding_deepseek_fuse(
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox_style,
    )
def rotary_embedding_deepseek_fuse_fake(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
                                        head_size: int, cos_sin_cache: torch.Tensor,
                                        is_neox_style: bool) -> None:
    """Fake implementation for ``rotary_embedding_deepseek_fuse``.

    It does nothing: no argument is read or written, and None is returned.
    """
    return None
# Register the fused rotary embedding as a custom op together with its fake
# implementation. The op returns None and its caller goes on to use the same
# ``query`` and ``key`` tensors, i.e. the results are written in place. Both
# tensors are therefore declared as mutated so the op's schema reflects the
# side effect and it is not treated as a pure, result-less call.
direct_register_custom_op(
    op_name="rotary_embedding_deepseek_fuse",
    op_func=rotary_embedding_deepseek_fuse,
    mutates_args=["query", "key"],
    fake_impl=rotary_embedding_deepseek_fuse_fake,
)
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -938,8 +958,11 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -938,8 +958,11 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
num_warps=1) num_warps=1)
call(query) if envs.VLLM_USE_LIGHTOP:
call(key) torch.ops.vllm.rotary_embedding_deepseek_fuse(positions, query, key, self.head_size, self.cos_sin_cache, self.is_neox_style)
else:
call(query)
call(key)
return query, key return query, key
else: else:
query_rot = query[..., :self.rotary_dim] query_rot = query[..., :self.rotary_dim]
......
...@@ -566,7 +566,10 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -566,7 +566,10 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else:
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim) q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe # Add head dim of 1 to k_pe
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment