Unverified Commit 0f46a780 authored by Kyle Sayers's avatar Kyle Sayers Committed by GitHub
Browse files

[Model] [Quantization] Support quantization for Gemma3n (#21974)


Signed-off-by: default avatarKyle Sayers <kylesayrs@gmail.com>
parent e1a7fe4a
...@@ -46,6 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -46,6 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsQuant
from .utils import (AutoWeightsLoader, extract_layer_index, from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter, make_layers, maybe_prefix) is_pp_missing_parameter, make_layers, maybe_prefix)
...@@ -68,6 +69,7 @@ class Gemma3nAltUp(nn.Module): ...@@ -68,6 +69,7 @@ class Gemma3nAltUp(nn.Module):
altup_num_inputs: int, altup_num_inputs: int,
altup_coef_clip: float, altup_coef_clip: float,
altup_active_idx: int, altup_active_idx: int,
quant_config: QuantizationConfig,
prefix: str, prefix: str,
): ):
super().__init__() super().__init__()
...@@ -80,6 +82,7 @@ class Gemma3nAltUp(nn.Module): ...@@ -80,6 +82,7 @@ class Gemma3nAltUp(nn.Module):
altup_num_inputs, altup_num_inputs,
altup_num_inputs, altup_num_inputs,
bias=False, bias=False,
quant_config=quant_config,
prefix=f"{prefix}.correction_coefs", prefix=f"{prefix}.correction_coefs",
return_bias=False, return_bias=False,
) )
...@@ -87,6 +90,7 @@ class Gemma3nAltUp(nn.Module): ...@@ -87,6 +90,7 @@ class Gemma3nAltUp(nn.Module):
altup_num_inputs, altup_num_inputs,
altup_num_inputs**2, altup_num_inputs**2,
bias=False, bias=False,
quant_config=quant_config,
prefix=f"{prefix}.prediction_coefs", prefix=f"{prefix}.prediction_coefs",
return_bias=False, return_bias=False,
) )
...@@ -94,6 +98,7 @@ class Gemma3nAltUp(nn.Module): ...@@ -94,6 +98,7 @@ class Gemma3nAltUp(nn.Module):
hidden_size, hidden_size,
altup_num_inputs, altup_num_inputs,
bias=False, bias=False,
quant_config=quant_config,
prefix=f"{prefix}.modality_router", prefix=f"{prefix}.modality_router",
return_bias=False, return_bias=False,
) )
...@@ -400,6 +405,7 @@ class Gemma3nDecoderLayer(nn.Module): ...@@ -400,6 +405,7 @@ class Gemma3nDecoderLayer(nn.Module):
altup_num_inputs=config.altup_num_inputs, altup_num_inputs=config.altup_num_inputs,
altup_coef_clip=config.altup_coef_clip, altup_coef_clip=config.altup_coef_clip,
altup_active_idx=config.altup_active_idx, altup_active_idx=config.altup_active_idx,
quant_config=quant_config,
prefix=f"{prefix}.altup", prefix=f"{prefix}.altup",
) )
self.self_attn = Gemma3nAttention( self.self_attn = Gemma3nAttention(
...@@ -527,7 +533,7 @@ class Gemma3nDecoderLayer(nn.Module): ...@@ -527,7 +533,7 @@ class Gemma3nDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class Gemma3nTextModel(nn.Module): class Gemma3nTextModel(nn.Module, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
...@@ -540,6 +546,7 @@ class Gemma3nTextModel(nn.Module): ...@@ -540,6 +546,7 @@ class Gemma3nTextModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens", prefix=f"{prefix}.embed_tokens",
) )
self.embed_scale = torch.tensor( self.embed_scale = torch.tensor(
...@@ -549,6 +556,7 @@ class Gemma3nTextModel(nn.Module): ...@@ -549,6 +556,7 @@ class Gemma3nTextModel(nn.Module):
self.embed_tokens_per_layer = VocabParallelEmbedding( self.embed_tokens_per_layer = VocabParallelEmbedding(
config.vocab_size_per_layer_input, config.vocab_size_per_layer_input,
config.num_hidden_layers * config.hidden_size_per_layer_input, config.num_hidden_layers * config.hidden_size_per_layer_input,
quant_config=quant_config,
prefix=f"{prefix}.per_layer_embed_tokens", prefix=f"{prefix}.per_layer_embed_tokens",
) )
self.embed_scale_per_layer = torch.tensor( self.embed_scale_per_layer = torch.tensor(
...@@ -582,7 +590,7 @@ class Gemma3nTextModel(nn.Module): ...@@ -582,7 +590,7 @@ class Gemma3nTextModel(nn.Module):
gather_output=True, gather_output=True,
return_bias=False, return_bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.{idx-1}.altup_projections", prefix=f"{prefix}.altup_projections.{idx-1}",
) for idx in range(1, self.config.altup_num_inputs) ) for idx in range(1, self.config.altup_num_inputs)
]) ])
self.altup_unembed_projections = nn.ModuleList([ self.altup_unembed_projections = nn.ModuleList([
...@@ -593,7 +601,7 @@ class Gemma3nTextModel(nn.Module): ...@@ -593,7 +601,7 @@ class Gemma3nTextModel(nn.Module):
gather_output=True, gather_output=True,
return_bias=False, return_bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.{idx-1}.altup_unembed_projections", prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
) for idx in range(1, self.config.altup_num_inputs) ) for idx in range(1, self.config.altup_num_inputs)
]) ])
...@@ -774,7 +782,7 @@ class Gemma3nModel(nn.Module): ...@@ -774,7 +782,7 @@ class Gemma3nModel(nn.Module):
**kwargs) **kwargs)
class Gemma3nForConditionalGeneration(nn.Module): class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment