Commit 36dd1601 authored by Daniël de Kok, committed by Daniël de Kok

Add support for exl2 quantization

Mostly straightforward; changes to existing code:

* Wrap quantizer parameters in a small wrapper class to avoid
  passing around untyped tuples and needing to repack them as a
  dict (see the sketch below).
* Move scratch space computation to warmup, because we need the
  maximum input sequence length to avoid allocating huge
  scratch buffers that OOM (see the sketch after the Warmup
  hunk below).
parent cbced7f0
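For reference, a minimal sketch of the wrapper idea from the first bullet, reconstructed only from the field names visible in this diff; the real definitions live in text_generation_server.layers.gptq and text_generation_server.layers.exl2 and may carry extra fields or helpers:

# Sketch only: field names are taken from the call sites in the diff below,
# not copied from the actual layer modules.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class GPTQWeight:
    # Replaces the old untyped 7-tuple
    # (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama).
    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: Optional[torch.Tensor]
    bits: int
    groupsize: int
    use_exllama: bool


@dataclass
class Exl2Weight:
    # The five tensors an exl2 checkpoint stores per quantized linear layer.
    q_weight: torch.Tensor
    q_scale: torch.Tensor
    q_invperm: torch.Tensor
    q_scale_max: torch.Tensor
    q_groups: torch.Tensor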
@@ -67,7 +67,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "exl2"]:
             weights._set_gptq_params(model_id, revision)

         prefix = ""
...
@@ -89,7 +89,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
-        if self.quantize == "gptq":
+        if self.quantize in {"exl2", "gptq"}:
             try:
                 # When using GPTQ, Exllama kernels need some global kernels
                 # For which we have the final shapes only after the model has loaded
...
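To illustrate the second bullet of the commit message: scratch buffers for the exllama kernels are now sized during warmup, when the maximum number of input tokens is known, instead of at load time. The helper below is a hypothetical sketch of that sizing decision; the function name, parameters, and buffer shape are illustrative assumptions, not the actual kernel code.

import torch


def allocate_scratch(max_total_tokens: int, hidden_size: int, device: torch.device) -> torch.Tensor:
    # Hypothetical: size the temporary buffer for the largest batch the server
    # will ever see (known only at warmup), rather than using a worst-case
    # guess at model-load time that could be large enough to OOM.
    return torch.empty(
        (max_total_tokens, hidden_size), dtype=torch.float16, device=device
    )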
+from dataclasses import dataclass, field
 import os
 from pathlib import Path
-from typing import List, Dict, Optional, Tuple
+from typing import List, Dict, Optional, Set, Tuple, Union
 from safetensors import safe_open, SafetensorError
 import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json

+from text_generation_server.layers.exl2 import Exl2Weight
+from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.utils.log import log_once
@@ -76,8 +79,9 @@ class Weights:
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
-        # u4 which are disguised as int32
-        if tensor.dtype not in [torch.int32, torch.int64]:
+        # u4 which are disguised as int32. Exl2 uses int16
+        # as well.
+        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
@@ -102,8 +106,8 @@ class Weights:
         else:
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
-        # u4 which are disguised as int32
-        if tensor.dtype != torch.int32:
+        # u4 which are disguised as int32. exl2 uses int16.
+        if tensor.dtype not in (torch.int16, torch.int32):
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
@@ -183,7 +187,15 @@ class Weights:
             else:
                 g_idx = None

-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+            weight = GPTQWeight(
+                qweight=qweight,
+                qzeros=qzeros,
+                scales=scales,
+                g_idx=g_idx,
+                bits=bits,
+                groupsize=groupsize,
+                use_exllama=False,
+            )
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
@@ -207,8 +219,34 @@ class Weights:
             weight = weight.to(dtype=self.dtype)
         return weight

+    def get_weights_col(self, prefix: str, quantize: str):
+        if quantize == "exl2":
+            try:
+                q_weight = self.get_tensor(f"{prefix}.q_weight")
+            except RuntimeError:
+                raise RuntimeError(
+                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
+                )
+
+            q_scale = self.get_tensor(f"{prefix}.q_scale")
+            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
+            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
+            q_groups = self.get_tensor(f"{prefix}.q_groups")
+
+            return Exl2Weight(
+                q_weight=q_weight,
+                q_scale=q_scale,
+                q_invperm=q_invperm,
+                q_scale_max=q_scale_max,
+                q_groups=q_groups,
+            )
+
+        return self.get_multi_weights_col([prefix], quantize, 0)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
-        if quantize in ["gptq", "awq"]:
+        if quantize == "exl2":
+            raise ValueError("get_multi_weights_col is not supported for exl2")
+        elif quantize in ["gptq", "awq"]:
             try:
                 qweight = torch.cat(
                     [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -259,7 +297,15 @@ class Weights:
             else:
                 g_idx = None

-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+            weight = GPTQWeight(
+                qweight=qweight,
+                qzeros=qzeros,
+                scales=scales,
+                g_idx=g_idx,
+                bits=bits,
+                groupsize=groupsize,
+                use_exllama=use_exllama,
+            )
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
@@ -282,7 +328,28 @@ class Weights:
         return tensor

     def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize == "gptq":
+        if quantize == "exl2":
+            try:
+                q_weight = self.get_tensor(f"{prefix}.q_weight")
+            except RuntimeError:
+                raise RuntimeError(
+                    f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
+                )
+
+            q_scale = self.get_tensor(f"{prefix}.q_scale")
+            q_invperm = self.get_tensor(f"{prefix}.q_invperm")
+            q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
+            q_groups = self.get_tensor(f"{prefix}.q_groups")
+
+            return Exl2Weight(
+                q_weight=q_weight,
+                q_scale=q_scale,
+                q_invperm=q_invperm,
+                q_scale_max=q_scale_max,
+                q_groups=q_groups,
+            )
+
+        elif quantize == "gptq":
             use_exllama = True
             bits, groupsize, desc_act, quant_method = self._get_gptq_params()
@@ -363,7 +430,15 @@ class Weights:
                     // groupsize
                 ).to(dtype=torch.int32)

-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+            weight = GPTQWeight(
+                qweight=qweight,
+                qzeros=qzeros,
+                scales=scales,
+                g_idx=g_idx,
+                bits=bits,
+                groupsize=groupsize,
+                use_exllama=use_exllama,
+            )
         elif quantize == "awq":
             bits, groupsize, _, _ = self._get_gptq_params()
@@ -379,7 +454,15 @@ class Weights:
             g_idx = None
             use_exllama = False

-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+            weight = GPTQWeight(
+                qweight=qweight,
+                qzeros=qzeros,
+                scales=scales,
+                g_idx=g_idx,
+                bits=bits,
+                groupsize=groupsize,
+                use_exllama=use_exllama,
+            )
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
...