Unverified commit fdde1822, authored by Areeb Syed, committed by GitHub
Browse files

[Bugfix] Fix shape mismatch assertion error when loading Gemma3n model with...


[Bugfix] Fix shape mismatch assertion error when loading Gemma3n model with BitsAndBytes quantization (#21808)
Signed-off-by: sydarb <areebsyed237@gmail.com>
parent b917da44
...@@ -167,22 +167,33 @@ class Gemma3nAltUp(nn.Module): ...@@ -167,22 +167,33 @@ class Gemma3nAltUp(nn.Module):
class Gemma3nLaurelBlock(nn.Module): class Gemma3nLaurelBlock(nn.Module):
"""Learned Augmented Residual Layer""" """Learned Augmented Residual Layer"""
def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float, def __init__(
prefix: str): self,
hidden_size: int,
laurel_rank: int,
rms_norm_eps: float,
*,
quant_config: Optional[QuantizationConfig] = None,
prefix: str,
) -> None:
super().__init__() super().__init__()
self.linear_left = ColumnParallelLinear( self.linear_left = ColumnParallelLinear(
hidden_size, hidden_size,
laurel_rank, laurel_rank,
bias=False, bias=False,
quant_config=quant_config,
prefix=f"{prefix}.linear_left", prefix=f"{prefix}.linear_left",
return_bias=False, return_bias=False,
) )
self.linear_right = RowParallelLinear(laurel_rank, self.linear_right = RowParallelLinear(
laurel_rank,
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config,
prefix=f"{prefix}.linear_right", prefix=f"{prefix}.linear_right",
return_bias=False) return_bias=False,
)
self.post_laurel_norm = RMSNorm( self.post_laurel_norm = RMSNorm(
hidden_size=hidden_size, hidden_size=hidden_size,
eps=rms_norm_eps, eps=rms_norm_eps,
...@@ -417,6 +428,7 @@ class Gemma3nDecoderLayer(nn.Module): ...@@ -417,6 +428,7 @@ class Gemma3nDecoderLayer(nn.Module):
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
laurel_rank=config.laurel_rank, laurel_rank=config.laurel_rank,
rms_norm_eps=config.rms_norm_eps, rms_norm_eps=config.rms_norm_eps,
quant_config=quant_config,
prefix=f"{prefix}.laurel", prefix=f"{prefix}.laurel",
) )
...@@ -427,6 +439,7 @@ class Gemma3nDecoderLayer(nn.Module): ...@@ -427,6 +439,7 @@ class Gemma3nDecoderLayer(nn.Module):
config.hidden_size, config.hidden_size,
config.hidden_size_per_layer_input, config.hidden_size_per_layer_input,
bias=False, bias=False,
quant_config=quant_config,
prefix=f"{prefix}.per_layer_input_gate", prefix=f"{prefix}.per_layer_input_gate",
return_bias=False, return_bias=False,
) )
...@@ -434,6 +447,7 @@ class Gemma3nDecoderLayer(nn.Module): ...@@ -434,6 +447,7 @@ class Gemma3nDecoderLayer(nn.Module):
config.hidden_size_per_layer_input, config.hidden_size_per_layer_input,
config.hidden_size, config.hidden_size,
bias=False, bias=False,
quant_config=quant_config,
prefix=f"{prefix}.per_layer_projection", prefix=f"{prefix}.per_layer_projection",
return_bias=False, return_bias=False,
) )
...@@ -547,6 +561,7 @@ class Gemma3nTextModel(nn.Module): ...@@ -547,6 +561,7 @@ class Gemma3nTextModel(nn.Module):
bias=False, bias=False,
gather_output=True, gather_output=True,
return_bias=False, return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.per_layer_model_projection", prefix=f"{prefix}.per_layer_model_projection",
) )
self.per_layer_projection_norm = RMSNorm( self.per_layer_projection_norm = RMSNorm(
...@@ -566,6 +581,7 @@ class Gemma3nTextModel(nn.Module): ...@@ -566,6 +581,7 @@ class Gemma3nTextModel(nn.Module):
bias=False, bias=False,
gather_output=True, gather_output=True,
return_bias=False, return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.{idx-1}.altup_projections", prefix=f"{prefix}.{idx-1}.altup_projections",
) for idx in range(1, self.config.altup_num_inputs) ) for idx in range(1, self.config.altup_num_inputs)
]) ])
...@@ -576,6 +592,7 @@ class Gemma3nTextModel(nn.Module): ...@@ -576,6 +592,7 @@ class Gemma3nTextModel(nn.Module):
bias=False, bias=False,
gather_output=True, gather_output=True,
return_bias=False, return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.{idx-1}.altup_unembed_projections", prefix=f"{prefix}.{idx-1}.altup_unembed_projections",
) for idx in range(1, self.config.altup_num_inputs) ) for idx in range(1, self.config.altup_num_inputs)
]) ])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment