Unverified commit 8eb73c87 authored by Qin Zhou, committed by GitHub

Support passing kwargs to the SD3 custom attention processor (#9818)



* Support passing kwargs to the SD3 custom attention processor


---------
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 88b015dc
@@ -188,8 +188,13 @@ class JointTransformerBlock(nn.Module):
         self._chunk_dim = dim

     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
@@ -206,7 +211,9 @@ class JointTransformerBlock(nn.Module):

         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
@@ -214,7 +221,7 @@ class JointTransformerBlock(nn.Module):
         hidden_states = hidden_states + attn_output

         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
...
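With this change, any extra keys in `joint_attention_kwargs` are forwarded from the block into `self.attn` (and `self.attn2` on the dual-attention path) and, from there, to whatever attention processor is installed. Below is a minimal sketch of a processor that consumes such a key; `ScaledJointAttnProcessor` and its `attn_scale` argument are hypothetical names introduced for illustration, assuming `JointAttnProcessor2_0` as the base implementation.

from diffusers.models.attention_processor import JointAttnProcessor2_0


class ScaledJointAttnProcessor(JointAttnProcessor2_0):
    # Hypothetical processor: `attn_scale` arrives via `joint_attention_kwargs`
    # and rescales the attention output.
    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        attn_scale: float = 1.0,
        **kwargs,
    ):
        out = super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
        # Joint attention returns (hidden_states, encoder_hidden_states);
        # the dual-attention branch (attn2) returns a single tensor.
        if isinstance(out, tuple):
            return tuple(o * attn_scale for o in out)
        return out * attn_scale

Note that `Attention.forward` filters incoming kwargs against the processor's call signature, so a key like `attn_scale` should be declared explicitly rather than relied on to arrive through `**kwargs`.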
@@ -411,11 +411,15 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
             elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )

             # controlnet residual
...
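End to end, the extra keys travel from the pipeline call into `SD3Transformer2DModel.forward` and on into every `JointTransformerBlock`. A usage sketch, assuming the hypothetical `ScaledJointAttnProcessor` defined above and the public SD3 medium checkpoint:

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Install the custom processor on every attention module of the transformer.
pipe.transformer.set_attn_processor(ScaledJointAttnProcessor())

# "attn_scale" is the hypothetical key consumed by the processor; it is passed as
# joint_attention_kwargs and forwarded block by block to self.attn(**joint_attention_kwargs).
image = pipe(
    "a photo of an astronaut riding a horse on the moon",
    num_inference_steps=28,
    joint_attention_kwargs={"attn_scale": 0.9},
).images[0]
image.save("astronaut.png")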