Commit 62920e37 authored by zhuwenwen's avatar zhuwenwen
Browse files

add triton_key

parent abf008ef
...@@ -5,6 +5,8 @@ from dataclasses import dataclass ...@@ -5,6 +5,8 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch import torch
import triton
from triton.compiler.compiler import triton_key
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -30,7 +32,7 @@ _ON_NAVI = "gfx1" in _GPU_ARCH ...@@ -30,7 +32,7 @@ _ON_NAVI = "gfx1" in _GPU_ARCH
_ON_MI250_MI300 = any(arch in _GPU_ARCH _ON_MI250_MI300 = any(arch in _GPU_ARCH
for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"]) for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"])
class ROCmFlashAttentionBackend(AttentionBackend): class ROCmFlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
...@@ -778,6 +780,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -778,6 +780,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
# prefix-enabled attention # prefix-enabled attention
# not applicable for encoder-only models # not applicable for encoder-only models
version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY: if self.attn_type != AttentionType.ENCODER_ONLY:
output[: output[:
num_prefill_tokens] = PagedAttention.forward_prefix( num_prefill_tokens] = PagedAttention.forward_prefix(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment