"torchvision/transforms/v2/_transform.py" did not exist on "4941c6b6b62bbda3cd462e8e59d114bbdc9683c6"
Unverified Commit 8511669c authored by Daniël de Kok's avatar Daniël de Kok Committed by GitHub
Browse files

Move quantized weight handling out of the `Weights` class (#2194)

Quantized weights were loaded in the `Weights` class, but this was
getting quite unwieldy, where every higher level method to load weights
was a long conditional to cover all the different quantizers.

This change moves loading of quantized weights out of the `Weights`
class. This is done by defining a simple `WeightsLoader` interface
that is implemented by `Exl2WeightsLoader`, `GPTQWeightsLoader`,
and `MarlinWeightsLoader`. These implementations are in the quantizers'
respective modules. The `Weights` class provides the low-level load
operations (such as loading tensors or sharded tensors), but delegates
loads that need quantizer-specific weight processing to a loader. The
loaders still use the low-level functionality provided by `Weights`.

I initially tried making a hierarchy where a class like `GPTQWeights`
would inherit from `Weights`. But it is not very flexible (e.g. does
not work well with the new weight storage mock used in tests) and
the implicit indirections made the code harder to follow.
parent 4c976fb4
...@@ -28,6 +28,7 @@ from text_generation_server.models.types import ( ...@@ -28,6 +28,7 @@ from text_generation_server.models.types import (
GeneratedText, GeneratedText,
) )
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens, Sampling from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
...@@ -448,8 +449,17 @@ class Mamba(Model): ...@@ -448,8 +449,17 @@ class Mamba(Model):
config.quantize = quantize config.quantize = quantize
config.speculator = speculator config.speculator = speculator
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group) weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = MambaModel(config, weights) model = MambaModel(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(Mamba, self).__init__( super(Mamba, self).__init__(
......
...@@ -18,6 +18,7 @@ from text_generation_server.utils import ( ...@@ -18,6 +18,7 @@ from text_generation_server.utils import (
Weights, Weights,
) )
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model from text_generation_server.models import Model
from text_generation_server.models.types import ( from text_generation_server.models.types import (
...@@ -586,6 +587,9 @@ class Seq2SeqLM(Model): ...@@ -586,6 +587,9 @@ class Seq2SeqLM(Model):
) )
tokenizer.bos_token_id = config.decoder_start_token_id tokenizer.bos_token_id = config.decoder_start_token_id
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights( weights = Weights(
...@@ -594,6 +598,7 @@ class Seq2SeqLM(Model): ...@@ -594,6 +598,7 @@ class Seq2SeqLM(Model):
dtype=dtype, dtype=dtype,
process_group=self.process_group, process_group=self.process_group,
aliases=aliases, aliases=aliases,
weights_loader=weights_loader,
) )
if config.quantize in ["awq", "exl2", "gptq", "marlin"]: if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision) weights._set_gptq_params(model_id, revision)
......
from typing import Optional
import os
import json
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader
@dataclass
class _QuantizerConfig:
bits: int
checkpoint_format: Optional[str]
desc_act: bool
groupsize: int
quant_method: str
sym: bool
# We should probably do this with Pytantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id, revision):
bits = 4
groupsize = -1
quant_method = "gptq"
checkpoint_format = None
sym = True
desc_act = False
filename = "config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename, revision=revision)
with open(filename, "r") as f:
data = json.load(f)
bits = data["quantization_config"]["bits"]
groupsize = data["quantization_config"]["group_size"]
# Order is important here, desc_act is missing on some real models
quant_method = data["quantization_config"]["quant_method"]
checkpoint_format = data["quantization_config"].get("checkpoint_format")
sym = data["quantization_config"]["sym"]
desc_act = data["quantization_config"]["desc_act"]
except Exception:
filename = "quantize_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
bits = data["bits"]
groupsize = data["group_size"]
sym = data["sym"]
desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
quant_method = "awq"
except Exception:
filename = "quant_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
bits = data["w_bit"]
groupsize = data["q_group_size"]
desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
quant_method = "awq"
except Exception:
pass
return _QuantizerConfig(
bits=bits,
groupsize=groupsize,
quant_method=quant_method,
checkpoint_format=checkpoint_format,
sym=sym,
desc_act=desc_act,
)
def get_loader(
quantize: Optional[str], model_id: str, revision: Optional[str]
) -> WeightsLoader:
quantizer_config = _get_quantizer_config(model_id, revision)
if quantize in {"awq", "gptq"}:
from text_generation_server.layers.gptq import GPTQWeightsLoader
return GPTQWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
)
elif quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2WeightsLoader
return Exl2WeightsLoader()
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeightsLoader
return MarlinWeightsLoader(
bits=quantizer_config.bits,
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
)
else:
return DefaultWeightsLoader()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment