Unverified Commit bb00f66e authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Use `quantization_config` in hf config (#1695)

parent e87557b0
...@@ -104,14 +104,30 @@ class ModelConfig: ...@@ -104,14 +104,30 @@ class ModelConfig:
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = ["awq", "squeezellm"] supported_quantization = ["awq", "squeezellm"]
if self.quantization is None: if self.quantization is not None:
return self.quantization = self.quantization.lower()
quantization = self.quantization.lower()
if quantization not in supported_quantization: # Parse quantization method from the HF model config, if available.
raise ValueError( hf_quant_config = getattr(self.hf_config, "quantization_config", None)
f"Unknown quantization: {self.quantization}. Must be one of " if hf_quant_config is not None:
f"{supported_quantization}.") hf_quant_method = str(hf_quant_config["quant_method"]).lower()
self.quantization = quantization if self.quantization is None:
self.quantization = hf_quant_method
elif self.quantization != hf_quant_method:
raise ValueError(
"Quantization method specified in the model config "
f"({hf_quant_method}) does not match the quantization "
f"method specified in the `quantization` argument "
f"({self.quantization}).")
if self.quantization is not None:
if self.quantization not in supported_quantization:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
logger.warning(f"{self.quantization} quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.")
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
......
...@@ -66,6 +66,7 @@ def get_model(model_config: ModelConfig) -> nn.Module: ...@@ -66,6 +66,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
if model_config.quantization is not None: if model_config.quantization is not None:
quant_config = get_quant_config(model_config.quantization, quant_config = get_quant_config(model_config.quantization,
model_config.model, model_config.model,
model_config.hf_config,
model_config.download_dir) model_config.download_dir)
capability = torch.cuda.get_device_capability() capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
......
...@@ -7,9 +7,10 @@ from collections import defaultdict ...@@ -7,9 +7,10 @@ from collections import defaultdict
from typing import Any, Iterator, List, Optional, Tuple from typing import Any, Iterator, List, Optional, Tuple
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file, safe_open
import numpy as np import numpy as np
from safetensors.torch import load_file, save_file, safe_open
import torch import torch
from transformers import PretrainedConfig
from tqdm.auto import tqdm from tqdm.auto import tqdm
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -84,8 +85,15 @@ def convert_bin_to_safetensor_file( ...@@ -84,8 +85,15 @@ def convert_bin_to_safetensor_file(
def get_quant_config( def get_quant_config(
quantization: str, quantization: str,
model_name_or_path: str, model_name_or_path: str,
hf_config: PretrainedConfig,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
) -> QuantizationConfig: ) -> QuantizationConfig:
quant_cls = get_quantization_config(quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(hf_config, "quantization_config", None)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
is_local = os.path.isdir(model_name_or_path) is_local = os.path.isdir(model_name_or_path)
if not is_local: if not is_local:
# Download the config files. # Download the config files.
...@@ -98,7 +106,6 @@ def get_quant_config( ...@@ -98,7 +106,6 @@ def get_quant_config(
hf_folder = model_name_or_path hf_folder = model_name_or_path
config_files = glob.glob(os.path.join(hf_folder, "*.json")) config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_cls = get_quantization_config(quantization)
quant_config_files = [ quant_config_files = [
f for f in config_files if any( f for f in config_files if any(
f.endswith(x) for x in quant_cls.get_config_filenames()) f.endswith(x) for x in quant_cls.get_config_filenames())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment