Unverified Commit ba291dad authored by Daniël de Kok, committed by GitHub

Improve the handling of quantized weights (#2250)

* Improve the handling of quantized weights

Handling of quantized weights was split between two mechanisms:

- For quantized checkpoints, we used the new weight loader
  infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes), we
  instead relied on conditionals in `get_linear`.

Weight loaders support context managers to selectively load
particular layers with different weight loaders. This is useful
for models such as Idefics2 AWQ, which uses a quantized text model
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
the `quantize` value. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.
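
For illustration, here is a minimal sketch of the loader-override pattern this change makes uniform across quantizers. `use_loader`, `DefaultWeightsLoader`, and `UnquantizedWeight` appear in the Idefics2 hunk further down; the two `build_*` helpers are placeholders, not functions from the repository:

```python
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight


def load_multimodal_model(prefix, config, weights):
    # Text model: loaded with the quantizer-specific loader that `weights`
    # was created with (AWQ, GPTQ, EETQ, FP8, bitsandbytes, ...).
    text_model = build_text_model(prefix, config, weights)  # placeholder helper

    # Vision/connector models: not quantized, so temporarily swap in a loader
    # that wraps raw tensors in UnquantizedWeight. After this change the
    # override works for every quantizer, not only pre-quantized checkpoints.
    with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
        vision_model = build_vision_model(prefix, config, weights)  # placeholder helper

    return text_model, vision_model
```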

This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:

- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers;
  `get_linear` does not need to know how to handle quantized linear
  layers (see the sketch after this list).
- All quantizer weights are strongly typed; we don't pass around
  raw tensors.
- We don't have to pass around the `quantize` string everywhere.

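A minimal sketch of what the typed-weight dispatch means for `get_linear` (illustrative only; the real helper in `text_generation_server.layers.linear` may handle additional cases such as bare tensors):

```python
# Before this change (schematic): get_linear(weight, bias, quantize) received
# a raw tensor plus the `quantize` string and branched on it for GPTQ, AWQ,
# EETQ, FP8, bitsandbytes, ...
#
# After this change (schematic): every loader returns a typed Weight, so the
# quantizer-specific construction lives in the Weight class itself.
def get_linear(weight, bias):
    # `weight` is a Weight instance such as UnquantizedWeight or Fp8Weight;
    # each class knows how to build its own linear layer (see the Weight ABC
    # in the weights.py hunk below).
    return weight.get_linear(bias)
```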
* Exclude non-MLP layers when using FP8 quantization with Llama
parent 1d1b1efa
......@@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=None))
def _load_experts(config, prefix: str, mat, weights):
......
......@@ -56,7 +56,7 @@ def load_row(config, prefix: str, weights, bias: bool):
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
if config.use_parallel_residual:
return linear
else:
......@@ -81,7 +81,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
if config.use_parallel_residual:
return linear
else:
......
......@@ -100,9 +100,7 @@ def _load_gqa(config, prefix: str, weights):
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
# this is the same as llama except that Phi uses bias=True
return TensorParallelColumnLinear(
get_linear(weight, bias=True, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=True))
class FlashPhiAttention(torch.nn.Module):
......
......@@ -31,7 +31,7 @@ def load_row(config, prefix: str, weights, bias: bool):
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
if config.parallel_attn:
return linear
else:
......
......@@ -105,6 +105,7 @@ def _load_multi_mqa_gptq(
g_idx=g_idx,
bits=loader.bits,
groupsize=loader.groupsize,
use_awq_kernel=loader.quantize == "awq",
use_exllama=HAS_EXLLAMA,
)
......@@ -121,7 +122,7 @@ def _load_multi_mqa_gptq(
bias = torch.cat([q_tensor, kv_tensor], dim=0)
bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
else:
raise NotImplementedError("Gptq loading with santacoder is not implemented")
......@@ -193,7 +194,7 @@ def _load_multi_mqa(
assert list(bias.shape) == [
(num_heads + 2) * head_size
], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
def load_col(config, prefix: str, weights, bias: bool):
......@@ -206,7 +207,7 @@ def load_col(config, prefix: str, weights, bias: bool):
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
def load_row(config, prefix: str, weights, bias: bool):
......@@ -221,7 +222,7 @@ def load_row(config, prefix: str, weights, bias: bool):
else:
bias = None
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group
get_linear(weight, bias), process_group=weights.process_group
)
......
......@@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights):
else:
bias = None
return TensorParallelColumnLinear(
get_linear(weight, bias=bias, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=bias))
class Starcoder2Attention(torch.nn.Module):
......
......@@ -34,7 +34,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.utils.weights import DefaultWeightsLoader
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
......@@ -698,7 +698,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
self.dtype = weights.dtype
# The vision and connector models are not quantized.
with weights.use_loader(DefaultWeightsLoader()):
with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
self.vision_model = Idefics2VisionTransformer(
prefix=(
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
......@@ -707,16 +707,12 @@ class Idefics2ForConditionalGeneration(nn.Module):
weights=weights,
)
quantize = config.quantize
try:
config.quantize = None
self.connector = Idefics2Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
finally:
config.quantize = quantize
config.quantize = None
self.connector = Idefics2Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
self.config = config
self.image_seq_len = config.perceiver_config.resampler_n_latents
......
......@@ -75,7 +75,7 @@ def load_col(config, prefix, weights, bias):
bias = bias.to(device=weights.device)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
return TensorParallelColumnLinear(linear)
......
from typing import Optional
import os
import json
import os
from dataclasses import dataclass
from typing import Optional
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
WeightsLoader,
)
@dataclass
......@@ -104,10 +107,30 @@ def get_loader(
quantize=quantize,
sym=quantizer_config.sym,
)
elif quantize == "bitsandbytes":
from text_generation_server.layers.bnb import BNBWeight
return DefaultWeightsLoader(BNBWeight)
elif quantize == "bitsandbytes-fp4":
from text_generation_server.layers.bnb import BNBFP4Weight
return DefaultWeightsLoader(BNBFP4Weight)
elif quantize == "bitsandbytes-nf4":
from text_generation_server.layers.bnb import BNBNF4Weight
return DefaultWeightsLoader(BNBNF4Weight)
elif quantize == "eetq":
from text_generation_server.layers.eetq import EETQWeight
return DefaultWeightsLoader(EETQWeight)
elif quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2WeightsLoader
return Exl2WeightsLoader()
elif quantize == "fp8":
from text_generation_server.layers.fp8 import Fp8Weight
return DefaultWeightsLoader(Fp8Weight)
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeightsLoader
......@@ -115,5 +138,7 @@ def get_loader(
bits=quantizer_config.bits,
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
)
elif quantize is None:
return DefaultWeightsLoader(UnquantizedWeight)
else:
return DefaultWeightsLoader()
raise ValueError(f"Unknown quantization method: {quantize}")
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Dict, List, Optional, Union
from safetensors import safe_open
import torch
from safetensors import safe_open
from text_generation_server.utils.import_utils import SYSTEM
class WeightsLoader(ABC):
......@@ -62,7 +66,39 @@ class WeightsLoader(ABC):
...
class Weight(ABC):
"""Instances of this type implement unquantized/quantized/to-be
quantized weights."""
@abstractmethod
def get_linear(self, bias: torch.Tensor):
"""Create a linear layer from this weight."""
...
@dataclass
class UnquantizedWeight:
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
from text_generation_server.layers.linear import FastLinear, FastLinearROCm
if SYSTEM == "rocm":
return FastLinearROCm(self.weight, bias)
else:
return FastLinear(self.weight, bias)
class DefaultWeightsLoader(WeightsLoader):
"""Weight loader that loads (unquantized) Torch tensors."""
def __init__(self, weight_class):
"""Create a loader. Weights will be wrapped using the given `weights_class`,
normally this will be `UnquantizedWeight`, but a quantizer-specific class
such as `Fp8Weight` can be used to quantize the weights during loading.
"""
self.weight_class = weight_class
"""
Loader that uses tensors as-is with the exception of applying sharding
and/or concatenation.
......@@ -74,16 +110,21 @@ class DefaultWeightsLoader(WeightsLoader):
prefix: str,
block_sizes: Union[int, List[int]],
):
return weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
return self.weight_class(
weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
),
)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
return torch.cat(w, dim=dim)
return self.weight_class(torch.cat(w, dim=dim))
def get_weights_row(self, weights: "Weights", prefix: str):
return weights.get_sharded(f"{prefix}.weight", dim=1)
return self.weight_class(
weights.get_sharded(f"{prefix}.weight", dim=1),
)
class Weights:
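
As a usage illustration of the `weight_class` hook, here is a toy quantize-on-load `Weight` written against the interface shown above. `DemoInt8Weight`, `DemoInt8Linear`, and the naive int8 scheme are invented for this sketch and are not part of the repository:

```python
from dataclasses import dataclass

import torch


class DemoInt8Linear(torch.nn.Module):
    """Toy linear layer that dequantizes on the fly (illustration only)."""

    def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias):
        super().__init__()
        self.register_buffer("qweight", qweight)
        self.register_buffer("scale", scale)
        self.bias = bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A real kernel (EETQ, FP8, ...) would fuse dequantization and matmul.
        w = self.qweight.to(x.dtype) * self.scale.to(x.dtype)
        out = torch.nn.functional.linear(x, w)
        return out if self.bias is None else out + self.bias


@dataclass
class DemoInt8Weight:
    """Toy quantize-on-load Weight following the Weight ABC from the hunk above."""

    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        # Naive symmetric per-row int8 quantization, purely for illustration.
        scale = self.weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        qweight = torch.round(self.weight / scale).to(torch.int8)
        return DemoInt8Linear(qweight, scale, bias)
```

Plugging such a class in would mirror the real branches in `get_loader`: return `DefaultWeightsLoader(DemoInt8Weight)` for the matching `quantize` value, optionally scoped to particular submodules with `weights.use_loader(...)`.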
......