Unverified commit 4e2d95e3, authored by wangshuai09, committed by GitHub
Browse files

[Hardware][ROCM] using current_platform.is_rocm (#9642)


Signed-off-by: wangshuai09 <391746016@qq.com>
parent 34a99416
...@@ -11,7 +11,7 @@ from unittest.mock import patch ...@@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest import pytest
from vllm import LLM from vllm import LLM
from vllm.utils import is_hip from vllm.platforms import current_platform
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from ..models.utils import check_outputs_equal from ..models.utils import check_outputs_equal
...@@ -51,7 +51,7 @@ def test_models( ...@@ -51,7 +51,7 @@ def test_models(
enforce_eager: bool, enforce_eager: bool,
) -> None: ) -> None:
if backend == "FLASHINFER" and is_hip(): if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.") pytest.skip("Flashinfer does not support ROCm/HIP.")
os.environ["VLLM_ATTENTION_BACKEND"] = backend os.environ["VLLM_ATTENTION_BACKEND"] = backend
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.compilation.levels import CompilationLevel from vllm.compilation.levels import CompilationLevel
from vllm.utils import is_hip from vllm.platforms import current_platform
TEST_MODELS = [ TEST_MODELS = [
("facebook/opt-125m", {}), ("facebook/opt-125m", {}),
...@@ -55,7 +55,7 @@ if is_quant_method_supported("marlin"): ...@@ -55,7 +55,7 @@ if is_quant_method_supported("marlin"):
"quantization": "marlin" "quantization": "marlin"
})) }))
if not is_hip() and is_quant_method_supported("awq"): if not current_platform.is_rocm() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ" "quantization": "AWQ"
})) }))
......
...@@ -2,12 +2,13 @@ from typing import Optional, Tuple, Union ...@@ -2,12 +2,13 @@ from typing import Optional, Tuple, Union
import torch import torch
from vllm.utils import is_hip from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy # Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm. # issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX = 224.0 ROCM_FP8_MAX = 224.0
FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \
else torch.float8_e4m3fn
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
...@@ -24,8 +25,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor, ...@@ -24,8 +25,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype) else torch.finfo(quant_dtype)
qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max) qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0) s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0) s_512 = as_float32_tensor(512.0)
...@@ -66,8 +69,10 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ ...@@ -66,8 +69,10 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]: -> Tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(FP8_DTYPE) fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else fp8_traits.min
fp8_max = as_float32_tensor(fp8_traits_max) fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0) one = as_float32_tensor(1.0)
......
...@@ -6,11 +6,12 @@ import torch ...@@ -6,11 +6,12 @@ import torch
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything
from .allclose_default import get_default_atol, get_default_rtol from .allclose_default import get_default_atol, get_default_rtol
if not is_hip(): if not current_platform.is_rocm():
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
...@@ -23,8 +24,9 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 ...@@ -23,8 +24,9 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 4321 # Arbitrary values for testing NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512 PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float DTYPES = [
] if not is_hip() else [torch.half, torch.bfloat16] torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
...@@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention( ...@@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"]) "version",
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
...@@ -317,8 +320,8 @@ def test_paged_attention( ...@@ -317,8 +320,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two # NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two # implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test. # outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3 atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test. # so we use a relaxed tolerance for the test.
...@@ -368,7 +371,7 @@ def ref_multi_query_kv_attention( ...@@ -368,7 +371,7 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(is_hip(), @pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.") reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode() @torch.inference_mode()
def test_multi_query_kv_attention( def test_multi_query_kv_attention(
...@@ -425,6 +428,6 @@ def test_multi_query_kv_attention( ...@@ -425,6 +428,6 @@ def test_multi_query_kv_attention(
scale, scale,
dtype, dtype,
) )
atol = get_default_atol(output) if is_hip() else 1e-3 atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
...@@ -25,7 +25,8 @@ def test_env(name: str, device: str, monkeypatch): ...@@ -25,7 +25,8 @@ def test_env(name: str, device: str, monkeypatch):
False) False)
assert backend.name == "TORCH_SDPA" assert backend.name == "TORCH_SDPA"
elif device == "hip": elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True): with patch("vllm.attention.selector.current_platform.is_rocm",
return_value=True):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False) False)
assert backend.name == "ROCM_FLASH" assert backend.name == "ROCM_FLASH"
......
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import ( from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn) LocalStridedBlockSparseAttn)
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything
from .allclose_default import get_default_atol, get_default_rtol from .allclose_default import get_default_atol, get_default_rtol
...@@ -316,8 +317,8 @@ def test_paged_attention( ...@@ -316,8 +317,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two # NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two # implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test. # outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3 atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test. # so we use a relaxed tolerance for the test.
......
...@@ -18,7 +18,7 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, ...@@ -18,7 +18,7 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.utils import is_hip from vllm.platforms import current_platform
# List of support backends for encoder/decoder models # List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
...@@ -726,7 +726,8 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -726,7 +726,8 @@ def _run_encoder_decoder_cross_attention_test(
attn_type=attn_type) attn_type=attn_type)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.skipif(current_platform.is_rocm(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
...@@ -755,7 +756,8 @@ def test_encoder_only( ...@@ -755,7 +756,8 @@ def test_encoder_only(
No KV cache is required for encoder-only attention. No KV cache is required for encoder-only attention.
Note on ROCm/HIP: currently encoder/decoder models are not supported on Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip(). AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
This test globally forces an override of the usual backend This test globally forces an override of the usual backend
auto-selection process, forcing the specific backend-under-test auto-selection process, forcing the specific backend-under-test
...@@ -811,7 +813,8 @@ def test_encoder_only( ...@@ -811,7 +813,8 @@ def test_encoder_only(
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.skipif(current_platform.is_rocm(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
...@@ -864,7 +867,8 @@ def test_e2e_enc_dec_attn( ...@@ -864,7 +867,8 @@ def test_e2e_enc_dec_attn(
to be utilized. to be utilized.
Note on ROCm/HIP: currently encoder/decoder models are not supported on Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip(). AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
Note on metadata: there is a single attention metadata structure shared by Note on metadata: there is a single attention metadata structure shared by
all prefill-phase attention operations (encoder, decoder, enc/dec cross), all prefill-phase attention operations (encoder, decoder, enc/dec cross),
......
...@@ -18,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -18,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize) marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import is_hip, seed_everything from vllm.utils import seed_everything
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
...@@ -103,7 +104,7 @@ def test_mixtral_moe(dtype: torch.dtype): ...@@ -103,7 +104,7 @@ def test_mixtral_moe(dtype: torch.dtype):
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(is_hip(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe( def test_fused_marlin_moe(
m: int, m: int,
n: int, n: int,
...@@ -256,7 +257,7 @@ def test_fused_marlin_moe( ...@@ -256,7 +257,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(is_hip(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_single_marlin_moe_multiply( def test_single_marlin_moe_multiply(
m: int, m: int,
n: int, n: int,
......
...@@ -4,7 +4,7 @@ import pytest ...@@ -4,7 +4,7 @@ import pytest
import vllm import vllm
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.utils import is_hip from vllm.platforms import current_platform
MODEL_PATH = "google/gemma-7b" MODEL_PATH = "google/gemma-7b"
...@@ -31,7 +31,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: ...@@ -31,7 +31,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts return generated_texts
@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm") @pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files): def test_gemma_lora(gemma_lora_files):
llm = vllm.LLM(MODEL_PATH, llm = vllm.LLM(MODEL_PATH,
max_model_len=1024, max_model_len=1024,
......
...@@ -8,7 +8,7 @@ import pytest ...@@ -8,7 +8,7 @@ import pytest
import vllm import vllm
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.utils import is_hip from vllm.platforms import current_platform
@dataclass @dataclass
...@@ -19,7 +19,7 @@ class ModelWithQuantization: ...@@ -19,7 +19,7 @@ class ModelWithQuantization:
MODELS: List[ModelWithQuantization] MODELS: List[ModelWithQuantization]
#AWQ quantization is currently not supported in ROCm. #AWQ quantization is currently not supported in ROCm.
if is_hip(): if current_platform.is_rocm():
MODELS = [ MODELS = [
ModelWithQuantization( ModelWithQuantization(
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
......
...@@ -6,8 +6,9 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, ...@@ -6,8 +6,9 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding) BatchEncoding)
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_hip from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
...@@ -24,7 +25,7 @@ models = ["google/paligemma-3b-mix-224"] ...@@ -24,7 +25,7 @@ models = ["google/paligemma-3b-mix-224"]
# ROCm Triton FA can run into compilation issues with these models due to # ROCm Triton FA can run into compilation issues with these models due to
# excessive use of shared memory. Use other backends in the meantime. # excessive use of shared memory. Use other backends in the meantime.
# FIXME (mattwong, gshtrasb, hongxiayan) # FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip(): if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
...@@ -151,7 +152,7 @@ def run_test( ...@@ -151,7 +152,7 @@ def run_test(
pytest.param( pytest.param(
"float", "float",
marks=pytest.mark.skipif( marks=pytest.mark.skipif(
is_hip(), current_platform.is_rocm(),
reason= reason=
"ROCm FA does not yet fully support 32-bit precision on PaliGemma") "ROCm FA does not yet fully support 32-bit precision on PaliGemma")
), "half" ), "half"
......
...@@ -12,7 +12,6 @@ from vllm.multimodal import MultiModalRegistry ...@@ -12,7 +12,6 @@ from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import is_hip
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets) _ImageAssets)
...@@ -56,7 +55,7 @@ if current_platform.is_cpu(): ...@@ -56,7 +55,7 @@ if current_platform.is_cpu():
# ROCm Triton FA can run into shared memory issues with these models, # ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime # use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan) # FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip(): if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
......
...@@ -5,7 +5,7 @@ tensor parallelism. ...@@ -5,7 +5,7 @@ tensor parallelism.
import pytest import pytest
import torch import torch
from vllm.utils import is_hip from vllm.platforms import current_platform
from .conftest import run_equality_correctness_test_tp from .conftest import run_equality_correctness_test_tp
...@@ -51,7 +51,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs, ...@@ -51,7 +51,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
batch_size: int, output_len: int, seed: int): batch_size: int, output_len: int, seed: int):
"""Verify greedy equality when tensor parallelism is used. """Verify greedy equality when tensor parallelism is used.
""" """
if is_hip(): if current_platform.is_rocm():
pytest.skip("hip is not well-supported yet") pytest.skip("hip is not well-supported yet")
run_equality_correctness_test_tp("JackFram/llama-68m", run_equality_correctness_test_tp("JackFram/llama-68m",
common_llm_kwargs, common_llm_kwargs,
......
...@@ -26,7 +26,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader ...@@ -26,7 +26,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (FlexibleArgumentParser, GB_bytes, from vllm.utils import (FlexibleArgumentParser, GB_bytes,
cuda_device_count_stateless, get_open_port, is_hip) cuda_device_count_stateless, get_open_port)
if current_platform.is_rocm(): if current_platform.is_rocm():
from amdsmi import (amdsmi_get_gpu_vram_usage, from amdsmi import (amdsmi_get_gpu_vram_usage,
...@@ -487,7 +487,7 @@ def wait_for_gpu_memory_to_clear(devices: List[int], ...@@ -487,7 +487,7 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
output: Dict[int, str] = {} output: Dict[int, str] = {}
output_raw: Dict[int, float] = {} output_raw: Dict[int, float] = {}
for device in devices: for device in devices:
if is_hip(): if current_platform.is_rocm():
dev_handle = amdsmi_get_processor_handles()[device] dev_handle = amdsmi_get_processor_handles()[device]
mem_info = amdsmi_get_gpu_vram_usage(dev_handle) mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
gb_used = mem_info["vram_used"] / 2**10 gb_used = mem_info["vram_used"] / 2**10
......
...@@ -674,8 +674,8 @@ def scaled_fp8_quant( ...@@ -674,8 +674,8 @@ def scaled_fp8_quant(
assert (input.ndim == 2) assert (input.ndim == 2)
shape: Union[Tuple[int, int], torch.Size] = input.shape shape: Union[Tuple[int, int], torch.Size] = input.shape
# For rocm, the output fp8 dtype is torch.float8_e4m3fnuz # For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \ out_dtype: torch.dtype = torch.float8_e4m3fnuz \
else torch.float8_e4m3fn if current_platform.is_rocm() else torch.float8_e4m3fn
if num_token_padding: if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1]) shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype) output = torch.empty(shape, device=input.device, dtype=out_dtype)
......
...@@ -3,7 +3,6 @@ import math ...@@ -3,7 +3,6 @@ import math
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step, from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask) get_sparse_attn_mask)
...@@ -32,7 +31,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module): ...@@ -32,7 +31,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
): ):
super().__init__() super().__init__()
if use_spda is None: if use_spda is None:
use_spda = is_hip() or current_platform.is_cpu() or not \ use_spda = current_platform.is_rocm() or \
current_platform.is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device() device = device or (torch.cuda.current_device()
if current_platform.is_cuda_alike() else "cpu") if current_platform.is_cuda_alike() else "cpu")
......
...@@ -10,7 +10,7 @@ import vllm.envs as envs ...@@ -10,7 +10,7 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, is_hip from vllm.utils import STR_BACKEND_ENV_VAR
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -208,7 +208,7 @@ def which_attn_to_use( ...@@ -208,7 +208,7 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on TPU.", selected_backend) logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS return _Backend.PALLAS
if is_hip(): if current_platform.is_rocm():
# AMD GPUs. # AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
......
...@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config, ...@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config, get_hf_image_processor_config,
get_hf_text_config) get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, print_warning_once) print_warning_once)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
...@@ -350,7 +350,7 @@ class ModelConfig: ...@@ -350,7 +350,7 @@ class ModelConfig:
raise ValueError( raise ValueError(
f"Unknown quantization method: {self.quantization}. Must " f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.") f"be one of {supported_quantization}.")
if is_hip( if current_platform.is_rocm(
) and self.quantization not in rocm_supported_quantization: ) and self.quantization not in rocm_supported_quantization:
raise ValueError( raise ValueError(
f"{self.quantization} quantization is currently not " f"{self.quantization} quantization is currently not "
...@@ -365,7 +365,7 @@ class ModelConfig: ...@@ -365,7 +365,7 @@ class ModelConfig:
"%s quantization is not fully " "%s quantization is not fully "
"optimized yet. The speed can be slower than " "optimized yet. The speed can be slower than "
"non-quantized models.", self.quantization) "non-quantized models.", self.quantization)
if (self.quantization == "awq" and is_hip() if (self.quantization == "awq" and current_platform.is_rocm()
and not envs.VLLM_USE_TRITON_AWQ): and not envs.VLLM_USE_TRITON_AWQ):
logger.warning( logger.warning(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
...@@ -843,7 +843,8 @@ class LoadConfig: ...@@ -843,7 +843,8 @@ class LoadConfig:
self.load_format = LoadFormat(load_format) self.load_format = LoadFormat(load_format)
rocm_not_supported_load_format: List[str] = [] rocm_not_supported_load_format: List[str] = []
if is_hip() and load_format in rocm_not_supported_load_format: if current_platform.is_rocm(
) and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [ rocm_supported_load_format = [
f for f in LoadFormat.__members__ f for f in LoadFormat.__members__
if (f not in rocm_not_supported_load_format) if (f not in rocm_not_supported_load_format)
...@@ -967,7 +968,7 @@ class ParallelConfig: ...@@ -967,7 +968,7 @@ class ParallelConfig:
if self.use_ray: if self.use_ray:
from vllm.executor import ray_utils from vllm.executor import ray_utils
ray_utils.assert_ray_available() ray_utils.assert_ray_available()
if is_hip(): if current_platform.is_rocm():
self.disable_custom_all_reduce = True self.disable_custom_all_reduce = True
logger.info( logger.info(
"Disabled the custom all-reduce kernel because it is not " "Disabled the custom all-reduce kernel because it is not "
......
...@@ -10,7 +10,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook ...@@ -10,7 +10,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip, is_hip from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -231,7 +231,7 @@ def initialize_ray_cluster( ...@@ -231,7 +231,7 @@ def initialize_ray_cluster(
assert_ray_available() assert_ray_available()
# Connect to a ray cluster. # Connect to a ray cluster.
if is_hip() or current_platform.is_xpu(): if current_platform.is_rocm() or current_platform.is_xpu():
ray.init(address=ray_address, ray.init(address=ray_address,
ignore_reinit_error=True, ignore_reinit_error=True,
num_gpus=parallel_config.world_size) num_gpus=parallel_config.world_size)
......
...@@ -7,7 +7,7 @@ import vllm.envs as envs ...@@ -7,7 +7,7 @@ import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once from vllm.utils import print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -72,7 +72,7 @@ class CustomOp(nn.Module): ...@@ -72,7 +72,7 @@ class CustomOp(nn.Module):
if not enabled: if not enabled:
return self.forward_native return self.forward_native
if is_hip(): if current_platform.is_rocm():
return self.forward_hip return self.forward_hip
elif current_platform.is_cpu(): elif current_platform.is_cpu():
return self.forward_cpu return self.forward_cpu
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment