Unverified commit 093a27c5 authored by Daniël de Kok, committed by GitHub

Add support for GPTQ Marlin (#2052)

Add support for GPTQ Marlin kernels

GPTQ Marlin extends the Marlin kernels to support common GPTQ
configurations:

- bits: 4 or 8
- groupsize: -1, 32, 64, or 128
- desc_act: true/false

Using the GPTQ Marlin kernels requires repacking the parameters into the
Marlin quantizer format.

The kernels were contributed by Neural Magic to vLLM. We vendor them
here for convenience.
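As an illustration, a GPTQ checkpoint whose config.json contains a
quantization_config within the supported ranges above now routes through the
Marlin kernels. A hypothetical example (the key names are the ones that
_set_gptq_params reads below):

    # Hypothetical quantization_config; values chosen from the supported ranges.
    quantization_config = {
        "bits": 4,          # 4 or 8
        "group_size": 128,  # -1, 32, 64, or 128
        "desc_act": False,  # true/false both supported
        "sym": True,
        "quant_method": "gptq",
    }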
parent f433f1f7
@@ -83,7 +83,7 @@ class BLOOMSharded(CausalLM):
             process_group=self.process_group,
             prefix="transformer",
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = BloomForCausalLM(config, weights)
......
@@ -166,7 +166,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
-    if config.quantize not in ["gptq", "awq"]:
+    if config.quantize not in ["gptq", "awq", "marlin"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
     head_size = config.hidden_size // config.num_attention_heads
......
@@ -81,16 +81,11 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
         qzeros = qzeros.to(device=weights.device)
-        (
-            bits,
-            groupsize,
-            _,
-            quant_method,
-        ) = weights._get_gptq_params()
-        if quant_method == "gptq":
+        gptq_params = weights._get_gptq_params()
+        if gptq_params.quant_method == "gptq":
             g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
             g_idx = g_idx.to(device=weights.device)
-        elif quant_method == "awq":
+        elif gptq_params.quant_method == "awq":
             g_idx = None
             from text_generation_server.layers.awq.conversion_utils import (
                 fast_awq_to_gptq,
@@ -105,8 +100,8 @@ def _load_multi_mqa_gptq(
             qzeros=qzeros,
             scales=scales,
             g_idx=g_idx,
-            bits=bits,
-            groupsize=groupsize,
+            bits=gptq_params.bits,
+            groupsize=gptq_params.groupsize,
             use_exllama=HAS_EXLLAMA,
         )
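The change in this hunk and in the analogous ones below is mechanical: the
positional tuple returned by _get_gptq_params() is replaced with the
_GPTQParams dataclass introduced in weights.py further down, so call sites
read named fields instead of order-sensitive unpacking. A minimal sketch of
the pattern:

    # Before: positional unpacking; adding a field breaks every call site.
    bits, groupsize, _, quant_method = weights._get_gptq_params()

    # After: named access; new fields (such as sym) are picked up only where needed.
    gptq_params = weights._get_gptq_params()
    bits, groupsize = gptq_params.bits, gptq_params.groupsize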
......
@@ -130,7 +130,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
-    if config.quantize not in ["gptq", "awq"]:
+    if config.quantize not in ["gptq", "awq", "marlin"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
     head_size = config.hidden_size // config.num_attention_heads
......
@@ -55,7 +55,7 @@ class FlashCohere(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashCohereForCausalLM(config, weights)
......
@@ -80,7 +80,7 @@ class FlashDbrx(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashDbrxForCausalLM(config, weights)
......
@@ -53,7 +53,7 @@ class FlashGemma(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         # TODO hardcoded
......
@@ -67,7 +67,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "exl2"]:
+        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         prefix = ""
......
@@ -68,7 +68,7 @@ class BaseFlashMistral(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         prefix = ""
......
@@ -58,7 +58,7 @@ class FlashNeoXSharded(FlashCausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashGPTNeoXForCausalLM(config, weights)
......
@@ -53,7 +53,7 @@ class FlashPhi(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashPhiForCausalLM(config, weights)
......
@@ -62,7 +62,7 @@ class FlashQwen2(BaseFlashMistral):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = Qwen2ForCausalLM(config, weights)
......
@@ -67,7 +67,7 @@ class FlashRWSharded(FlashCausalLM):
         config.quantize = quantize
         config.speculator = speculator
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashRWForCausalLM(config, weights)
......
@@ -69,7 +69,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             process_group=self.process_group,
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashSantacoderForCausalLM(config, weights)
......
@@ -61,7 +61,7 @@ class FlashStarcoder2(BaseFlashMistral):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashStarcoder2ForCausalLM(config, weights)
......
@@ -205,7 +205,7 @@ class GalacticaSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = OPTForCausalLM(config, weights)
......
@@ -58,7 +58,7 @@ class GPTNeoxSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = GPTNeoxForCausalLM(config, weights)
......
@@ -82,7 +82,7 @@ class MPTSharded(CausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         config.quantize = quantize
......
@@ -56,7 +56,7 @@ class OPTSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = OPTForCausalLM(config, weights)
......
 import os
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
@@ -9,6 +10,15 @@ import json
 from text_generation_server.utils.log import log_once
+@dataclass
+class _GPTQParams:
+    bits: int
+    groupsize: int
+    desc_act: bool
+    quant_method: str
+    sym: bool
+
 class Weights:
     def __init__(
         self,
@@ -181,15 +191,15 @@ class Weights:
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                 )
-            bits, groupsize, _, quant_method = self._get_gptq_params()
+            gptq_params = self._get_gptq_params()
             qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
             scales = self._get_qweight(f"{prefix}.scales", block_sizes)
             scales = scales.to(dtype=self.dtype)
-            if quantize == "gptq" and quant_method == "gptq":
+            if quantize == "gptq" and gptq_params.quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
-            elif quantize == "gptq" and quant_method == "awq":
+            elif quantize == "gptq" and gptq_params.quant_method == "awq":
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
@@ -199,8 +209,11 @@ class Weights:
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                 g_idx = (
-                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
-                    // groupsize
+                    torch.arange(
+                        qweight.shape[0] * (32 // gptq_params.bits),
+                        device=qweight.device,
+                    )
+                    // gptq_params.groupsize
                 ).to(dtype=torch.int32)
             else:
                 g_idx = None
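The reconstructed g_idx maps each input feature to its quantization group.
GPTQ packs 32 // bits values into every int32 row of qweight, so
qweight.shape[0] * (32 // bits) recovers the number of input features. A
worked example with hypothetical shapes (bits=4, groupsize=128):

    # 512 packed rows * 8 four-bit values per int32 row = 4096 input features.
    infeatures = 512 * (32 // 4)
    # Feature i belongs to group i // 128, so:
    # g_idx = [0]*128 + [1]*128 + ... + [31]*128  (int32, length 4096)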
@@ -210,16 +223,43 @@ class Weights:
                 qzeros=qzeros,
                 scales=scales,
                 g_idx=g_idx,
-                bits=bits,
-                groupsize=groupsize,
+                bits=gptq_params.bits,
+                groupsize=gptq_params.groupsize,
                 use_exllama=False,
             )
         elif quantize == "marlin":
-            from text_generation_server.layers.marlin import MarlinWeight
+            from text_generation_server.layers.marlin import (
+                MarlinWeight,
+                repack_gptq_for_marlin,
+            )
-            B = self._get_qweight(f"{prefix}.B", block_sizes)
-            s = self._get_qweight(f"{prefix}.s", block_sizes)
-            weight = MarlinWeight(B=B, s=s)
+            quant_method = getattr(self, "quant_method", "marlin")
+            if quant_method == "gptq":
+                gptq_params = self._get_gptq_params()
+                try:
+                    qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
+                except RuntimeError:
+                    raise RuntimeError(
+                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
+                    )
+                scales = self._get_qweight(f"{prefix}.scales", block_sizes)
+                g_idx = self.get_tensor(f"{prefix}.g_idx")
+                weight = repack_gptq_for_marlin(
+                    qweight=qweight,
+                    scales=scales,
+                    g_idx=g_idx,
+                    bits=gptq_params.bits,
+                    desc_act=gptq_params.desc_act,
+                    groupsize=gptq_params.groupsize,
+                    sym=gptq_params.sym,
+                    sharded_infeatures=False,
+                )
+            else:
+                B = self._get_qweight(f"{prefix}.B", block_sizes)
+                s = self._get_qweight(f"{prefix}.s", block_sizes)
+                weight = MarlinWeight(B=B, s=s)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
@@ -295,20 +335,23 @@ class Weights:
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
-            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
+            gptq_params = self._get_gptq_params()
             from text_generation_server.layers.gptq import HAS_EXLLAMA
             use_exllama = (
-                bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
+                gptq_params.bits == 4
+                and HAS_EXLLAMA
+                and quantize == "gptq"
+                and not gptq_params.desc_act
             )
-            if quantize == "gptq" and quant_method == "gptq":
+            if quantize == "gptq" and gptq_params.quant_method == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
                 g_idx = w[0]
-            elif quantize == "gptq" and quant_method == "awq":
+            elif quantize == "gptq" and gptq_params.quant_method == "awq":
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
@@ -322,9 +365,10 @@ class Weights:
                 else:
                     g_idx = (
                         torch.arange(
-                            qweight.shape[0] * (32 // bits), device=qweight.device
+                            qweight.shape[0] * (32 // gptq_params.bits),
+                            device=qweight.device,
                         )
-                        // groupsize
+                        // gptq_params.groupsize
                     ).to(dtype=torch.int32)
             else:
                 g_idx = None
@@ -334,24 +378,62 @@ class Weights:
                 qzeros=qzeros,
                 scales=scales,
                 g_idx=g_idx,
-                bits=bits,
-                groupsize=groupsize,
+                bits=gptq_params.bits,
+                groupsize=gptq_params.groupsize,
                 use_exllama=use_exllama,
             )
         elif quantize == "marlin":
-            from text_generation_server.layers.marlin import MarlinWeight
+            from text_generation_server.layers.gptq import GPTQWeight
+            from text_generation_server.layers.marlin import (
+                MarlinWeight,
+                repack_gptq_for_marlin,
+            )
-            try:
-                B = torch.cat(
-                    [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
-                )
-            except RuntimeError:
-                raise RuntimeError(
-                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
-                )
-            s = torch.cat([self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1)
-            weight = MarlinWeight(B=B, s=s)
+            quant_method = getattr(self, "quant_method", "marlin")
+            if quant_method == "gptq":
+                gptq_params = self._get_gptq_params()
+                try:
+                    qweight = torch.cat(
+                        [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
+                        dim=1,
+                    )
+                except RuntimeError:
+                    raise RuntimeError(
+                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
+                    )
+                scales = torch.cat(
+                    [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+                )
+                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
+                for w2 in w[1:]:
+                    torch.testing.assert_close(w2, w[0])
+                g_idx = w[0]
+                weight = repack_gptq_for_marlin(
+                    qweight=qweight,
+                    scales=scales,
+                    g_idx=g_idx,
+                    bits=gptq_params.bits,
+                    desc_act=gptq_params.desc_act,
+                    groupsize=gptq_params.groupsize,
+                    sym=gptq_params.sym,
+                    sharded_infeatures=False,
+                )
+            else:
+                try:
+                    B = torch.cat(
+                        [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
+                    )
+                except RuntimeError:
+                    raise RuntimeError(
+                        f"Cannot load `{quantize}` weight, make sure the model is already quantized"
+                    )
+                s = torch.cat(
+                    [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
+                )
+                weight = MarlinWeight(B=B, s=s)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -401,12 +483,12 @@ class Weights:
         elif quantize == "gptq":
             use_exllama = True
-            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
+            gptq_params = self._get_gptq_params()
-            if bits != 4:
+            if gptq_params.bits != 4:
                 use_exllama = False
-            if desc_act:
+            if gptq_params.desc_act:
                 log_once(logger.warning, "Disabling exllama because desc_act=True")
                 use_exllama = False
@@ -417,9 +499,9 @@ class Weights:
                     "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                 )
-            if quant_method == "gptq":
+            if gptq_params.quant_method == "gptq":
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-            elif quant_method == "awq":
+            elif gptq_params.quant_method == "awq":
                 g_idx = None
             if self.process_group.size() > 1:
@@ -428,7 +510,10 @@ class Weights:
                         not torch.equal(
                             g_idx.cpu(),
                             torch.tensor(
-                                [i // groupsize for i in range(g_idx.shape[0])],
+                                [
+                                    i // gptq_params.groupsize
+                                    for i in range(g_idx.shape[0])
+                                ],
                                 dtype=torch.int32,
                             ),
                         )
@@ -455,7 +540,7 @@ class Weights:
             else:
                 log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
-            if use_exllama and groupsize != -1:
+            if use_exllama and gptq_params.groupsize != -1:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                 scales = self.get_sharded(f"{prefix}.scales", dim=0)
             else:
@@ -465,7 +550,7 @@ class Weights:
             if use_exllama and g_idx is not None:
                 g_idx = g_idx - g_idx[0]
-            if quant_method == "awq":
+            if gptq_params.quant_method == "awq":
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
@@ -479,9 +564,10 @@ class Weights:
                 else:
                     g_idx = (
                         torch.arange(
-                            qweight.shape[0] * (32 // bits), device=qweight.device
+                            qweight.shape[0] * (32 // gptq_params.bits),
+                            device=qweight.device,
                         )
-                        // groupsize
+                        // gptq_params.groupsize
                     ).to(dtype=torch.int32)
             weight = GPTQWeight(
@@ -489,14 +575,14 @@ class Weights:
                 qzeros=qzeros,
                 scales=scales,
                 g_idx=g_idx,
-                bits=bits,
-                groupsize=groupsize,
+                bits=gptq_params.bits,
+                groupsize=gptq_params.groupsize,
                 use_exllama=use_exllama,
             )
         elif quantize == "awq":
             from text_generation_server.layers.gptq import GPTQWeight
-            bits, groupsize, _, _ = self._get_gptq_params()
+            gptq_params = self._get_gptq_params()
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -515,38 +601,74 @@ class Weights:
                 qzeros=qzeros,
                 scales=scales,
                 g_idx=g_idx,
-                bits=bits,
-                groupsize=groupsize,
+                bits=gptq_params.bits,
+                groupsize=gptq_params.groupsize,
                 use_exllama=use_exllama,
             )
         elif quantize == "marlin":
-            from text_generation_server.layers.marlin import MarlinWeight
+            from text_generation_server.layers.gptq import GPTQWeight
+            from text_generation_server.layers.marlin import (
+                MarlinWeight,
+                repack_gptq_for_marlin,
+            )
-            try:
-                B = self.get_sharded(f"{prefix}.B", dim=0)
-            except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
-            num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
-            if num_groups == 1:
-                # The number of groups is 1 when group_size == -1. share
-                # scales between all shards in this case.
-                s = self.get_tensor(f"{prefix}.s")
-            else:
-                s = self.get_sharded(f"{prefix}.s", dim=0)
-            weight = MarlinWeight(B=B, s=s)
+            quant_method = getattr(self, "quant_method", "marlin")
+            if quant_method == "gptq":
+                log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
+                gptq_params = self._get_gptq_params()
+                try:
+                    qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                except RuntimeError:
+                    raise RuntimeError(
+                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
+                    )
+                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+                if gptq_params.desc_act or gptq_params.groupsize == -1:
+                    scales = self.get_tensor(f"{prefix}.scales")
+                else:
+                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
+                sharded_in_features = self.process_group.size() > 1
+                weight = repack_gptq_for_marlin(
+                    qweight=qweight,
+                    scales=scales,
+                    g_idx=g_idx,
+                    bits=gptq_params.bits,
+                    desc_act=gptq_params.desc_act,
+                    groupsize=gptq_params.groupsize,
+                    sym=gptq_params.sym,
+                    sharded_infeatures=sharded_in_features,
+                )
+            else:
+                try:
+                    B = self.get_sharded(f"{prefix}.B", dim=0)
+                except RuntimeError:
+                    raise RuntimeError(
+                        "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                    )
+                num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
+                if num_groups == 1:
+                    # The number of groups is 1 when groupsize == -1. share
+                    # scales between all shards in this case.
+                    s = self.get_tensor(f"{prefix}.s")
+                else:
+                    s = self.get_sharded(f"{prefix}.s", dim=0)
+                weight = MarlinWeight(B=B, s=s)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
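Two details of this row-sharded path are worth spelling out. When desc_act is
set or groupsize == -1, the scales are loaded with get_tensor rather than
get_sharded: with groupsize == -1 there is a single scale group covering all
input features, and with desc_act the group assignment follows the g_idx
permutation, so every tensor-parallel shard needs the full scales tensor. The
sharded_infeatures flag likewise tells the repacking code that the input
features are split across ranks. A rough shape sketch with hypothetical sizes:

    # K = 4096 in-features, N = 11008 out-features, groupsize = 128:
    # scales: [K // groupsize, N] = [32, N]  -> shardable along dim 0.
    # groupsize == -1 -> scales: [1, N]      -> one group; replicated on all shards.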
-    def _get_gptq_params(self) -> Tuple[int, int, int, str]:
+    def _get_gptq_params(self) -> _GPTQParams:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
             desc_act = False
+            sym = True
             quant_method = "gptq"
         except (SafetensorError, RuntimeError) as e:
             try:
@@ -554,10 +676,17 @@ class Weights:
                 groupsize = self.gptq_groupsize
                 desc_act = getattr(self, "gptq_desc_act", False)
                 quant_method = getattr(self, "quant_method", "gptq")
+                sym = getattr(self, "sym", True)
             except Exception:
                 raise e
-        return bits, groupsize, desc_act, quant_method
+        return _GPTQParams(
+            bits=bits,
+            desc_act=desc_act,
+            groupsize=groupsize,
+            quant_method=quant_method,
+            sym=sym,
+        )
     def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
@@ -574,6 +703,7 @@ class Weights:
                 self.gptq_groupsize = data["quantization_config"]["group_size"]
                 # Order is important here, desc_act is missing on some real models
                 self.quant_method = data["quantization_config"]["quant_method"]
+                self.gptq_sym = data["quantization_config"]["sym"]
                 self.gptq_desc_act = data["quantization_config"]["desc_act"]
             except Exception:
                 filename = "quantize_config.json"
@@ -588,6 +718,7 @@ class Weights:
                     data = json.load(f)
                     self.gptq_bits = data["bits"]
                     self.gptq_groupsize = data["group_size"]
+                    self.gptq_sym = data["sym"]
                     self.gptq_desc_act = data["desc_act"]
                     if "version" in data and data["version"] == "GEMM":
                         self.quant_method = "awq"
......