Commit d76fc11e authored by zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev

parents 38166ec4 58996f35
......@@ -562,7 +562,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -1101,7 +1101,7 @@ class DeepseekV2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -1284,7 +1284,7 @@ class DeepseekV2ForCausalLM(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -614,7 +614,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -394,7 +394,7 @@ class Dots1Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -538,7 +538,7 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -754,7 +754,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -432,7 +432,7 @@ class Eagle2_5_VLForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -440,7 +440,6 @@ class Eagle2_5_VLForConditionalGeneration(
) -> IntermediateTensors:
"""Forward pass through the model."""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
forward_kwargs = {
......
......@@ -466,7 +466,7 @@ class Ernie4_5_MoeModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -727,7 +727,7 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -1650,7 +1650,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -565,7 +565,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -646,7 +646,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -164,7 +164,7 @@ class ErnieMTP(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
......
......@@ -496,7 +496,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -490,7 +490,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -549,7 +549,7 @@ class ExaoneMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -423,7 +423,7 @@ class FalconModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -459,7 +459,7 @@ class FalconH1Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -602,7 +602,7 @@ class FalconH1ForCausalLM(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -340,7 +340,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -297,7 +297,7 @@ class GemmaModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -400,7 +400,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -410,7 +410,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -494,7 +494,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -618,7 +618,7 @@ class Gemma3ForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -656,3 +656,41 @@ class Gemma3ForConditionalGeneration(
connector="multi_modal_projector",
tower_model="vision_tower",
)
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
    """Return the vision-encoder token count for a placeholder token count.

    The result is ``num_image_tokens`` multiplied by a fixed factor of 16.

    Args:
        num_image_tokens: Number of image placeholder tokens in the prompt
            (typically mm_tokens_per_image per image).

    Returns:
        Number of tokens output by the vision encoder.
    """
    # Fixed expansion ratio between prompt placeholder tokens and vision
    # encoder output tokens. The original author relates it to
    # tokens_per_side = sqrt(mm_tokens_per_image) = sqrt(256) = 16 and
    # reports that this multiplier was validated empirically.
    # NOTE(review): the factor is hard-coded rather than derived from the
    # model config -- confirm it holds for non-default image/patch sizes.
    encoder_tokens_per_placeholder = 16
    return num_image_tokens * encoder_tokens_per_placeholder
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
    """Return the multimodal-connector token count for a vision token count.

    The value is returned unchanged. Per the original author's note, the
    Gemma3 connector applies projection and normalization but keeps a 1:1
    token mapping (NOTE(review): not verifiable from this method alone --
    confirm against the connector implementation).

    Args:
        num_vision_tokens: Number of tokens from the vision encoder.

    Returns:
        Number of tokens after connector processing, equal to
        ``num_vision_tokens``.
    """
    # Identity mapping: the input count is passed through as-is.
    return num_vision_tokens
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment