Unverified commit 12358622, authored by linjiapro and committed by GitHub
Browse files

Improve control net block index for sd3 (#9758)



* improve control net index

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 805aa937
...@@ -56,6 +56,8 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal ...@@ -56,6 +56,8 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
out_channels: int = 16, out_channels: int = 16,
pos_embed_max_size: int = 96, pos_embed_max_size: int = 96,
extra_conditioning_channels: int = 0, extra_conditioning_channels: int = 0,
dual_attention_layers: Tuple[int, ...] = (),
qk_norm: Optional[str] = None,
): ):
super().__init__() super().__init__()
default_out_channels = in_channels default_out_channels = in_channels
...@@ -84,6 +86,8 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal ...@@ -84,6 +86,8 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim, attention_head_dim=self.config.attention_head_dim,
context_pre_only=False, context_pre_only=False,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
) )
for i in range(num_layers) for i in range(num_layers)
] ]
...@@ -248,7 +252,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal ...@@ -248,7 +252,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
config = transformer.config config = transformer.config
config["num_layers"] = num_layers or config.num_layers config["num_layers"] = num_layers or config.num_layers
config["extra_conditioning_channels"] = num_extra_conditioning_channels config["extra_conditioning_channels"] = num_extra_conditioning_channels
controlnet = cls(**config) controlnet = cls.from_config(config)
if load_weights_from_transformer: if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -349,7 +350,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi ...@@ -349,7 +350,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
# controlnet residual # controlnet residual
if block_controlnet_hidden_states is not None and block.context_pre_only is False: if block_controlnet_hidden_states is not None and block.context_pre_only is False:
interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.norm_out(hidden_states, temb)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import gc import gc
import unittest import unittest
from typing import Optional
import numpy as np import numpy as np
import pytest import pytest
...@@ -59,7 +60,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ...@@ -59,7 +60,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
) )
batch_params = frozenset(["prompt", "negative_prompt"]) batch_params = frozenset(["prompt", "negative_prompt"])
def get_dummy_components(self): def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm"):
torch.manual_seed(0) torch.manual_seed(0)
transformer = SD3Transformer2DModel( transformer = SD3Transformer2DModel(
sample_size=32, sample_size=32,
...@@ -72,6 +73,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ...@@ -72,6 +73,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
caption_projection_dim=32, caption_projection_dim=32,
pooled_projection_dim=64, pooled_projection_dim=64,
out_channels=8, out_channels=8,
qk_norm=qk_norm,
) )
torch.manual_seed(0) torch.manual_seed(0)
...@@ -79,7 +81,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ...@@ -79,7 +81,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
sample_size=32, sample_size=32,
patch_size=1, patch_size=1,
in_channels=8, in_channels=8,
num_layers=1, num_layers=num_controlnet_layers,
attention_head_dim=8, attention_head_dim=8,
num_attention_heads=4, num_attention_heads=4,
joint_attention_dim=32, joint_attention_dim=32,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment