Unverified Commit 093a27c5 authored by Daniël de Kok, committed by GitHub

Add support for GPTQ Marlin (#2052)

Add support for GPTQ Marlin kernels

GPTQ Marlin extends the Marlin kernels to support common GPTQ
configurations:

- bits: 4 or 8
- groupsize: -1, 32, 64, or 128
- desc_act: true/false
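
As a rough illustration, an eligibility check for this support matrix might look like the sketch below (not the actual TGI code; the helper name and constants are assumptions for illustration):

```python
# Hypothetical check mirroring the supported GPTQ Marlin configurations above.
GPTQ_MARLIN_BITS = {4, 8}
GPTQ_MARLIN_GROUP_SIZES = {-1, 32, 64, 128}


def can_use_gptq_marlin(bits: int, groupsize: int, desc_act: bool) -> bool:
    # desc_act (activation reordering) is supported either way; it only
    # affects how g_idx is handled during repacking, not eligibility.
    return bits in GPTQ_MARLIN_BITS and groupsize in GPTQ_MARLIN_GROUP_SIZES
```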

Using the GPTQ Marlin kernels requires repacking the parameters into the
Marlin quantizer format.
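
A minimal sketch of what that load-time repacking could look like, assuming a `gptq_marlin_repack` entry point in the vendored kernels (the module name and argument order here are assumptions, not the committed API):

```python
import torch


def repack_gptq_for_marlin_sketch(qweight, g_idx, bits, desc_act, in_features, out_features):
    # Hypothetical sketch: convert GPTQ-packed weights into the Marlin tile
    # layout once at load time so the Marlin matmul kernels can consume them.
    import marlin_kernels  # vendored kernels; name assumed for illustration

    if desc_act and g_idx is not None:
        # With desc_act=True, rows are permuted into group order before repacking.
        perm = torch.argsort(g_idx).to(torch.int32)
    else:
        perm = torch.empty(0, dtype=torch.int32, device=qweight.device)

    # Signature assumed for illustration: (packed weight, permutation, k, n, bits).
    return marlin_kernels.gptq_marlin_repack(qweight, perm, in_features, out_features, bits)
```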

The kernels were contributed by Neural Magic to vLLM. We vendor them
here for convenience.
parent f433f1f7
@@ -83,7 +83,7 @@ class BLOOMSharded(CausalLM):
             process_group=self.process_group,
             prefix="transformer",
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = BloomForCausalLM(config, weights)

@@ -166,7 +166,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
-    if config.quantize not in ["gptq", "awq"]:
+    if config.quantize not in ["gptq", "awq", "marlin"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
     head_size = config.hidden_size // config.num_attention_heads

@@ -81,16 +81,11 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
         qzeros = qzeros.to(device=weights.device)
-        (
-            bits,
-            groupsize,
-            _,
-            quant_method,
-        ) = weights._get_gptq_params()
-        if quant_method == "gptq":
+        gptq_params = weights._get_gptq_params()
+        if gptq_params.quant_method == "gptq":
             g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
             g_idx = g_idx.to(device=weights.device)
-        elif quant_method == "awq":
+        elif gptq_params.quant_method == "awq":
             g_idx = None
             from text_generation_server.layers.awq.conversion_utils import (
                 fast_awq_to_gptq,
@@ -105,8 +100,8 @@ def _load_multi_mqa_gptq(
             qzeros=qzeros,
             scales=scales,
             g_idx=g_idx,
-            bits=bits,
-            groupsize=groupsize,
+            bits=gptq_params.bits,
+            groupsize=gptq_params.groupsize,
             use_exllama=HAS_EXLLAMA,
         )

@@ -130,7 +130,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
-    if config.quantize not in ["gptq", "awq"]:
+    if config.quantize not in ["gptq", "awq", "marlin"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
     head_size = config.hidden_size // config.num_attention_heads

@@ -55,7 +55,7 @@ class FlashCohere(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashCohereForCausalLM(config, weights)

@@ -80,7 +80,7 @@ class FlashDbrx(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashDbrxForCausalLM(config, weights)

@@ -53,7 +53,7 @@ class FlashGemma(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         # TODO hardcoded

@@ -67,7 +67,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq", "exl2"]:
+        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         prefix = ""

@@ -68,7 +68,7 @@ class BaseFlashMistral(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         prefix = ""

@@ -58,7 +58,7 @@ class FlashNeoXSharded(FlashCausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashGPTNeoXForCausalLM(config, weights)

@@ -53,7 +53,7 @@ class FlashPhi(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashPhiForCausalLM(config, weights)

@@ -62,7 +62,7 @@ class FlashQwen2(BaseFlashMistral):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = Qwen2ForCausalLM(config, weights)

@@ -67,7 +67,7 @@ class FlashRWSharded(FlashCausalLM):
         config.quantize = quantize
         config.speculator = speculator
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashRWForCausalLM(config, weights)

@@ -69,7 +69,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             process_group=self.process_group,
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashSantacoderForCausalLM(config, weights)

@@ -61,7 +61,7 @@ class FlashStarcoder2(BaseFlashMistral):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = FlashStarcoder2ForCausalLM(config, weights)

@@ -205,7 +205,7 @@ class GalacticaSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = OPTForCausalLM(config, weights)

@@ -58,7 +58,7 @@ class GPTNeoxSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)
         model = GPTNeoxForCausalLM(config, weights)

@@ -82,7 +82,7 @@ class MPTSharded(CausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         config.quantize = quantize

@@ -56,7 +56,7 @@ class OPTSharded(CausalLM):
         weights = Weights(
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "marlin"]:
             weights._set_gptq_params(model_id, revision)
         model = OPTForCausalLM(config, weights)