Unverified Commit 9319044e authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE][Perf] Wrap DSV3 QKVAProj GEMM in custom op for torch.compile (#35751)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
parent c42dc402
...@@ -75,6 +75,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -75,6 +75,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mla.indexer import ( from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend, DeepseekV32IndexerBackend,
...@@ -717,6 +718,44 @@ class Indexer(nn.Module): ...@@ -717,6 +718,44 @@ class Indexer(nn.Module):
return self.indexer_op(hidden_states, q_fp8, k, weights) return self.indexer_op(hidden_states, q_fp8, k, weights)
def _min_latency_fused_qkv_a_proj_impl(
input_: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
"""
Dynamically run min-latency gemm if num_tokens <= 16.
This must be wrapped in a custom op because our torch.compile integration
does not support runtime dispatching on num_tokens.
"""
num_tokens = input_.shape[0]
if 0 < num_tokens <= 16:
output = torch.empty(
num_tokens,
weight.shape[0],
dtype=torch.bfloat16,
device=input_.device,
)
ops.dsv3_fused_a_gemm(output, input_, weight.T)
return output
else:
return torch.nn.functional.linear(input_, weight)
def _min_latency_fused_qkv_a_proj_fake(
input_: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return input_.new_empty(input_.shape[0], weight.shape[0])
direct_register_custom_op(
op_name="min_latency_fused_qkv_a_proj",
op_func=_min_latency_fused_qkv_a_proj_impl,
mutates_args=[],
fake_impl=_min_latency_fused_qkv_a_proj_fake,
)
class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear): class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
def __init__( def __init__(
self, self,
...@@ -752,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear): ...@@ -752,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
self, self,
input_, input_,
) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]: ) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
num_tokens = input_.shape[0] if self._use_min_latency_gemm:
if self._use_min_latency_gemm and (0 < num_tokens <= 16): output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
output = torch.empty(
num_tokens,
2112,
dtype=torch.bfloat16,
device=input_.device,
)
ops.dsv3_fused_a_gemm(
output,
input_,
self.weight.T,
)
if not self.return_bias: if not self.return_bias:
return output return output
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment