Unverified Commit 5884c2b4 authored by Dipika Sikka's avatar Dipika Sikka Committed by GitHub
Browse files

[Misc] Update to comply with the new `compressed-tensors` config (#5350)


Co-authored-by: default avatarMichael Goin <michael@neuralmagic.com>
parent 45f92c00
...@@ -5,15 +5,15 @@ Run `pytest tests/quantization/test_compressed_tensors.py`. ...@@ -5,15 +5,15 @@ Run `pytest tests/quantization/test_compressed_tensors.py`.
import torch import torch
from vllm import SamplingParams
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken, CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
CompressedTensorsW8A8StaticTensor) CompressedTensorsW8A8StaticTensor)
def test_compressed_tensors_w8a8_static_setup(vllm_runner): def test_compressed_tensors_w8a8_static_setup(vllm_runner):
model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed" model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
with vllm_runner(model_path, quantization="sparseml", with vllm_runner(model_path, enforce_eager=True) as llm:
enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0] layer = model.model.layers[0]
...@@ -40,11 +40,17 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner): ...@@ -40,11 +40,17 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
assert qkv_proj.input_scale.dtype is torch.float32 assert qkv_proj.input_scale.dtype is torch.float32
def test_compressed_tensors_no_enforce_eager(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
with vllm_runner(model_path) as llm:
sampling_params = SamplingParams()
output = llm.generate("Hello world!", sampling_params=sampling_params)
assert output
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner): def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
model_path = "nm-testing/tinyllama-one-shot-dynamic-test" model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
with vllm_runner(model_path, with vllm_runner(model_path, enforce_eager=True,
quantization="sparseml",
enforce_eager=True,
dtype=torch.float16) as llm: dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0] layer = model.model.layers[0]
......
...@@ -164,12 +164,8 @@ class ModelConfig: ...@@ -164,12 +164,8 @@ class ModelConfig:
def _parse_quant_hf_config(self): def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None) quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None: if quant_cfg is None:
# SparseML uses a "compression_config" with a "quantization_config". # compress-tensors uses a "compression_config" key
compression_cfg = getattr(self.hf_config, "compression_config", quant_cfg = getattr(self.hf_config, "compression_config", None)
None)
if compression_cfg is not None:
quant_cfg = compression_cfg.get("quantization_config", None)
return quant_cfg return quant_cfg
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
......
...@@ -31,7 +31,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { ...@@ -31,7 +31,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"gptq_marlin": GPTQMarlinConfig, "gptq_marlin": GPTQMarlinConfig,
"gptq": GPTQConfig, "gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig, "squeezellm": SqueezeLLMConfig,
"sparseml": CompressedTensorsConfig, "compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig, "bitsandbytes": BitsAndBytesConfig,
} }
......
...@@ -122,12 +122,9 @@ def get_quant_config(model_config: ModelConfig, ...@@ -122,12 +122,9 @@ def get_quant_config(model_config: ModelConfig,
hf_quant_config = getattr(model_config.hf_config, "quantization_config", hf_quant_config = getattr(model_config.hf_config, "quantization_config",
None) None)
if hf_quant_config is None: if hf_quant_config is None:
compression_config = getattr(model_config.hf_config, # compressed-tensors uses a compressions_config
"compression_config", None) hf_quant_config = getattr(model_config.hf_config, "compression_config",
if compression_config is not None: None)
hf_quant_config = compression_config.get("quantization_config",
None)
if hf_quant_config is not None: if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config) return quant_cls.from_config(hf_quant_config)
# In case of bitsandbytes/QLoRA, get quant config from the adapter model. # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment