"vscode:/vscode.git/clone" did not exist on "fbfe20c62ce5efe1f26a96cbd2cf28ed49b6282e"
Unverified Commit 9319044e authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE][Perf] Wrap DSV3 QKVAProj GEMM in custom op for torch.compile (#35751)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
parent c42dc402
......@@ -75,6 +75,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend,
......@@ -717,6 +718,44 @@ class Indexer(nn.Module):
return self.indexer_op(hidden_states, q_fp8, k, weights)
def _min_latency_fused_qkv_a_proj_impl(
input_: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
"""
Dynamically run min-latency gemm if num_tokens <= 16.
This must be wrapped in a custom op because our torch.compile integration
does not support runtime dispatching on num_tokens.
"""
num_tokens = input_.shape[0]
if 0 < num_tokens <= 16:
output = torch.empty(
num_tokens,
weight.shape[0],
dtype=torch.bfloat16,
device=input_.device,
)
ops.dsv3_fused_a_gemm(output, input_, weight.T)
return output
else:
return torch.nn.functional.linear(input_, weight)
def _min_latency_fused_qkv_a_proj_fake(
input_: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return input_.new_empty(input_.shape[0], weight.shape[0])
direct_register_custom_op(
op_name="min_latency_fused_qkv_a_proj",
op_func=_min_latency_fused_qkv_a_proj_impl,
mutates_args=[],
fake_impl=_min_latency_fused_qkv_a_proj_fake,
)
class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
def __init__(
self,
......@@ -752,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
self,
input_,
) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
num_tokens = input_.shape[0]
if self._use_min_latency_gemm and (0 < num_tokens <= 16):
output = torch.empty(
num_tokens,
2112,
dtype=torch.bfloat16,
device=input_.device,
)
ops.dsv3_fused_a_gemm(
output,
input_,
self.weight.T,
)
if self._use_min_latency_gemm:
output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment