"examples/vscode:/vscode.git/clone" did not exist on "e86a280c455130d597e352a6fe90367b14bfe925"
Unverified commit efb7a299, authored by DefTruth and committed by GitHub

Fix many type hint errors (#12289)

* fix hidream type hint

* fix hunyuan-video type hint

* fix many type hint

* fix many type hint errors

* fix many type hint errors

* fix many type hint errors

* make style & make quality
parent d06750a5
@@ -674,7 +674,7 @@ class JointTransformerBlock(nn.Module):
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
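These joint ("dual-stream") blocks update the image and text streams together and hand both back, which is what the new two-element Tuple hint records. A minimal sketch of that shape of API (toy code, not the diffusers implementation; all names are illustrative):

from typing import Tuple

import torch
import torch.nn as nn


class ToyJointBlock(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.img_proj = nn.Linear(dim, dim)
        self.txt_proj = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Both streams are updated and both are returned, hence the
        # Tuple[torch.Tensor, torch.Tensor] annotation.
        return self.img_proj(hidden_states), self.txt_proj(encoder_hidden_states)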
@@ -13,7 +13,7 @@
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
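The import hunks only bring Tuple (and, where needed, Union) into scope so the new annotations resolve at import time. As an aside not taken by this PR: on Python 3.9+ the builtin generics can spell the same hints without the typing import (PEP 585), e.g.:

import torch

# Equivalent spelling with builtin generics (Python 3.9+); diffusers keeps
# typing.Tuple/typing.Union for compatibility with older interpreters.
def forward(latent: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    return latent, latent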
@@ -92,7 +92,7 @@ class AuraFlowPatchEmbed(nn.Module):
         return selected_indices
 
-    def forward(self, latent):
+    def forward(self, latent) -> torch.Tensor:
         batch_size, num_channels, height, width = latent.size()
         latent = latent.view(
             batch_size,
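The hunk is truncated inside the view call, but the overall move is the usual patchify reshape, (B, C, H, W) -> (B, num_patches, C * p * p), which is why a single torch.Tensor return hint fits. A minimal stand-alone sketch of that reshape (illustrative, not the diffusers code):

import torch

def patchify(latent: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, (H/p)*(W/p), C*p*p)
    b, c, h, w = latent.size()
    latent = latent.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
    latent = latent.permute(0, 2, 4, 1, 3, 5).reshape(b, -1, c * patch_size * patch_size)
    return latent

# patchify(torch.randn(1, 4, 8, 8), patch_size=2).shape -> (1, 16, 16)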
@@ -173,7 +173,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         residual = hidden_states
         attention_kwargs = attention_kwargs or {}
@@ -242,7 +242,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         residual_context = encoder_hidden_states
         attention_kwargs = attention_kwargs or {}
@@ -472,7 +472,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
         timestep: torch.LongTensor = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
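At the model level, the corrected hint Union[Tuple[torch.Tensor], Transformer2DModelOutput] encodes diffusers' standard return_dict convention: a one-element plain tuple when return_dict=False, otherwise an output object. A minimal sketch of that convention (ToyOutput is an illustrative stand-in for Transformer2DModelOutput):

from dataclasses import dataclass
from typing import Tuple, Union

import torch


@dataclass
class ToyOutput:
    sample: torch.Tensor


def forward(sample: torch.Tensor, return_dict: bool = True) -> Union[Tuple[torch.Tensor], ToyOutput]:
    if not return_dict:
        return (sample,)  # plain one-element tuple
    return ToyOutput(sample=sample)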
@@ -122,7 +122,7 @@ class CogVideoXBlock(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)
         attention_kwargs = attention_kwargs or {}
@@ -441,7 +441,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
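Several of these model forwards open with the same preamble the hunks show: copy attention_kwargs, then pop the LoRA "scale" with a default of 1.0. Copying first keeps the pop from mutating the caller's dict. A small illustrative helper capturing just that step (not a diffusers function):

from typing import Any, Dict, Optional

def read_lora_scale(attention_kwargs: Optional[Dict[str, Any]]) -> float:
    if attention_kwargs is not None:
        # Copy before popping so the caller's dict is left untouched.
        attention_kwargs = attention_kwargs.copy()
        return attention_kwargs.pop("scale", 1.0)
    return 1.0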
@@ -315,7 +315,7 @@ class ConsisIDBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)
 
         # norm & modulate
@@ -691,7 +691,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         id_cond: Optional[torch.Tensor] = None,
         id_vit_hidden: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -124,7 +124,7 @@ class LuminaNextDiTBlock(nn.Module):
         encoder_mask: torch.Tensor,
         temb: torch.Tensor,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         """
         Perform a forward pass through the LuminaNextDiTBlock.
@@ -297,7 +297,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
         image_rotary_emb: torch.Tensor,
         cross_attention_kwargs: Dict[str, Any] = None,
         return_dict=True,
-    ) -> torch.Tensor:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         Forward pass of LuminaNextDiT.
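Callers see the benefit of hints like Lumina's corrected Union return when a static checker is in the loop: the result has to be narrowed before tuple indexing or attribute access. A self-contained illustration (toy_forward and ToyOutput are hypothetical stand-ins, not diffusers API):

from dataclasses import dataclass
from typing import Tuple, Union

import torch


@dataclass
class ToyOutput:
    sample: torch.Tensor


def toy_forward(x: torch.Tensor, return_dict: bool = True) -> Union[Tuple[torch.Tensor], ToyOutput]:
    return ToyOutput(sample=x) if return_dict else (x,)


out = toy_forward(torch.zeros(2, 4), return_dict=False)
# mypy/pyright require this narrowing before out[0] or out.sample is used.
sample = out[0] if isinstance(out, tuple) else out.sample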
@@ -472,7 +472,7 @@ class BriaSingleTransformerBlock(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_len = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -588,7 +588,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         return_dict: bool = True,
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         The [`BriaTransformer2DModel`] forward method.
@@ -13,7 +13,7 @@
 # limitations under the License.
 
-from typing import Dict, Union
+from typing import Dict, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -79,7 +79,7 @@ class CogView3PlusTransformerBlock(nn.Module):
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         emb: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)
 
         # norm & modulate
@@ -293,7 +293,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
         return_dict: bool = True,
-    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         The [`CogView3PlusTransformer2DModel`] forward method.
@@ -494,7 +494,7 @@ class CogView4TransformerBlock(nn.Module):
         ] = None,
         attention_mask: Optional[Dict[str, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Timestep conditioning
         (
             norm_hidden_states,
@@ -717,7 +717,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         image_rotary_emb: Optional[
             Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
         ] = None,
-    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
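The Tuple[torch.Tensor, torch.Tensor] inside CogView4's image_rotary_emb hint is the usual (cos, sin) pair of rotary-embedding tables; the List variant carries one pair per block. A minimal illustrative rotation using such a pair (a sketch, not the diffusers implementation):

from typing import Tuple

import torch


def apply_rotary(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    # freqs is the (cos, sin) pair; each half of the feature dim is rotated.
    cos, sin = freqs
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)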
@@ -55,7 +55,7 @@ class HiDreamImageTimestepEmbed(nn.Module):
         self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
 
-    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
+    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
         t_emb = self.time_proj(timesteps).to(dtype=wdtype)
         t_emb = self.timestep_embedder(t_emb)
         return t_emb
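Here wdtype only controls the dtype the sinusoidal projection is cast to before the embedding MLP, so a plain -> torch.Tensor hint is all the signature needs. A toy stand-in with the same signature shape (illustrative only; it approximates, not reproduces, what Timesteps computes):

import math
from typing import Optional

import torch
import torch.nn as nn


class ToySinusoidalProj(nn.Module):
    def __init__(self, freq_dim: int = 256) -> None:
        super().__init__()
        self.freq_dim = freq_dim

    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
        # Standard sinusoidal features in float32, then cast to wdtype.
        half = self.freq_dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
        args = timesteps.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return emb if wdtype is None else emb.to(dtype=wdtype)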
@@ -87,7 +87,7 @@ class HiDreamImagePatchEmbed(nn.Module):
         self.out_channels = out_channels
         self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
 
-    def forward(self, latent):
+    def forward(self, latent) -> torch.Tensor:
         latent = self.proj(latent)
         return latent
@@ -534,7 +534,7 @@ class HiDreamImageTransformerBlock(nn.Module):
         encoder_hidden_states: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         wtype = hidden_states.dtype
         (
             shift_msa_i,
@@ -592,7 +592,7 @@ class HiDreamBlock(nn.Module):
         encoder_hidden_states: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         return self.block(
             hidden_states=hidden_states,
             hidden_states_masks=hidden_states_masks,
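HiDreamBlock is the one place where a Union of both shapes is the honest hint: the wrapper just forwards to whichever block it holds, and a single-stream block returns one tensor while a dual-stream block returns two. A hypothetical sketch of that wrapper pattern (illustrative names):

from typing import Tuple, Union

import torch
import torch.nn as nn


class ToyWrapper(nn.Module):
    def __init__(self, block: nn.Module) -> None:
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Single-stream inner block -> Tensor; dual-stream -> (Tensor, Tensor).
        return self.block(x)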
@@ -786,7 +786,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
 
         if encoder_hidden_states is not None:
@@ -529,7 +529,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         *args,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
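The HunyuanVideo single blocks fuse the two streams for attention and slice them apart again by text_seq_length, which is why their corrected hint is a two-tensor Tuple rather than torch.Tensor. The split step, as an illustrative helper (the real blocks do this inline):

from typing import Tuple

import torch


def split_streams(fused: torch.Tensor, text_seq_length: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Image tokens were concatenated first, text tokens last (dim=1).
    hidden_states = fused[:, :-text_seq_length, :]
    encoder_hidden_states = fused[:, -text_seq_length:, :]
    return hidden_states, encoder_hidden_states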
@@ -684,7 +684,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         token_replace_emb: torch.Tensor = None,
         num_tokens: int = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -1038,7 +1038,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         guidance: torch.Tensor = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -216,7 +216,7 @@ class HunyuanVideoFramepackTransformer3DModel(
         indices_latents_history_4x: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
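One practical consequence of annotating return types this way: they are introspectable at runtime, so a quick sanity check is possible without running a type checker. A minimal example using only the standard library:

from typing import Tuple, get_type_hints

import torch


def forward(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    return x, x


print(get_type_hints(forward))
# {'x': <class 'torch.Tensor'>, 'return': typing.Tuple[torch.Tensor, torch.Tensor]}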