Unverified commit 12358622, authored by linjiapro and committed by GitHub
Browse files

Improve control net block index for sd3 (#9758)



* improve control net index

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 805aa937
......@@ -56,6 +56,8 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
out_channels: int = 16,
pos_embed_max_size: int = 96,
extra_conditioning_channels: int = 0,
dual_attention_layers: Tuple[int, ...] = (),
qk_norm: Optional[str] = None,
):
super().__init__()
default_out_channels = in_channels
......@@ -84,6 +86,8 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=False,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(num_layers)
]
......@@ -248,7 +252,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
config = transformer.config
config["num_layers"] = num_layers or config.num_layers
config["extra_conditioning_channels"] = num_extra_conditioning_channels
controlnet = cls(**config)
controlnet = cls.from_config(config)
if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
......
......@@ -15,6 +15,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
......@@ -349,7 +350,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
# controlnet residual
if block_controlnet_hidden_states is not None and block.context_pre_only is False:
interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb)
......
......@@ -15,6 +15,7 @@
import gc
import unittest
from typing import Optional
import numpy as np
import pytest
......@@ -59,7 +60,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
)
batch_params = frozenset(["prompt", "negative_prompt"])
def get_dummy_components(self):
def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm"):
torch.manual_seed(0)
transformer = SD3Transformer2DModel(
sample_size=32,
......@@ -72,6 +73,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
caption_projection_dim=32,
pooled_projection_dim=64,
out_channels=8,
qk_norm=qk_norm,
)
torch.manual_seed(0)
......@@ -79,7 +81,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
sample_size=32,
patch_size=1,
in_channels=8,
num_layers=1,
num_layers=num_controlnet_layers,
attention_head_dim=8,
num_attention_heads=4,
joint_attention_dim=32,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment