Unverified commit dcd80206, authored by Cyrus Leung, committed by GitHub
Browse files

[Chore] Update type annotation of `input_ids` in model forward (#33063)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f4a0921c
...@@ -494,7 +494,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -494,7 +494,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -618,7 +618,7 @@ class Gemma3ForConditionalGeneration( ...@@ -618,7 +618,7 @@ class Gemma3ForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -704,7 +704,7 @@ class Gemma3nSelfDecoder(nn.Module): ...@@ -704,7 +704,7 @@ class Gemma3nSelfDecoder(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None, per_layer_inputs: torch.Tensor | None = None,
...@@ -887,7 +887,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ...@@ -887,7 +887,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
def fast_prefill_forward( def fast_prefill_forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None, per_layer_inputs: torch.Tensor | None = None,
...@@ -964,7 +964,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ...@@ -964,7 +964,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
def normal_forward( def normal_forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None, per_layer_inputs: torch.Tensor | None = None,
...@@ -1131,7 +1131,7 @@ class Gemma3nForCausalLM(nn.Module): ...@@ -1131,7 +1131,7 @@ class Gemma3nForCausalLM(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
*, *,
per_layer_inputs: torch.Tensor | None = None, per_layer_inputs: torch.Tensor | None = None,
......
...@@ -707,7 +707,7 @@ class Gemma3nForConditionalGeneration( ...@@ -707,7 +707,7 @@ class Gemma3nForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -270,7 +270,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -270,7 +270,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -1711,7 +1711,7 @@ class Glm4vForConditionalGeneration( ...@@ -1711,7 +1711,7 @@ class Glm4vForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -451,7 +451,7 @@ class Glm4MoeModel(nn.Module): ...@@ -451,7 +451,7 @@ class Glm4MoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -687,7 +687,7 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, Glm4MixtureOfExper ...@@ -687,7 +687,7 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, Glm4MixtureOfExper
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -264,7 +264,7 @@ class Glm4MoeLiteModel(nn.Module): ...@@ -264,7 +264,7 @@ class Glm4MoeLiteModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -596,7 +596,7 @@ class Glm4MoeLiteForCausalLM( ...@@ -596,7 +596,7 @@ class Glm4MoeLiteForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -230,7 +230,7 @@ class Glm4MoeLiteMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts): ...@@ -230,7 +230,7 @@ class Glm4MoeLiteMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
......
...@@ -216,7 +216,7 @@ class Glm4MoeMTP(nn.Module, Glm4MixtureOfExperts): ...@@ -216,7 +216,7 @@ class Glm4MoeMTP(nn.Module, Glm4MixtureOfExperts):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
......
...@@ -769,7 +769,7 @@ class GLM4VForCausalLM( ...@@ -769,7 +769,7 @@ class GLM4VForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -1075,7 +1075,7 @@ class GlmAsrForConditionalGeneration( ...@@ -1075,7 +1075,7 @@ class GlmAsrForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -218,7 +218,7 @@ class GPT2Model(nn.Module): ...@@ -218,7 +218,7 @@ class GPT2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
position_ids: torch.Tensor, position_ids: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None, inputs_embeds: torch.Tensor | None,
...@@ -298,7 +298,7 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -298,7 +298,7 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -362,7 +362,7 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -362,7 +362,7 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -235,7 +235,7 @@ class GPTBigCodeModel(nn.Module): ...@@ -235,7 +235,7 @@ class GPTBigCodeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
position_ids: torch.Tensor, position_ids: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -311,7 +311,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -311,7 +311,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -220,7 +220,7 @@ class GPTJModel(nn.Module): ...@@ -220,7 +220,7 @@ class GPTJModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
position_ids: torch.Tensor, position_ids: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -324,7 +324,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -324,7 +324,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -230,7 +230,7 @@ class GPTNeoXModel(nn.Module): ...@@ -230,7 +230,7 @@ class GPTNeoXModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
position_ids: torch.Tensor, position_ids: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -318,7 +318,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -318,7 +318,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -275,7 +275,7 @@ class GptOssModel(nn.Module): ...@@ -275,7 +275,7 @@ class GptOssModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -714,7 +714,7 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): ...@@ -714,7 +714,7 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -437,7 +437,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -437,7 +437,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -806,7 +806,7 @@ class GraniteSpeechForConditionalGeneration( ...@@ -806,7 +806,7 @@ class GraniteSpeechForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -312,7 +312,7 @@ class GraniteMoeModel(nn.Module): ...@@ -312,7 +312,7 @@ class GraniteMoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -528,7 +528,7 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -528,7 +528,7 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment