Unverified Commit 4082338a authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Remove unneeded ROCm platform import when using CUDA (#22765)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent c6b92879
......@@ -22,7 +22,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
......@@ -886,6 +885,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads
from vllm.platforms.rocm import use_rocm_custom_paged_attention
use_custom = use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, self.sliding_window,
......
......@@ -11,7 +11,6 @@ import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.triton_utils import tl, triton
from .prefix_prefill import context_attention_fwd
......@@ -296,6 +295,7 @@ def chunked_prefill_paged_decode(
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
16)
from vllm.platforms.rocm import use_rocm_custom_paged_attention
use_custom = use_rocm_custom_paged_attention(
query.dtype,
head_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment