"tests/vscode:/vscode.git/clone" did not exist on "11629d52686a0f5899db5159bf3357a8e8f1b9da"
controlnet_sd3.py 19.9 KB
Newer Older
Aryan's avatar
Aryan committed
1
# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin, JointTransformerBlock
from ..attention_processor import Attention, FusedJointAttnProcessor2_0
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
from .controlnet import BaseOutput, zero_module


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class SD3ControlNetOutput(BaseOutput):
    controlnet_block_samples: Tuple[torch.Tensor]


class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).

    Parameters:
        sample_size (`int`, defaults to `128`):
            The width/height of the latents. This is fixed during training since it is used to learn a number of
            position embeddings.
        patch_size (`int`, defaults to `2`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `16`):
            The number of latent channels in the input.
        num_layers (`int`, defaults to `18`):
            The number of layers of transformer blocks to use.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        num_attention_heads (`int`, defaults to `18`):
            The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, defaults to `4096`):
            The embedding dimension to use for joint text-image attention.
        caption_projection_dim (`int`, defaults to `1152`):
            The embedding dimension of caption embeddings.
        pooled_projection_dim (`int`, defaults to `2048`):
            The embedding dimension of pooled text projections.
        out_channels (`int`, defaults to `16`):
            The number of latent channels in the output.
        pos_embed_max_size (`int`, defaults to `96`):
            The maximum latent height/width of positional embeddings.
        extra_conditioning_channels (`int`, defaults to `0`):
            The number of extra channels to use for conditioning for patch embedding.
        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
            The indices of the transformer blocks that use dual-stream attention.
        qk_norm (`str`, *optional*, defaults to `None`):
            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
        pos_embed_type (`str`, defaults to `"sincos"`):
            The type of positional embedding to use. Choose between `"sincos"` and `None`.
        use_pos_embed (`bool`, defaults to `True`):
            Whether to use positional embeddings.
        force_zeros_for_pooled_projection (`bool`, defaults to `True`):
            Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
            config value of the ControlNet model.
    """

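    # Configuration note (illustrative, not part of the upstream file): the transformer width is
    # `num_attention_heads * attention_head_dim` (18 * 64 = 1152 with the defaults above), and
    # `caption_projection_dim` is expected to match it, since the projected text embeddings are fed
    # directly into the joint transformer blocks. For example, a default-sized ControlNet with one
    # extra conditioning channel could be built as:
    #
    #   controlnet = SD3ControlNetModel(num_layers=12, extra_conditioning_channels=1)
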
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 18,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        joint_attention_dim: int = 4096,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        extra_conditioning_channels: int = 0,
        dual_attention_layers: Tuple[int, ...] = (),
        qk_norm: Optional[str] = None,
        pos_embed_type: Optional[str] = "sincos",
        use_pos_embed: bool = True,
        force_zeros_for_pooled_projection: bool = True,
    ):
        super().__init__()
        default_out_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        if use_pos_embed:
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=self.inner_dim,
                pos_embed_max_size=pos_embed_max_size,
                pos_embed_type=pos_embed_type,
            )
        else:
            self.pos_embed = None
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
        if joint_attention_dim is not None:
            self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

            # `attention_head_dim` is doubled to account for the mixing.
            # It needs to be crafted when we get the actual checkpoints.
            self.transformer_blocks = nn.ModuleList(
                [
                    JointTransformerBlock(
                        dim=self.inner_dim,
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=attention_head_dim,
                        context_pre_only=False,
                        qk_norm=qk_norm,
                        use_dual_attention=True if i in dual_attention_layers else False,
                    )
                    for i in range(num_layers)
                ]
            )
        else:
            self.context_embedder = None
            self.transformer_blocks = nn.ModuleList(
                [
                    SD3SingleTransformerBlock(
                        dim=self.inner_dim,
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=attention_head_dim,
                    )
                    for _ in range(num_layers)
                ]
            )

        # controlnet_blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)
        pos_embed_input = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels + extra_conditioning_channels,
            embed_dim=self.inner_dim,
            pos_embed_type=None,
        )
        self.pos_embed_input = zero_module(pos_embed_input)

        self.gradient_checkpointing = False

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)
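
    # Usage sketch (illustrative, not part of the upstream file): chunk the feed-forward layers over
    # the sequence dimension to reduce peak memory at some cost in speed, e.g.
    #
    #   controlnet.enable_forward_chunking(chunk_size=2, dim=1)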

    # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        > [!WARNING]
        > This API is 🧪 experimental.
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedJointAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        > [!WARNING]
        > This API is 🧪 experimental.

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
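
    # Usage sketch (illustrative, not part of the upstream file): fuse the QKV projections for
    # inference and restore the original attention processors afterwards.
    #
    #   controlnet.fuse_qkv_projections()
    #   ...  # run inference
    #   controlnet.unfuse_qkv_projections()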

    # Note: This is for the SD3.5 8B ControlNet, which shares its `pos_embed` with the transformer.
    # Ideally this would have been handled in the conversion script.
    def _get_pos_embed_from_transformer(self, transformer):
        pos_embed = PatchEmbed(
            height=transformer.config.sample_size,
            width=transformer.config.sample_size,
            patch_size=transformer.config.patch_size,
            in_channels=transformer.config.in_channels,
            embed_dim=transformer.inner_dim,
            pos_embed_max_size=transformer.config.pos_embed_max_size,
        )
        pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=True)
        return pos_embed

    @classmethod
    def from_transformer(
        cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
    ):
        config = transformer.config
        config["num_layers"] = num_layers or config.num_layers
        config["extra_conditioning_channels"] = num_extra_conditioning_channels
        controlnet = cls.from_config(config)

        if load_weights_from_transformer:
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)

            controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)

        return controlnet
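
    # Usage sketch (illustrative, not part of the upstream file): initialize a 12-layer ControlNet
    # from a pretrained SD3 transformer, reusing its embeddings and the first blocks' weights. The
    # names below are placeholders.
    #
    #   transformer = SD3Transformer2DModel.from_pretrained(..., subfolder="transformer")
    #   controlnet = SD3ControlNetModel.from_transformer(transformer, num_layers=12)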

    def forward(
        self,
        hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[SD3ControlNetOutput, Tuple]:
        """
        The [`SD3ControlNetModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            controlnet_cond (`torch.Tensor`):
                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`SD3ControlNetOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is `True`, an [`SD3ControlNetOutput`] is returned, otherwise a `tuple` whose first
            element is the tuple of scaled ControlNet block residuals.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        if self.pos_embed is not None and hidden_states.ndim != 4:
            raise ValueError("hidden_states must be 4D when pos_embed is used")

        # The SD3.5 8B ControlNet does not have a `pos_embed`; it relies on the transformer's `pos_embed`
        # to process the input before it is passed to the ControlNet.
        elif self.pos_embed is None and hidden_states.ndim != 3:
            raise ValueError("hidden_states must be 3D when pos_embed is not used")

        if self.context_embedder is not None and encoder_hidden_states is None:
            raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
        # The SD3.5 8B ControlNet does not have a `context_embedder`, so it does not use `encoder_hidden_states`.
        elif self.context_embedder is None and encoder_hidden_states is not None:
            raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")

        if self.pos_embed is not None:
            hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.

        temb = self.time_text_embed(timestep, pooled_projections)

        if self.context_embedder is not None:
            encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # add the ControlNet conditioning to the hidden states
        hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)

        block_res_samples = ()

        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                if self.context_embedder is not None:
                    encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                        block,
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                    )
                else:
                    # The SD3.5 8B ControlNet uses single transformer blocks, which do not take `encoder_hidden_states`
                    hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb)

            else:
                if self.context_embedder is not None:
                    encoder_hidden_states, hidden_states = block(
                        hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
                    )
                else:
                    # The SD3.5 8B ControlNet uses single transformer blocks, which do not take `encoder_hidden_states`
                    hidden_states = block(hidden_states, temb)

            block_res_samples = block_res_samples + (hidden_states,)

        controlnet_block_res_samples = ()
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

        # scale the ControlNet block residuals
        controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_block_res_samples,)

        return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)

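
# Usage sketch (illustrative, not part of the upstream file): during denoising, the scaled block
# residuals returned above are typically handed to the SD3 transformer (e.g. via its
# `block_controlnet_hidden_states` argument) so they can be added to the corresponding transformer
# block outputs. The variable names below are placeholders.
#
#   control_block_samples = controlnet(
#       hidden_states=latents,
#       controlnet_cond=control_cond,
#       conditioning_scale=0.7,
#       encoder_hidden_states=prompt_embeds,
#       pooled_projections=pooled_prompt_embeds,
#       timestep=timestep,
#       return_dict=False,
#   )[0]
#   noise_pred = transformer(
#       hidden_states=latents,
#       encoder_hidden_states=prompt_embeds,
#       pooled_projections=pooled_prompt_embeds,
#       timestep=timestep,
#       block_controlnet_hidden_states=control_block_samples,
#       return_dict=False,
#   )[0]
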

class SD3MultiControlNetModel(ModelMixin):
    r"""
    Wrapper class for using multiple `SD3ControlNetModel` instances together (Multi-ControlNet).

    This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
    compatible with `SD3ControlNetModel`.

    Args:
        controlnets (`List[SD3ControlNetModel]`):
            Provides additional conditioning to the transformer during the denoising process. You must pass
            multiple `SD3ControlNetModel` instances as a list.
    """

    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        pooled_projections: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[SD3ControlNetOutput, Tuple]:
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            block_samples = controlnet(
                hidden_states=hidden_states,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                pooled_projections=pooled_projections,
                controlnet_cond=image,
                conditioning_scale=scale,
                joint_attention_kwargs=joint_attention_kwargs,
                return_dict=return_dict,
            )

            # merge samples
            if i == 0:
                control_block_samples = block_samples
            else:
                control_block_samples = [
                    control_block_sample + block_sample
                    for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
                ]
                control_block_samples = (tuple(control_block_samples),)

        return control_block_samples
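

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the upstream file): build a tiny, randomly
    # initialized ControlNet, wrap it in `SD3MultiControlNetModel`, and run a single forward pass.
    # All sizes below are made up for speed; note that `caption_projection_dim` must equal
    # `num_attention_heads * attention_head_dim` so the projected text embeddings match the block width.
    tiny_controlnet = SD3ControlNetModel(
        sample_size=32,
        patch_size=2,
        in_channels=16,
        num_layers=2,
        attention_head_dim=8,
        num_attention_heads=4,
        joint_attention_dim=64,
        caption_projection_dim=32,
        pooled_projection_dim=64,
        pos_embed_max_size=32,
    )
    multi_controlnet = SD3MultiControlNetModel([tiny_controlnet])

    latents = torch.randn(1, 16, 32, 32)
    control_cond = torch.randn(1, 16, 32, 32)
    prompt_embeds = torch.randn(1, 77, 64)
    pooled_prompt_embeds = torch.randn(1, 64)
    timestep = torch.tensor([500])

    (block_samples,) = multi_controlnet(
        hidden_states=latents,
        controlnet_cond=[control_cond],
        conditioning_scale=[1.0],
        pooled_projections=pooled_prompt_embeds,
        encoder_hidden_states=prompt_embeds,
        timestep=timestep,
        return_dict=False,
    )
    print([tuple(sample.shape) for sample in block_samples])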