Commit 0d96468e authored by Daniël de Kok, committed by Daniël de Kok

marlin: support tp>1 when group_size==-1

parent 4594e6fa
@@ -513,7 +513,13 @@ class Weights:
                     "Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                 )
-            s = self.get_sharded(f"{prefix}.s", dim=0)
+            num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
+            if num_groups == 1:
+                # The number of groups is 1 when group_size == -1. share
+                # scales between all shards in this case.
+                s = self.get_tensor(f"{prefix}.s")
+            else:
+                s = self.get_sharded(f"{prefix}.s", dim=0)
             weight = MarlinWeight(B=B, s=s)
         else:
......
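The branch added in this commit can be illustrated in isolation. The following is a minimal sketch, not the code from the repository: `select_scales` is a hypothetical helper, the scales are modelled as a plain list of rows with shape `[num_groups, out_features]`, and the even split along dimension 0 is an assumption about what a sharded load does. It shows why a single group (the `group_size == -1` case) has to be handed to every rank whole rather than split.

```python
from typing import List


def select_scales(
    scales: List[List[float]], rank: int, world_size: int
) -> List[List[float]]:
    """Return the rows of `scales` that one tensor-parallel rank should use.

    Hypothetical stand-in for the get_tensor / get_sharded choice in the diff.
    """
    num_groups = len(scales)
    if num_groups == 1:
        # A single group covers the whole input dimension, so there is
        # nothing to split: every rank uses the same scales.
        return scales
    if num_groups % world_size != 0:
        raise ValueError(
            f"{num_groups} groups cannot be split evenly over {world_size} ranks"
        )
    # Assumed sharding scheme: contiguous, equally sized blocks along dim 0.
    block = num_groups // world_size
    return scales[rank * block : (rank + 1) * block]


# One group: group_size == -1.
per_channel = [[0.5, 0.25]]
# Four groups, to be split across two ranks.
grouped = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

for rank in range(2):
    print(rank, select_scales(per_channel, rank, 2), select_scales(grouped, rank, 2))
```

With two ranks, both receive the full single-row scales in the first case, while the four-group case yields two rows per rank. Without the `num_groups == 1` check, the single-group case could not be split across more than one rank at all.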