"...text-generation-inference.git" did not exist on "e64a65891bd842a559ba1ea2d37525d7dae7f0f4"
Unverified commit 75bd1e83, authored by YiYi Xu, committed by GitHub

Sd35 controlnet (#10020)



* add model/pipeline
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 8d477dae
"""
A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.
Example:
Convert a SD3.5 ControlNet checkpoint to Diffusers format using local file:
```bash
python scripts/convert_sd3_controlnet_to_diffusers.py \
--checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors" \
--output_path "output/sd35-controlnet-canny" \
--dtype "fp16" # optional, defaults to fp32
```
Or download and convert from HuggingFace repository:
```bash
python scripts/convert_sd3_controlnet_to_diffusers.py \
--original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets" \
--filename "sd3.5_large_controlnet_canny.safetensors" \
--output_path "/raid/yiyi/sd35-controlnet-canny-diffusers" \
--dtype "fp32" # optional, defaults to fp32
```
Note:
The script supports the following ControlNet types from SD3.5:
- Canny edge detection
- Depth estimation
- Blur detection
The checkpoint files can be downloaded from:
https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
"""
import argparse
import safetensors.torch
import torch
from huggingface_hub import hf_hub_download
from diffusers import SD3ControlNetModel
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
parser.add_argument(
    "--original_state_dict_repo_id", type=str, default=None, help="HuggingFace repo ID containing the checkpoint"
)
parser.add_argument("--filename", type=str, default=None, help="Filename of the checkpoint in the HF repo")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
parser.add_argument(
    "--dtype", type=str, default="fp32", help="Data type for the converted model (fp16, bf16, or fp32)"
)
args = parser.parse_args()

def load_original_checkpoint(args):
    if args.original_state_dict_repo_id is not None:
        if args.filename is None:
            raise ValueError("When using `original_state_dict_repo_id`, `filename` must also be specified")
        print(f"Downloading checkpoint from {args.original_state_dict_repo_id}/{args.filename}")
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        print(f"Loading checkpoint from local path: {args.checkpoint_path}")
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict

def convert_sd3_controlnet_checkpoint_to_diffusers(original_state_dict):
    converted_state_dict = {}

    # Direct mappings for controlnet blocks
    for i in range(19):  # 19 controlnet blocks
        converted_state_dict[f"controlnet_blocks.{i}.weight"] = original_state_dict[f"controlnet_blocks.{i}.weight"]
        converted_state_dict[f"controlnet_blocks.{i}.bias"] = original_state_dict[f"controlnet_blocks.{i}.bias"]

    # Positional embeddings
    converted_state_dict["pos_embed_input.proj.weight"] = original_state_dict["pos_embed_input.proj.weight"]
    converted_state_dict["pos_embed_input.proj.bias"] = original_state_dict["pos_embed_input.proj.bias"]

    # Time and text embeddings
    time_text_mappings = {
        "time_text_embed.timestep_embedder.linear_1.weight": "time_text_embed.timestep_embedder.linear_1.weight",
        "time_text_embed.timestep_embedder.linear_1.bias": "time_text_embed.timestep_embedder.linear_1.bias",
        "time_text_embed.timestep_embedder.linear_2.weight": "time_text_embed.timestep_embedder.linear_2.weight",
        "time_text_embed.timestep_embedder.linear_2.bias": "time_text_embed.timestep_embedder.linear_2.bias",
        "time_text_embed.text_embedder.linear_1.weight": "time_text_embed.text_embedder.linear_1.weight",
        "time_text_embed.text_embedder.linear_1.bias": "time_text_embed.text_embedder.linear_1.bias",
        "time_text_embed.text_embedder.linear_2.weight": "time_text_embed.text_embedder.linear_2.weight",
        "time_text_embed.text_embedder.linear_2.bias": "time_text_embed.text_embedder.linear_2.bias",
    }
    for new_key, old_key in time_text_mappings.items():
        if old_key in original_state_dict:
            converted_state_dict[new_key] = original_state_dict[old_key]

    # Transformer blocks
    for i in range(19):
        # Split QKV into separate Q, K, V
        qkv_weight = original_state_dict[f"transformer_blocks.{i}.attn.qkv.weight"]
        qkv_bias = original_state_dict[f"transformer_blocks.{i}.attn.qkv.bias"]
        q, k, v = torch.chunk(qkv_weight, 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)

        block_mappings = {
            f"transformer_blocks.{i}.attn.to_q.weight": q,
            f"transformer_blocks.{i}.attn.to_q.bias": q_bias,
            f"transformer_blocks.{i}.attn.to_k.weight": k,
            f"transformer_blocks.{i}.attn.to_k.bias": k_bias,
            f"transformer_blocks.{i}.attn.to_v.weight": v,
            f"transformer_blocks.{i}.attn.to_v.bias": v_bias,
            # Output projections
            f"transformer_blocks.{i}.attn.to_out.0.weight": original_state_dict[
                f"transformer_blocks.{i}.attn.proj.weight"
            ],
            f"transformer_blocks.{i}.attn.to_out.0.bias": original_state_dict[
                f"transformer_blocks.{i}.attn.proj.bias"
            ],
            # Feed forward
            f"transformer_blocks.{i}.ff.net.0.proj.weight": original_state_dict[
                f"transformer_blocks.{i}.mlp.fc1.weight"
            ],
            f"transformer_blocks.{i}.ff.net.0.proj.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc1.bias"],
            f"transformer_blocks.{i}.ff.net.2.weight": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.weight"],
            f"transformer_blocks.{i}.ff.net.2.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.bias"],
            # Norms
            f"transformer_blocks.{i}.norm1.linear.weight": original_state_dict[
                f"transformer_blocks.{i}.adaLN_modulation.1.weight"
            ],
            f"transformer_blocks.{i}.norm1.linear.bias": original_state_dict[
                f"transformer_blocks.{i}.adaLN_modulation.1.bias"
            ],
        }
        converted_state_dict.update(block_mappings)

    return converted_state_dict

def main(args):
    original_ckpt = load_original_checkpoint(args)
    original_dtype = next(iter(original_ckpt.values())).dtype

    # Initialize dtype with fp32 as default
    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")

    if dtype != original_dtype:
        print(
            f"Converting checkpoint from {original_dtype} to {dtype}. This can lead to unexpected results, proceed with caution."
        )

    converted_controlnet_state_dict = convert_sd3_controlnet_checkpoint_to_diffusers(original_ckpt)

    controlnet = SD3ControlNetModel(
        patch_size=2,
        in_channels=16,
        num_layers=19,
        attention_head_dim=64,
        num_attention_heads=38,
        joint_attention_dim=None,
        caption_projection_dim=2048,
        pooled_projection_dim=2048,
        out_channels=16,
        pos_embed_max_size=None,
        pos_embed_type=None,
        use_pos_embed=False,
        force_zeros_for_pooled_projection=False,
    )

    controlnet.load_state_dict(converted_controlnet_state_dict, strict=True)

    print(f"Saving SD3 ControlNet in Diffusers format in {args.output_path}.")
    controlnet.to(dtype).save_pretrained(args.output_path)


if __name__ == "__main__":
    main(args)
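Once the script has written the output directory, the result can be loaded back with the standard `from_pretrained` API as a quick sanity check. The sketch below is not part of the commit; the local path reuses the first example invocation from the docstring, and the printed values simply reflect the config flags set above during conversion.

```python
import torch

from diffusers import SD3ControlNetModel

# Path produced by the conversion script above (assumed example path).
controlnet = SD3ControlNetModel.from_pretrained(
    "output/sd35-controlnet-canny", torch_dtype=torch.float16
)

# The converted SD3.5 8b ControlNet has no pos_embed or context_embedder,
# so the registered config should mirror the flags used at conversion time.
print(controlnet.config.use_pos_embed)        # False
print(controlnet.config.joint_attention_dim)  # None
print(controlnet.config.num_layers)           # 19
```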
@@ -27,6 +27,7 @@ from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
 from .controlnet import BaseOutput, zero_module
@@ -58,40 +59,60 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         extra_conditioning_channels: int = 0,
         dual_attention_layers: Tuple[int, ...] = (),
         qk_norm: Optional[str] = None,
+        pos_embed_type: Optional[str] = "sincos",
+        use_pos_embed: bool = True,
+        force_zeros_for_pooled_projection: bool = True,
     ):
         super().__init__()
         default_out_channels = in_channels
         self.out_channels = out_channels if out_channels is not None else default_out_channels
         self.inner_dim = num_attention_heads * attention_head_dim

-        self.pos_embed = PatchEmbed(
-            height=sample_size,
-            width=sample_size,
-            patch_size=patch_size,
-            in_channels=in_channels,
-            embed_dim=self.inner_dim,
-            pos_embed_max_size=pos_embed_max_size,
-        )
+        if use_pos_embed:
+            self.pos_embed = PatchEmbed(
+                height=sample_size,
+                width=sample_size,
+                patch_size=patch_size,
+                in_channels=in_channels,
+                embed_dim=self.inner_dim,
+                pos_embed_max_size=pos_embed_max_size,
+                pos_embed_type=pos_embed_type,
+            )
+        else:
+            self.pos_embed = None
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
-
-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
-        self.transformer_blocks = nn.ModuleList(
-            [
-                JointTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
-                    context_pre_only=False,
-                    qk_norm=qk_norm,
-                    use_dual_attention=True if i in dual_attention_layers else False,
-                )
-                for i in range(num_layers)
-            ]
-        )
+        if joint_attention_dim is not None:
+            self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+            # `attention_head_dim` is doubled to account for the mixing.
+            # It needs to crafted when we get the actual checkpoints.
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    JointTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                        context_pre_only=False,
+                        qk_norm=qk_norm,
+                        use_dual_attention=True if i in dual_attention_layers else False,
+                    )
+                    for i in range(num_layers)
+                ]
+            )
+        else:
+            self.context_embedder = None
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    SD3SingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )

         # controlnet_blocks
         self.controlnet_blocks = nn.ModuleList([])
@@ -318,9 +339,27 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )

-        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+        if self.pos_embed is not None and hidden_states.ndim != 4:
+            raise ValueError("hidden_states must be 4D when pos_embed is used")
+        # SD3.5 8b controlnet does not have a `pos_embed`,
+        # it use the `pos_embed` from the transformer to process input before passing to controlnet
+        elif self.pos_embed is None and hidden_states.ndim != 3:
+            raise ValueError("hidden_states must be 3D when pos_embed is not used")
+
+        if self.context_embedder is not None and encoder_hidden_states is None:
+            raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
+        # SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
+        elif self.context_embedder is None and encoder_hidden_states is not None:
+            raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
+
+        if self.pos_embed is not None:
+            hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+
         temb = self.time_text_embed(timestep, pooled_projections)
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if self.context_embedder is not None:
+            encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         # add
         hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
@@ -349,9 +388,13 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                 )
             else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
-                )
+                if self.context_embedder is not None:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    )
+                else:
+                    # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
+                    hidden_states = block(hidden_states, temb)

             block_res_samples = block_res_samples + (hidden_states,)
...
@@ -18,14 +18,21 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ...models.attention import FeedForward, JointTransformerBlock
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FusedJointAttnProcessor2_0,
+    JointAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
+from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -33,6 +40,72 @@ from ..modeling_outputs import Transformer2DModelOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+@maybe_allow_in_graph
+class SD3SingleTransformerBlock(nn.Module):
+    r"""
+    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
+
+    Reference: https://arxiv.org/abs/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+    ):
+        super().__init__()
+
+        self.norm1 = AdaLayerNormZero(dim)
+
+        if hasattr(F, "scaled_dot_product_attention"):
+            processor = JointAttnProcessor2_0()
+        else:
+            raise ValueError(
+                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+            )
+
+        self.attn = Attention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=True,
+            processor=processor,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+        # Attention.
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=None,
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        return hidden_states
+
+
 class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     """
     The Transformer model introduced in Stable Diffusion 3.
...
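The new `SD3SingleTransformerBlock` only consumes `hidden_states` and `temb`; there is no `encoder_hidden_states` path, which is why the 8b ControlNet above skips the `context_embedder`. A small shape-check sketch, not part of the commit, using the SD3.5-large dimensions from the conversion script (38 heads of size 64, so dim = 2432); batch and token counts are arbitrary:

```python
import torch

# Import path follows the relative import used in the diff above.
from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

dim = 38 * 64  # num_attention_heads * attention_head_dim, matching the conversion script config
block = SD3SingleTransformerBlock(dim=dim, num_attention_heads=38, attention_head_dim=64)

hidden_states = torch.randn(2, 1024, dim)  # (batch, tokens, dim): latents already patch-embedded
temb = torch.randn(2, dim)                 # combined timestep/text embedding

out = block(hidden_states, temb)
print(out.shape)  # torch.Size([2, 1024, 2432]); the block preserves the token and channel dims
```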
@@ -858,6 +858,12 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor

+        controlnet_config = (
+            self.controlnet.config
+            if isinstance(self.controlnet, SD3ControlNetModel)
+            else self.controlnet.nets[0].config
+        )
+
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -932,6 +938,11 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

         # 3. Prepare control image
+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet does not apply shift factor
+            vae_shift_factor = 0
+        else:
+            vae_shift_factor = self.vae.config.shift_factor
+
         if isinstance(self.controlnet, SD3ControlNetModel):
             control_image = self.prepare_image(
                 image=control_image,
@@ -947,8 +958,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             height, width = control_image.shape[-2:]

             control_image = self.vae.encode(control_image).latent_dist.sample()
-            control_image = control_image * self.vae.config.scaling_factor
+            control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor

         elif isinstance(self.controlnet, SD3MultiControlNetModel):
             control_images = []
@@ -966,7 +976,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                 )

                 control_image_ = self.vae.encode(control_image_).latent_dist.sample()
-                control_image_ = control_image_ * self.vae.config.scaling_factor
+                control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor

                 control_images.append(control_image_)
@@ -974,11 +984,6 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         else:
             assert False

-        if controlnet_pooled_projections is None:
-            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
-        else:
-            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
-
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1006,6 +1011,18 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             ]
             controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)

+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet used zero pooled projection
+            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
+        else:
+            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
+
+        if controlnet_config.joint_attention_dim is not None:
+            controlnet_encoder_hidden_states = prompt_embeds
+        else:
+            # SD35 official 8b controlnet does not use encoder_hidden_states
+            controlnet_encoder_hidden_states = None
+
         # 7. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1025,11 +1042,17 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                         controlnet_cond_scale = controlnet_cond_scale[0]
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]

+                if controlnet_config.use_pos_embed is False:
+                    # sd35 (offical) 8b controlnet
+                    controlnet_model_input = self.transformer.pos_embed(latent_model_input)
+                else:
+                    controlnet_model_input = latent_model_input
+
                 # controlnet(s) inference
                 control_block_samples = self.controlnet(
-                    hidden_states=latent_model_input,
+                    hidden_states=controlnet_model_input,
                     timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_encoder_hidden_states,
                     pooled_projections=controlnet_pooled_projections,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     controlnet_cond=control_image,
...
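With the pipeline changes above, a converted SD3.5 ControlNet is used like any other SD3 ControlNet: the config flags (`use_pos_embed`, `joint_attention_dim`, `force_zeros_for_pooled_projection`) decide internally whether the transformer's `pos_embed` is applied to the ControlNet input, whether prompt embeddings are passed, and how pooled projections and the VAE shift factor are handled. The end-to-end sketch below is not part of the commit; the base model id, control image, and sampler settings are assumptions.

```python
import torch

from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Converted checkpoint from the script above (local path is an assumed example).
controlnet = SD3ControlNetModel.from_pretrained("output/sd35-controlnet-canny", torch_dtype=torch.float16)

pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",  # assumed base model for the SD3.5 large ControlNets
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

control_image = load_image("canny_edges.png")  # pre-computed Canny map (assumed input)

image = pipe(
    prompt="a photo of a futuristic city at night",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd35_controlnet_canny.png")
```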