Unverified commit 093a27c5 authored by Daniël de Kok, committed by GitHub

Add support for GPTQ Marlin (#2052)

Add support for GPTQ Marlin kernels

GPTQ Marlin extends the Marlin kernels to support common GPTQ
configurations:

- bits: 4 or 8
- groupsize: -1, 32, 64, or 128
- desc_act: true/false

Using the GPTQ Marlin kernels requires repacking the parameters in the
Marlin quantizer format.

The kernels were contributed by Neural Magic to vLLM. We vendor them
here for convenience.
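
For illustration, here is a minimal sketch of the compatibility check implied by the list above (hypothetical helper names, not part of this commit; the actual gating in TGI lives in the repacking path):

```python
# Hypothetical helper, not part of this commit: checks whether a GPTQ
# checkpoint's quantization config falls within what the GPTQ Marlin
# kernels support, per the list above.
from dataclasses import dataclass

GPTQ_MARLIN_BITS = {4, 8}
GPTQ_MARLIN_GROUP_SIZES = {-1, 32, 64, 128}


@dataclass
class GPTQConfig:
    bits: int
    group_size: int
    desc_act: bool


def can_use_gptq_marlin(cfg: GPTQConfig) -> bool:
    # desc_act=True and desc_act=False are both handled by the kernels,
    # so only bits and group size need to be checked here.
    return cfg.bits in GPTQ_MARLIN_BITS and cfg.group_size in GPTQ_MARLIN_GROUP_SIZES


# A typical 4-bit, group size 128 GPTQ checkpoint qualifies:
assert can_use_gptq_marlin(GPTQConfig(bits=4, group_size=128, desc_act=True))
```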
parent f433f1f7
@@ -83,7 +83,7 @@ class BLOOMSharded(CausalLM):
             process_group=self.process_group,
             prefix="transformer",
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = BloomForCausalLM(config, weights)
...
@@ -166,7 +166,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq"]:
+    if config.quantize not in ["gptq", "awq", "marlin"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)

         head_size = config.hidden_size // config.num_attention_heads
...
@@ -81,16 +81,11 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
         qzeros = qzeros.to(device=weights.device)

-        (
-            bits,
-            groupsize,
-            _,
-            quant_method,
-        ) = weights._get_gptq_params()
-        if quant_method == "gptq":
+        gptq_params = weights._get_gptq_params()
+        if gptq_params.quant_method == "gptq":
             g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
             g_idx = g_idx.to(device=weights.device)
-        elif quant_method == "awq":
+        elif gptq_params.quant_method == "awq":
             g_idx = None
             from text_generation_server.layers.awq.conversion_utils import (
                 fast_awq_to_gptq,
@@ -105,8 +100,8 @@ def _load_multi_mqa_gptq(
             qzeros=qzeros,
             scales=scales,
             g_idx=g_idx,
-            bits=bits,
-            groupsize=groupsize,
+            bits=gptq_params.bits,
+            groupsize=gptq_params.groupsize,
             use_exllama=HAS_EXLLAMA,
         )
...
@@ -130,7 +130,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq"]:
+    if config.quantize not in ["gptq", "awq", "marlin"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)

         head_size = config.hidden_size // config.num_attention_heads
...
@@ -55,7 +55,7 @@ class FlashCohere(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = FlashCohereForCausalLM(config, weights)
...
@@ -80,7 +80,7 @@ class FlashDbrx(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = FlashDbrxForCausalLM(config, weights)
...
@@ -53,7 +53,7 @@ class FlashGemma(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         # TODO hardcoded
...
@@ -67,7 +67,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

-        if config.quantize in ["gptq", "awq", "exl2"]:
+        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         prefix = ""
...
@@ -68,7 +68,7 @@ class BaseFlashMistral(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         prefix = ""
...
@@ -58,7 +58,7 @@ class FlashNeoXSharded(FlashCausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = FlashGPTNeoXForCausalLM(config, weights)
...
@@ -53,7 +53,7 @@ class FlashPhi(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = FlashPhiForCausalLM(config, weights)
...
@@ -62,7 +62,7 @@ class FlashQwen2(BaseFlashMistral):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = Qwen2ForCausalLM(config, weights)
...
@@ -67,7 +67,7 @@ class FlashRWSharded(FlashCausalLM):
         config.quantize = quantize
         config.speculator = speculator

-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = FlashRWForCausalLM(config, weights)
...
@@ -69,7 +69,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             process_group=self.process_group,
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = FlashSantacoderForCausalLM(config, weights)
...
@@ -61,7 +61,7 @@ class FlashStarcoder2(BaseFlashMistral):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = FlashStarcoder2ForCausalLM(config, weights)
...
@@ -205,7 +205,7 @@ class GalacticaSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = OPTForCausalLM(config, weights)
...
@@ -58,7 +58,7 @@ class GPTNeoxSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = GPTNeoxForCausalLM(config, weights)
...
@@ -82,7 +82,7 @@ class MPTSharded(CausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         config.quantize = quantize
...
@@ -56,7 +56,7 @@ class OPTSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         model = OPTForCausalLM(config, weights)
...
 import os
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
...
@@ -9,6 +10,15 @@ import json
 from text_generation_server.utils.log import log_once


+@dataclass
+class _GPTQParams:
+    bits: int
+    groupsize: int
+    desc_act: bool
+    quant_method: str
+    sym: bool
+
+
 class Weights:
     def __init__(
         self,
@@ -181,15 +191,15 @@ class Weights:
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                 )

-            bits, groupsize, _, quant_method = self._get_gptq_params()
+            gptq_params = self._get_gptq_params()

             qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
             scales = self._get_qweight(f"{prefix}.scales", block_sizes)
             scales = scales.to(dtype=self.dtype)

-            if quantize == "gptq" and quant_method == "gptq":
+            if quantize == "gptq" and gptq_params.quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
-            elif quantize == "gptq" and quant_method == "awq":
+            elif quantize == "gptq" and gptq_params.quant_method == "awq":
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
@@ -199,8 +209,11 @@ class Weights:
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
                 g_idx = (
-                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
-                    // groupsize
+                    torch.arange(
+                        qweight.shape[0] * (32 // gptq_params.bits),
+                        device=qweight.device,
+                    )
+                    // gptq_params.groupsize
                 ).to(dtype=torch.int32)
             else:
                 g_idx = None
@@ -210,16 +223,43 @@ class Weights:
                 qzeros=qzeros,
                 scales=scales,
                 g_idx=g_idx,
-                bits=bits,
-                groupsize=groupsize,
+                bits=gptq_params.bits,
+                groupsize=gptq_params.groupsize,
                 use_exllama=False,
             )
         elif quantize == "marlin":
-            from text_generation_server.layers.marlin import MarlinWeight
+            from text_generation_server.layers.marlin import (
+                MarlinWeight,
+                repack_gptq_for_marlin,
+            )

-            B = self._get_qweight(f"{prefix}.B", block_sizes)
-            s = self._get_qweight(f"{prefix}.s", block_sizes)
-            weight = MarlinWeight(B=B, s=s)
+            quant_method = getattr(self, "quant_method", "marlin")
+            if quant_method == "gptq":
+                gptq_params = self._get_gptq_params()
+                try:
+                    qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
+                except RuntimeError:
+                    raise RuntimeError(
+                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
+                    )
+
+                scales = self._get_qweight(f"{prefix}.scales", block_sizes)
+                g_idx = self.get_tensor(f"{prefix}.g_idx")
+                weight = repack_gptq_for_marlin(
+                    qweight=qweight,
+                    scales=scales,
+                    g_idx=g_idx,
+                    bits=gptq_params.bits,
+                    desc_act=gptq_params.desc_act,
+                    groupsize=gptq_params.groupsize,
+                    sym=gptq_params.sym,
+                    sharded_infeatures=False,
+                )
+            else:
+                B = self._get_qweight(f"{prefix}.B", block_sizes)
+                s = self._get_qweight(f"{prefix}.s", block_sizes)
+                weight = MarlinWeight(B=B, s=s)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
@@ -295,20 +335,23 @@ class Weights:
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )

-            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
+            gptq_params = self._get_gptq_params()

             from text_generation_server.layers.gptq import HAS_EXLLAMA

             use_exllama = (
-                bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
+                gptq_params.bits == 4
+                and HAS_EXLLAMA
+                and quantize == "gptq"
+                and not gptq_params.desc_act
             )

-            if quantize == "gptq" and quant_method == "gptq":
+            if quantize == "gptq" and gptq_params.quant_method == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
                 g_idx = w[0]
-            elif quantize == "gptq" and quant_method == "awq":
+            elif quantize == "gptq" and gptq_params.quant_method == "awq":
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
@@ -322,9 +365,10 @@ class Weights:
                 else:
                     g_idx = (
                         torch.arange(
-                            qweight.shape[0] * (32 // bits), device=qweight.device
+                            qweight.shape[0] * (32 // gptq_params.bits),
+                            device=qweight.device,
                         )
-                        // groupsize
+                        // gptq_params.groupsize
                     ).to(dtype=torch.int32)
             else:
                 g_idx = None
@@ -334,24 +378,62 @@ class Weights:
                 qzeros=qzeros,
                 scales=scales,
                 g_idx=g_idx,
-                bits=bits,
-                groupsize=groupsize,
+                bits=gptq_params.bits,
+                groupsize=gptq_params.groupsize,
                 use_exllama=use_exllama,
             )
         elif quantize == "marlin":
-            from text_generation_server.layers.marlin import MarlinWeight
+            from text_generation_server.layers.gptq import GPTQWeight
+            from text_generation_server.layers.marlin import (
+                MarlinWeight,
+                repack_gptq_for_marlin,
+            )

-            try:
-                B = torch.cat(
-                    [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
-                )
-            except RuntimeError:
-                raise RuntimeError(
-                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
-                )
-            s = torch.cat([self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1)
-            weight = MarlinWeight(B=B, s=s)
+            quant_method = getattr(self, "quant_method", "marlin")
+            if quant_method == "gptq":
+                gptq_params = self._get_gptq_params()
+                try:
+                    qweight = torch.cat(
+                        [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
+                        dim=1,
+                    )
+                except RuntimeError:
+                    raise RuntimeError(
+                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
+                    )
+
+                scales = torch.cat(
+                    [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+                )
+                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
+                for w2 in w[1:]:
+                    torch.testing.assert_close(w2, w[0])
+                g_idx = w[0]
+
+                weight = repack_gptq_for_marlin(
+                    qweight=qweight,
+                    scales=scales,
+                    g_idx=g_idx,
+                    bits=gptq_params.bits,
+                    desc_act=gptq_params.desc_act,
+                    groupsize=gptq_params.groupsize,
+                    sym=gptq_params.sym,
+                    sharded_infeatures=False,
+                )
+            else:
+                try:
+                    B = torch.cat(
+                        [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
+                    )
+                except RuntimeError:
+                    raise RuntimeError(
+                        f"Cannot load `{quantize}` weight, make sure the model is already quantized"
+                    )
+                s = torch.cat(
+                    [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
+                )
+                weight = MarlinWeight(B=B, s=s)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -401,12 +483,12 @@ class Weights:
         elif quantize == "gptq":
             use_exllama = True
-            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
+            gptq_params = self._get_gptq_params()

-            if bits != 4:
+            if gptq_params.bits != 4:
                 use_exllama = False

-            if desc_act:
+            if gptq_params.desc_act:
                 log_once(logger.warning, "Disabling exllama because desc_act=True")
                 use_exllama = False
@@ -417,9 +499,9 @@ class Weights:
                     "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                 )

-            if quant_method == "gptq":
+            if gptq_params.quant_method == "gptq":
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-            elif quant_method == "awq":
+            elif gptq_params.quant_method == "awq":
                 g_idx = None

             if self.process_group.size() > 1:
@@ -428,7 +510,10 @@ class Weights:
                        not torch.equal(
                            g_idx.cpu(),
                            torch.tensor(
-                                [i // groupsize for i in range(g_idx.shape[0])],
+                                [
+                                    i // gptq_params.groupsize
+                                    for i in range(g_idx.shape[0])
+                                ],
                                dtype=torch.int32,
                            ),
                        )
@@ -455,7 +540,7 @@ class Weights:
             else:
                 log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

-            if use_exllama and groupsize != -1:
+            if use_exllama and gptq_params.groupsize != -1:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                 scales = self.get_sharded(f"{prefix}.scales", dim=0)
             else:
@@ -465,7 +550,7 @@ class Weights:
             if use_exllama and g_idx is not None:
                 g_idx = g_idx - g_idx[0]

-            if quant_method == "awq":
+            if gptq_params.quant_method == "awq":
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
@@ -479,9 +564,10 @@ class Weights:
                 else:
                     g_idx = (
                         torch.arange(
-                            qweight.shape[0] * (32 // bits), device=qweight.device
+                            qweight.shape[0] * (32 // gptq_params.bits),
+                            device=qweight.device,
                         )
-                        // groupsize
+                        // gptq_params.groupsize
                     ).to(dtype=torch.int32)

             weight = GPTQWeight(
@@ -489,14 +575,14 @@ class Weights:
                 qzeros=qzeros,
                 scales=scales,
                 g_idx=g_idx,
-                bits=bits,
-                groupsize=groupsize,
+                bits=gptq_params.bits,
+                groupsize=gptq_params.groupsize,
                 use_exllama=use_exllama,
             )
         elif quantize == "awq":
             from text_generation_server.layers.gptq import GPTQWeight

-            bits, groupsize, _, _ = self._get_gptq_params()
+            gptq_params = self._get_gptq_params()

             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -515,38 +601,74 @@ class Weights:
                 qzeros=qzeros,
                 scales=scales,
                 g_idx=g_idx,
-                bits=bits,
-                groupsize=groupsize,
+                bits=gptq_params.bits,
+                groupsize=gptq_params.groupsize,
                 use_exllama=use_exllama,
             )
         elif quantize == "marlin":
-            from text_generation_server.layers.marlin import MarlinWeight
+            from text_generation_server.layers.gptq import GPTQWeight
+            from text_generation_server.layers.marlin import (
+                MarlinWeight,
+                repack_gptq_for_marlin,
+            )

-            try:
-                B = self.get_sharded(f"{prefix}.B", dim=0)
-            except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
-
-            num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
-            if num_groups == 1:
-                # The number of groups is 1 when group_size == -1. share
-                # scales between all shards in this case.
-                s = self.get_tensor(f"{prefix}.s")
+            quant_method = getattr(self, "quant_method", "marlin")
+            if quant_method == "gptq":
+                log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
+                gptq_params = self._get_gptq_params()
+
+                try:
+                    qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                except RuntimeError:
+                    raise RuntimeError(
+                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
+                    )
+
+                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+                if gptq_params.desc_act or gptq_params.groupsize == -1:
+                    scales = self.get_tensor(f"{prefix}.scales")
+                else:
+                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
+
+                sharded_in_features = self.process_group.size() > 1
+
+                weight = repack_gptq_for_marlin(
+                    qweight=qweight,
+                    scales=scales,
+                    g_idx=g_idx,
+                    bits=gptq_params.bits,
+                    desc_act=gptq_params.desc_act,
+                    groupsize=gptq_params.groupsize,
+                    sym=gptq_params.sym,
+                    sharded_infeatures=sharded_in_features,
+                )
             else:
-                s = self.get_sharded(f"{prefix}.s", dim=0)
-            weight = MarlinWeight(B=B, s=s)
+                try:
+                    B = self.get_sharded(f"{prefix}.B", dim=0)
+                except RuntimeError:
+                    raise RuntimeError(
+                        "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                    )
+
+                num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
+                if num_groups == 1:
+                    # The number of groups is 1 when groupsize == -1. share
+                    # scales between all shards in this case.
+                    s = self.get_tensor(f"{prefix}.s")
+                else:
+                    s = self.get_sharded(f"{prefix}.s", dim=0)
+                weight = MarlinWeight(B=B, s=s)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight

-    def _get_gptq_params(self) -> Tuple[int, int, int, str]:
+    def _get_gptq_params(self) -> _GPTQParams:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
             desc_act = False
+            sym = True
             quant_method = "gptq"
         except (SafetensorError, RuntimeError) as e:
             try:
@@ -554,10 +676,17 @@ class Weights:
                 groupsize = self.gptq_groupsize
                 desc_act = getattr(self, "gptq_desc_act", False)
                 quant_method = getattr(self, "quant_method", "gptq")
+                sym = getattr(self, "sym", True)
             except Exception:
                 raise e

-        return bits, groupsize, desc_act, quant_method
+        return _GPTQParams(
+            bits=bits,
+            desc_act=desc_act,
+            groupsize=groupsize,
+            quant_method=quant_method,
+            sym=sym,
+        )

     def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
@@ -574,6 +703,7 @@ class Weights:
             self.gptq_groupsize = data["quantization_config"]["group_size"]
             # Order is important here, desc_act is missing on some real models
             self.quant_method = data["quantization_config"]["quant_method"]
+            self.gptq_sym = data["quantization_config"]["sym"]
             self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
@@ -588,6 +718,7 @@ class Weights:
                     data = json.load(f)
                     self.gptq_bits = data["bits"]
                     self.gptq_groupsize = data["group_size"]
+                    self.gptq_sym = data["sym"]
                     self.gptq_desc_act = data["desc_act"]
                     if "version" in data and data["version"] == "GEMM":
                         self.quant_method = "awq"
...