Unverified Commit 9b00990b authored by Xinyuan Tong, committed by GitHub

chore: remove vlm unnecessary import (#7541)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
parent 4d67025a
......@@ -565,6 +565,7 @@ multimodal_model_archs = [
"CLIPModel",
"DeepseekVL2ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3nForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
"LlavaLlamaForCausalLM",
......
......@@ -823,6 +823,7 @@ register_conv_template(
sep_style=SeparatorStyle.GEMMA3,
stop_str=["<end_of_turn>"],
image_token="<start_of_image>",
audio_token="<start_of_audio>",
)
)
......
......@@ -23,6 +23,7 @@ class MultimodalInputFormat(Enum):
RAW_IMAGES = "raw_images"
PRECOMPUTED_FEATURES = "precomputed_features"
PIXEL_VALUES = "pixel_values"
AUDIO = "audio"
@dataclasses.dataclass
......@@ -441,10 +442,13 @@ class BaseMultimodalProcessor(ABC):
has_image = False
has_pixel_values = False
has_precomputed_features = False
has_audio = False
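# Detect the input format: PIL images mark image input, raw numpy arrays are
# treated as audio waveforms, and dicts are inspected for precomputed features
# (or pixel values).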
for mm_input in mm_inputs:
if isinstance(mm_input, Image.Image):
has_image = True
elif isinstance(mm_input, np.ndarray):
has_audio = True
elif isinstance(mm_input, dict):
if mm_input.get("precomputed_features", None) is not None:
has_precomputed_features = True
......@@ -461,13 +465,13 @@ class BaseMultimodalProcessor(ABC):
# Validate format consistency
format_count = sum(
[has_image, has_pixel_values, has_precomputed_features]
[has_image, has_pixel_values, has_precomputed_features, has_audio]
)
if format_count > 1:
raise ValueError(
"Unsupported: mixture of multimodal input formats. "
f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
f"precomputed_features={has_precomputed_features}"
f"precomputed_features={has_precomputed_features}, audio={has_audio}"
)
if has_image:
......@@ -476,6 +480,8 @@ class BaseMultimodalProcessor(ABC):
return MultimodalInputFormat.PRECOMPUTED_FEATURES
elif has_pixel_values:
return MultimodalInputFormat.PIXEL_VALUES
elif has_audio:
return MultimodalInputFormat.AUDIO
else:
raise ValueError("No valid multimodal input format found")
except Exception as e:
......@@ -521,20 +527,47 @@ class BaseMultimodalProcessor(ABC):
input_ids = tokenize_text(base_output.input_text)
return combined_mm_item, input_ids
def process_audio(
base_output: BaseMultiModalProcessorOutput,
) -> Tuple[MultimodalDataItem, torch.Tensor]:
"""Process inputs with audio."""
ret = self.process_mm_data(
input_text=base_output.input_text,
audio=base_output.audios, # Note: "audio" is for gemma3n only
)
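# The processor output is copied attribute-by-attribute onto a MultimodalDataItem
# below; for Gemma3n audio this typically includes `input_features` and
# `input_features_mask`, matching the new fields added to MultimodalDataItem in
# this change.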
combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
for key, value in ret.items():
if key != "input_ids" and hasattr(combined_mm_item, key):
setattr(combined_mm_item, key, value)
input_ids = ret["input_ids"].flatten()
return combined_mm_item, input_ids
def finalize_mm_item(
combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
) -> MultimodalDataItem:
"""Apply common post-processing to the multimodal item."""
combined_mm_item.image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.IM_TOKEN_ID,
)
if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
combined_mm_item.image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.IM_TOKEN_ID,
)
elif combined_mm_item.modality == Modality.AUDIO:
combined_mm_item.audio_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.AUDIO_TOKEN_ID,
)
elif combined_mm_item.modality == Modality.VIDEO:
combined_mm_item.video_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.VIDEO_TOKEN_ID,
)
else:
raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
return combined_mm_item
# Main logic
mm_inputs = base_output.images
# Main logic - determine input type and handle text-only case
mm_inputs = base_output.images or base_output.audios
if not mm_inputs:
# Return text-only case
input_ids = tokenize_text(base_output.input_text)
return None, input_ids
......@@ -548,6 +581,8 @@ class BaseMultimodalProcessor(ABC):
combined_mm_item, input_ids = process_precomputed_features(base_output)
elif input_format == MultimodalInputFormat.PIXEL_VALUES:
combined_mm_item, input_ids = process_pixel_values(base_output)
elif input_format == MultimodalInputFormat.AUDIO:
combined_mm_item, input_ids = process_audio(base_output)
else:
raise ValueError(f"Unknown input format: {input_format}")
......
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import re
from typing import Dict, List, Optional, Union
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
)
from sglang.srt.models.gemma3n_mm import Gemma3nForConditionalGeneration
class Gemma3nSGLangProcessor(SGLangBaseProcessor):
"""Multimodal processor for Gemma3n supporting image and audio inputs."""
models = [Gemma3nForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<image_soft_token>"
self.IMAGE_TOKEN_REGEX = re.compile(
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
)
self.AUDIO_TOKEN = "<audio_soft_token>"
self.AUDIO_TOKEN_REGEX = re.compile(
r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
)
self.IM_TOKEN_ID = hf_config.image_token_id
self.IM_START_TOKEN_ID = hf_config.boi_token_id
self.IM_END_TOKEN_ID = hf_config.eoi_token_id
self.AUDIO_TOKEN_ID = hf_config.audio_token_id
self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id
async def process_mm_data_async(
self,
image_data: Optional[List[Union[str, bytes, Dict]]] = None,
audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
input_text: str = "",
request_obj=None,
max_req_input_len: int = 0,
*args,
**kwargs,
):
"""Process multimodal data including images and audio."""
audio_data = request_obj.audio_data
if not image_data and not audio_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(audio_data, str):
audio_data = [audio_data]
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
audio_data=audio_data,
max_req_input_len=max_req_input_len,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_regex=self.IMAGE_TOKEN_REGEX,
audio_token=self.AUDIO_TOKEN,
audio_token_regex=self.AUDIO_TOKEN_REGEX,
),
)
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
return {
"input_ids": input_ids.tolist(),
"mm_items": [combined_mm_item] if combined_mm_item is not None else [],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"audio_start_id": self.AUDIO_START_TOKEN_ID,
"audio_end_id": self.AUDIO_END_TOKEN_ID,
}
......@@ -214,6 +214,10 @@ class MultimodalDataItem:
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
# gemma3n related
input_features: Optional[torch.Tensor] = None
input_features_mask: Optional[torch.Tensor] = None
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
@staticmethod
......@@ -277,7 +281,10 @@ class MultimodalDataItem:
if self.precomputed_features is not None:
self.hash = hash_feature(self.precomputed_features)
elif self.is_audio():
self.hash = hash_feature(self.audio_features)
if self.audio_features is not None:
self.hash = hash_feature(self.audio_features)
elif self.input_features is not None:
self.hash = hash_feature(self.input_features)
else:
self.hash = hash_feature(self.pixel_values)
......@@ -288,6 +295,7 @@ class MultimodalDataItem:
return (self.modality == Modality.AUDIO) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.audio_features)
or not MultimodalDataItem.is_empty_list(self.input_features)
)
def is_image(self):
......
import math
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Gemma3nAudioConfig, PreTrainedModel
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm
from sglang.srt.utils import add_prefix, make_layers
class Gemma3nCumulativeGroupNorm(nn.Module):
"""Applies Group Normalization cumulatively over the time dimension.
This layer normalizes the input by calculating the mean and variance
cumulatively over the time dimension (dim 1). The statistics are computed
over all feature dimensions (specified by `feature_dims` and `num_channels`)
for elements marked as valid by the optional `mask`.
If a `mask` is provided (True for valid, False for invalid/padded),
invalid time steps do not contribute to the statistics calculation, and
their corresponding output values are zeroed out.
Scale and bias, if enabled, are applied per-channel (last dimension).
This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
and `cumulative=True`.
"""
def __init__(
self,
num_channels: int, # Number of channels (size of the last dimension)
feature_dims: Sequence[
int
], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
eps: float = 1e-3,
):
super().__init__()
self.num_channels = num_channels
self.feature_dims = tuple(feature_dims)
self.eps = eps
# Scale parameter depends only on the channel dimension
self.weight = nn.Parameter(torch.ones(num_channels))
# Axes for normalization: all dimensions except Batch (0) and Time (1).
# For input [B, T, *feature_dims, C], these are dims from 2 onwards.
self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
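# Example: for input [B, T, H, W, C] with feature_dims=(H, W) the reduction axes
# are (2, 3, 4); for the [B, T, F, C] tensors produced by the SSCP conv blocks
# (feature_dims=(F,)) they are (2, 3).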
def forward(
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Applies cumulative group norm, optionally using a mask.
Args:
x: Input tensor, shape [B, T, *feature_dims, C].
mask: Optional boolean mask, shape [B, T]. True indicates a valid
(non-padded) time step. If None, all time steps are considered valid.
Returns:
Normalized tensor with the same shape as x.
"""
expected_input_suffix = self.feature_dims + (self.num_channels,)
if x.shape[2:] != expected_input_suffix:
raise ValueError(
f"Input tensor shape suffix {x.shape[2:]} does not match expected"
f" suffix (feature_dims + num_channels) {expected_input_suffix}"
)
input_dtype = x.dtype
# Calculations are performed in float32 for numerical stability.
calc_dtype = torch.float32
x_calc = x.to(calc_dtype)
# Prepare a broadcastable mask (`mask_calc`).
# If no mask is provided, treat all elements as valid
# (mask_calc is all ones).
# Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)
# Cumulative Statistics Calculation
# 1. Sum of values over reduction axes at each time step.
sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
# 2. Cumulative sum of values over time.
cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
# 3. Count of valid elements in the normalization group at each time step.
# (A "group" here consists of all features at a given Batch, Time).
elements_in_group_at_t = torch.sum(
mask_calc, dim=self.reduction_axes, keepdim=True
)
# 4. Cumulative count of valid elements over time.
cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
# Avoid division by zero if all preceding elements were masked.
safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)
# 5. Cumulative mean.
cum_mean = cum_sum_values / safe_cum_count_elements
# 6. Sum of squared differences from the cumulative mean.
# Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc.
# Using x_calc here for the difference, as cum_mean already accounts for masking.
squared_diff_from_mean = (x_calc - cum_mean).pow(2)
sum_sq_diff_at_t = torch.sum(
squared_diff_from_mean, dim=self.reduction_axes, keepdim=True
)
# 7. Cumulative sum of squared differences over time.
cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
# 8. Cumulative variance.
cum_variance = cum_sum_sq_diff / safe_cum_count_elements
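# Running statistics at time step t (sums taken over all feature dims):
#   cum_mean[t]     = sum_{tau<=t} x[tau] / count[<=t]
#   cum_variance[t] = sum_{tau<=t} (x[tau] - cum_mean[tau])^2 / count[<=t]
# Each squared difference is centered at the running mean of its own time step,
# so this is a streaming approximation of the exact prefix variance.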
# Normalize the input using the calculated cumulative statistics:
# (x - E[x]) / sqrt(Var[x] + eps)
normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)
# Apply affine transformation (scale and bias) if enabled.
# Scale and bias are applied per-channel (last dimension).
scale = self.weight.to(calc_dtype)
# Reshape for broadcasting: [C] -> [1, ..., 1, C]
scale_view_shape = [1] * (x.dim() - 1) + [self.num_channels]
normalized_x = normalized_x * scale.view(scale_view_shape)
# Zero out outputs for time steps that were originally masked (where mask_calc is 0).
# This ensures padded/invalid positions in the input result in zero output.
final_output = normalized_x * mask_calc
return final_output.to(input_dtype)
class Gemma3nAudioRelativePositionEmbedding(nn.Module):
def __init__(
self,
config: Gemma3nAudioConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.channels = self.config.hidden_size
self.head_dim = self.channels // self.num_heads
self.max_backward = max(0, self.config.conf_attention_context_left - 1)
self.max_forward = self.config.conf_attention_context_right
self.pos_proj = ColumnParallelLinear(
self.channels,
self.num_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("pos_proj", prefix),
)
min_timescale = 1.0
max_timescale = 1.0e4
num_timescales = self.channels // 2
log_timescale_increment = math.log(
float(max_timescale) / float(min_timescale)
) / max(num_timescales - 1, 1)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales) * -log_timescale_increment
)
self.register_buffer(
"inv_timescales",
inv_timescales.float().unsqueeze(0).unsqueeze(0),
persistent=False,
)
def _get_timing_signal_1d_pos(
self, position: torch.Tensor, dtype: torch.dtype
) -> torch.Tensor:
assert position.ndim == 2
position = position.float().unsqueeze(-1)
scaled_time = position * self.inv_timescales.to(
device=position.device, dtype=torch.float32
)
timing_signal = torch.cat(
[torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1
)
return timing_signal.type(dtype)
def _relative_shift(
self,
term_bd_before_shift: torch.Tensor,
batch_size: int,
num_heads: int,
num_query_blocks: int,
query_block_size: int,
key_context_size: int,
max_span_plus_1: int,
) -> torch.Tensor:
"""Performs the relative shift."""
pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
padding_tuple = (0, pad_amount_last_dim)
term_bd_padded = F.pad(term_bd_before_shift, padding_tuple)
term_bd_reshaped = term_bd_padded.reshape(
(
batch_size,
num_heads,
num_query_blocks,
query_block_size * (key_context_size + 1),
)
)
term_bd_sliced = term_bd_reshaped[
:, :, :, : query_block_size * key_context_size
]
term_bd_shifted = term_bd_sliced.reshape(
(
batch_size,
num_heads,
num_query_blocks,
query_block_size,
key_context_size,
)
)
return term_bd_shifted
def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
batch_size, num_query_blocks, query_block_size, num_heads, head_dim = (
queries.shape
)
_, _, key_context_size, _, _ = keys.shape
pos_indices = torch.arange(
self.max_backward, -self.max_forward - 1, -1, device=queries.device
).unsqueeze(0)
max_span_plus_1 = pos_indices.shape[1]
sin_emb_timing_signal = self._get_timing_signal_1d_pos(
pos_indices, dtype=queries.dtype
)
projected_sin_emb, _ = self.pos_proj(sin_emb_timing_signal)
sin_emb = projected_sin_emb.reshape(
1, max_span_plus_1, self.num_heads, self.head_dim
).squeeze(0)
queries_p = queries.permute(0, 3, 1, 2, 4)
keys_p_t = keys.permute(0, 3, 1, 4, 2)
term_ac = torch.matmul(queries_p, keys_p_t)
q_permuted = queries.permute(0, 3, 1, 2, 4)
s_permuted = sin_emb.permute(1, 2, 0)
q_reshaped = q_permuted.reshape(
batch_size, num_heads, num_query_blocks * query_block_size, head_dim
)
term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
term_bd_unshifed = term_bd_unshifed_matmul.reshape(
batch_size,
num_heads,
num_query_blocks,
query_block_size,
max_span_plus_1,
)
term_bd_shifted = self._relative_shift(
term_bd_unshifed,
batch_size,
num_heads,
num_query_blocks,
query_block_size,
key_context_size,
max_span_plus_1,
)
return term_ac + term_bd_shifted
class Gemma3nAudioAttention(nn.Module):
"""Local dot product self-attention for audio."""
def __init__(
self,
config: Gemma3nAudioConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.num_heads = self.config.conf_num_attention_heads
self.hidden_size = self.config.hidden_size
self.head_dim = self.hidden_size // self.num_heads
self.chunk_size = self.config.conf_attention_chunk_size
self.max_future_horizon = self.config.conf_attention_context_right
self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
self.context_size = (
self.chunk_size + self.max_past_horizon + self.max_future_horizon
)
self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(
config,
quant_config,
prefix=add_prefix("relative_position_embedding", prefix),
)
self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.num_heads,
self.num_heads,
bias=False,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
q_scale = self.head_dim**-0.5
r_softplus_0 = 1.0 / F.softplus(torch.tensor(0.0))
self.register_buffer(
"q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False
)
# Create local causal mask
lower_causal_mask = torch.tril(
torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
diagonal=0,
).T
upper_causal_mask = torch.tril(
torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
diagonal=self.max_past_horizon + self.max_future_horizon,
)
local_causal_valid_mask = torch.ones(
(self.chunk_size, self.context_size), dtype=torch.bool
)
local_causal_valid_mask = (
local_causal_valid_mask * lower_causal_mask * upper_causal_mask
)
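# Combined, entry [w, c] is valid iff w <= c <= w + max_past_horizon + max_future_horizon.
# Because each extracted key/value block starts max_past_horizon frames before its
# query block, this lets query position w attend from max_past_horizon steps in the
# past up to max_future_horizon steps in the future.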
self.register_buffer(
"local_causal_valid_mask", local_causal_valid_mask, persistent=False
)
self.register_buffer(
"softcap",
torch.tensor(self.attention_logits_soft_cap).float(),
persistent=False,
)
def _pad_dim1(
self, x: torch.Tensor, dim10_val: int, dim11_val: int
) -> torch.Tensor:
padding_tuple = [0] * x.ndim * 2
dim_idx_from_end = x.ndim - 2
start_idx_for_dim = 2 * dim_idx_from_end
padding_tuple[start_idx_for_dim] = dim10_val
padding_tuple[start_idx_for_dim + 1] = dim11_val
return F.pad(x, tuple(padding_tuple))
def _convert_to_block(self, x: torch.Tensor) -> torch.Tensor:
"""Turns a sequence to non overlapping blocks."""
shape = x.shape
b, t = shape[:2]
num_blocks = (t + self.chunk_size - 1) // self.chunk_size
if (padding_len := num_blocks * self.chunk_size - t) > 0:
x = self._pad_dim1(x, 0, padding_len)
permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
x = x.reshape(permute_dims).contiguous()
return x
def _extract_block_context(self, x: torch.Tensor) -> torch.Tensor:
"""Extracts temporal context for every block."""
pad_left = self.max_past_horizon
pad_right = self.max_future_horizon + self.chunk_size - 1
x = self._pad_dim1(x, pad_left, pad_right)
frame_len = self.context_size
frame_step = self.chunk_size
x_unfolded = x.unfold(dimension=1, size=frame_len, step=frame_step)
if x.ndim > 2 and x_unfolded.ndim > 3:
x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)
return x_unfolded.contiguous()
def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
# Project to Q, K, V
qkv, _ = self.qkv_proj(x)
query_states, key_states, value_states = qkv.chunk(chunks=3, dim=-1)
# Reshape
query_states = query_states.reshape(
*x.shape[:-1], self.num_heads, self.head_dim
).contiguous()
key_states = key_states.reshape(
*x.shape[:-1], self.num_heads, self.head_dim
).contiguous()
value_states = value_states.reshape(
*x.shape[:-1], self.num_heads, self.head_dim
).contiguous()
# Apply per-dim scale
per_dim_scale_sp = F.softplus(self.per_dim_scale)
broadcast_shape = (1, 1, 1, self.head_dim)
per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast
batch_size, q_time = query_states.shape[:2]
# Convert to blocks
query_blocks = self._convert_to_block(query_states)
key_blocks = self._extract_block_context(key_states)
value_blocks = self._extract_block_context(value_states)
num_query_blocks = query_blocks.shape[1]
# Create mask for valid positions
original_valid_mask = ~mask
extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)
if (
extracted_valid_mask_blocks.ndim == 4
and extracted_valid_mask_blocks.shape[0] == batch_size
and extracted_valid_mask_blocks.shape[1] == num_query_blocks
and extracted_valid_mask_blocks.shape[2]
* extracted_valid_mask_blocks.shape[3]
== self.context_size
):
extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
batch_size, num_query_blocks, self.context_size
)
condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(
1
).unsqueeze(-2)
condition_from_causality = (
self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
)
final_condition_for_where = torch.logical_and(
condition_from_input_validity,
condition_from_causality.to(condition_from_input_validity.device),
)
# Compute attention scores
logits = self.relative_position_embedding(query_blocks, key_blocks)
# Apply attention logit softcap
softcap_val = self.softcap.to(logits.device)
logits = logits / softcap_val
logits = torch.tanh(logits)
logits = logits * softcap_val
# Apply the combined mask.
# final_condition_for_where will broadcast with logits [B,N,U,W,C]
logits = torch.where(
final_condition_for_where, logits, torch.finfo(logits.dtype).min
)
probabilities = F.softmax(logits, dim=-1, dtype=torch.float32).to(
dtype=value_blocks.dtype
)
# context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
h_dim = value_blocks.shape[-1]
prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
result_bmm = torch.bmm(prob_bun, v_bun)
context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(
0, 1, 3, 2, 4
)
context_vectors = context_vectors.reshape(
(
batch_size,
num_query_blocks * self.chunk_size,
self.num_heads,
self.head_dim,
)
)
context_vectors = context_vectors[:, :q_time]
return context_vectors
class Gemma3nAudioSSCPConvBlock(nn.Module):
"""A single convolution block for the SubSampleConvProjection."""
def __init__(
self,
config: Gemma3nAudioConfig,
idx: int,
input_freq_dim: int,
manual_padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.manual_padding = manual_padding
in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
out_channels = self.config.sscp_conv_channel_size[idx]
kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(kernel_h, kernel_w),
stride=(stride_h, stride_w),
padding=(0, 0), # Manual padding is used
bias=False,
)
f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
self.norm = Gemma3nCumulativeGroupNorm(
num_channels=out_channels,
feature_dims=(f_out_conv,),
eps=self.config.sscp_conv_group_norm_eps,
)
self.activation = nn.ReLU()
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
audio_encodings_padded = F.pad(
audio_encodings, self.manual_padding, mode="constant", value=0.0
)
audio_encodings_conv = self.conv(audio_encodings_padded)
x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
x_normed = self.norm(x_for_norm)
audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
return self.activation(audio_encodings_normed)
class Gemma3nAudioSubSampleConvProjection(nn.Module):
def __init__(
self,
config: Gemma3nAudioConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
current_f_for_block_input = config.input_feat_size
calculated_block_padding = []
calculated_f_out_dims = []
for i in range(2): # Assuming 2 conv layers
kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
stride_h, stride_w = config.sscp_conv_stride_size[i]
# Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
pad_t_top = 0
pad_t_bottom = kernel_h - 1
# Frequency Padding (Width for Conv2d)
pad_f_left = 1
pad_f_right = 1
manual_padding_tuple = (pad_f_left, pad_f_right, pad_t_top, pad_t_bottom)
calculated_block_padding.append(manual_padding_tuple)
f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1
calculated_f_out_dims.append(f_out_after_conv)
current_f_for_block_input = f_out_after_conv
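# Each block pads the frequency axis by one bin on both sides and the time axis
# reverse-causally (kernel_h - 1 frames appended at the end), so time is
# subsampled by stride_h per block while frequency follows
# f_out = (f_in + 2 - kernel_w) // stride_w + 1.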
self.conv_0 = Gemma3nAudioSSCPConvBlock(
idx=0,
input_freq_dim=config.input_feat_size,
config=config,
manual_padding=calculated_block_padding[0],
quant_config=quant_config,
prefix=add_prefix("conv_0", prefix),
)
self.conv_1 = Gemma3nAudioSSCPConvBlock(
idx=1,
input_freq_dim=calculated_f_out_dims[0],
config=config,
manual_padding=calculated_block_padding[1],
quant_config=quant_config,
prefix=add_prefix("conv_1", prefix),
)
final_c_out = config.sscp_conv_channel_size[-1]
final_f_out = calculated_f_out_dims[-1]
self.input_proj_in_features = final_c_out * final_f_out
self.input_proj_linear = RowParallelLinear(
self.input_proj_in_features,
self.config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("input_proj_linear", prefix),
)
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
audio_encodings_reshaped = audio_encodings.unsqueeze(1)
x = self.conv_0(audio_encodings_reshaped)
x = self.conv_1(x)
b, c_out, t_out, f_out = x.shape
x_permuted = x.permute(0, 2, 3, 1).contiguous()
output_flattened = x_permuted.view(b, t_out, f_out * c_out)
output, _ = self.input_proj_linear(output_flattened)
return output
class Gemma3nAudioConformerAttention(nn.Module):
def __init__(
self,
config: Gemma3nAudioConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
head_dim = self.config.hidden_size // self.config.conf_num_attention_heads
self.post_in_shape = (self.config.conf_num_attention_heads, head_dim)
self.post_in_features = self.config.hidden_size
self.register_buffer(
"gradient_clipping",
torch.tensor(self.config.gradient_clipping),
persistent=False,
)
self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
self.attn = Gemma3nAudioAttention(
config, quant_config, prefix=add_prefix("attn", prefix)
)
self.post = RowParallelLinear(
self.post_in_features,
self.config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("post", prefix),
)
self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
def forward(
self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor
) -> torch.Tensor:
audio_encodings_input_to_attn = audio_encodings
audio_encodings = torch.clamp(
audio_encodings, -self.gradient_clipping, self.gradient_clipping
)
audio_encodings_norm = self.pre_attn_norm(audio_encodings)
audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)
b, t, num_heads, head_dim = audio_encodings_attn_out.shape
audio_encodings_reshaped = audio_encodings_attn_out.reshape(
b, t, num_heads * head_dim
)
audio_encodings, _ = self.post(audio_encodings_reshaped)
audio_encodings = torch.clamp(
audio_encodings, -self.gradient_clipping, self.gradient_clipping
)
return audio_encodings_input_to_attn + self.post_norm(audio_encodings)
class Gemma3nAudioConformerFeedForward(nn.Module):
def __init__(
self,
config: Gemma3nAudioConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.register_buffer(
"gradient_clipping",
torch.tensor(self.config.gradient_clipping),
persistent=False,
)
self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
self.ffw_layer_1 = ColumnParallelLinear(
self.config.hidden_size,
self.config.hidden_size * 4,
bias=False,
quant_config=quant_config,
prefix=add_prefix("ffw_layer_1", prefix),
)
self.ffw_layer_2 = RowParallelLinear(
self.config.hidden_size * 4,
self.config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("ffw_layer_2", prefix),
)
self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
residual = audio_encodings
audio_encodings = torch.clamp(
audio_encodings, -self.gradient_clipping, self.gradient_clipping
)
audio_encodings = self.pre_layer_norm(audio_encodings)
audio_encodings, _ = self.ffw_layer_1(audio_encodings)
audio_encodings = F.silu(audio_encodings)
audio_encodings, _ = self.ffw_layer_2(audio_encodings)
audio_encodings = torch.clamp(
audio_encodings, -self.gradient_clipping, self.gradient_clipping
)
audio_encodings = self.post_layer_norm(audio_encodings)
return residual + (audio_encodings * self.post_layer_scale)
class Gemma3nAudioConformerLightConv1d(nn.Module):
def __init__(
self,
config: Gemma3nAudioConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.pre_layer_norm = Gemma3nRMSNorm(
self.config.hidden_size, eps=self.config.rms_norm_eps
)
self.linear_start = ColumnParallelLinear(
self.config.hidden_size,
self.config.hidden_size * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("linear_start", prefix),
)
self.depthwise_conv1d = nn.Conv1d(
in_channels=self.config.hidden_size,
out_channels=self.config.hidden_size,
kernel_size=self.config.conf_conv_kernel_size,
stride=1,
padding=0, # Manual causal padding
groups=self.config.hidden_size, # Depthwise
bias=False,
)
self.register_buffer(
"gradient_clipping",
torch.tensor(self.config.gradient_clipping),
persistent=False,
)
self.conv_norm = Gemma3nRMSNorm(
self.config.hidden_size, eps=self.config.rms_norm_eps
)
self.linear_end = RowParallelLinear(
self.config.hidden_size,
self.config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("linear_end", prefix),
)
self.causal_padding = self.config.conf_conv_kernel_size - 1
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
audio_encodings_residual = audio_encodings # Save for residual connection
audio_encodings = self.pre_layer_norm(audio_encodings)
audio_encodings, _ = self.linear_start(audio_encodings)
audio_encodings = F.glu(audio_encodings, dim=-1)
# Permute for Conv1d: [B, T, D] -> [B, D, T]
audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
# Apply manual causal padding
audio_encodings_permuted_padded = F.pad(
audio_encodings_permuted, (self.causal_padding, 0)
)
audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
# Permute back: [B, D, T_out] -> [B, T_out, D]
audio_encodings = audio_encodings.permute(0, 2, 1)
audio_encodings = torch.clamp(
audio_encodings, -self.gradient_clipping, self.gradient_clipping
)
audio_encodings = self.conv_norm(audio_encodings)
audio_encodings = F.silu(audio_encodings)
audio_encodings, _ = self.linear_end(audio_encodings)
output = audio_encodings + audio_encodings_residual
return output
class Gemma3nAudioConformerBlock(nn.Module):
def __init__(
self,
config: Gemma3nAudioConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.ffw_layer_start = Gemma3nAudioConformerFeedForward(
config, quant_config, prefix=add_prefix("ffw_layer_start", prefix)
)
self.attention = Gemma3nAudioConformerAttention(
config, quant_config, prefix=add_prefix("attention", prefix)
)
self.lconv1d = Gemma3nAudioConformerLightConv1d(
config, quant_config, prefix=add_prefix("lconv1d", prefix)
)
self.ffw_layer_end = Gemma3nAudioConformerFeedForward(
config, quant_config, prefix=add_prefix("ffw_layer_end", prefix)
)
self.register_buffer(
"gradient_clipping",
torch.tensor(self.config.gradient_clipping),
persistent=False,
)
self.norm = Gemma3nRMSNorm(self.config.hidden_size)
def forward(
self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor
) -> torch.Tensor:
audio_encodings = self.ffw_layer_start(audio_encodings)
audio_encodings = self.attention(audio_encodings, audio_mel_mask)
validity_mask_for_lconv = ~audio_mel_mask # True for valid
audio_encodings_for_lconv_input = (
audio_encodings
* validity_mask_for_lconv.unsqueeze(-1).to(audio_encodings.dtype)
)
audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
audio_encodings = self.ffw_layer_end(audio_encodings)
audio_encodings = torch.clamp(
audio_encodings, -self.gradient_clipping, self.gradient_clipping
)
output = self.norm(audio_encodings)
return output
class Gemma3nAudioEncoder(PreTrainedModel):
"""A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037"""
config_class = Gemma3nAudioConfig
def __init__(
self,
config: Gemma3nAudioConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(config)
self.config = config
self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(
config, quant_config, prefix=add_prefix("subsample_conv_projection", prefix)
)
self.conformer = make_layers(
config.conf_num_hidden_layers,
lambda idx, prefix: Gemma3nAudioConformerBlock(
config=config,
quant_config=quant_config,
prefix=prefix,
),
prefix=add_prefix("conformer", prefix),
)
def forward(
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
) -> Tuple[torch.Tensor, torch.BoolTensor]:
"""Encodes a batch of MELs.
Args:
audio_mel: a torch.Tensor of shape [batch, num_frames, mel_bins].
audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
Returns:
audio_encodings: a torch.Tensor of shape
`[batch_size, reduced_time_frames, hidden_size]`
audio_mel_mask: a torch.BoolTensor of shape [batch, reduced_time_frames].
"""
audio_encodings = self.subsample_conv_projection(
audio_mel
) # audio_encodings: [B, T_sub, D]
# Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
t_sub = audio_encodings.shape[1]
time_stride_product = 1
for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
# Create indices for gathering from the original mask.
# These indices map to original time steps corresponding to the start of each
# receptive field in the subsampled output.
indices = (
torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
)
indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1)
# Expand indices for batch compatibility if B > 1 and indices is 1D.
if audio_mel_mask.ndim > 1 and indices.ndim == 1:
indices = indices.unsqueeze(0).expand(
audio_mel_mask.shape[0], -1
) # [B, T_sub]
elif (
audio_mel_mask.ndim == indices.ndim
and audio_mel_mask.shape[0] == 1
and indices.shape[0] != 1
and t_sub == indices.shape[0]
):
# Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
indices = indices.unsqueeze(0)
current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub]
# Fallback: Ensure mask length matches feature length after gather.
if current_mask.shape[1] != t_sub:
if current_mask.shape[1] > t_sub:
current_mask = current_mask[:, :t_sub]
else: # current_mask.shape[1] < t_sub
padding_needed = t_sub - current_mask.shape[1]
current_mask = F.pad(
current_mask, (0, padding_needed), value=True
) # Pad with True (masked)
for i, block in enumerate(self.conformer):
audio_encodings = block(
audio_encodings, current_mask
) # Pass the processed mask
if self.config.conf_reduction_factor > 1:
audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
# Reduce the mask as well
current_mask = current_mask[:, :: self.config.conf_reduction_factor]
# Final masking of audio_encodings based on the final current_mask
# Ensure current_mask length matches the finally reduced audio_encodings length
if current_mask.shape[1] != audio_encodings.shape[1]:
target_len = audio_encodings.shape[1]
mask_current_len = current_mask.shape[1]
if target_len > mask_current_len:
padding_needed = target_len - mask_current_len
current_mask = F.pad(current_mask, (0, padding_needed), value=True)
elif mask_current_len > target_len: # mask is longer
current_mask = current_mask[:, :target_len]
audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
return audio_encodings, current_mask
from typing import Iterable, Optional, Set, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, Gemma3nTextConfig, PretrainedConfig, PreTrainedModel
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3_causal import Gemma3TextScaledWordEmbedding
from sglang.srt.utils import add_prefix, make_layers
# Aligned with HF's implementation: HF treats the sliding window as inclusive of
# the last token, while SGLang assumes it is exclusive, hence the -1 below.
def get_attention_sliding_window_size(config):
return config.sliding_window - 1
class Gemma3nRMSNorm(RMSNorm):
def __init__(
self,
dim: int,
eps: float = 1e-6,
with_scale: bool = True,
) -> None:
super().__init__(dim, eps=eps)
if not with_scale:
del self.weight
self.register_buffer(
"weight",
torch.ones(dim, dtype=torch.get_default_dtype()),
persistent=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
original_shape = x.shape
x_2d = x.contiguous().reshape(-1, original_shape[-1])
x_2d = super().forward(x_2d)
x = x_2d.reshape(original_shape)
return x
class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
pass
class Gemma3nMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_activation: str,
activation_sparsity: float = 0.0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
)
if hidden_activation != "gelu_pytorch_tanh":
raise ValueError(
"Gemma3n uses `gelu_pytorch_tanh` as the hidden activation "
"function. Please set `hidden_activation` to "
"`gelu_pytorch_tanh`."
)
# Use proper GELU with tanh approximation as specified
self.act_fn = GeluAndMul()
self.activation_sparsity = activation_sparsity
self.register_buffer(
"target_sparsity_tensor",
torch.tensor(self.activation_sparsity, dtype=torch.float32),
persistent=False,
) # moved from _gaussian_topk for cuda graph
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
# Split gate and up projections
gate_proj, up_proj = gate_up.chunk(2, dim=-1)
# Apply activation sparsity if needed
if self.activation_sparsity > 0.0:
gate_proj = self._gaussian_topk(gate_proj)
gate_up = torch.cat([gate_proj, up_proj], dim=-1)
# Apply GELU activation to gate projection and multiply with up projection
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
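# _gaussian_topk below keeps roughly the top (1 - activation_sparsity) fraction of
# gate activations per token: assuming the activations are approximately Gaussian,
# the cutoff mean + std * Phi^{-1}(activation_sparsity) is exceeded by about
# (1 - activation_sparsity) of them, and relu(x - cutoff) zeroes everything below
# the cutoff (shifting the survivors down by it).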
def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
normal_dist = torch.distributions.normal.Normal(0, 1)
std_multiplier = normal_dist.icdf(self.target_sparsity_tensor)
std_multiplier = std_multiplier.type(inputs.dtype)
inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
cutoff_x = inputs_mean + inputs_std * std_multiplier
return F.relu(inputs - cutoff_x)
class Gemma3nLaurelBlock(nn.Module):
"""Learned Augmented Residual Layer"""
def __init__(
self,
config: Gemma3nTextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.linear_left = ColumnParallelLinear(
config.hidden_size,
config.laurel_rank,
bias=False,
quant_config=quant_config,
prefix=add_prefix("linear_left", prefix),
)
self.linear_right = RowParallelLinear(
config.laurel_rank,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("linear_right", prefix),
)
self.post_laurel_norm = Gemma3nRMSNorm(
dim=config.hidden_size,
eps=config.rms_norm_eps,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# [num_tokens, hidden_size]
laurel_x, _ = self.linear_left(x)
laurel_x, _ = self.linear_right(laurel_x)
normed_laurel_x = self.post_laurel_norm(laurel_x)
return x + normed_laurel_x
class Gemma3nAltUp(nn.Module):
"""Alternating Updates (AltUp)"""
def __init__(
self,
config: Gemma3nTextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.correct_output_scale = nn.Parameter(
torch.zeros(config.hidden_size, dtype=torch.float32)
)
self.correction_coefs = ColumnParallelLinear(
config.altup_num_inputs,
config.altup_num_inputs,
bias=False,
quant_config=quant_config,
prefix=add_prefix("correction_coefs", prefix),
)
self.prediction_coefs = ColumnParallelLinear(
config.altup_num_inputs,
config.altup_num_inputs**2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("prediction_coefs", prefix),
)
self.modality_router = ColumnParallelLinear(
config.hidden_size,
config.altup_num_inputs,
bias=False,
quant_config=quant_config,
prefix=add_prefix("modality_router", prefix),
)
self.router_norm = Gemma3nRMSNorm(
dim=config.hidden_size,
eps=config.rms_norm_eps,
)
self.register_buffer(
"router_input_scale",
torch.tensor(config.hidden_size**-1.0),
persistent=False,
)
def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
# x : [num_tokens, hidden_size]
router_inputs = self.router_norm(x) * self.router_input_scale.to(
self.router_norm.weight.dtype
)
# router_inputs : [num_tokens, hidden_size]
routed, _ = self.modality_router(router_inputs)
# routed : [num_tokens, altup_num_inputs]
return torch.tanh(routed.float()).type_as(routed)
def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Predicts the output of a layer using a trainable map.
hidden_states: [num_altup_inputs, num_tokens, hidden_size]
"""
modalities = self.compute_router_modalities(
hidden_states[self.config.altup_active_idx]
) # (n_tokens, altup_num_inputs)
# TODO: CHECK DO WE NEED THIS: self.prediction_coefs.float() # Force computation in float32, in-place operation
if self.config.altup_coef_clip is not None:
self.prediction_coefs.weight.data.clamp_(
-self.config.altup_coef_clip, self.config.altup_coef_clip
)
all_coefs, _ = self.prediction_coefs(
modalities
) # (n_tokens, altup_num_inputs) -> (n_tokens, altup_num_inputs**2)
all_coefs = all_coefs.reshape(
*modalities.shape[:-1],
self.config.altup_num_inputs,
self.config.altup_num_inputs,
).permute(0, 2, 1)
# permute hidden_states from [num_altup_inputs, num_tokens, hidden_size] to [num_tokens, hidden_size, altup_num_inputs]
predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs)
predictions = predictions.permute(2, 0, 1) # undo the permute
predictions += hidden_states # add the original input
return predictions.contiguous().type_as(
hidden_states
) # [num_altup_inputs, num_tokens, hidden_size]
def correct(
self, predictions: torch.Tensor, activated: torch.Tensor
) -> torch.Tensor:
"""Corrects the predictions relative to the activated inputs."""
# prediction : [num_altup_inputs, num_tokens, hidden_size]
# activated : [num_tokens, hidden_size]
modalities = self.compute_router_modalities(
activated
) # [num_tokens, altup_num_inputs]
innovation = (
activated - predictions[self.config.altup_active_idx]
) # [num_tokens, hidden_size]
innovation = innovation.repeat(
self.config.altup_num_inputs, 1, 1
) # (self.config.altup_num_inputs, num_tokens, hidden_size)
if self.config.altup_coef_clip is not None:
self.correction_coefs.weight.data.clamp_(
-self.config.altup_coef_clip, self.config.altup_coef_clip
)
all_coefs, _ = self.correction_coefs(
modalities
) # [num_tokens, altup_num_inputs]
all_coefs = (all_coefs + 1.0).permute(1, 0).unsqueeze(-1)
# # [num_tokens, altup_num_inputs, 1]
corrected = torch.mul(innovation, all_coefs)
corrected += predictions
return corrected.contiguous().type_as(activated)
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
"""Scales the provided 3D tensor."""
return corrected * self.correct_output_scale.to(corrected.dtype)
def forward(
self, hidden_states: torch.Tensor, activated: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts, correct, and optionally scales the output of a layer using trainable maps.
hidden_states: [num_altup_inputs, num_tokens, hidden_size]
"""
predictions = self.predict(hidden_states)
corrected = self.correct(predictions=predictions, activated=activated)
output = corrected[self.config.altup_active_idx]
if self.config.altup_correct_scale:
output = self.scale_corrected_output(output)
return corrected, output
class Gemma3nAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
layer_id: int,
config: Gemma3nTextConfig,
max_position_embeddings: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
self.config = config
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
if self.total_num_kv_heads >= tp_size:
assert self.total_num_kv_heads % tp_size == 0
else:
assert tp_size % self.total_num_kv_heads == 0
hidden_size = config.hidden_size
head_dim = getattr(
config, "head_dim", hidden_size // config.num_attention_heads
)
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# self.scaling = config.query_rescale_scalar / config.query_pre_attn_scalar
self.scaling = 1.0
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=config.attention_bias,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
)
# Determine if layer uses sliding window based on pattern
self.is_sliding = config.layer_types[layer_id] == "sliding_attention"
# Check if this is a KV shared layer
first_kv_shared_layer_idx = (
config.num_hidden_layers - config.num_kv_shared_layers
)
self.is_kv_shared_layer = layer_id >= first_kv_shared_layer_idx
# Compute the layer index from which shared KV cache values will be retrieved
if not self.is_kv_shared_layer:
self.kv_shared_layer_index = None
elif self.is_sliding:
self.kv_shared_layer_index = first_kv_shared_layer_idx - 2
else:
self.kv_shared_layer_index = first_kv_shared_layer_idx - 1
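# Layers at or beyond first_kv_shared_layer_idx do not write their own KV cache;
# a sliding-attention layer reads the cache written at first_kv_shared_layer_idx - 2
# and a full-attention layer the one at first_kv_shared_layer_idx - 1, which assumes
# the two layers just before the shared region are one sliding and one global layer.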
if self.is_sliding:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
base=config.rope_local_base_freq,
rope_scaling={"rope_type": "default"},
)
else:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
base=config.rope_theta,
rope_scaling=config.rope_scaling,
)
self.sliding_window = config.sliding_window if self.is_sliding else None
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=(
layer_id if not self.is_kv_shared_layer else self.kv_shared_layer_index
),
logit_cap=0.0,
sliding_window_size=self.sliding_window,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
# Gemma3n adds normalization for q, k, v
self.q_norm = Gemma3nRMSNorm(
dim=config.head_dim,
eps=config.rms_norm_eps,
)
self.k_norm = Gemma3nRMSNorm(
dim=config.head_dim,
eps=config.rms_norm_eps,
)
self.v_norm = Gemma3nRMSNorm(
dim=config.head_dim,
eps=config.rms_norm_eps,
with_scale=False,
)
def forward(
self,
hidden_states: torch.Tensor,
positions: Tuple[torch.Tensor, torch.Tensor],
forward_batch: ForwardBatch,
**kwargs,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
# TODO: for first 20 layers, we use QKVParallelLinear
# for others, we only calc Q.
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Apply normalization to q, k, v
q = q.unflatten(-1, (self.num_heads, self.head_dim))
q = self.q_norm(q)
# Check if we should use shared KV cache
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None:
# For KV shared layers, we skip K/V computation and normalization
# The RadixAttention will handle retrieving shared KV from cache
k = None
v = None
else:
k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
k = self.k_norm(k)
v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
v = self.v_norm(v)
# Flatten back for rotary embedding
q = q.flatten(-2, -1)
# Apply rotary embedding
if k is not None:
k = k.flatten(-2, -1)
q, k = self.rotary_emb(positions, q, k)
# Reshape k back to head format for attention
k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
else:
# For shared KV layers, create a dummy key for rotary embedding and discard it
dummy_k = torch.zeros_like(
q[:, : self.kv_size]
) # Create dummy key with same shape as needed
q, _ = self.rotary_emb(positions, q, dummy_k)
# Reshape q back to head format for attention
q = q.unflatten(-1, (self.num_heads, self.head_dim))
attn_output = self.attn(
q,
k,
v,
forward_batch=forward_batch,
save_kv_cache=not self.is_kv_shared_layer,
)
output, _ = self.o_proj(attn_output)
return output
class Gemma3nDecoderLayer(nn.Module):
def __init__(
self,
layer_id: int,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.layer_id = layer_id
self.attention_type = config.layer_types[layer_id]
self.config = config
self.self_attn = Gemma3nAttention(
layer_id=layer_id,
config=config,
max_position_embeddings=config.max_position_embeddings,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
)
activation_sparsity = config.activation_sparsity_pattern[layer_id]
self.mlp = Gemma3nMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_activation=config.hidden_activation,
activation_sparsity=activation_sparsity,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma3nRMSNorm(
self.hidden_size, eps=config.rms_norm_eps
)
self.pre_feedforward_layernorm = Gemma3nRMSNorm(
self.hidden_size, eps=config.rms_norm_eps
)
self.post_feedforward_layernorm = Gemma3nRMSNorm(
self.hidden_size, eps=config.rms_norm_eps
)
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
self.altup = Gemma3nAltUp(
config, quant_config, prefix=add_prefix("altup", prefix)
)
self.laurel = Gemma3nLaurelBlock(
config, quant_config, prefix=add_prefix("laurel", prefix)
)
self.per_layer_input_gate = ColumnParallelLinear(
self.hidden_size,
self.hidden_size_per_layer_input,
bias=False,
quant_config=quant_config,
prefix=add_prefix("per_layer_input_gate", prefix),
)
self.per_layer_projection = RowParallelLinear(
self.hidden_size_per_layer_input,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("per_layer_projection", prefix),
)
self.post_per_layer_input_norm = Gemma3nRMSNorm(
self.hidden_size, eps=config.rms_norm_eps
)
self.is_sliding = self.self_attn.is_sliding
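# Per-layer dataflow (see forward below): altup.predict mixes the K hidden streams;
# the normalized active stream goes through self-attention and the Laurel block,
# the two residual branches are summed and scaled by 1/sqrt(2), and the result
# feeds the MLP. altup.correct folds the MLP branch back into all K streams, and a
# gated projection of the corrected active stream (multiplied element-wise by the
# per-layer input embedding) is added to streams 1..K-1.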
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
per_layer_input: torch.Tensor,
forward_batch: ForwardBatch,
**kwargs,
) -> torch.Tensor:
predictions = self.altup.predict(
hidden_states
) # [num_altup_inputs, num_tokens, hidden_size]
active_prediction = predictions[self.config.altup_active_idx]
active_prediction_normed = self.input_layernorm(active_prediction)
laurel_output = self.laurel(
active_prediction_normed
) # laurel_output: [num_tokens, hidden_size]
# active_prediction: [num_tokens, hidden_size]
attn = self.self_attn(
positions=positions,
hidden_states=active_prediction_normed,
forward_batch=forward_batch,
**kwargs,
)
attn = self.post_attention_layernorm(attn) # [num_tokens, hidden_size]
attn_gated = active_prediction + attn # [num_tokens, hidden_size]
attn_laurel = (attn_gated + laurel_output) / torch.sqrt(torch.tensor(2.0))
attn_norm = self.pre_feedforward_layernorm(
attn_laurel
) # [num_tokens, hidden_size]
attn_ffw = self.mlp(attn_norm) # [num_tokens, hidden_size]
attn_ffw_norm = self.post_feedforward_layernorm(
attn_ffw
) # [num_tokens, hidden_size]
attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm # [num_tokens, hidden_size]
corrected_predictions = self.altup.correct(
predictions, attn_ffw_laurel_gated
) # prediction : [num_altup_inputs, num_tokens, hidden_size]
# attn_ffw_laurel_gated: [num_tokens, hidden_size]
first_prediction = corrected_predictions[self.config.altup_active_idx]
if self.config.altup_correct_scale:
first_prediction = self.altup.scale_corrected_output(first_prediction)
# per_layer_input_gate
first_prediction = first_prediction.to(self.per_layer_input_gate.weight.dtype)
first_prediction, _ = self.per_layer_input_gate(first_prediction)
first_prediction = F.gelu(first_prediction, approximate="tanh")
first_prediction = torch.multiply(first_prediction, per_layer_input)
# per_layer_projection
first_prediction, _ = self.per_layer_projection(first_prediction)
first_prediction = self.post_per_layer_input_norm(first_prediction)
corrected_predictions[1:] += first_prediction
return corrected_predictions
class Gemma3nTextModel(PreTrainedModel):
def __init__(
self,
config: Gemma3nTextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config)
self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
self.padding_idx = config.pad_token_id
# Gemma3n downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
self.embed_tokens = Gemma3nTextScaledWordEmbedding(
config.vocab_size,
config.hidden_size,
self.padding_idx,
embed_scale=self.config.hidden_size**0.5,
)
self.norm = Gemma3nRMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
)
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: Gemma3nDecoderLayer(
layer_id=idx,
config=config,
quant_config=quant_config,
prefix=prefix,
),
prefix=add_prefix("layers", prefix),
)
# Per-layer input embeddings
self.hidden_size = config.hidden_size
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
config.vocab_size_per_layer_input,
config.num_hidden_layers * config.hidden_size_per_layer_input,
self.padding_idx,
embed_scale=self.config.hidden_size_per_layer_input**0.5,
)
self.per_layer_model_projection = ColumnParallelLinear(
self.hidden_size,
config.num_hidden_layers * config.hidden_size_per_layer_input,
bias=False,
quant_config=quant_config,
prefix=add_prefix("per_layer_model_projection", prefix),
)
self.per_layer_projection_norm = Gemma3nRMSNorm(
dim=config.hidden_size_per_layer_input,
eps=config.rms_norm_eps,
)
self.altup_projections = make_layers(
self.config.altup_num_inputs - 1,
lambda idx, prefix: ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=prefix,
),
prefix=add_prefix("altup_projections", prefix),
)
self.altup_unembed_projections = make_layers(
self.config.altup_num_inputs - 1,
lambda idx, prefix: ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=prefix,
),
prefix=add_prefix("altup_unembed_projections", prefix),
)
self.register_buffer(
"per_layer_projection_scale",
torch.tensor(self.hidden_size**-0.5),
persistent=False,
)
self.register_buffer(
"per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False
)
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
def dtype(self) -> torch.dtype:
return next(self.parameters()).dtype
def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
embeddings = self.embed_tokens_per_layer(input_ids)
return embeddings.reshape(
*input_ids.shape,
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
per_layer_projection, _ = self.per_layer_model_projection(inputs_embeds)
per_layer_projection *= self.per_layer_projection_scale.type(
inputs_embeds.dtype
)
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
if per_layer_inputs is None:
return per_layer_projection
if per_layer_projection.shape != per_layer_inputs.shape:
# per-layer inputs are sometimes padded with zeros, slice the relevant embeddings
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
return (
per_layer_projection + per_layer_inputs
) * self.per_layer_input_scale.type(inputs_embeds.dtype)
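# Shape sketch for the per-layer input path above, assuming a flat batch of
# T tokens, L = num_hidden_layers, D = hidden_size_per_layer_input:
#   embed_tokens_per_layer(input_ids)         -> [T, L * D] -> reshape -> [T, L, D]
#   per_layer_model_projection(inputs_embeds) -> [T, L * D], scaled by hidden_size ** -0.5,
#                                                reshaped to [T, L, D] and RMS-normed
#   result = (projection + per_layer_inputs) / sqrt(2)         # [T, L, D]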
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
per_layer_inputs: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (input_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if input_ids is not None:
input_embeds = self.embed_tokens(input_ids)
per_layer_inputs = self.get_per_layer_inputs(input_ids)
per_layer_inputs = self.project_per_layer_inputs(input_embeds, per_layer_inputs)
if positions.dim() == 1:
positions = positions.unsqueeze(0)
# Expand the embeddings into altup_num_inputs parallel streams, matching each
# extra stream's magnitude to that of the original embeddings
target_magnitude = torch.mean(input_embeds**2, dim=-1, keepdim=True) ** 0.5
epsilon_tensor = torch.tensor(torch.finfo(input_embeds.dtype).min)
hidden_states_0 = input_embeds
temp_hidden_states = [hidden_states_0]
for i in range(1, self.config.altup_num_inputs):
altup_proj, _ = self.altup_projections[i - 1](hidden_states_0)
current_hidden_state = altup_proj.type(hidden_states_0.dtype)
new_magnitude = (
torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
)
current_hidden_state = current_hidden_state * (
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
)
temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(
temp_hidden_states, dim=0
) # [num_altup_inputs, n_tokens, hidden_size]
for layer_idx, layer in enumerate(self.layers):
per_layer_input = per_layer_inputs[:, layer_idx, :]
hidden_states = layer(
positions=positions,
per_layer_input=per_layer_input,
hidden_states=hidden_states,
forward_batch=forward_batch,
**kwargs,
)
# Per-layer inputs to single output
target_magnitude = (
torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
)
temp_hidden_states = [hidden_states[0]]
for i in range(1, self.config.altup_num_inputs):
# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
altup_unemb_proj, _ = self.altup_unembed_projections[i - 1](
hidden_states[i]
)
current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
new_magnitude = (
torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
)
current_hidden_state = current_hidden_state * (
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
)
temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(temp_hidden_states)
hidden_states = torch.mean(hidden_states, dim=0)
hidden_states = self.norm(hidden_states)
return hidden_states
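# Overall flow of Gemma3nTextModel.forward above (a summary, not extra logic):
#   1. Embed tokens and build per-layer inputs of shape [T, L, D].
#   2. Fan the embeddings out into altup_num_inputs streams via altup_projections,
#      rescaling each extra stream to the magnitude of the original embeddings.
#   3. Run every decoder layer on the stacked streams, feeding it that layer's
#      slice of per_layer_inputs.
#   4. Map the extra streams back with altup_unembed_projections, rescale,
#      average all streams, and apply the final RMSNorm.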
class Gemma3nForCausalLM(PreTrainedModel):
config_class = Gemma3nTextConfig
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
base_model_prefix = "language_model"
# BitsAndBytes-specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
".q_proj": (".qkv_proj", 0),
".k_proj": (".qkv_proj", 1),
".v_proj": (".qkv_proj", 2),
".gate_proj": (".gate_up_proj", 0),
".up_proj": (".gate_up_proj", 1),
}
packed_modules_mapping = {
".qkv_proj": [
".q_proj",
".k_proj",
".v_proj",
],
".gate_up_proj": [
".gate_proj",
".up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
".qkv_proj",
".o_proj",
".gate_up_proj",
".down_proj",
]
# Gemma does not apply LoRA to the embedding layer
embedding_modules = {}
embedding_padding_modules = []
supports_lora = True
def __init__(
self,
config: Gemma3nTextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config)
self.config = config
self.quant_config = quant_config
self.model = Gemma3nTextModel(
config=config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
self.logits_processor = LogitsProcessor(config)
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config)
def dtype(self) -> torch.dtype:
return next(self.parameters()).dtype
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
per_layer_inputs: Optional[torch.Tensor] = None,
**kwargs,
) -> LogitsProcessor:
hidden_states = self.model(
input_ids,
positions,
forward_batch,
input_embeds,
per_layer_inputs,
**kwargs,
)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
name = name.replace("model.language_model.", "model.")
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
# Skip loading weights that are not in the model
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is tied with embed_tokens, so its weight is not loaded separately
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if name not in params_dict:
# Skip loading weights that are not in the model
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
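# Example of how a single (hypothetical) checkpoint entry flows through
# load_weights above:
#   "model.language_model.layers.0.self_attn.q_proj.weight"
#     -> prefix rewrite  -> "model.layers.0.self_attn.q_proj.weight"
#     -> stacked mapping -> "model.layers.0.self_attn.qkv_proj.weight", shard_id="q"
#     -> param.weight_loader(param, loaded_weight, "q") writes the q slice of qkv_proj.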
EntryClass = Gemma3nForCausalLM
AutoModel.register(Gemma3nTextConfig, Gemma3nForCausalLM, exist_ok=True)
import logging
import re
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
from transformers import (
Gemma3nAudioConfig,
Gemma3nConfig,
Gemma3nTextConfig,
Gemma3nVisionConfig,
PreTrainedModel,
)
from transformers.models.auto.modeling_auto import AutoModel
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
flatten_nested_list,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
cached_get_processor = lru_cache(get_processor)
class Gemma3nImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Gemma3nAudioInputs(TypedDict):
input_features: torch.Tensor
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
input_features_mask: torch.Tensor
"""Shape: `(batch_size * num_audio, seq_length)`"""
class Gemma3nMultimodalEmbedder(nn.Module):
"""Embeds token ids or soft tokens for multimodal content into language model space."""
def __init__(
self,
multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
text_config: Gemma3nTextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.multimodal_hidden_size = multimodal_config.hidden_size
self.eps = multimodal_config.rms_norm_eps
self.vocab_offset = multimodal_config.vocab_offset
self.vocab_size = multimodal_config.vocab_size
self.text_hidden_size = text_config.hidden_size
self.embedding = VocabParallelEmbedding(
self.vocab_size,
self.multimodal_hidden_size,
quant_config=quant_config,
prefix=add_prefix("embedding", prefix),
)
self.hard_embedding_norm = Gemma3nRMSNorm(
self.multimodal_hidden_size,
eps=self.eps,
)
self.soft_embedding_norm = Gemma3nRMSNorm(
self.multimodal_hidden_size,
eps=self.eps,
)
self.embedding_projection = RowParallelLinear(
self.multimodal_hidden_size,
self.text_hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("embedding_projection", prefix),
)
self.embedding_post_projection_norm = Gemma3nRMSNorm(
self.text_hidden_size,
eps=self.eps,
with_scale=False,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Embeds token ids or soft tokens for multimodal content into language model space.
Args:
input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
`[vocab_offset, vocab_offset + vocab_size)`.
inputs_embeds: A torch.Tensor containing the soft tokens to embed.
Returns:
A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if inputs_embeds is not None:
emb_norm = self.soft_embedding_norm(inputs_embeds)
else:
# Handle out of vocab ids to prevent CUDA assertion failures
out_of_vocab_id = self.vocab_size - 1
adjusted_ids = input_ids - self.vocab_offset
adjusted_ids = torch.where(adjusted_ids < 0, out_of_vocab_id, adjusted_ids)
adjusted_ids = torch.where(
adjusted_ids >= self.vocab_size, out_of_vocab_id, adjusted_ids
)
hard_emb = self.embedding(adjusted_ids)
emb_norm = self.hard_embedding_norm(hard_emb)
emb_norm_proj, _ = self.embedding_projection(emb_norm)
return self.embedding_post_projection_norm(emb_norm_proj)
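# Rough behaviour of the embedder above:
#   - hard path (input_ids): ids in [vocab_offset, vocab_offset + vocab_size) are
#     shifted into the local vocab (e.g. vocab_offset + 7 -> row 7), out-of-range
#     ids are clamped to the last row, then embedded and RMS-normed.
#   - soft path (inputs_embeds): encoder features of size multimodal_hidden_size
#     are RMS-normed directly.
#   Both paths are then projected to text_hidden_size and normed once more.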
class Gemma3nForConditionalGeneration(PreTrainedModel):
"""Gemma3n multimodal model for conditional generation."""
config_class = Gemma3nConfig
# BitsAndBytes-specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
".out_proj.",
]
bitsandbytes_stacked_params_mapping = {
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
"out_proj": ("proj", 0),
}
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
# Gemma does not apply LoRA to the embedding layer
embedding_modules = {}
embedding_padding_modules = []
supports_lora = True
def __init__(
self,
config: Gemma3nConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config)
self.config = config
self.quant_config = quant_config
prefix = add_prefix("model", prefix)
# Vision components
# TODO: Use sglang's vision model
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.embed_vision = Gemma3nMultimodalEmbedder(
config.vision_config,
config.text_config,
quant_config=quant_config,
prefix=add_prefix("embed_vision", prefix),
)
# Audio components
self.embed_audio = Gemma3nMultimodalEmbedder(
config.audio_config,
config.text_config,
quant_config=quant_config,
prefix=add_prefix("embed_audio", prefix),
)
self.audio_tower = Gemma3nAudioEncoder(
config.audio_config,
quant_config=quant_config,
prefix=add_prefix("audio_tower", prefix),
)
self.vocab_size = config.text_config.vocab_size
self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
# Text model
self.language_model = Gemma3nTextModel(
config.text_config,
quant_config,
prefix=add_prefix("language_model", prefix),
)
# Create logits processor for the multimodal model
self.logits_processor = LogitsProcessor(config.text_config)
self.post_init()
def pad_input_ids(
self,
input_ids: List[int],
mm_inputs: Optional[MultimodalInputs] = None,
) -> List[int]:
"""Pad input IDs with image and audio tokens."""
if mm_inputs is None:
return input_ids
# Collect available media token pairs
media_token_pairs = []
for attr_name in ["im_start_id", "audio_start_id"]:
if hasattr(mm_inputs, attr_name):
start_id = getattr(mm_inputs, attr_name)
end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
media_token_pairs.append((start_id, end_id))
# Apply padding pattern if we have media tokens
if media_token_pairs:
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, mm_inputs)
return input_ids
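# Sketch of pad_input_ids above: for every (start_id, end_id) pair available on
# mm_inputs (image and/or audio), MultiModalityDataPaddingPatternTokenPairs
# rewrites the tokens enclosed by each pair into that mm item's pad values, so
# downstream embedding lookup and prefix caching see a stable placeholder per
# media span instead of the raw media tokens.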
def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.get_input_embeddings()
def get_attention_sliding_window_size(self):
return self.config.text_config.sliding_window - 1
def get_image_feature(self, items: List[MultimodalDataItem]):
"""
Projects the last hidden state from the vision model into language model space.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
# Process images one by one to handle flatten_batch=True constraint in vision_tower
all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
vision_outputs_list = []
for pixel_values_batch in all_pixel_values:
# Normalize input shape to [batch_size, channels, height, width]
if pixel_values_batch.dim() == 5:
pixel_values_batch = pixel_values_batch.squeeze(0)
elif pixel_values_batch.dim() == 3:
pixel_values_batch = pixel_values_batch.unsqueeze(0)
elif pixel_values_batch.dim() != 4:
raise ValueError(
f"Unexpected pixel_values shape: {pixel_values_batch.shape}"
)
# Process each image in the batch
batch_size = pixel_values_batch.shape[0]
for i in range(batch_size):
pixel_value = pixel_values_batch[i : i + 1] # Keep batch dimension as 1
pixel_value = pixel_value.to(
device=self.vision_tower.device, dtype=self.language_model.dtype()
)
vision_outputs = self.vision_tower(
pixel_values=pixel_value, do_pooling=False, return_dict=True
).last_hidden_state
vision_outputs_list.append(vision_outputs)
# Concatenate all vision outputs
vision_outputs = torch.cat(vision_outputs_list, dim=0)
# Convert from (batch, channels, height, width) to (batch, height * width, channels)
vision_outputs = vision_outputs.reshape(
vision_outputs.shape[0],
self.config.vision_config.hidden_size,
self.config.vision_soft_tokens_per_image,
).permute(0, 2, 1)
# Normalize and embed the soft tokens into language model space
vision_outputs *= self.config.vision_config.hidden_size**0.5
return self.embed_vision(inputs_embeds=vision_outputs)
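# Shape walk-through for get_image_feature above, for one image with
# C = vision_config.hidden_size and S = vision_soft_tokens_per_image:
#   pixel_values [1, 3, H, W]
#     -> vision_tower(...).last_hidden_state [1, C, h, w]
#     -> reshape [1, C, S] -> permute [1, S, C] -> scale by sqrt(C)
#     -> embed_vision(inputs_embeds=...) -> [1, S, text_hidden_size]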
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
"""
Projects the last hidden state from the audio encoder into language model space.
Args:
items: List of multimodal data items containing audio data.
Returns:
audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`.
"""
# Extract audio features and masks from items
all_input_features = flatten_nested_list(
[item.input_features for item in items]
)
all_input_features_mask = flatten_nested_list(
[~item.input_features_mask for item in items]
) # Note(Xinyuan): invert the mask to match the HF implementation
# Process audio features one by one
audio_features_list = []
for input_features, input_features_mask in zip(
all_input_features, all_input_features_mask
):
# Ensure proper tensor format
if input_features.dim() == 2:
input_features = input_features.unsqueeze(0)
if input_features_mask.dim() == 1:
input_features_mask = input_features_mask.unsqueeze(0)
# Move to device and dtype
input_features = input_features.to(
device=next(self.audio_tower.parameters()).device,
dtype=self.language_model.dtype(),
)
input_features_mask = input_features_mask.to(device=input_features.device)
# Process through audio tower
audio_outputs, audio_mask = self.audio_tower(
input_features, input_features_mask
)
# Embed the audio outputs
audio_embeds = self.embed_audio(inputs_embeds=audio_outputs)
audio_features_list.append(audio_embeds)
# Concatenate all audio features
if audio_features_list:
audio_features = torch.cat(audio_features_list, dim=0)
# The Gemma3nProcessor expects all audio to be 30s in length and inserts 188 audio soft tokens into the
# text to account for this. However, the audio preprocessing and encoder do not guarantee they will
# produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
# the audio features out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
audio_padding_toks = torch.tensor(
[[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
)
audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
audio_features = torch.where(
audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
)
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
extra_padding_tokens = (
self.config.audio_soft_tokens_per_image - audio_seq_len
)
extra_padding_features = audio_padding_embs.expand(
audio_batch_size, extra_padding_tokens, audio_embed_dim
)
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
return audio_features
else:
return torch.empty(
0,
0,
self.language_model.config.hidden_size,
device=next(self.parameters()).device,
dtype=self.language_model.dtype(),
)
def get_per_layer_inputs(
self, input_ids: torch.LongTensor
) -> Optional[torch.Tensor]:
return self.language_model.get_per_layer_inputs(input_ids)
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.language_model.project_per_layer_inputs(
inputs_embeds, per_layer_inputs
)
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
**kwargs: object,
) -> LogitsProcessor:
"""Forward pass for multimodal Gemma3n."""
if (input_ids is None) ^ (input_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
positions += 1
if input_ids is not None:
# Prepare per-layer inputs from input_ids
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0, input_ids < self.vocab_size_per_layer_input
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
per_layer_inputs = self.language_model.get_per_layer_inputs(
per_layer_inputs_tokens
)
# Use general_mm_embed_routine for handling multimodal data
# This will automatically handle text, image, and audio embeddings
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
audio_data_embedding_func=self.get_audio_feature,
positions=positions,
per_layer_inputs=per_layer_inputs,
)
# Process hidden states through logits processor
return self.logits_processor(
input_ids, hidden_states, self.language_model.embed_tokens, forward_batch
)
def tie_weights(self):
return self.language_model.tie_weights()
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load weights for the model."""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".up_proj", 1),
(".gate_up_proj", ".gate_proj", 0),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
name = re.sub(r"^model\.", "", name)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "vision_model" in name:
# adapt to VisionAttention
name = name.replace(".self_attn.out_proj", ".self_attn.proj")
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
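# Sketch of the fallback branch in load_weights above: weights that do not match
# the stacked qkv/gate_up mapping keep their name (minus the leading "model."),
# except that names containing "vision_model" have ".self_attn.out_proj" renamed
# to ".self_attn.proj" to match sglang's VisionAttention; they are then loaded
# through the parameter's weight_loader, falling back to default_weight_loader.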
EntryClass = Gemma3nForConditionalGeneration