Unverified Commit d9f71ab3 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

correct `attention_head_dim` for `JointTransformerBlock` (#8608)



* add

* update sd3 controlnet

* Update src/diffusers/models/controlnet_sd3.py

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent dd4b731e
...@@ -128,9 +128,9 @@ class JointTransformerBlock(nn.Module): ...@@ -128,9 +128,9 @@ class JointTransformerBlock(nn.Module):
query_dim=dim, query_dim=dim,
cross_attention_dim=None, cross_attention_dim=None,
added_kv_proj_dim=dim, added_kv_proj_dim=dim,
dim_head=attention_head_dim // num_attention_heads, dim_head=attention_head_dim,
heads=num_attention_heads, heads=num_attention_heads,
out_dim=attention_head_dim, out_dim=dim,
context_pre_only=context_pre_only, context_pre_only=context_pre_only,
bias=True, bias=True,
processor=processor, processor=processor,
......
...@@ -81,7 +81,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal ...@@ -81,7 +81,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
JointTransformerBlock( JointTransformerBlock(
dim=self.inner_dim, dim=self.inner_dim,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
attention_head_dim=self.inner_dim, attention_head_dim=self.config.attention_head_dim,
context_pre_only=False, context_pre_only=False,
) )
for i in range(num_layers) for i in range(num_layers)
......
...@@ -97,7 +97,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi ...@@ -97,7 +97,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
JointTransformerBlock( JointTransformerBlock(
dim=self.inner_dim, dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads, num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.inner_dim, attention_head_dim=self.config.attention_head_dim,
context_pre_only=i == num_layers - 1, context_pre_only=i == num_layers - 1,
) )
for i in range(self.config.num_layers) for i in range(self.config.num_layers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment