Commit 36dd1601 authored by Daniël de Kok, committed by Daniël de Kok

Add support for exl2 quantization

Mostly straightforward; changes to existing code:

* Wrap quantizer parameters in a small wrapper to avoid passing
  around untyped tuples and needing to repack them as a dict
  (a sketch of these wrappers follows below).
* Move scratch space computation to warmup, because we need the
  maximum input sequence length to avoid allocating huge
  scratch buffers that cause OOMs.
parent cbced7f0
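For context, here is a minimal sketch of the two wrapper classes, reconstructed only from the field names that appear in the diffs below; the actual definitions live in `text_generation_server.layers.exl2` and `text_generation_server.layers.gptq`:

```python
# Sketch only: field names taken from the diffs below; the real
# definitions live in text_generation_server.layers.{exl2,gptq}.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Exl2Weight:
    """Tensors produced by the exllamav2 quantizer; all integer-typed."""

    q_weight: torch.Tensor
    q_scale: torch.Tensor
    q_invperm: torch.Tensor
    q_scale_max: torch.Tensor
    q_groups: torch.Tensor


@dataclass
class GPTQWeight:
    """GPTQ/AWQ tensors plus the metadata the kernels need."""

    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: Optional[torch.Tensor]
    bits: int
    groupsize: int
    use_exllama: bool
```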
@@ -67,7 +67,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "exl2"]:
             weights._set_gptq_params(model_id, revision)

         prefix = ""
@@ -89,7 +89,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
-        if self.quantize == "gptq":
+        if self.quantize in {"exl2", "gptq"}:
             try:
                 # When using GPTQ, Exllama kernels need some global buffers
                 # for which we have the final shapes only after the model has loaded
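This is the warmup change from the commit message: buffer allocation now happens in `Warmup`, where the request carries the true maximum prefill size. A rough sketch of how the try-block plausibly continues; the helper name `create_exllama_buffers` and its signature are assumptions, not taken from this diff:

```python
# Hypothetical continuation of the try-block above. Only the pattern
# (allocate scratch buffers once the real maximum input length is
# known) comes from the commit; the helper is assumed.
try:
    from text_generation_server.layers.gptq import create_exllama_buffers  # assumed helper

    # Sizing scratch space with the actual maximum prefill length avoids
    # allocating worst-case buffers up front, which could OOM.
    create_exllama_buffers(request.max_prefill_tokens)
except ImportError:
    # Exllama kernels are optional; fall back silently when absent.
    pass
```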
 from dataclasses import dataclass, field
 import os
 from pathlib import Path
-from typing import List, Dict, Optional, Tuple
+from typing import List, Dict, Optional, Set, Tuple, Union
 from safetensors import safe_open, SafetensorError
 import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
+from text_generation_server.layers.exl2 import Exl2Weight
+from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.utils.log import log_once
@@ -76,8 +79,9 @@ class Weights:
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
-        # u4 which are disguised as int32
-        if tensor.dtype not in [torch.int32, torch.int64]:
+        # u4 which are disguised as int32. Exl2 uses int16
+        # as well.
+        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
@@ -102,8 +106,8 @@ class Weights:
         else:
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
-        # u4 which are disguised as int32
-        if tensor.dtype != torch.int32:
+        # u4 which are disguised as int32. exl2 uses int16.
+        if tensor.dtype not in (torch.int16, torch.int32):
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
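The reason integer tensors are passed through untouched: quantized checkpoints pack several low-bit values into each int32 (or, for exl2, int16) element, and casting them to the model dtype would destroy the packing. A standalone illustration of the "u4 disguised as int32" layout:

```python
import torch

# Eight 4-bit values packed into a single int32 element. Casting this
# tensor to float16 would reinterpret the raw bit pattern as a number
# and lose the packed values.
packed = torch.tensor([0x76543210], dtype=torch.int32)
nibbles = [(packed.item() >> (4 * i)) & 0xF for i in range(8)]
print(nibbles)  # [0, 1, 2, 3, 4, 5, 6, 7]
```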
@@ -183,7 +187,15 @@ class Weights:
             else:
                 g_idx = None

-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+            weight = GPTQWeight(
+                qweight=qweight,
+                qzeros=qzeros,
+                scales=scales,
+                g_idx=g_idx,
+                bits=bits,
+                groupsize=groupsize,
+                use_exllama=False,
+            )
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
@@ -207,8 +219,34 @@ class Weights:
             weight = weight.to(dtype=self.dtype)
         return weight

+    def get_weights_col(self, prefix: str, quantize: str):
+        if quantize == "exl2":
+            try:
+                q_weight = self.get_tensor(f"{prefix}.q_weight")
+            except RuntimeError:
+                raise RuntimeError(
+                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
+                )
+
+            q_scale = self.get_tensor(f"{prefix}.q_scale")
+            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
+            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
+            q_groups = self.get_tensor(f"{prefix}.q_groups")
+
+            return Exl2Weight(
+                q_weight=q_weight,
+                q_scale=q_scale,
+                q_invperm=q_invperm,
+                q_scale_max=q_scale_max,
+                q_groups=q_groups,
+            )
+
+        return self.get_multi_weights_col([prefix], quantize, 0)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
-        if quantize in ["gptq", "awq"]:
+        if quantize == "exl2":
+            raise ValueError("get_multi_weights_col is not supported for exl2")
+        elif quantize in ["gptq", "awq"]:
             try:
                 qweight = torch.cat(
                     [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -259,7 +297,15 @@ class Weights:
             else:
                 g_idx = None

-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+            weight = GPTQWeight(
+                qweight=qweight,
+                qzeros=qzeros,
+                scales=scales,
+                g_idx=g_idx,
+                bits=bits,
+                groupsize=groupsize,
+                use_exllama=use_exllama,
+            )
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
@@ -282,7 +328,28 @@ class Weights:
         return tensor

     def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize == "gptq":
+        if quantize == "exl2":
+            try:
+                q_weight = self.get_tensor(f"{prefix}.q_weight")
+            except RuntimeError:
+                raise RuntimeError(
+                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
+                )
+
+            q_scale = self.get_tensor(f"{prefix}.q_scale")
+            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
+            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
+            q_groups = self.get_tensor(f"{prefix}.q_groups")
+
+            return Exl2Weight(
+                q_weight=q_weight,
+                q_scale=q_scale,
+                q_invperm=q_invperm,
+                q_scale_max=q_scale_max,
+                q_groups=q_groups,
+            )
+
+        elif quantize == "gptq":
             use_exllama = True
             bits, groupsize, desc_act, quant_method = self._get_gptq_params()
@@ -363,7 +430,15 @@ class Weights:
                     // groupsize
                 ).to(dtype=torch.int32)

-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+            weight = GPTQWeight(
+                qweight=qweight,
+                qzeros=qzeros,
+                scales=scales,
+                g_idx=g_idx,
+                bits=bits,
+                groupsize=groupsize,
+                use_exllama=use_exllama,
+            )
         elif quantize == "awq":
             bits, groupsize, _, _ = self._get_gptq_params()
@@ -379,7 +454,15 @@ class Weights:
             g_idx = None
             use_exllama = False

-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+            weight = GPTQWeight(
+                qweight=qweight,
+                qzeros=qzeros,
+                scales=scales,
+                g_idx=g_idx,
+                bits=bits,
+                groupsize=groupsize,
+                use_exllama=use_exllama,
+            )
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
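With the positional tuple replaced by typed wrappers, downstream code can dispatch on the weight's type instead of unpacking by index. A hypothetical caller, not part of this commit, reusing the sketched wrapper classes from above:

```python
import torch

# Hypothetical downstream consumer: isinstance dispatch on the wrapper
# type replaces positional tuple unpacking.
def describe_weight(weight) -> str:
    if isinstance(weight, Exl2Weight):
        return "exl2-quantized weight"
    elif isinstance(weight, GPTQWeight):
        return f"gptq weight: {weight.bits}-bit, groupsize {weight.groupsize}"
    elif isinstance(weight, torch.Tensor):
        return f"unquantized weight, dtype {weight.dtype}"
    raise TypeError(f"unexpected weight type: {type(weight)}")
```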