Unverified commit bde57ab2, authored by Matt and committed by GitHub
Browse files

[Hardware][AMD][CI][Bugfix] Fix AMD Quantization test group (#31713)


Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
parent 9103ed16
...@@ -731,7 +731,7 @@ steps: ...@@ -731,7 +731,7 @@ steps:
- label: Quantization Test # 70min - label: Quantization Test # 70min
timeout_in_minutes: 90 timeout_in_minutes: 90
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1 agent_pool: mi325_1
# grade: Blocking # grade: Blocking
source_file_dependencies: source_file_dependencies:
......
...@@ -644,6 +644,9 @@ def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): ...@@ -644,6 +644,9 @@ def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
assert output assert output
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"args", "args",
[ [
...@@ -762,7 +765,10 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner): ...@@ -762,7 +765,10 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner):
input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
assert isinstance(input_quant_op, QuantFP8) assert isinstance(input_quant_op, QuantFP8)
assert input_quant_op._forward_method == input_quant_op.forward_cuda assert input_quant_op._forward_method in (
input_quant_op.forward_cuda,
input_quant_op.forward_hip,
)
llm.apply_model(check_model) llm.apply_model(check_model)
......
...@@ -10,6 +10,7 @@ from dataclasses import dataclass ...@@ -10,6 +10,7 @@ from dataclasses import dataclass
import pytest import pytest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.platforms import current_platform
@dataclass @dataclass
...@@ -23,20 +24,44 @@ MODEL_ARG_EXPTYPES = [ ...@@ -23,20 +24,44 @@ MODEL_ARG_EXPTYPES = [
# AUTOGPTQ # AUTOGPTQ
# compat: autogptq <=0.7.1 is_marlin_format: bool # compat: autogptq <=0.7.1 is_marlin_format: bool
# Model Serialized in Exllama Format. # Model Serialized in Exllama Format.
("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"), (
("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"), "TheBloke/Llama-2-7B-Chat-GPTQ",
None,
"gptq_marlin" if current_platform.is_cuda() else "gptq",
),
(
"TheBloke/Llama-2-7B-Chat-GPTQ",
"marlin",
"gptq_marlin" if current_platform.is_cuda() else "ERROR",
),
("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"), ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"),
("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"), ("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"),
# compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq >=0.8.0 use checkpoint_format: str
# Model Serialized in Exllama Format. # Model Serialized in Exllama Format.
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"), (
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"), "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
None,
"gptq_marlin" if current_platform.is_cuda() else "gptq",
),
(
"LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"marlin",
"gptq_marlin" if current_platform.is_cuda() else "ERROR",
),
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"), ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"),
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"), ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"),
# AUTOAWQ # AUTOAWQ
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq_marlin"), (
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ",
None,
"awq_marlin" if current_platform.is_cuda() else "awq",
),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"), ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "awq_marlin"), (
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ",
"marlin",
"awq_marlin" if current_platform.is_cuda() else "ERROR",
),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"), ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"),
] ]
......
...@@ -66,7 +66,7 @@ def test_cpu_offload_compressed_tensors(monkeypatch): ...@@ -66,7 +66,7 @@ def test_cpu_offload_compressed_tensors(monkeypatch):
monkeypatch.setenv("VLLM_TEST_FORCE_LOAD_FORMAT", "auto") monkeypatch.setenv("VLLM_TEST_FORCE_LOAD_FORMAT", "auto")
# Test wNa16 # Test wNa16
compare_two_settings( compare_two_settings(
"nm-testing/tinyllama-oneshot-w4a16-channel-v2", "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16",
["--enforce_eager"], ["--enforce_eager"],
["--enforce_eager", "--cpu-offload-gb", "1"], ["--enforce_eager", "--cpu-offload-gb", "1"],
max_wait_seconds=480, max_wait_seconds=480,
......
...@@ -36,7 +36,9 @@ MODELS = [ ...@@ -36,7 +36,9 @@ MODELS = [
reason="FP8 is not supported on this GPU type.", reason="FP8 is not supported on this GPU type.",
) )
@pytest.mark.parametrize("model_id", MODELS) @pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True]) @pytest.mark.parametrize(
"force_marlin", [False] if current_platform.is_rocm() else [False, True]
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
) )
...@@ -125,7 +127,9 @@ def test_kv_cache_model_load_and_run( ...@@ -125,7 +127,9 @@ def test_kv_cache_model_load_and_run(
reason="FP8 is not supported on this GPU type.", reason="FP8 is not supported on this GPU type.",
) )
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True]) @pytest.mark.parametrize(
"force_marlin", [False] if current_platform.is_rocm() else [False, True]
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
) )
...@@ -197,10 +201,10 @@ def test_scaled_fp8_quant(dtype) -> None: ...@@ -197,10 +201,10 @@ def test_scaled_fp8_quant(dtype) -> None:
def quantize_ref(tensor, inv_scale): def quantize_ref(tensor, inv_scale):
# The reference implementation that fully aligns to # The reference implementation that fully aligns to
# the kernel being tested. # the kernel being tested.
finfo = torch.finfo(torch.float8_e4m3fn) finfo = torch.finfo(current_platform.fp8_dtype())
scale = inv_scale.reciprocal() scale = inv_scale.reciprocal()
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max) qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
qweight = qweight.to(torch.float8_e4m3fn) qweight = qweight.to(current_platform.fp8_dtype())
return qweight return qweight
def per_tensor_dequantize(tensor, inv_scale, dtype): def per_tensor_dequantize(tensor, inv_scale, dtype):
...@@ -267,6 +271,10 @@ def test_scaled_fp8_quant(dtype) -> None: ...@@ -267,6 +271,10 @@ def test_scaled_fp8_quant(dtype) -> None:
) )
@pytest.mark.skipif(
current_platform.is_fp8_fnuz(),
reason="FP8 e4m3fn weight reloading is not supported on e4m3fnuz platforms",
)
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod]) @pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
# FP8 weight reloading does not support online quantization # FP8 weight reloading does not support online quantization
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True]) # skip False @pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True]) # skip False
......
...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinLinear ...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinLinear
from vllm.model_executor.layers.quantization.utils.gptq_utils import ( from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_dynamic_override, get_dynamic_override,
) )
from vllm.platforms import current_platform
PROMPT = "On the surface of Mars, we found" PROMPT = "On the surface of Mars, we found"
...@@ -21,7 +22,10 @@ PROMPT = "On the surface of Mars, we found" ...@@ -21,7 +22,10 @@ PROMPT = "On the surface of Mars, we found"
# The second layer is quantized using bits=8, group_size=32 # The second layer is quantized using bits=8, group_size=32
# All other layers (layer index >= 2) are not quantized # All other layers (layer index >= 2) are not quantized
MODEL_QUANT = [ MODEL_QUANT = [
("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue", True), (
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue",
current_platform.is_cuda(),
),
( (
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse", "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse",
False, False,
......
...@@ -6,18 +6,12 @@ Run `pytest tests/quantization/test_ptpc_fp8.py --forked`. ...@@ -6,18 +6,12 @@ Run `pytest tests/quantization/test_ptpc_fp8.py --forked`.
""" """
import pytest import pytest
import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
from vllm.model_executor.layers.quantization.ptpc_fp8 import PTPCFp8LinearMethod from vllm.model_executor.layers.quantization.ptpc_fp8 import PTPCFp8LinearMethod
from vllm.platforms import current_platform from vllm.platforms import current_platform
UNSUPPORTED_STR = (
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
"support output dtype of bfloat16. torch.float16 is specified."
)
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch): def enable_pickle(monkeypatch):
...@@ -30,24 +24,17 @@ def enable_pickle(monkeypatch): ...@@ -30,24 +24,17 @@ def enable_pickle(monkeypatch):
reason="PTPC FP8 is not supported on this GPU type.", reason="PTPC FP8 is not supported on this GPU type.",
) )
@pytest.mark.skipif(not current_platform.is_rocm(), reason="This test is for ROCm GPU.") @pytest.mark.skipif(not current_platform.is_rocm(), reason="This test is for ROCm GPU.")
@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"]) @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None: def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
try: llm = vllm_runner(
llm = vllm_runner( "facebook/opt-125m",
"facebook/opt-125m", dtype=dtype,
dtype=dtype, quantization="ptpc_fp8",
quantization="ptpc_fp8", enforce_eager=True,
enforce_eager=True, kv_cache_dtype=kv_cache_dtype,
kv_cache_dtype=kv_cache_dtype, allow_deprecated_quantization=True,
) )
except AssertionError as e:
if str(e) == UNSUPPORTED_STR:
# If the error message matches, the test passes
return
else:
# If the error message does not match, re-raise the exception
raise
with llm: with llm:
...@@ -60,9 +47,9 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None: ...@@ -60,9 +47,9 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
assert attn._k_scale == 1.0 assert attn._k_scale == 1.0
assert attn._v_scale == 1.0 assert attn._v_scale == 1.0
# For GPUs with hardware support, we keep weights in fp8
if current_platform.has_device_capability(94): if current_platform.has_device_capability(94):
# For GPUs with hardware support, we keep weights in fp8 assert fc1.weight.dtype == current_platform.fp8_dtype()
assert fc1.weight.dtype == torch.float8_e4m3fnuz
llm.apply_model(check_model) llm.apply_model(check_model)
......
...@@ -10,6 +10,11 @@ def is_quant_method_supported(quant_method: str) -> bool: ...@@ -10,6 +10,11 @@ def is_quant_method_supported(quant_method: str) -> bool:
if not (current_platform.is_cuda() or current_platform.is_rocm()): if not (current_platform.is_cuda() or current_platform.is_rocm()):
return False return False
try:
current_platform.verify_quantization(quant_method)
except ValueError:
return False
capability = current_platform.get_device_capability() capability = current_platform.get_device_capability()
assert capability is not None assert capability is not None
......
...@@ -5,6 +5,7 @@ from typing import Literal, get_args ...@@ -5,6 +5,7 @@ from typing import Literal, get_args
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -98,6 +99,9 @@ def register_quantization_config(quantization: str): ...@@ -98,6 +99,9 @@ def register_quantization_config(quantization: str):
) )
else: else:
QUANTIZATION_METHODS.append(quantization) QUANTIZATION_METHODS.append(quantization)
# Automatically assume the custom quantization config is supported
if sq := current_platform.supported_quantization:
sq.append(quantization)
if not issubclass(quant_config_cls, QuantizationConfig): if not issubclass(quant_config_cls, QuantizationConfig):
raise ValueError( raise ValueError(
......
...@@ -9,6 +9,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm ...@@ -9,6 +9,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm
triton_scaled_mm, triton_scaled_mm,
) )
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
...@@ -37,6 +40,20 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -37,6 +40,20 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
torch.nn.Parameter(weight.t().data, requires_grad=False), torch.nn.Parameter(weight.t().data, requires_grad=False),
) )
# WEIGHT SCALE
# Triton kernel supports only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE # INPUT SCALE
if self.config.is_static_input_scheme: if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name) input_scale = getattr(layer, self.i_s_name)
......
...@@ -103,21 +103,25 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): ...@@ -103,21 +103,25 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) assert layer.weight.data.dtype not in (torch.float16, torch.float32), (
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support "
assert layer.weight.data.dtype == torch.bfloat16, ( f"output dtype of bfloat16. {layer.weight.data.dtype} is specified."
f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501
)
# Quantize the weights.
qweight, weight_scale = ops.scaled_fp8_quant(
layer.weight, scale=None, use_per_token_if_dynamic=True
) )
# Update the layer with the new values. if layer.weight.data.dtype == torch.bfloat16:
layer.weight = Parameter( # Quantize the weights.
qweight.t(), requires_grad=False qweight, weight_scale = ops.scaled_fp8_quant(
) # Pretranspose the weight layer.weight, scale=None, use_per_token_if_dynamic=True
layer.weight_scale = Parameter(weight_scale, requires_grad=False) )
# Update the layer with the new values.
layer.weight = Parameter(
qweight.t(), requires_grad=False
) # Pretranspose the weight
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else:
assert layer.weight.data.dtype == current_platform.fp8_dtype()
assert getattr(layer, "weight_scale", None) is not None
layer.input_scale = None layer.input_scale = None
def apply( def apply(
......
...@@ -170,7 +170,9 @@ class RocmPlatform(Platform): ...@@ -170,7 +170,9 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "awq",
"awq_marlin", # will be overwritten with awq
"gptq", "gptq",
"gptq_marlin", # will be overwritten with gptq
"fp8", "fp8",
"compressed-tensors", "compressed-tensors",
"fbgemm_fp8", "fbgemm_fp8",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment