Unverified Commit d9f71ab3 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

correct `attention_head_dim` for `JointTransformerBlock` (#8608)



* add

* update sd3 controlnet

* Update src/diffusers/models/controlnet_sd3.py

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent dd4b731e
......@@ -128,9 +128,9 @@ class JointTransformerBlock(nn.Module):
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim // num_attention_heads,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=attention_head_dim,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
processor=processor,
......
......@@ -81,7 +81,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=self.inner_dim,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=False,
)
for i in range(num_layers)
......
......@@ -97,7 +97,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.inner_dim,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=i == num_layers - 1,
)
for i in range(self.config.num_layers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment