Unverified Commit efb7a299 authored by DefTruth, committed by GitHub

Fix many type hint errors (#12289)

* fix hidream type hint

* fix hunyuan-video type hint

* fix many type hint

* fix many type hint errors

* fix many type hint errors

* fix many type hint errors

* make style & make quality
parent d06750a5
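The changes follow two recurring patterns. Transformer blocks that hand back both the image and text streams are now annotated as returning a two-element tuple, and top-level model `forward` methods that honor a `return_dict` flag are annotated as returning either a plain tuple or a `Transformer2DModelOutput`. A minimal sketch of both patterns follows; `ExampleBlock` and `ExampleModel` are hypothetical stand-ins for illustration, not classes touched by this PR.

```python
# Minimal sketch of the two annotation patterns applied in this commit.
# ExampleBlock / ExampleModel are hypothetical placeholders, not diffusers classes.
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers.models.modeling_outputs import Transformer2DModelOutput


class ExampleBlock(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Blocks that return both the image and text streams get a 2-tuple annotation.
        return hidden_states, encoder_hidden_states


class ExampleModel(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
        # Top-level forwards either return a one-element tuple or wrap the sample
        # in the output dataclass, depending on `return_dict`.
        sample = hidden_states
        if not return_dict:
            return (sample,)
        return Transformer2DModelOutput(sample=sample)
```

The `Union[Tuple[torch.Tensor], Transformer2DModelOutput]` spelling matches what these forwards actually return: with `return_dict=False` they return a one-element tuple, hence `Tuple[torch.Tensor]` rather than a bare `torch.Tensor`.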
@@ -674,7 +674,7 @@ class JointTransformerBlock(nn.Module):
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(

@@ -13,7 +13,7 @@
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -92,7 +92,7 @@ class AuraFlowPatchEmbed(nn.Module):
         return selected_indices
 
-    def forward(self, latent):
+    def forward(self, latent) -> torch.Tensor:
         batch_size, num_channels, height, width = latent.size()
         latent = latent.view(
             batch_size,
@@ -173,7 +173,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         residual = hidden_states
         attention_kwargs = attention_kwargs or {}
@@ -242,7 +242,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         residual_context = encoder_hidden_states
         attention_kwargs = attention_kwargs or {}
@@ -472,7 +472,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
         timestep: torch.LongTensor = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

@@ -122,7 +122,7 @@ class CogVideoXBlock(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)
         attention_kwargs = attention_kwargs or {}
@@ -441,7 +441,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

@@ -315,7 +315,7 @@ class ConsisIDBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)
 
         # norm & modulate
@@ -691,7 +691,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         id_cond: Optional[torch.Tensor] = None,
         id_vit_hidden: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -124,7 +124,7 @@ class LuminaNextDiTBlock(nn.Module):
         encoder_mask: torch.Tensor,
         temb: torch.Tensor,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         """
         Perform a forward pass through the LuminaNextDiTBlock.
@@ -297,7 +297,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
         image_rotary_emb: torch.Tensor,
         cross_attention_kwargs: Dict[str, Any] = None,
         return_dict=True,
-    ) -> torch.Tensor:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         Forward pass of LuminaNextDiT.

@@ -472,7 +472,7 @@ class BriaSingleTransformerBlock(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_len = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -588,7 +588,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         return_dict: bool = True,
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         The [`BriaTransformer2DModel`] forward method.

@@ -13,7 +13,7 @@
 # limitations under the License.
 
-from typing import Dict, Union
+from typing import Dict, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -79,7 +79,7 @@ class CogView3PlusTransformerBlock(nn.Module):
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         emb: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)
 
         # norm & modulate
@@ -293,7 +293,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
         return_dict: bool = True,
-    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         The [`CogView3PlusTransformer2DModel`] forward method.

@@ -494,7 +494,7 @@ class CogView4TransformerBlock(nn.Module):
         ] = None,
         attention_mask: Optional[Dict[str, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Timestep conditioning
         (
             norm_hidden_states,
@@ -717,7 +717,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         image_rotary_emb: Optional[
             Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
         ] = None,
-    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

@@ -55,7 +55,7 @@ class HiDreamImageTimestepEmbed(nn.Module):
         self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
 
-    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
+    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
         t_emb = self.time_proj(timesteps).to(dtype=wdtype)
         t_emb = self.timestep_embedder(t_emb)
         return t_emb
@@ -87,7 +87,7 @@ class HiDreamImagePatchEmbed(nn.Module):
         self.out_channels = out_channels
         self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
 
-    def forward(self, latent):
+    def forward(self, latent) -> torch.Tensor:
         latent = self.proj(latent)
         return latent
@@ -534,7 +534,7 @@ class HiDreamImageTransformerBlock(nn.Module):
         encoder_hidden_states: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         wtype = hidden_states.dtype
         (
             shift_msa_i,
@@ -592,7 +592,7 @@ class HiDreamBlock(nn.Module):
         encoder_hidden_states: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         return self.block(
             hidden_states=hidden_states,
             hidden_states_masks=hidden_states_masks,
@@ -786,7 +786,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
         if encoder_hidden_states is not None:

@@ -529,7 +529,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         *args,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -684,7 +684,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         token_replace_emb: torch.Tensor = None,
         num_tokens: int = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -1038,7 +1038,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         guidance: torch.Tensor = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -216,7 +216,7 @@ class HunyuanVideoFramepackTransformer3DModel(
         indices_latents_history_4x: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)