Unverified commit a7850008 authored by Daniël de Kok, committed by GitHub

Add initial support for compressed-tensors checkpoints (#2732)

compressed-tensors is a safetensors extension for sparse, quantized
tensors. The format is more powerful than the earlier AWQ/GPTQ/FP8
quantization formats because:

- Different quantizer configurations can be used for different targets.
- The format can specify input/output quantizers in addition to weight
  quantizers.
- Exclusions from quantization are configurable.

This change adds a dependency on the `compressed-tensors` package for
its configuration parsing and layer matching functionality.

The following types of quantization are supported in this PR:

- W8A16 and W4A16 INT using GPTQ-Marlin kernels.
- W8A8 and W8A16 FP using FP8-Marlin and cutlass kernels.

Support for other quantization types will be added in subsequent PRs.
parent 97f7a22f
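For reference, a compressed-tensors checkpoint carries this information in the `quantization_config` section of `config.json`. The sketch below is a hypothetical W4A16 INT configuration written as a Python dict; the field names follow the `QuantizationConfig` schema that the new loader parses, but the concrete values are illustrative only.

# Hypothetical `quantization_config` for a W4A16 INT checkpoint (illustrative values).
example_quantization_config = {
    "quant_method": "compressed-tensors",
    "format": "pack-quantized",
    "quantization_status": "compressed",
    # Modules excluded from quantization.
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            # Quantizer configuration applied to all `Linear` targets.
            "targets": ["Linear"],
            "weights": {
                "num_bits": 4,
                "type": "int",
                "strategy": "group",
                "group_size": 128,
                "symmetric": True,
                "dynamic": False,
            },
            # No `input_activations` section: activations stay in 16 bit.
        }
    },
}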
from .loader import CompressedTensorsLoader
__all__ = ["CompressedTensorsLoader"]
from typing import Any, Dict, List, Union
from compressed_tensors import QuantizationConfig, QuantizationStatus
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
QuantizationScheme,
QuantizationType,
find_name_or_class_matches,
)
from loguru import logger
from pydantic import ValidationError
from torch import nn
from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
from text_generation_server.layers.compressed_tensors.wna16_int import WNA16Loader
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)
# compressed-tensors can match modules as quantization targets. However,
# they need to be objects rather than classes or class names. Since we
# need to match `Linear` targets, make an instance that can be re-used.
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)
class CompressedTensorsLoader(WeightsLoader):
"""Loader for checkpoints stored in the compressed-tensors format."""
def __init__(self, config: Dict[str, Any]):
quantization_config_raw = config.get("quantization_config")
if quantization_config_raw is None:
# `compression_config` was renamed to `quantization_config`; support
# retained for backward compatibility.
quantization_config_raw = config.get("compression_config")
if quantization_config_raw is None:
raise ValueError(
"Checkpoint does not have compressed-tensors configuration"
)
try:
quantization_config = QuantizationConfig.model_validate(
quantization_config_raw
)
except ValidationError as e:
raise ValueError("Cannot parse compressed-tensors configuration") from e
if quantization_config.quantization_status not in (
QuantizationStatus.COMPRESSED,
QuantizationStatus.FROZEN,
):
raise ValueError(
f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
)
self.ignore = (
quantization_config.ignore if quantization_config.ignore is not None else []
)
self.loaders = self._get_target_loaders(quantization_config)
for target, loader in self.loaders.items():
log_once(
logger.info,
f"Using {loader} for compressed-tensors target '{target}'",
)
def get_weights(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights(weights, prefix)
def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
loader = self._lookup_loader(prefix)
return loader.get_weights_col_packed(weights, prefix, block_sizes)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
loader = self._lookup_loader(prefixes[0])
return loader.get_multi_weights_col(weights, prefixes, dim)
def get_weights_row(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights_row(weights, prefix)
def _get_target_loaders(
self, quantization_config: QuantizationConfig
) -> Dict[str, WeightsLoader]:
"""
A compressed-tensors checkpoint can use different quantizations
for different targets. This method returns a dictionary with a
loader per target.
"""
loaders: Dict[str, WeightsLoader] = {}
format = quantization_config.format
for group_name, group in quantization_config.config_groups.items():
# The group configuration can be a string, but does that ever
# happen in a serialized quantization config?
assert isinstance(group, QuantizationScheme)
loader = self._create_loader_for_group(format, group_name, group)
# A quantized parameter group can have multiple targets, add the
# loader for all the targets.
for target in group.targets:
if target in loaders:
raise ValueError(
f"Target '{target} has multiple configured loaders'"
)
loaders[target] = loader
return loaders
def _create_loader_for_group(
self, format: str, group_name: str, group: QuantizationScheme
) -> WeightsLoader:
"""
Find and create a loader for the group with the given quantization
scheme.
"""
# NOTE: we ignore group.output_activations because we don't support
# output quantization yet.
input_activations = group.input_activations
weights = group.weights
if (
format
in {
CompressionFormat.float_quantized.value,
CompressionFormat.naive_quantized.value,
}
and weights is not None
and weights.type == QuantizationType.FLOAT
and weights.num_bits == 8
):
# FP W8A8 or W8A16.
return W8ANFpLoader(input_activations=input_activations, weights=weights)
elif (
format == CompressionFormat.pack_quantized.value
and weights is not None
and weights.type == QuantizationType.INT
and weights.num_bits in (4, 8)
):
# INT W4A16 or W8A16 (GPTQ/AWQ-like).
return WNA16Loader(weights)
else:
raise ValueError(
f"Group '{group_name}' has unsupported compressed-tensors configurtion"
)
def _lookup_loader(self, prefix: str) -> WeightsLoader:
"""
Look up the loader to use for a given parameter name (prefix).
"""
if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
return DefaultWeightsLoader(UnquantizedWeight)
# We currently only handle linear layers, so unconditionally pass
# a `Linear` instance.
targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
if len(targets) == 0:
raise ValueError(
f"Cannot find compressed-tensors target for prefix: {prefix}"
)
return self.loaders[targets[0]]
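A minimal sketch of how this dispatch behaves, assuming the `compressed-tensors` package is installed. The config values are hypothetical, and `_lookup_loader` is a private helper used here only for illustration.

from text_generation_server.layers.compressed_tensors import CompressedTensorsLoader

# Hypothetical `config.json` contents for a static FP8 (W8A8) checkpoint.
config = {
    "quantization_config": {
        "quant_method": "compressed-tensors",
        "format": "float-quantized",
        "quantization_status": "compressed",
        "ignore": ["lm_head"],
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {"num_bits": 8, "type": "float", "strategy": "tensor", "dynamic": False},
                "input_activations": {"num_bits": 8, "type": "float", "strategy": "tensor", "dynamic": False},
            }
        },
    }
}

loader = CompressedTensorsLoader(config)
# A projection matches the `Linear` target, so the FP8 loader handles it.
print(loader._lookup_loader("model.layers.0.self_attn.q_proj"))
# `lm_head` is listed in `ignore`, so it falls back to the unquantized loader.
print(loader._lookup_loader("lm_head"))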
from typing import List, Optional, Union
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import Fp8Weight, _load_scalar_or_matrix_scale
from text_generation_server.utils.weights import Weights, WeightsLoader
class W8ANFpLoader(WeightsLoader):
"""
Loader for W8A8/W8A16 FP compressed-tensors parameters.
"""
def __init__(
self,
*,
input_activations: Optional[QuantizationArgs],
weights: QuantizationArgs,
):
assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8
# We ignore the `strategy` option which sets the scales to be
# per-tensor, per-channel or per-token. What scales are supported
# is dependent on the kernels used (e.g. cutlass can do tokenwise,
# Torch cannot, and FP8-Marlin does not quantize inputs at all).
# So, instead we try to use the best-possible configuration.
self.load_weight_scale = not weights.dynamic
self.load_input_scale = (
input_activations is not None and not input_activations.dynamic
)
self.force_w8a16 = (
input_activations is not None and input_activations.num_bits == 16
)
def __str__(self) -> str:
def scale_to_str(scale):
return "static" if scale else "dynamic"
quantization_type = f"W8A{16 if self.force_w8a16 else 8}"
return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})"
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
weight_scale = None
if self.load_weight_scale:
weight_scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
w = weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if weight_scale.numel() > 1:
weight_scale = weights.get_packed_sharded(
f"{prefix}.weight_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
weight_scale = None
if self.load_weight_scale:
weight_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
input_scale = None
if self.load_input_scale:
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
weight_scale = None
if self.load_weight_scale:
weight_scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
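The scale handling in `get_multi_weights_col` can be seen in isolation with plain torch; the sizes and scale values below are made up.

import torch

# Made-up sizes and scales for three column-parallel projections (e.g. q/k/v).
out_features = [128, 64, 64]
weight_scales = [torch.tensor(0.02), torch.tensor(0.03), torch.tensor(0.01)]
input_scales = [torch.tensor(0.11), torch.tensor(0.09), torch.tensor(0.10)]

# Per-tensor weight scales are expanded to one scale per output row, then
# concatenated so each row of the fused weight keeps its own scale.
weight_scale = torch.cat(
    [s.reshape(-1).expand(n) for s, n in zip(weight_scales, out_features)]
)
assert weight_scale.shape == (sum(out_features),)

# The fused GEMM sees a single input, so the static input scales collapse to
# the most conservative (largest) one.
input_scale = torch.cat([s.reshape(-1) for s in input_scales]).max()
print(weight_scale.shape, float(input_scale))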
from typing import List, Union
import torch
from compressed_tensors.quantization import ActivationOrdering, QuantizationArgs
from loguru import logger
from text_generation_server.layers.marlin.gptq import repack_gptq_for_marlin
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights, WeightsLoader
class WNA16Loader(WeightsLoader):
"""
Loader for W4A16/W8A16 INT compressed-tensors parameters.
"""
def __init__(self, weights: QuantizationArgs):
self.weights = weights
self.desc_act = self.weights.actorder == ActivationOrdering.GROUP
self.groupsize = (
-1 if self.weights.group_size is None else self.weights.group_size
)
def __str__(self) -> str:
quantization_type = f"W{self.weights.num_bits}8A16"
return f"{self.__class__.__name__} ({quantization_type})"
def get_weights(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
weight_packed = weights.get_tensor(f"{prefix}.weight_packed").t()
except RuntimeError:
raise RuntimeError(
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
)
zero_point = None
if not self.weights.symmetric:
zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t()
g_idx = None
if self.desc_act:
g_idx = weights.get_tensor(f"{prefix}.weight_g_idx")
scales = weights.get_tensor(f"{prefix}.weight.scales").t()
return repack_gptq_for_marlin(
qweight=weight_packed.contiguous(),
scales=scales,
qzeros=zero_point,
g_idx=g_idx,
bits=self.weights.num_bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method="compressed-tensors",
sym=self.weights.symmetric,
sharded_infeatures=False,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
try:
weight_packed = weights.get_packed_sharded(
f"{prefix}.weight_packed", dim=0, block_sizes=block_sizes
).t()
except RuntimeError:
raise RuntimeError(
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
)
scales = weights.get_packed_sharded(
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
).t()
scales = scales.to(dtype=weights.dtype)
zero_point = None
if not self.weights.symmetric:
zero_point = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=0, block_sizes=block_sizes
).t()
g_idx = None
if self.desc_act:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=weight_packed.contiguous(),
scales=scales,
qzeros=zero_point,
g_idx=g_idx,
bits=self.weights.num_bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method="compressed-tensors",
sym=self.weights.symmetric,
sharded_infeatures=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
try:
weight_packed = torch.cat(
[
weights.get_sharded(f"{p}.weight_packed", dim=0).t()
for p in prefixes
],
dim=1,
)
except RuntimeError:
raise RuntimeError(
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.weight_scale", dim=0).t() for p in prefixes],
dim=1,
)
zero_point = None
if not self.weights.symmetric:
zero_point = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=0).t() for p in prefixes], dim=1
).t()
g_idx = None
if self.desc_act:
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=weight_packed.contiguous(),
scales=scales,
qzeros=zero_point,
g_idx=g_idx,
bits=self.weights.num_bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method="compressed-tensors",
sym=self.weights.symmetric,
sharded_infeatures=False,
)
def get_weights_row(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=1).t()
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
zero_point = None
if not self.weights.symmetric:
if self.desc_act or self.groupsize == -1:
zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t()
else:
zero_point = weights.get_sharded(
f"{prefix}.weight_zero_point", dim=1
).t()
g_idx = None
if self.desc_act:
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.weight_scale").t()
else:
scales = weights.get_sharded(f"{prefix}.weight_scale", dim=1).t()
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=weight_packed.contiguous(),
scales=scales,
qzeros=zero_point,
g_idx=g_idx,
bits=self.weights.num_bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method="compressed-tensors",
sym=self.weights.symmetric,
sharded_infeatures=sharded_in_features,
)
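How a group's weight `QuantizationArgs` translate into the loader's Marlin repack settings can be sketched as follows; the field values are hypothetical.

from compressed_tensors.quantization import QuantizationArgs
from text_generation_server.layers.compressed_tensors.wna16_int import WNA16Loader

# Hypothetical W4A16 weight quantizer: 4-bit symmetric INT with 128-wide groups.
weight_args = QuantizationArgs(
    num_bits=4,
    type="int",
    symmetric=True,
    strategy="group",
    group_size=128,
)

loader = WNA16Loader(weight_args)
print(loader)            # e.g. WNA16Loader (W4A16)
print(loader.desc_act)   # False: no activation ordering configured
print(loader.groupsize)  # 128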
@@ -29,7 +29,7 @@ else:
     CUTLASS_FP8_AVAILABLE = False


-def get_fp8_linear() -> Type[torch.nn.Module]:
+def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
     """
     Return an FP8 linear `Module` that is compatible with the current system.
     """
@@ -37,7 +37,14 @@ def get_fp8_linear() -> Type[torch.nn.Module]:
     if SYSTEM == "cuda":
         major, _ = torch.cuda.get_device_capability()
-        if major == 8 and os.getenv("USE_CUTLASS_W8A8", "0") != "1":
+        # Marlin is W8A16, use it when:
+        #
+        # - On capability 8.x where x < 8: W8A8 FP8 GEMM is not supported.
+        # - On capability 8.9: W8A8 FP8 GEMM is supported, but Marlin-FP8 is faster.
+        # - On capability 9.x when force_w8a16: cutlass kernels do not support W8A16.
+        if (major == 8 or (major == 9 and force_w8a16)) and os.getenv(
+            "USE_CUTLASS_W8A8", "0"
+        ) != "1":
             # NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin
             # gives better decoding throughput on L4 and L40.
             from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
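The capability check above boils down to a small decision table. The helper below is a standalone sketch that mirrors that logic; it is not part of TGI.

# Standalone sketch of the FP8 kernel choice; `capability` is the tuple
# returned by torch.cuda.get_device_capability().
def fp8_kernel_choice(capability, force_w8a16=False, use_cutlass_env="0"):
    major, _ = capability
    if (major == 8 or (major == 9 and force_w8a16)) and use_cutlass_env != "1":
        return "marlin-fp8"        # W8A16 GEMM path
    return "cutlass-or-torch-fp8"  # W8A8 GEMM path

assert fp8_kernel_choice((8, 9)) == "marlin-fp8"                    # L4/L40
assert fp8_kernel_choice((9, 0)) == "cutlass-or-torch-fp8"          # H100, W8A8
assert fp8_kernel_choice((9, 0), force_w8a16=True) == "marlin-fp8"  # H100, W8A16 checkpoint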
@@ -283,14 +290,17 @@ class Fp8Weight(Weight):
     weight_scale: Optional[torch.Tensor] = None
     input_scale: Optional[torch.Tensor] = None
     activation_scale_ub: Optional[float] = None
+    force_w8a16: bool = False

     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
-            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+            return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
+                self.weight, bias, self.dtype
+            )
         # This is not checked by the fbgemm kernels, but they require contiguous
         # memory. Can be non-contiguous when we e.g. expand from scalars.
         self.weight_scale = self.weight_scale.contiguous()
-        return get_fp8_linear().from_fp8(
+        return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
             weight=self.weight,
             scale=self.weight_scale,
             dtype=self.dtype,
...
@@ -261,7 +261,7 @@ class GPTQMarlinWeight(Weight):

     def __post_init__(self):
         assert self.qweight.dtype == torch.int32
-        assert self.scales.dtype == torch.float16
+        assert self.scales.dtype in (torch.float16, torch.bfloat16)
         assert self.g_idx.dtype == torch.int32
         assert self.perm.dtype == torch.int32
@@ -300,7 +300,7 @@ def repack_gptq_for_marlin(
         raise RuntimeError(
            f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
         )
-    if not (sym or quant_method == "awq"):
+    if not (sym or quant_method == "awq" or quant_method == "compressed-tensors"):
         raise RuntimeError(
             "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
         )
...
@@ -370,46 +370,23 @@ def get_model(
     compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
-        config_groups = quantization_config.get("config_groups", None)
         if method in {"gptq", "awq", "exl2"}:
             log_master(logger.info, f"Auto selecting quantization method {method}")
             quantize = method
         elif method == "fbgemm_fp8" or method == "fp8":
             log_master(logger.info, "Auto selecting quantization method fp8")
             quantize = "fp8"
-        elif config_groups is not None:
-            # TODO: at some point we should probably fully parse the compression
-            # configuration to know which parameters are compressed.
-            for _, group in config_groups.items():
-                weights_config = group.get("weights")
-                if weights_config is not None:
-                    if (
-                        weights_config["type"] == "float"
-                        and weights_config["num_bits"] == 8
-                    ):
-                        log_master(
-                            logger.info, "Auto selecting quantization method fp8"
-                        )
-                        quantize = "fp8"
-                        break
+        elif method == "compressed-tensors":
+            log_master(
+                logger.info, "Auto selecting quantization method compressed-tensors"
+            )
+            quantize = "compressed-tensors"
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
     elif compression_config is not None:
         # `compression_config` renamed to `quantization_config`; support retained for backward compatibility.
-        config_groups = compression_config.get("config_groups")
-        if config_groups is not None:
-            for _, group in config_groups.items():
-                weights_config = group.get("weights")
-                if weights_config is not None:
-                    if (
-                        weights_config["type"] == "float"
-                        and weights_config["num_bits"] == 8
-                    ):
-                        log_master(
-                            logger.info, "Auto selecting quantization method fp8"
-                        )
-                        quantize = "fp8"
-                        break
+        log_master(logger.info, "Auto selecting quantization method compressed-tensors")
+        quantize = "compressed-tensors"

     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
...
@@ -27,7 +27,20 @@ class _FP8QuantizerConfig:
     activation_scale_ub: float


-# We should probably do this with Pytantic JSON deserialization,
+def _get_config_json(model_id: str, revision: Optional[str], filename: str):
+    if os.path.exists(
+        os.path.join(
+            model_id,
+            filename,
+        )
+    ):
+        filename = os.path.join(model_id, filename)
+    else:
+        filename = hf_hub_download(model_id, filename=filename, revision=revision)
+
+    with open(filename, "r") as f:
+        return json.load(f)
+
+
+# We should probably do this with Pydantic JSON deserialization,
 # but for now we'll stay close to the old _set_gptq_params.
 def _get_quantizer_config(model_id, revision):
     bits = 4
@@ -39,12 +52,7 @@ def _get_quantizer_config(model_id, revision):
     filename = "config.json"
     try:
-        if os.path.exists(os.path.join(model_id, filename)):
-            filename = os.path.join(model_id, filename)
-        else:
-            filename = hf_hub_download(model_id, filename=filename, revision=revision)
-        with open(filename, "r") as f:
-            data = json.load(f)
+        data = _get_config_json(model_id, revision, filename)

         # FP8 config
         if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
@@ -67,14 +75,7 @@ def _get_quantizer_config(model_id, revision):
     except Exception:
         filename = "quantize_config.json"
         try:
-            if os.path.exists(os.path.join(model_id, filename)):
-                filename = os.path.join(model_id, filename)
-            else:
-                filename = hf_hub_download(
-                    model_id, filename=filename, revision=revision
-                )
-            with open(filename, "r") as f:
-                data = json.load(f)
+            data = _get_config_json(model_id, revision, filename)

             bits = data["bits"]
             groupsize = data["group_size"]
@@ -90,14 +91,7 @@ def _get_quantizer_config(model_id, revision):
     except Exception:
         filename = "quant_config.json"
         try:
-            if os.path.exists(os.path.join(model_id, filename)):
-                filename = os.path.join(model_id, filename)
-            else:
-                filename = hf_hub_download(
-                    model_id, filename=filename, revision=revision
-                )
-            with open(filename, "r") as f:
-                data = json.load(f)
+            data = _get_config_json(model_id, revision, filename)

             bits = data["w_bit"]
             groupsize = data["q_group_size"]
             desc_act = data["desc_act"]
@@ -119,6 +113,14 @@ def _get_quantizer_config(model_id, revision):
 def get_loader(
     quantize: Optional[str], model_id: str, revision: Optional[str]
 ) -> WeightsLoader:
+    if quantize == "compressed-tensors":
+        config = _get_config_json(model_id, revision, "config.json")
+        from text_generation_server.layers.compressed_tensors import (
+            CompressedTensorsLoader,
+        )
+
+        return CompressedTensorsLoader(config)
+
     quantizer_config = _get_quantizer_config(model_id, revision)
     if quantize in {"awq", "gptq"}:
         from text_generation_server.layers.gptq import GPTQWeightsLoader
...
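A hedged usage sketch of the new branch in `get_loader`; the model id is a placeholder for any checkpoint whose `config.json` contains a compressed-tensors `quantization_config`.

from text_generation_server.utils.quantization import get_loader

# Placeholder model id; resolving it reads/downloads its config.json and
# constructs a CompressedTensorsLoader from the parsed quantization config.
loader = get_loader(
    quantize="compressed-tensors",
    model_id="org/model-w8a8-compressed-tensors",
    revision=None,
)
print(type(loader).__name__)  # CompressedTensorsLoader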