Unverified Commit ba291dad authored by Daniël de Kok, committed by GitHub

Improve the handling of quantized weights (#2250)

* Improve the handling of quantized weights

Handling of quantized weights was split between two mechanisms:

- For quantized checkpoints, we used the new weight loader
  infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
  instead relied on conditionals in `get_linear`.

Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
the `quantize` argument. The context manager also did not work
with EETQ, FP8, and bitsandbytes.
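
As an illustration of the context-manager pattern, here is a minimal
sketch based on the Idefics2 change further down in this diff; the
helper name is illustrative, and `Idefics2VisionTransformer` plus the
`prefix`/`config`/`weights` arguments come from that model file and are
assumed to be available:

from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight

def load_unquantized_vision_model(prefix, config, weights):
    # `weights` was built with the checkpoint-wide loader (e.g. AWQ for
    # Idefics2 AWQ). Inside this block the plain loader is used instead,
    # so the vision tower is loaded as regular torch tensors.
    with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
        return Idefics2VisionTransformer(
            prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
            config=config.vision_config,
            weights=weights,
        )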

This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:

- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers;
  `get_linear` does not need to know how to handle quantized linear
  layers.
- All quantizer weights are strongly typed; we don't pass around
  raw tensors (see the sketch below).
- We don't have to pass around the `quantize` string everywhere.
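
To make the strong-typing point concrete, here is a minimal sketch of
how the pieces added in this diff fit together. The new body of
`get_linear` itself is not part of this diff, so the one-liner below is
an assumption about its shape, grounded in the `Weight.get_linear`
interface added to `utils/weights.py`:

import torch

from text_generation_server.utils.weights import UnquantizedWeight, Weight

def get_linear(weight: Weight, bias):
    # Loaders now hand back typed Weight instances (UnquantizedWeight,
    # Fp8Weight, EETQWeight, ...), so building the linear layer is a
    # method call instead of branching on a `quantize` string.
    return weight.get_linear(bias)

# An unquantized weight simply wraps a raw tensor and knows how to turn
# itself into a (Fast)Linear layer.
linear = get_linear(UnquantizedWeight(torch.randn(128, 64)), bias=None)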

* Exclude non-MLP layers when using FP8 quantization with Llama
parent 1d1b1efa
@@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=None, quantize=config.quantize)
-    )
+    return TensorParallelColumnLinear(get_linear(weight, bias=None))


 def _load_experts(config, prefix: str, mat, weights):
...
@@ -56,7 +56,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     else:
         bias = None

-    linear = get_linear(weight, bias, config.quantize)
+    linear = get_linear(weight, bias)
     if config.use_parallel_residual:
         return linear
     else:
@@ -81,7 +81,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
     bias = weights.get_sharded(f"{prefix}.bias", dim=0)
     bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)

-    linear = get_linear(weight, bias, config.quantize)
+    linear = get_linear(weight, bias)
     if config.use_parallel_residual:
         return linear
     else:
...
@@ -100,9 +100,7 @@ def _load_gqa(config, prefix: str, weights):
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     # this is the same as llama except for Phi uses bias=True
-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=True, quantize=config.quantize)
-    )
+    return TensorParallelColumnLinear(get_linear(weight, bias=True))


 class FlashPhiAttention(torch.nn.Module):
...
@@ -31,7 +31,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     else:
         bias = None

-    linear = get_linear(weight, bias, config.quantize)
+    linear = get_linear(weight, bias)
     if config.parallel_attn:
         return linear
     else:
...
@@ -105,6 +105,7 @@ def _load_multi_mqa_gptq(
             g_idx=g_idx,
             bits=loader.bits,
             groupsize=loader.groupsize,
+            use_awq_kernel=loader.quantize == "awq",
             use_exllama=HAS_EXLLAMA,
         )
@@ -121,7 +122,7 @@ def _load_multi_mqa_gptq(
         bias = torch.cat([q_tensor, kv_tensor], dim=0)
         bias = bias.to(device=weights.device)

-        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+        return TensorParallelColumnLinear(get_linear(weight, bias))
     else:
         raise NotImplementedError("Gptq loading with santacoder is not implemented")
@@ -193,7 +194,7 @@ def _load_multi_mqa(
         assert list(bias.shape) == [
             (num_heads + 2) * head_size
         ], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias))


 def load_col(config, prefix: str, weights, bias: bool):
@@ -206,7 +207,7 @@ def load_col(config, prefix: str, weights, bias: bool):
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
     else:
         bias = None
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias))


 def load_row(config, prefix: str, weights, bias: bool):
@@ -221,7 +222,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     else:
         bias = None
     return TensorParallelRowLinear(
-        get_linear(weight, bias, config.quantize), process_group=weights.process_group
+        get_linear(weight, bias), process_group=weights.process_group
     )
...
@@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights):
     else:
         bias = None

-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=bias, quantize=config.quantize)
-    )
+    return TensorParallelColumnLinear(get_linear(weight, bias=bias))


 class Starcoder2Attention(torch.nn.Module):
...
@@ -34,7 +34,7 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
 )
-from text_generation_server.utils.weights import DefaultWeightsLoader
+from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -698,7 +698,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
         self.dtype = weights.dtype

         # The vision and connector models are not quantized.
-        with weights.use_loader(DefaultWeightsLoader()):
+        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
             self.vision_model = Idefics2VisionTransformer(
                 prefix=(
                     f"{prefix}.model.vision_model" if prefix else "model.vision_model"
@@ -707,16 +707,12 @@ class Idefics2ForConditionalGeneration(nn.Module):
                 weights=weights,
             )

-        quantize = config.quantize
-        try:
-            config.quantize = None
-            self.connector = Idefics2Connector(
-                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
-                config=config,
-                weights=weights,
-            )
-        finally:
-            config.quantize = quantize
+            config.quantize = None
+            self.connector = Idefics2Connector(
+                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
+                config=config,
+                weights=weights,
+            )

         self.config = config
         self.image_seq_len = config.perceiver_config.resampler_n_latents
...
@@ -75,7 +75,7 @@ def load_col(config, prefix, weights, bias):
         bias = bias.to(device=weights.device)
     else:
         bias = None
-    linear = get_linear(weight, bias, config.quantize)
+    linear = get_linear(weight, bias)
     return TensorParallelColumnLinear(linear)
...
-from typing import Optional
-import os
 import json
+import os
 from dataclasses import dataclass
+from typing import Optional

 from huggingface_hub import hf_hub_download

-from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader
+from text_generation_server.utils.weights import (
+    DefaultWeightsLoader,
+    UnquantizedWeight,
+    WeightsLoader,
+)


 @dataclass
@@ -104,10 +107,30 @@ def get_loader(
             quantize=quantize,
             sym=quantizer_config.sym,
         )
+    elif quantize == "bitsandbytes":
+        from text_generation_server.layers.bnb import BNBWeight
+
+        return DefaultWeightsLoader(BNBWeight)
+    elif quantize == "bitsandbytes-fp4":
+        from text_generation_server.layers.bnb import BNBFP4Weight
+
+        return DefaultWeightsLoader(BNBFP4Weight)
+    elif quantize == "bitsandbytes-nf4":
+        from text_generation_server.layers.bnb import BNBNF4Weight
+
+        return DefaultWeightsLoader(BNBNF4Weight)
+    elif quantize == "eetq":
+        from text_generation_server.layers.eetq import EETQWeight
+
+        return DefaultWeightsLoader(EETQWeight)
     elif quantize == "exl2":
         from text_generation_server.layers.exl2 import Exl2WeightsLoader

         return Exl2WeightsLoader()
+    elif quantize == "fp8":
+        from text_generation_server.layers.fp8 import Fp8Weight
+
+        return DefaultWeightsLoader(Fp8Weight)
     elif quantize == "marlin":
         from text_generation_server.layers.marlin import MarlinWeightsLoader
@@ -115,5 +138,7 @@ def get_loader(
             bits=quantizer_config.bits,
             is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
         )
+    elif quantize is None:
+        return DefaultWeightsLoader(UnquantizedWeight)
     else:
-        return DefaultWeightsLoader()
+        raise ValueError(f"Unknown quantization method: {quantize}")
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
+from dataclasses import dataclass
+from enum import Enum, auto
 from pathlib import Path
 from typing import Dict, List, Optional, Union
-from safetensors import safe_open
 import torch
+from safetensors import safe_open
+
+from text_generation_server.utils.import_utils import SYSTEM


 class WeightsLoader(ABC):
@@ -62,7 +66,39 @@ class WeightsLoader(ABC):
         ...


+class Weight(ABC):
+    """Instances of this type implement unquantized/quantized/to-be
+    quantized weights."""
+
+    @abstractmethod
+    def get_linear(self, bias: torch.Tensor):
+        """Create a linear layer from this weight."""
+        ...
+
+
+@dataclass
+class UnquantizedWeight:
+    weight: torch.Tensor
+
+    def get_linear(self, bias: torch.Tensor):
+        from text_generation_server.layers.linear import FastLinear, FastLinearROCm
+
+        if SYSTEM == "rocm":
+            return FastLinearROCm(self.weight, bias)
+        else:
+            return FastLinear(self.weight, bias)
+
+
 class DefaultWeightsLoader(WeightsLoader):
+    """Weight loader that loads (unquantized) Torch tensors."""
+
+    def __init__(self, weight_class):
+        """Create a loader. Weights will be wrapped using the given `weight_class`,
+        normally this will be `UnquantizedWeight`, but a quantizer-specific class
+        such as `Fp8Weight` can be used to quantize the weights during loading.
+        """
+        self.weight_class = weight_class
+
     """
     Loader that uses tensors as-is with the exception of applying sharding
     and/or concatenation.
@@ -74,16 +110,21 @@ class DefaultWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
-        return weights.get_packed_sharded(
-            f"{prefix}.weight", dim=0, block_sizes=block_sizes
-        )
+        return self.weight_class(
+            weights.get_packed_sharded(
+                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            ),
+        )

     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-        return torch.cat(w, dim=dim)
+        return self.weight_class(torch.cat(w, dim=dim))

     def get_weights_row(self, weights: "Weights", prefix: str):
-        return weights.get_sharded(f"{prefix}.weight", dim=1)
+        return self.weight_class(
+            weights.get_sharded(f"{prefix}.weight", dim=1),
+        )


 class Weights:
...