Commit 4594e6fa authored by Daniël de Kok, committed by Daniël de Kok

Add support for Marlin-quantized models

This change adds support for Marlin-quantized models. Marlin is an
FP16xINT4 matmul kernel that provides good speedups when decoding
batches of 16-32 tokens. It supports models quantized to 4 bits with
symmetric quantization and a group size of -1 or 128.

Tested with:

- Llama 2
- Llama 3
- Phi 3
parent cf0d459a
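
The new loading paths below wrap the Marlin checkpoint tensors in a `MarlinWeight` container imported from `text_generation_server.layers.marlin`. The diff only shows the container being constructed as `MarlinWeight(B=B, s=s)`; the sketch below is a guess at a minimal equivalent, with the field meanings inferred from the Marlin format rather than taken from this commit.

# Minimal sketch only: the field names B and s come from the diff; the types and
# comments are assumptions about what a Marlin checkpoint stores, not this repo's code.
from dataclasses import dataclass

import torch


@dataclass
class MarlinWeight:
    B: torch.Tensor  # packed 4-bit quantized weights (integer storage)
    s: torch.Tensor  # FP16 scales, per group (or per column when the group size is -1)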
...@@ -29,6 +29,10 @@ def load_multi_mqa(
         return _load_multi_mqa_gptq(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
         )
+    elif config.quantize == "marlin":
+        raise RuntimeError(
+            "santacoder models with marlin quantization are not yet supported"
+        )
     else:
         return _load_multi_mqa(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
...
...@@ -58,7 +58,7 @@ class FlashGPT2(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize in ["gptq", "awq"]:
+        if config.quantize in ["gptq", "awq", "marlin"]:
             weights._set_gptq_params(model_id, revision)

         prefix = ""
...
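
The commit message constrains what Marlin can consume: 4-bit weights, symmetric quantization, and a group size of -1 or 128. A small illustrative check for those constraints might look like the following; the real validation happens inside the Marlin layer/kernel code, not in this hunk.

# Illustrative only: enforces the constraints listed in the commit message
# (4-bit, symmetric quantization, group size -1 or 128). The function name and
# signature are hypothetical, not part of this commit.
def check_marlin_compatibility(bits: int, groupsize: int, sym: bool) -> None:
    if bits != 4:
        raise ValueError(f"Marlin only supports 4-bit quantization, got {bits} bits")
    if groupsize not in (-1, 128):
        raise ValueError(f"Marlin only supports group size -1 or 128, got {groupsize}")
    if not sym:
        raise ValueError("Marlin only supports symmetric quantization")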
...@@ -202,6 +202,12 @@ class Weights:
                 groupsize=groupsize,
                 use_exllama=False,
             )
+        elif quantize == "marlin":
+            from text_generation_server.layers.marlin import MarlinWeight
+
+            B = self._get_qweight(f"{prefix}.B", blocks)
+            s = self._get_qweight(f"{prefix}.s", blocks)
+            weight = MarlinWeight(B=B, s=s)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
...@@ -316,9 +322,25 @@ class Weights:
                 groupsize=groupsize,
                 use_exllama=use_exllama,
             )
+        elif quantize == "marlin":
+            from text_generation_server.layers.marlin import MarlinWeight
+
+            try:
+                B = torch.cat(
+                    [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
+                )
+            except RuntimeError:
+                raise RuntimeError(
+                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
+                )
+
+            s = torch.cat([self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1)
+
+            weight = MarlinWeight(B=B, s=s)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)

         return weight

     def get_tensor_shard(self, var, dim):
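
In the hunk above, column-parallel layers (for example fused QKV or gate/up projections) load one Marlin shard per prefix and concatenate both `B` and `s` along `dim=1`, the output dimension. A toy version of that concatenation pattern, with made-up shapes:

# Toy illustration of the dim=1 concatenation used above for marlin weights.
# Shapes are arbitrary; real Marlin tensors are packed, so actual shapes differ.
import torch

q_B, k_B, v_B = (torch.zeros(64, 256, dtype=torch.int32) for _ in range(3))
q_s, k_s, v_s = (torch.zeros(1, 256, dtype=torch.float16) for _ in range(3))

B = torch.cat([q_B, k_B, v_B], dim=1)  # (64, 768): output columns side by side
s = torch.cat([q_s, k_s, v_s], dim=1)  # (1, 768): scales follow the same layout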
...@@ -481,6 +503,19 @@ class Weights:
                 groupsize=groupsize,
                 use_exllama=use_exllama,
             )
+        elif quantize == "marlin":
+            from text_generation_server.layers.marlin import MarlinWeight
+
+            try:
+                B = self.get_sharded(f"{prefix}.B", dim=0)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
+
+            s = self.get_sharded(f"{prefix}.s", dim=0)
+
+            weight = MarlinWeight(B=B, s=s)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
...
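
For row-parallel layers, each tensor-parallel rank instead loads a `dim=0` slice of `B` and `s` via `get_sharded`, so every rank only holds the portion of the input dimension it multiplies against. A rough sketch of that per-rank slicing (even splits assumed; the real `get_sharded` also handles packing and uneven shards):

# Rough sketch of per-rank dim=0 sharding, analogous to get_sharded(f"{prefix}.B", dim=0).
# The helper and the even-split assumption are illustrative, not the repository's code.
import torch


def shard_dim0(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    rows = t.shape[0]
    assert rows % world_size == 0, "illustration assumes an even split"
    block = rows // world_size
    return t[rank * block : (rank + 1) * block]


B_full = torch.zeros(128, 256, dtype=torch.int32)
s_full = torch.zeros(32, 256, dtype=torch.float16)
B_rank0 = shard_dim0(B_full, rank=0, world_size=2)  # rows 0..63
s_rank0 = shard_dim0(s_full, rank=0, world_size=2)  # rows 0..15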