"vscode:/vscode.git/clone" did not exist on "679f02c21f36249e3e033fb8376c14256f452652"
Unverified commit 8eb73c87 authored by Qin Zhou, committed by GitHub

Support passing kwargs to the SD3 custom attention processor (#9818)



* Support passing kwargs to the SD3 custom attention processor


---------
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 88b015dc
@@ -188,8 +188,13 @@ class JointTransformerBlock(nn.Module):
         self._chunk_dim = dim

     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
@@ -206,7 +211,9 @@ class JointTransformerBlock(nn.Module):

         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
@@ -214,7 +221,7 @@ class JointTransformerBlock(nn.Module):
         hidden_states = hidden_states + attn_output

         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
@@ -411,11 +411,15 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
             elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )

                 # controlnet residual
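With this change, extra keyword arguments supplied as joint_attention_kwargs flow from SD3Transformer2DModel.forward through each JointTransformerBlock into the attention processor. Below is a minimal sketch of how that can be used, assuming a custom processor that subclasses diffusers' JointAttnProcessor2_0; the ScaledJointAttnProcessor class and its attn_scale keyword are hypothetical illustrations, not part of this commit.

    import torch

    from diffusers import StableDiffusion3Pipeline
    from diffusers.models.attention_processor import JointAttnProcessor2_0


    class ScaledJointAttnProcessor(JointAttnProcessor2_0):
        # Hypothetical processor: declares an extra `attn_scale` keyword,
        # which now reaches it through `joint_attention_kwargs`.
        def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                     attention_mask=None, attn_scale=1.0, **kwargs):
            # Toy use of the extra kwarg: rescale the attention input.
            hidden_states = hidden_states * attn_scale
            return super().__call__(
                attn,
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                **kwargs,
            )


    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    ).to("cuda")
    pipe.transformer.set_attn_processor(ScaledJointAttnProcessor())

    # The extra kwarg is forwarded to every JointTransformerBlock's processor.
    image = pipe(
        "a photo of an astronaut riding a horse",
        joint_attention_kwargs={"attn_scale": 1.0},
    ).images[0]

Note that the processor should declare the extra keyword explicitly in its __call__ signature, since Attention.forward inspects the processor's signature and filters out keyword arguments it does not accept.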