Unverified Commit 01c22335 authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[Kernel] [V1] Fix performance regression for triton unified attention (#18161)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
parent 451da4bc
...@@ -56,11 +56,11 @@ def kernel_unified_attention_2d( ...@@ -56,11 +56,11 @@ def kernel_unified_attention_2d(
stride_k_cache_0: tl.int64, # int stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1] query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32, num_seqs: tl.int32,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Attention layer with PagedAttention and Triton prefix prefill.""" """Attention layer with PagedAttention and Triton prefix prefill."""
from typing import Any, Optional from typing import TYPE_CHECKING, Any, Optional
import torch import torch
...@@ -12,10 +12,23 @@ from vllm.logger import init_logger ...@@ -12,10 +12,23 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import ( from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata, FlashAttentionMetadataBuilder) FlashAttentionMetadata, FlashAttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__) logger = init_logger(__name__)
class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table)
self.aot_schedule = False
class TritonAttentionBackend(AttentionBackend): class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
...@@ -52,8 +65,8 @@ class TritonAttentionBackend(AttentionBackend): ...@@ -52,8 +65,8 @@ class TritonAttentionBackend(AttentionBackend):
return False return False
@staticmethod @staticmethod
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder return TritonAttentionMetadataBuilder
class TritonAttentionImpl(AttentionImpl): class TritonAttentionImpl(AttentionImpl):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment