Unverified Commit 4082338a authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Remove unneeded ROCm platform import when using CUDA (#22765)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent c6b92879
...@@ -22,7 +22,6 @@ from vllm.logger import init_logger ...@@ -22,7 +22,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape) GroupShape)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
...@@ -886,6 +885,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -886,6 +885,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads gqa_ratio = num_heads // self.num_kv_heads
from vllm.platforms.rocm import use_rocm_custom_paged_attention
use_custom = use_rocm_custom_paged_attention( use_custom = use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio, decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, self.sliding_window, decode_meta.max_decode_seq_len, self.sliding_window,
......
...@@ -11,7 +11,6 @@ import torch ...@@ -11,7 +11,6 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from .prefix_prefill import context_attention_fwd from .prefix_prefill import context_attention_fwd
...@@ -296,6 +295,7 @@ def chunked_prefill_paged_decode( ...@@ -296,6 +295,7 @@ def chunked_prefill_paged_decode(
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
16) 16)
from vllm.platforms.rocm import use_rocm_custom_paged_attention
use_custom = use_rocm_custom_paged_attention( use_custom = use_rocm_custom_paged_attention(
query.dtype, query.dtype,
head_size, head_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment