Unverified Commit dcd80206 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Chore] Update type annotation of `input_ids` in model forward (#33063)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f4a0921c
......@@ -714,7 +714,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -397,7 +397,7 @@ class VoxtralForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -173,7 +173,7 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -771,7 +771,7 @@ class Zamba2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
......@@ -947,7 +947,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
**kwargs: Any,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment