Unverified commit f77df946, authored by Benjamin Chislett, committed by GitHub
Browse files

[Perf] Add decode full-graph support to FlashInfer-MLA backend (#26313)


Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
parent f231e5bc
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union from typing import ClassVar, Optional, Union
import torch import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
...@@ -12,13 +12,20 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -12,13 +12,20 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
MLACommonMetadataBuilder,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport
logger = init_logger(__name__) logger = init_logger(__name__)
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """Metadata builder for the FlashInfer-MLA attention backend.

    Inherits all build logic from ``MLACommonMetadataBuilder`` and only
    overrides the advertised CUDA Graph capability, enabling full CUDA Graph
    support for decode-only capture.
    """

    # NOTE(review): UNIFORM_BATCH presumably restricts full-graph capture to
    # batches with a uniform query length (i.e. pure decode) -- confirm
    # against the AttentionCGSupport definition.
    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_BATCH
    )
class FlashInferMLABackend(MLACommonBackend): class FlashInferMLABackend(MLACommonBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -28,6 +35,10 @@ class FlashInferMLABackend(MLACommonBackend): ...@@ -28,6 +35,10 @@ class FlashInferMLABackend(MLACommonBackend):
def get_impl_cls() -> type["FlashInferMLAImpl"]: def get_impl_cls() -> type["FlashInferMLAImpl"]:
return FlashInferMLAImpl return FlashInferMLAImpl
@staticmethod
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
    """Return the metadata builder class paired with this backend.

    Returns:
        ``FlashInferMLAMetadataBuilder`` -- the class itself, not an
        instance.
    """
    return FlashInferMLAMetadataBuilder
g_fi_workspace = torch.zeros( g_fi_workspace = torch.zeros(
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE, FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment