Unverified Commit dcd80206 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Chore] Update type annotation of `input_ids` in model forward (#33063)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f4a0921c
...@@ -714,7 +714,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -714,7 +714,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: torch.Tensor | None = None, intermediate_tensors: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -397,7 +397,7 @@ class VoxtralForConditionalGeneration( ...@@ -397,7 +397,7 @@ class VoxtralForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -173,7 +173,7 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): ...@@ -173,7 +173,7 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -771,7 +771,7 @@ class Zamba2Model(nn.Module): ...@@ -771,7 +771,7 @@ class Zamba2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
...@@ -947,7 +947,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC ...@@ -947,7 +947,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
**kwargs: Any, **kwargs: Any,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment