Unverified commit 4e31b7f2, authored by Robert Shaw, committed by GitHub
Browse files

[Quantization][Deprecation] Remove RTN (#32697)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
parent 6c20e89c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright © 2025, Oracle and/or its affiliates.
"""Tests RTN quantization startup and generation.

Does not test correctness.
"""
import pytest
from tests.quantization.utils import is_quant_method_supported
# Model identifiers the RTN startup test is parametrized over.
MODELS = ["ai21labs/Jamba-tiny-dev"]  # MoE model
@pytest.mark.skipif(
    not is_quant_method_supported("rtn"),
    reason="RTN is not supported on this GPU type.",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_model_rtn_startup(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Smoke-test that a model starts and generates with ``quantization="rtn"``.

    The test only checks that the runner can be constructed and that greedy
    generation over ``example_prompts`` completes; the generated text is not
    inspected.

    Args:
        hf_runner: Requested fixture; not used in the test body.
        vllm_runner: Fixture called as a context manager to build the model.
        example_prompts: Prompts passed to ``generate_greedy``.
        model: Model identifier, parametrized from ``MODELS``.
        dtype: Dtype string forwarded to the runner.
        max_tokens: Token limit forwarded to ``generate_greedy``.
    """
    # RTN is a deprecated method, so the runner must be told explicitly to
    # allow it.
    runner_options = {
        "enforce_eager": True,
        "dtype": dtype,
        "quantization": "rtn",
        "allow_deprecated_quantization": True,
    }
    with vllm_runner(model, **runner_options) as llm:
        llm.generate_greedy(example_prompts, max_tokens)
...@@ -31,7 +31,6 @@ QuantizationMethods = Literal[ ...@@ -31,7 +31,6 @@ QuantizationMethods = Literal[
"quark", "quark",
"moe_wna16", "moe_wna16",
"torchao", "torchao",
"rtn",
"inc", "inc",
"mxfp4", "mxfp4",
"petit_nvfp4", "petit_nvfp4",
...@@ -49,7 +48,6 @@ DEPRECATED_QUANTIZATION_METHODS = [ ...@@ -49,7 +48,6 @@ DEPRECATED_QUANTIZATION_METHODS = [
"gptq_bitblas", "gptq_bitblas",
"experts_int8", "experts_int8",
"ipex", "ipex",
"rtn",
"petit_nvfp4", "petit_nvfp4",
] ]
...@@ -138,7 +136,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -138,7 +136,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .mxfp4 import Mxfp4Config from .mxfp4 import Mxfp4Config
from .petit import PetitNvFp4Config from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config from .ptpc_fp8 import PTPCFp8Config
from .rtn import RTNConfig
from .torchao import TorchAOConfig from .torchao import TorchAOConfig
method_to_config: dict[str, type[QuantizationConfig]] = { method_to_config: dict[str, type[QuantizationConfig]] = {
...@@ -163,7 +160,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -163,7 +160,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"quark": QuarkConfig, "quark": QuarkConfig,
"moe_wna16": MoeWNA16Config, "moe_wna16": MoeWNA16Config,
"torchao": TorchAOConfig, "torchao": TorchAOConfig,
"rtn": RTNConfig,
"auto-round": INCConfig, "auto-round": INCConfig,
"inc": INCConfig, "inc": INCConfig,
"mxfp4": Mxfp4Config, "mxfp4": Mxfp4Config,
......
This diff is collapsed.
...@@ -650,64 +650,3 @@ def apply_awq_marlin_linear( ...@@ -650,64 +650,3 @@ def apply_awq_marlin_linear(
) )
return output.reshape(out_shape) return output.reshape(out_shape)
def apply_rtn_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    quant_type: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_global_scale: torch.Tensor | None = None,
    bias: torch.Tensor | None = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
    input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Apply a linear layer through ``ops.gptq_marlin_gemm``.

    The input is flattened to 2-D (all leading dimensions collapsed),
    optionally quantized to ``input_dtype``, multiplied by the Marlin kernel,
    and reshaped so that the leading dimensions match ``input`` and the last
    dimension is ``output_size_per_partition``.

    Args:
        input: Activation tensor; its last dimension is the reduction dim.
        weight: Weight tensor handed to the kernel as-is.
        weight_scale: Weight scales handed to the kernel as-is.
        workspace: Workspace tensor handed to the kernel as-is.
        quant_type: Scalar type of the quantized weights.
        output_size_per_partition: Passed as ``size_n`` and used as the last
            dimension of the result.
        input_size_per_partition: Passed as ``size_k``.
        input_global_scale: Only read on the ``torch.int8`` path, where it is
            multiplied into the activation scales.
        bias: Optional bias forwarded to the kernel.
        use_fp32_reduce: Forwarded to the kernel.
        input_dtype: When ``torch.int8`` or ``torch.float8_e4m3fn``, the
            activations are quantized with ``marlin_quant_input`` first;
            any other value leaves them untouched.

    Returns:
        The kernel output reshaped to ``input.shape[:-1]`` plus
        ``(output_size_per_partition,)``.

    Raises:
        AssertionError: If activations are to be quantized and ``quant_type``
            is not ``scalar_types.uint4b8``.
    """
    x_2d = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (output_size_per_partition,)

    # The reduction strategy is chosen from the problem size of the
    # un-quantized activations and the original input's device/dtype.
    atomic_add = should_use_atomic_add_reduce(
        m=x_2d.size(0),
        n=output_size_per_partition,
        k=x_2d.size(1),
        device=input.device,
        dtype=input.dtype,
    )

    # Activation scales stay None unless the activations get quantized below.
    act_scales = None
    if input_dtype == torch.int8:
        assert quant_type == scalar_types.uint4b8, (
            "W8A8-INT8 is not supported by marlin kernel."
        )
        x_2d, act_scales = marlin_quant_input(x_2d, input_dtype)
        # NOTE(review): input_global_scale defaults to None; callers on this
        # path are presumably expected to supply it - confirm.
        act_scales = act_scales * input_global_scale
    elif input_dtype == torch.float8_e4m3fn:
        assert quant_type == scalar_types.uint4b8, (
            "INT8 weight + FP8 activation is not supported."
        )
        x_2d, act_scales = marlin_quant_input(x_2d, input_dtype)

    gemm_out = ops.gptq_marlin_gemm(
        x_2d,
        None,
        weight,
        bias,
        weight_scale,
        act_scales,
        None,
        None,
        None,
        None,
        workspace,
        quant_type,
        size_m=x_2d.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        use_atomic_add=atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    return gemm_out.reshape(result_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment