Unverified Commit 7b1a7423 authored by Vasiliy Kuznetsov's avatar Vasiliy Kuznetsov Committed by GitHub
Browse files

[Frontend] new online quantization frontend (#38138)


Signed-off-by: default avatarVasiliy Kuznetsov <vasiliy@meta.com>
parent 97f92c6b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests online quantization."""
import pytest
import torch
from tests.quantization.utils import (
_test_online_quant_peak_mem_impl,
is_quant_method_supported,
)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.online.fp8 import (
Fp8PerBlockOnlineLinearMethod,
Fp8PerBlockOnlineMoEMethod,
Fp8PerTensorOnlineLinearMethod,
Fp8PerTensorOnlineMoEMethod,
)
from vllm.platforms import current_platform
@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize(
"quant_scheme,online_quant_args,expected_linear_cls,expected_moe_cls",
[
# simple case - quantization='fp8_per_tensor'
(
"fp8_per_tensor",
None,
Fp8PerTensorOnlineLinearMethod,
Fp8PerTensorOnlineMoEMethod,
),
# simple case - quantization='fp8_per_block'
(
"fp8_per_block",
None,
Fp8PerBlockOnlineLinearMethod,
Fp8PerBlockOnlineMoEMethod,
),
# quantization='online with linear_scheme_override and
# moe_scheme_override
(
"online",
{
"linear_scheme_override": "fp8_per_block",
"moe_scheme_override": "fp8_per_tensor",
},
Fp8PerBlockOnlineLinearMethod,
Fp8PerTensorOnlineMoEMethod,
),
# ignore with direct layer name
(
"fp8_per_tensor",
# qkv_proj is fused from q_proj/k_proj/v_proj, so currently the
# ignore regex must match the unfused shard names
# TODO(future PR): also make 're:.*qkv_proj.*' work
{"ignore": ["model.layers.1.self_attn.o_proj", "re:.*[qkv]_proj"]},
Fp8PerTensorOnlineLinearMethod,
Fp8PerTensorOnlineMoEMethod,
),
],
)
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_online_quantization(
vllm_runner,
quant_scheme: str,
online_quant_args: dict | None,
expected_linear_cls,
expected_moe_cls,
use_rocm_aiter: bool,
monkeypatch,
) -> None:
"""
Tests that online quantization frontend configuration works -
selecting quant schemes, overriding quant schemes by type, ignoring
layers.
Does not test performance, peak memory usage, etc.
"""
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
# a tiny model with both dense and MoE layers
model_name = "ibm-granite/granite-3.0-1b-a400m-base"
runner_kwargs = dict(
quantization=quant_scheme,
enforce_eager=True,
)
if online_quant_args is not None:
runner_kwargs["quantization_config"] = online_quant_args
with vllm_runner(
model_name,
**runner_kwargs,
) as llm:
def check_model(model):
# checks further down in the test case are hardcoded for this
# model
assert model_name == "ibm-granite/granite-3.0-1b-a400m-base"
o_proj = model.model.layers[0].self_attn.o_proj
moe = model.model.layers[0].block_sparse_moe.experts
# o_proj and moe in layer 0 are always quantized (never ignored)
# because of how we craft the test case inputs
assert isinstance(o_proj.quant_method, expected_linear_cls)
if moe is not None:
assert isinstance(moe.quant_method, expected_moe_cls)
if current_platform.is_cuda():
assert o_proj.weight.dtype == torch.float8_e4m3fn
elif current_platform.is_rocm():
assert o_proj.weight.dtype == current_platform.fp8_dtype()
else:
pytest.skip("Only runs on CUDA and ROCm.")
# Verify ignored layers are unquantized.
if isinstance(online_quant_args, dict) and "ignore" in online_quant_args:
# only .*1.self_attn_o_proj is skipped
for layer_idx in range(len(model.model.layers)):
o_proj = model.model.layers[layer_idx].self_attn.o_proj
if layer_idx == 1:
assert isinstance(o_proj.quant_method, UnquantizedLinearMethod)
else:
assert isinstance(o_proj.quant_method, expected_linear_cls)
# every .*self_attn.qkv_proj is skipped
for layer_idx in range(len(model.model.layers)):
qkv_proj = model.model.layers[layer_idx].self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, UnquantizedLinearMethod)
llm.apply_model(check_model)
outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
print(outputs[0][1])
@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_peak_mem(
vllm_runner,
caplog_mp_spawn,
monkeypatch,
) -> None:
_test_online_quant_peak_mem_impl(
"fp8_per_tensor", vllm_runner, caplog_mp_spawn, monkeypatch
)
@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_load_format_dummy(
vllm_runner,
monkeypatch,
caplog,
) -> None:
with vllm_runner(
"ibm-granite/granite-3.0-1b-a400m-base",
quantization="fp8_per_tensor",
enforce_eager=True,
load_format="dummy",
) as llm:
outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
print(outputs[0][1])
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import regex as re
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.platforms import current_platform
......@@ -21,3 +25,74 @@ def is_quant_method_supported(quant_method: str) -> bool:
min_capability = get_quantization_config(quant_method).get_min_capability()
return capability.to_int() >= min_capability
def _test_online_quant_peak_mem_impl(
quantization_arg_value,
vllm_runner,
caplog_mp_spawn,
monkeypatch,
) -> None:
# Note: `allenai/OLMoE-1B-7B-0125-Instruct` was selected because:
# 1. it covers both Linear and MoE paths
# 2. it is already used by other tests in CI, so adding it here
# does not increase disk space for CI runners
# I really wanted to use `ibm-granite/granite-3.0-1b-a400m-base`
# which I think is the smallest MoE model in vLLM (2.5 GiB bf16,
# 1.3 GiB fp8), but could not as adding one more model makes CI
# run out of disk space.
model_name = "allenai/OLMoE-1B-7B-0125-Instruct"
# Force spawn to ensure caplog_mp_spawn works consistently
# (it relies on VLLM_LOGGING_CONFIG_PATH which spawn reads but fork ignores)
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
with (
caplog_mp_spawn(logging.DEBUG) as log_holder,
vllm_runner(
model_name,
quantization=quantization_arg_value,
enforce_eager=True,
) as llm,
):
outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
print(outputs[0][1])
log_text = log_holder.text
# Parse memory usage from captured logs
model_memory_gib = None
peak_memory_gib = None
for line in log_text.splitlines():
if model_memory_gib is None:
match = re.search(r"Model loading took ([\d.]+) GiB memory", line)
if match:
model_memory_gib = float(match.group(1))
if peak_memory_gib is None:
match = re.search(
r"Peak GPU memory after loading weights: ([\d.]+) GiB", line
)
if match:
peak_memory_gib = float(match.group(1))
assert model_memory_gib is not None, "Could not find model loading memory log"
assert peak_memory_gib is not None, "Could not find peak memory log"
print(f"GPU memory used after loading weights: {model_memory_gib} GiB")
print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB")
# model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant
# uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB)
expected_model_memory_gib = 6.7
# for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06
# GiB, which is 1.36x above model_memory_gib. A slightly higher number is
# expected as when we load and quantize weights in a streaming fashion we
# need to have individual weights in bf16 + fp8 alive at the same time.
expected_peak_memory_gib = expected_model_memory_gib * 1.4
assert model_memory_gib < expected_model_memory_gib, (
f"{model_memory_gib=} higher than {expected_model_memory_gib}"
)
assert peak_memory_gib < expected_peak_memory_gib, (
f"{peak_memory_gib=} higher than {expected_peak_memory_gib}"
)
......@@ -21,6 +21,7 @@ from vllm.config.multimodal import (
MultiModalConfig,
)
from vllm.config.pooler import PoolerConfig
from vllm.config.quantization import OnlineQuantizationConfigArgs
from vllm.config.scheduler import RunnerType
from vllm.config.utils import config, getattr_iter
from vllm.logger import init_logger
......@@ -199,6 +200,10 @@ class ModelConfig:
`quantization_config` attribute in the model config file. If that is
`None`, we assume the model weights are not quantized and use `dtype` to
determine the data type of the weights."""
quantization_config: dict[str, Any] | OnlineQuantizationConfigArgs | None = None
"""Arguments for online quantization.
Auto-created when `quantization` equals to one of the string values of
the `OnlineQuantScheme` enum."""
allow_deprecated_quantization: bool = False
"""Whether to allow deprecated quantization methods."""
enforce_eager: bool = False
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Any
from pydantic import Field, field_validator
from vllm.config.utils import config
class OnlineQuantScheme(Enum):
"""Supported online quantization schemes."""
# fp8, weights and activations scaled per-tensor
FP8_PER_TENSOR = "fp8_per_tensor"
# fp8, activations scaled in blocks of 1x128 elements, weights scaled in
# blocks of 128x128 elements (popularized by DeepSeek)
FP8_PER_BLOCK = "fp8_per_block"
# TODO(future PRs): add more online quant schemes here: mxfp8, etc
@config
class OnlineQuantizationConfigArgs:
"""Configuration for online quantization.
Controls how ``OnlineQuantizationConfig`` is applied to a model.
At least one of ``global_scheme``, ``linear_scheme_override``, or
``moe_scheme_override`` must be set.
"""
global_scheme: OnlineQuantScheme | None = None
"""Quantization scheme applied to every supported layer."""
linear_scheme_override: OnlineQuantScheme | None = None
"""Quantization scheme override for ``LinearBase`` layers."""
moe_scheme_override: OnlineQuantScheme | None = None
"""Quantization scheme override for ``FusedMoE`` layers."""
ignore: list[str] = Field(default_factory=list)
"""Layers to skip quantization for. Supports exact names and regex
patterns with ``re:`` prefix (e.g. ``re:.*attn.*``), consistent with
compressed_tensors layer skipping."""
@field_validator(
"global_scheme", "linear_scheme_override", "moe_scheme_override", mode="before"
)
@classmethod
def _coerce_scheme(
cls, v: str | OnlineQuantScheme | None
) -> OnlineQuantScheme | None:
if isinstance(v, str):
return OnlineQuantScheme(v)
return v
def resolve_online_quant_config(
quantization: str | None,
quantization_config: dict[str, Any] | OnlineQuantizationConfigArgs | None,
) -> OnlineQuantizationConfigArgs | None:
"""Resolve online quant scheme shorthand into a quantization config.
If ``quantization`` is an online quant scheme (e.g. ``'fp8_per_tensor'``),
ensures ``quantization_config`` has a matching ``global_scheme`` and casts
it to :class:`OnlineQuantizationConfigArgs` if needed.
"""
online_quant_values = {s.value for s in OnlineQuantScheme}
valid_quantization_values = online_quant_values | {"online"}
if quantization not in valid_quantization_values:
if quantization_config is not None:
raise ValueError(
f"quantization_config is only supported when quantization "
f"is one of {sorted(valid_quantization_values)}, "
f"got quantization={quantization!r}"
)
return None
if quantization in online_quant_values:
scheme = OnlineQuantScheme(quantization)
if quantization_config is None:
quantization_config = {
"global_scheme": scheme.value,
}
elif isinstance(quantization_config, OnlineQuantizationConfigArgs):
if quantization_config.global_scheme is None:
quantization_config.global_scheme = scheme
elif quantization_config.global_scheme != scheme:
raise ValueError(
f"quantization={quantization!r} conflicts with "
f"quantization_config.global_scheme="
f"{quantization_config.global_scheme.value!r}. "
f"These must match when both are specified."
)
elif isinstance(quantization_config, dict):
existing = quantization_config.get("global_scheme")
if existing is None:
quantization_config["global_scheme"] = scheme.value
else:
# Coerce to enum for comparison
existing_scheme = (
OnlineQuantScheme(existing)
if isinstance(existing, str)
else existing
)
if existing_scheme != scheme:
raise ValueError(
f"quantization={quantization!r} conflicts "
f"with quantization_config"
f"['global_scheme']={existing!r}. "
f"These must match when both are specified."
)
# Cast dict to OnlineQuantizationConfigArgs
if isinstance(quantization_config, dict):
quantization_config = OnlineQuantizationConfigArgs(**quantization_config)
return quantization_config
......@@ -1713,6 +1713,7 @@ class VllmConfig:
f"dcp_comm_backend={self.parallel_config.dcp_comm_backend}, " # noqa
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
f"quantization={self.model_config.quantization}, "
f"quantization_config={self.model_config.quantization_config}, " # noqa
f"enforce_eager={self.model_config.enforce_eager}, "
f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa
f"kv_cache_dtype={self.cache_config.cache_dtype}, "
......
......@@ -112,6 +112,7 @@ from vllm.v1.sample.logits_processor import LogitsProcessor
from vllm.version import __version__ as VLLM_VERSION
if TYPE_CHECKING:
from vllm.config.quantization import OnlineQuantizationConfigArgs
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.model_loader import LoadFormats
from vllm.usage.usage_lib import UsageContext
......@@ -483,6 +484,7 @@ class EngineArgs:
hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
tokenizer_revision: str | None = ModelConfig.tokenizer_revision
quantization: QuantizationMethods | str | None = ModelConfig.quantization
quantization_config: "dict[str, Any] | OnlineQuantizationConfigArgs | None" = None
allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
enforce_eager: bool = ModelConfig.enforce_eager
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
......@@ -661,6 +663,12 @@ class EngineArgs:
if isinstance(self.ir_op_priority, dict):
self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority)
from vllm.config.quantization import resolve_online_quant_config
self.quantization_config = resolve_online_quant_config(
self.quantization, self.quantization_config
)
# Setup plugins
from vllm.plugins import load_general_plugins
......@@ -1431,6 +1439,7 @@ class EngineArgs:
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
quantization_config=self.quantization_config,
allow_deprecated_quantization=self.allow_deprecated_quantization,
enforce_eager=self.enforce_eager,
enable_return_routed_experts=self.enable_return_routed_experts,
......
......@@ -34,6 +34,9 @@ from vllm.config.model import (
RunnerOption,
TokenizerMode,
)
from vllm.config.quantization import (
OnlineQuantizationConfigArgs,
)
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
......@@ -247,6 +250,9 @@ class LLM:
attention_config: dict[str, Any] | AttentionConfig | None = None,
kv_cache_memory_bytes: int | None = None,
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
quantization_config: dict[str, Any]
| OnlineQuantizationConfigArgs
| None = None,
logits_processors: list[str | type[LogitsProcessor]] | None = None,
**kwargs: Any,
) -> None:
......@@ -367,6 +373,7 @@ class LLM:
profiler_config=profiler_config_instance,
attention_config=attention_config_instance,
compilation_config=compilation_config_instance,
quantization_config=quantization_config,
logits_processors=logits_processors,
**kwargs,
)
......
......@@ -33,6 +33,13 @@ QuantizationMethods = Literal[
"mxfp8",
"petit_nvfp4",
"cpu_awq",
"online",
# Below are values of the OnlineQuantScheme enum, specified as strings to
# avoid circular import issues. This is here to provide a shortcut where
# the user can specify "LLM(..., quantization='fp8_per_tensor')" as
# shorthand for creating a more complicated online quant config object
"fp8_per_tensor",
"fp8_per_block",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
......@@ -103,6 +110,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
raise ValueError(f"Invalid quantization method: {quantization}")
# lazy import to avoid triggering `torch.compile` too early
from vllm.config.quantization import OnlineQuantScheme
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from .awq import AWQConfig
......@@ -129,6 +137,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .mxfp8 import Mxfp8Config
from .online.base import OnlineQuantizationConfig
from .petit import PetitNvFp4Config
from .torchao import TorchAOConfig
......@@ -157,7 +166,20 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"mxfp8": Mxfp8Config,
"petit_nvfp4": PetitNvFp4Config,
"cpu_awq": CPUAWQConfig,
"online": OnlineQuantizationConfig,
}
# Below are values of the OnlineQuantScheme enum. This is here to provide
# a shortcut where the user can specify
# "LLM(..., quantization='fp8_per_tensor')" as shorthand for creating a
# more complicated online quant config object
for scheme in OnlineQuantScheme:
assert scheme.value not in method_to_config, (
f"Online quant scheme {scheme.value!r} conflicts with an "
f"existing quantization method"
)
method_to_config[scheme.value] = OnlineQuantizationConfig
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
......
......@@ -497,6 +497,8 @@ class Fp8LinearMethod(LinearMethodBase):
return self.fp8_linear.apply_weights(layer, x, bias)
# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineLinearMethod
class Fp8OnlineLinearMethod(Fp8LinearMethod):
"""Online version of Fp8LinearMethod which loads a full precision checkpoint
and quantizes weights during loading."""
......@@ -919,6 +921,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineMoEMethod
class Fp8OnlineMoEMethod(Fp8MoEMethod):
"""MoE method for online FP8 quantization.
Supports loading quantized FP16/BF16 model checkpoints with dynamic
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.config.quantization import (
OnlineQuantizationConfigArgs,
OnlineQuantScheme,
)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.linear import (
LinearBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer,
)
from vllm.model_executor.layers.quantization.online.fp8 import (
Fp8PerBlockOnlineLinearMethod,
Fp8PerBlockOnlineMoEMethod,
Fp8PerTensorOnlineLinearMethod,
Fp8PerTensorOnlineMoEMethod,
)
class OnlineQuantizationConfig(QuantizationConfig):
"""Model-level config class for online quantization (quantize fp16/bf16 weights
during model loading, without requiring a pre-quantized checkpoint)."""
def __init__(
self,
args: OnlineQuantizationConfigArgs,
) -> None:
super().__init__()
if (
args.global_scheme is None
and args.linear_scheme_override is None
and args.moe_scheme_override is None
):
raise ValueError(
"OnlineQuantizationConfig requires at least one of "
"global_scheme, linear_scheme_override, or "
"moe_scheme_override to be set."
)
self.args = args
self.quant_scheme = args.global_scheme
self.ignored_layers: list[str] = args.ignore
@classmethod
def get_name(cls) -> QuantizationMethods:
return "online"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
# Note: as more online quant schemes will be added, this
# value will become the minimum across all supported schemes.
return 75
@classmethod
def get_config_filenames(cls) -> list[str]:
return []
@classmethod
def from_config(cls, config: dict[str, Any]) -> "OnlineQuantizationConfig":
raise NotImplementedError(
"OnlineQuantizationConfig does not support loading from a "
"checkpoint config. Use quantization_config or "
"quantization='fp8_per_tensor'/'fp8_per_block' instead."
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
if should_ignore_layer(
prefix,
ignore=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
linear_scheme = self.args.linear_scheme_override or self.args.global_scheme
if linear_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
return Fp8PerBlockOnlineLinearMethod()
else:
return Fp8PerTensorOnlineLinearMethod()
elif isinstance(layer, FusedMoE):
if should_ignore_layer(
prefix,
ignore=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
moe_scheme = self.args.moe_scheme_override or self.args.global_scheme
if moe_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
return Fp8PerBlockOnlineMoEMethod(layer=layer)
else:
return Fp8PerTensorOnlineMoEMethod(layer=layer)
return None
This diff is collapsed.
......@@ -296,6 +296,13 @@ def get_quant_config(
)
if hf_quant_config is not None:
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
# For modelopt_mixed, config.json's quantization_config may or may
# not contain the per-layer quantized_layers map. Newer checkpoints
# embed it directly; older ones keep it only in hf_quant_config.json.
......@@ -319,6 +326,12 @@ def get_quant_config(
quantization_config_file = hf_overrides.get("quantization_config_file", None)
if quantization_config_file is not None:
if hasattr(quant_cls, "from_config_file"):
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
return quant_cls.from_config_file(quantization_config_file)
else:
raise NotImplementedError(
......@@ -329,6 +342,12 @@ def get_quant_config(
quantization_config_json = hf_overrides.get("quantization_config_dict_json", None)
if quantization_config_json is not None:
if hasattr(quant_cls, "from_config_dict_json"):
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
return quant_cls.from_config_dict_json(quantization_config_json)
else:
raise NotImplementedError(
......@@ -337,6 +356,19 @@ def get_quant_config(
f"{quant_cls}"
)
# Online quantization doesn't read from checkpoint configs — it quantizes
# fp16/bf16 weights on the fly during loading.
if model_config.quantization_config is not None:
from vllm.config.quantization import OnlineQuantizationConfigArgs
from vllm.model_executor.layers.quantization.online.base import (
OnlineQuantizationConfig,
)
assert isinstance(
model_config.quantization_config, OnlineQuantizationConfigArgs
)
return OnlineQuantizationConfig(args=model_config.quantization_config)
# Inflight BNB quantization
if model_config.quantization == "bitsandbytes":
return quant_cls.from_config({})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment