Unverified Commit f6518b2b authored by hissu-hyvarinen's avatar hissu-hyvarinen Committed by GitHub
Browse files

[ROCm] Skip tests for quantizations incompatible with ROCm (#17905)


Signed-off-by: default avatarHissu Hyvarinen <hissu.hyvarinen@amd.com>
parent d67085c2
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import pytest import pytest
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm.platforms import current_platform
# These ground truth generations were generated using `transformers==4.38.1 # These ground truth generations were generated using `transformers==4.38.1
# aqlm==1.1.0 torch==2.2.0` # aqlm==1.1.0 torch==2.2.0`
...@@ -34,7 +35,9 @@ ground_truth_generations = [ ...@@ -34,7 +35,9 @@ ground_truth_generations = [
] ]
@pytest.mark.skipif(not is_quant_method_supported("aqlm"), @pytest.mark.skipif(not is_quant_method_supported("aqlm")
or current_platform.is_rocm()
or not current_platform.is_cuda(),
reason="AQLM is not supported on this GPU type.") reason="AQLM is not supported on this GPU type.")
@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"]) @pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
......
...@@ -55,6 +55,14 @@ def test_models( ...@@ -55,6 +55,14 @@ def test_models(
Only checks log probs match to cover the discrepancy in Only checks log probs match to cover the discrepancy in
numerical sensitive kernels. numerical sensitive kernels.
""" """
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")
if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
pytest.skip(
f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", 'true') m.setenv("TOKENIZERS_PARALLELISM", 'true')
m.setenv(STR_BACKEND_ENV_VAR, backend) m.setenv(STR_BACKEND_ENV_VAR, backend)
......
...@@ -14,6 +14,7 @@ import pytest ...@@ -14,6 +14,7 @@ import pytest
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
from vllm.platforms import current_platform
from ..utils import check_logprobs_close from ..utils import check_logprobs_close
...@@ -34,7 +35,9 @@ MODELS = [ ...@@ -34,7 +35,9 @@ MODELS = [
@pytest.mark.flaky(reruns=3) @pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin")
or current_platform.is_rocm()
or not current_platform.is_cuda(),
reason="gptq_marlin is not supported on this GPU type.") reason="gptq_marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half", "bfloat16"]) @pytest.mark.parametrize("dtype", ["half", "bfloat16"])
......
...@@ -10,6 +10,7 @@ from dataclasses import dataclass ...@@ -10,6 +10,7 @@ from dataclasses import dataclass
import pytest import pytest
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm.platforms import current_platform
from ..utils import check_logprobs_close from ..utils import check_logprobs_close
...@@ -38,7 +39,9 @@ model_pairs = [ ...@@ -38,7 +39,9 @@ model_pairs = [
@pytest.mark.flaky(reruns=2) @pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"), @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24")
or current_platform.is_rocm()
or not current_platform.is_cuda(),
reason="Marlin24 is not supported on this GPU type.") reason="Marlin24 is not supported on this GPU type.")
@pytest.mark.parametrize("model_pair", model_pairs) @pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment