Unverified Commit 80adb5be authored by Daniël de Kok, committed by GitHub

Hotfix: fix of use of unquantized weights in Gemma GQA loading (#2255)

parent ba291dad
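
The diff below replaces a quantizer-name check (config.quantize not in ["gptq", "awq", "marlin"]) with a type check on the weight wrapper returned by the loader, so only genuinely unquantized weights are cast to the serving dtype and device, and the shape assertion reads the tensor through the wrapper. A minimal sketch of the idea, using hypothetical simplified stand-ins for UnquantizedWeight and a quantized wrapper rather than the real text_generation_server classes:

from dataclasses import dataclass

import torch


@dataclass
class UnquantizedWeight:
    # Hypothetical stand-in for
    # text_generation_server.utils.weights.UnquantizedWeight.
    weight: torch.Tensor


@dataclass
class QuantizedWeight:
    # Placeholder for a quantizer-specific wrapper (gptq/awq/marlin/...).
    qweight: torch.Tensor


def cast_if_unquantized(weight, dtype: torch.dtype, device: torch.device):
    # Dispatch on the wrapper type instead of string-matching config.quantize:
    # quantized wrappers pass through untouched, plain weights get cast.
    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=dtype).to(device=device)
    return weight


w = cast_if_unquantized(
    UnquantizedWeight(torch.zeros(4, 4)), torch.float16, torch.device("cpu")
)
print(w.weight.dtype)  # torch.float16

The type-based check means newly added quantization backends do not have to be enumerated in this loading path.
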
@@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.weights import UnquantizedWeight


 class Gemma2Config(PretrainedConfig):
@@ -144,16 +145,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.head_dim
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()

-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None))
@@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.weights import UnquantizedWeight


 class GemmaConfig(PretrainedConfig):
@@ -144,16 +145,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.head_dim
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()

-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None))