Unverified Commit 01c22335 authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[Kernel] [V1] Fix performance regression for triton unified attention (#18161)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
parent 451da4bc
......@@ -56,11 +56,11 @@ def kernel_unified_attention_2d(
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int
stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
......
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with PagedAttention and Triton prefix prefill."""
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
import torch
......@@ -12,10 +12,23 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table)
self.aot_schedule = False
class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
......@@ -52,8 +65,8 @@ class TritonAttentionBackend(AttentionBackend):
return False
@staticmethod
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
return TritonAttentionMetadataBuilder
class TritonAttentionImpl(AttentionImpl):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment