Unverified Commit 6874638b authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[Model Bash] DeepSeek R1 BF16 Min Latency QKV A GEMM (0.5% E2E Speedup) (#34758)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
parent e24663c5
......@@ -771,6 +771,25 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
# DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later).
#
# Step 1: intersect the architectures this kernel supports with the ones the
# user asked to build for. The supported list is spelled differently for
# CUDA >= 13.0 (family-specific "f" targets) than for older toolkits
# (explicit arch-specific "a" targets).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
  cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
else()
  cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
# Step 2: build the kernel only when the compiler is new enough AND at least
# one compatible architecture was requested.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS)
  set(DSV3_FUSED_A_GEMM_SRC "csrc/dsv3_fused_a_gemm.cu")
  # Restrict code generation for this source to the matching archs only.
  set_gencode_flags_for_srcs(
    SRCS "${DSV3_FUSED_A_GEMM_SRC}"
    CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}")
  list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC})
  # Compile-time switch telling GPU sources that the kernel is available.
  list(APPEND VLLM_GPU_FLAGS "-DENABLE_DSV3_FUSED_A_GEMM=1")
  message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}")
else()
  # Report the actual reason the kernel is skipped: an old compiler and a
  # missing architecture are different problems with different remedies.
  if(NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS)
    message(STATUS "Not building dsv3_fused_a_gemm as CUDA Compiler version is "
                   "not >= 12.0, we recommend upgrading to CUDA 12.0 or later "
                   "to build dsv3_fused_a_gemm.")
  else()
    message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found "
                   "in CUDA target architectures.")
  endif()
endif()
# moe_data.cu is used by all CUTLASS MoE kernels.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
......
This diff is collapsed.
......@@ -410,3 +410,8 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#endif
#ifndef USE_ROCM
// DeepSeek V3 fused A GEMM entry point (declared for non-ROCm builds only).
//
// `output` is taken by non-const reference and is the tensor the result is
// written to; `mat_a` and `mat_b` are const inputs. The op registration that
// binds this function describes it as "SM 9.0+, bf16 only, 1-16 tokens".
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a,
torch::Tensor const& mat_b);
#endif
\ No newline at end of file
......@@ -239,6 +239,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization ops
#ifndef USE_ROCM
// DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
// Schema: `output` is annotated `Tensor!` because the op mutates it in place;
// the op itself returns nothing.
// NOTE(review): the function pointer below is referenced for every non-ROCm
// build, but the build system only adds csrc/dsv3_fused_a_gemm.cu when a
// compatible CUDA arch is selected -- confirm the symbol is always defined
// (or move the impl registration next to the kernel) to avoid an undefined
// symbol on other CUDA targets.
ops.def(
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
ops.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);
// Quantized GEMM for AWQ.
ops.def(
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
......
......@@ -2789,6 +2789,24 @@ def sm100_cutlass_mla_get_workspace_size(
)
def dsv3_fused_a_gemm(
    output: torch.Tensor,
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
) -> None:
    """DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).

    Computes ``output = mat_a @ mat_b`` and writes the result into the
    caller-provided ``output`` tensor (nothing is returned), where:

        mat_a:  [num_tokens, 7168] row-major bf16 (hidden states)
        mat_b:  [7168, 2112] column-major bf16, i.e. the ``.T`` view of a
                row-major [2112, 7168] weight (pass ``weight.T``; do not
                transpose again)
        output: [num_tokens, 2112] row-major bf16

    Optimized for the DeepSeek V2/V3 QKV A-projection at small batch sizes.
    Requires SM 9.0+ (Hopper).
    """
    # Thin wrapper: forwards straight to the registered custom op.
    torch.ops._C.dsv3_fused_a_gemm(output, mat_a, mat_b)
if hasattr(torch.ops._C, "weight_packed_linear"):
@register_fake("_C::weight_packed_linear")
......
......@@ -129,6 +129,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
assert self.q_b_proj is not None, (
"q_b_proj is required when q_lora_rank is not None"
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
......
......@@ -32,6 +32,7 @@ import torch
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
import vllm._custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
......@@ -711,6 +712,64 @@ class Indexer(nn.Module):
return self.indexer_op(hidden_states, q_fp8, k, weights)
class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
    """Fused QKV A-projection with a low-latency path for tiny batches.

    Behaves like the ``MergedColumnParallelLinear`` it derives from (no bias,
    tensor parallelism disabled), except that ``forward`` dispatches to the
    ``dsv3_fused_a_gemm`` custom op when the layer and the incoming
    activations match what that kernel is specialised for. In every other
    case the parent implementation is used unchanged.
    """

    # Problem size the fused kernel is specialised for: the weight must be
    # [2112, 7168] and the batch must hold between 1 and 16 tokens.
    _FUSED_GEMM_OUT_FEATURES = 2112
    _FUSED_GEMM_IN_FEATURES = 7168
    _FUSED_GEMM_MAX_TOKENS = 16

    def __init__(
        self,
        input_size: int,
        output_size: list[int],
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the merged projection and decide if the fused kernel applies.

        Args:
            input_size: Input feature dimension of the projection.
            output_size: Output sizes of the merged sub-projections; forwarded
                as-is to ``MergedColumnParallelLinear`` (the caller passes a
                list with one entry per merged projection).
            quant_config: Optional quantization configuration.
            prefix: Module-name prefix of this layer.
        """
        # NOTE(review): the caller already passes a prefix ending in
        # ".fused_qkv_a_proj"; appending ".kv_a_proj_with_mqa" changes the
        # name the base class sees compared with the plain
        # MergedColumnParallelLinear this class replaces -- confirm that this
        # is intended (e.g. for quantization-config name matching).
        super().__init__(
            input_size,
            output_size,
            bias=False,
            quant_config=quant_config,
            disable_tp=True,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
        )

        # Check if the DeepSeek V3 fused A GEMM kernel can be used.
        # This kernel supports PDL and is optimized for low batch size.
        # The layer-level conditions are evaluated once here; the per-call
        # conditions (activation layout and batch size) live in forward().
        self._use_min_latency_gemm = (
            hasattr(self, "weight")
            and self.weight.dtype == torch.bfloat16
            and self.weight.shape[0] == self._FUSED_GEMM_OUT_FEATURES
            and self.weight.shape[1] == self._FUSED_GEMM_IN_FEATURES
            and current_platform.is_cuda()
            and (
                current_platform.is_device_capability(90)
                or current_platform.is_device_capability_family(100)
            )
        )

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
        """Project ``input_``, using the fused kernel when it is applicable.

        The fused kernel is only used when the layer qualified at
        construction time and ``input_`` is a contiguous 2-D bf16 tensor
        holding 1-16 tokens. Anything else is routed to the parent
        ``forward`` so unsupported inputs never reach the specialised kernel.

        Returns:
            The projected tensor, or ``(tensor, bias)`` when ``return_bias``
            is set, mirroring the parent layer's return convention.
        """
        use_fused_kernel = (
            self._use_min_latency_gemm
            # The kernel expects a row-major [num_tokens, 7168] bf16 matrix;
            # guard against other ranks, dtypes, or strided views.
            and input_.dim() == 2
            and input_.dtype == torch.bfloat16
            and input_.is_contiguous()
            and 0 < input_.shape[0] <= self._FUSED_GEMM_MAX_TOKENS
        )
        if not use_fused_kernel:
            # Fallback to the standard forward method when
            # the fused A GEMM kernel cannot be used.
            return super().forward(input_)

        num_tokens = input_.shape[0]
        output = torch.empty(
            num_tokens,
            self._FUSED_GEMM_OUT_FEATURES,
            dtype=torch.bfloat16,
            device=input_.device,
        )
        # weight is stored [out, in] row-major; its .T view is the
        # column-major [in, out] operand the kernel expects.
        ops.dsv3_fused_a_gemm(
            output,
            input_,
            self.weight.T,
        )
        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
class DeepseekV2MLAAttention(nn.Module):
"""
Main reference: DeepseekV2 paper, and FlashInfer Implementation
......@@ -756,13 +815,11 @@ class DeepseekV2MLAAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
self.fused_qkv_a_proj = MergedColumnParallelLinear(
self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProj(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_qkv_a_proj",
disable_tp=True,
)
else:
self.kv_a_proj_with_mqa = ReplicatedLinear(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment