Commit 4594e6fa authored by Daniël de Kok, committed by Daniël de Kok

Add support for Marlin-quantized models

This change adds support for Marlin-quantized models. Marlin is an
FP16xINT4 matmul kernel that provides good speedups when decoding batches
of 16-32 tokens. It supports models with symmetric quantization, a group
size of -1 or 128, and 4-bit weights.

Tested with:

- Llama 2
- Llama 3
- Phi 3
parent cf0d459a
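For orientation, the compatibility constraints listed above can be written down as a small check. This is only an illustrative sketch, not code from this change; the function name and parameters are invented for the example.

# Illustrative sketch only (not part of this commit): the checkpoint
# constraints the commit message describes for Marlin's FP16xINT4 kernel.
def is_marlin_compatible(bits: int, groupsize: int, sym: bool) -> bool:
    # Requires 4-bit, symmetric quantization with group size -1 or 128.
    return bits == 4 and sym and groupsize in (-1, 128)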
@@ -29,6 +29,10 @@ def load_multi_mqa(
        return _load_multi_mqa_gptq(
            config, prefix, weights, bias, head_size, num_heads, hidden_size
        )
    elif config.quantize == "marlin":
        raise RuntimeError(
            "santacoder models with marlin quantization are not yet supported"
        )
    else:
        return _load_multi_mqa(
            config, prefix, weights, bias, head_size, num_heads, hidden_size
@@ -58,7 +58,7 @@ class FlashGPT2(FlashCausalLM):
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
            weights._set_gptq_params(model_id, revision)
        prefix = ""
@@ -202,6 +202,12 @@ class Weights:
                groupsize=groupsize,
                use_exllama=False,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.marlin import MarlinWeight

            B = self._get_qweight(f"{prefix}.B", blocks)
            s = self._get_qweight(f"{prefix}.s", blocks)
            weight = MarlinWeight(B=B, s=s)
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
            total_size = slice_.get_shape()[0]
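`MarlinWeight` is imported from `text_generation_server.layers.marlin`, which is not part of this diff. A minimal sketch of such a container, assuming it simply bundles the packed quantized matrix `B` with its FP16 scales `s` (the real definition may differ):

# Minimal sketch, assuming MarlinWeight is a plain container; the actual
# class lives in text_generation_server/layers/marlin and may differ.
from dataclasses import dataclass

import torch


@dataclass
class MarlinWeight:
    B: torch.Tensor  # packed 4-bit quantized weight matrix
    s: torch.Tensor  # FP16 scales (per group, or per column for groupsize -1)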
@@ -316,9 +322,25 @@ class Weights:
                groupsize=groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.marlin import MarlinWeight

            try:
                B = torch.cat(
                    [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
                )
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                )
            s = torch.cat([self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1)
            weight = MarlinWeight(B=B, s=s)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)
        return weight

    def get_tensor_shard(self, var, dim):
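In this multi-prefix path (used for packed projections such as fused QKV), the per-prefix `B` and `s` tensors are concatenated along dim=1, which suggests the output dimension of Marlin's packed layout sits on the last axis, unlike the (out_features, in_features) layout of a plain `.weight` handled in the `else` branch. A toy illustration of that concatenation, with made-up shapes:

# Toy illustration of the dim=1 concatenation above; shapes are made up
# and do not reflect Marlin's actual packing factors.
import torch

q_B, k_B, v_B = (torch.zeros(256, n, dtype=torch.int32) for n in (64, 16, 16))
B = torch.cat([q_B, k_B, v_B], dim=1)  # fused output dim: 64 + 16 + 16 = 96
print(B.shape)  # torch.Size([256, 96])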
@@ -481,6 +503,19 @@ class Weights:
                groupsize=groupsize,
                use_exllama=use_exllama,
            )
        elif quantize == "marlin":
            from text_generation_server.layers.marlin import MarlinWeight

            try:
                B = self.get_sharded(f"{prefix}.B", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                )
            s = self.get_sharded(f"{prefix}.s", dim=0)
            weight = MarlinWeight(B=B, s=s)
        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight
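This single-prefix path shards `B` and `s` along dim=0, i.e. the input dimension is split across ranks, whereas the multi-prefix hunk above concatenated along the output dimension. A toy sketch of that dim=0 split, with made-up sizes:

# Toy illustration of the dim=0 sharding above; sizes are made up.
# With group size 128, s holds one row of scales per input group, so
# splitting the input dimension across ranks slices both B and s on dim=0
# (assuming group boundaries line up with the shard boundaries).
import torch

world_size, in_groups, out_features = 2, 8, 64
s = torch.ones(in_groups, out_features)  # one row of scales per input group
per_rank = in_groups // world_size
shards = [s[r * per_rank : (r + 1) * per_rank] for r in range(world_size)]
print([tuple(t.shape) for t in shards])  # [(4, 64), (4, 64)]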