Unverified commit f26650d6, authored by Hongxia Yang, committed by GitHub
Browse files

[ROCm] add amd-quark package in requirements for rocm to use quantized models (#35658)


Signed-off-by: Hongxia Yang <hongxiay.yang@amd.com>
Co-authored-by: Hongxia Yang <hongxiay.yang@amd.com>
parent 92f5d0f0
...@@ -20,3 +20,6 @@ setuptools-scm>=8 ...@@ -20,3 +20,6 @@ setuptools-scm>=8
runai-model-streamer[s3,gcs]==0.15.3 runai-model-streamer[s3,gcs]==0.15.3
conch-triton-kernels==1.2.1 conch-triton-kernels==1.2.1
timm>=1.0.17 timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py
amd-quark>=0.8.99
\ No newline at end of file
...@@ -26,9 +26,12 @@ from vllm.platforms import current_platform ...@@ -26,9 +26,12 @@ from vllm.platforms import current_platform
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch
# Minimum amd-quark version for MXFP4/OCP_MX tests (single source of truth).
QUARK_MXFP4_MIN_VERSION = "0.8.99"
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse( QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark") importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99") ) >= version.parse(QUARK_MXFP4_MIN_VERSION)
if QUARK_MXFP4_AVAILABLE: if QUARK_MXFP4_AVAILABLE:
from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
...@@ -200,7 +203,10 @@ WIKITEXT_ACCURACY_CONFIGS = [ ...@@ -200,7 +203,10 @@ WIKITEXT_ACCURACY_CONFIGS = [
] ]
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.skipif(
not QUARK_MXFP4_AVAILABLE,
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS) @pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int): def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
...@@ -231,7 +237,10 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int): ...@@ -231,7 +237,10 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS) @pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.skipif(
not QUARK_MXFP4_AVAILABLE,
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.skipif( @pytest.mark.skipif(
not HF_HUB_AMD_ORG_ACCESS, not HF_HUB_AMD_ORG_ACCESS,
reason="Read access to huggingface.co/amd is required for this test.", reason="Read access to huggingface.co/amd is required for this test.",
...@@ -261,7 +270,10 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig): ...@@ -261,7 +270,10 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.skipif(
not QUARK_MXFP4_AVAILABLE,
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) @pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[int]): def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[int]):
...@@ -289,7 +301,10 @@ def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[in ...@@ -289,7 +301,10 @@ def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[in
) )
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.skipif(
not QUARK_MXFP4_AVAILABLE,
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) @pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark( def test_mxfp4_dequant_kernel_match_quark(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment