Unverified Commit 4b17fa2a authored by DefTruth's avatar DefTruth Committed by GitHub
Browse files

fix flux type hint (#12089)

fix-flux-type-hint
parent d45199a2
...@@ -384,7 +384,7 @@ class FluxSingleTransformerBlock(nn.Module): ...@@ -384,7 +384,7 @@ class FluxSingleTransformerBlock(nn.Module):
temb: torch.Tensor, temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor: ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_len = encoder_hidden_states.shape[1] text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment