renzhc / diffusers_dcu · Commits · 75bd1e83

Unverified commit 75bd1e83, authored Nov 27, 2024 by YiYi Xu, committed by GitHub Nov 27, 2024
Parent: 8d477dae

Sd35 controlnet (#10020)

* add model/pipeline

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Showing 4 changed files with 367 additions and 43 deletions:

  scripts/convert_sd3_controlnet_to_diffusers.py                                      +185    -0
  src/diffusers/models/controlnets/controlnet_sd3.py                                   +73   -30
  src/diffusers/models/transformers/transformer_sd3.py                                 +76    -3
  src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py     +33   -10
scripts/convert_sd3_controlnet_to_diffusers.py (new file, 0 → 100644)
"""
A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.
Example:
Convert a SD3.5 ControlNet checkpoint to Diffusers format using local file:
```bash
python scripts/convert_sd3_controlnet_to_diffusers.py
\
--checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors"
\
--output_path "output/sd35-controlnet-canny"
\
--dtype "fp16" # optional, defaults to fp32
```
Or download and convert from HuggingFace repository:
```bash
python scripts/convert_sd3_controlnet_to_diffusers.py
\
--original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets"
\
--filename "sd3.5_large_controlnet_canny.safetensors"
\
--output_path "/raid/yiyi/sd35-controlnet-canny-diffusers"
\
--dtype "fp32" # optional, defaults to fp32
```
Note:
The script supports the following ControlNet types from SD3.5:
- Canny edge detection
- Depth estimation
- Blur detection
The checkpoint files can be downloaded from:
https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
"""
import
argparse
import
safetensors.torch
import
torch
from
huggingface_hub
import
hf_hub_download
from
diffusers
import
SD3ControlNetModel
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--checkpoint_path"
,
type
=
str
,
default
=
None
,
help
=
"Path to local checkpoint file"
)
parser
.
add_argument
(
"--original_state_dict_repo_id"
,
type
=
str
,
default
=
None
,
help
=
"HuggingFace repo ID containing the checkpoint"
)
parser
.
add_argument
(
"--filename"
,
type
=
str
,
default
=
None
,
help
=
"Filename of the checkpoint in the HF repo"
)
parser
.
add_argument
(
"--output_path"
,
type
=
str
,
required
=
True
,
help
=
"Path to save the converted model"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"fp32"
,
help
=
"Data type for the converted model (fp16, bf16, or fp32)"
)
args
=
parser
.
parse_args
()
def
load_original_checkpoint
(
args
):
if
args
.
original_state_dict_repo_id
is
not
None
:
if
args
.
filename
is
None
:
raise
ValueError
(
"When using `original_state_dict_repo_id`, `filename` must also be specified"
)
print
(
f
"Downloading checkpoint from
{
args
.
original_state_dict_repo_id
}
/
{
args
.
filename
}
"
)
ckpt_path
=
hf_hub_download
(
repo_id
=
args
.
original_state_dict_repo_id
,
filename
=
args
.
filename
)
elif
args
.
checkpoint_path
is
not
None
:
print
(
f
"Loading checkpoint from local path:
{
args
.
checkpoint_path
}
"
)
ckpt_path
=
args
.
checkpoint_path
else
:
raise
ValueError
(
"Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`"
)
original_state_dict
=
safetensors
.
torch
.
load_file
(
ckpt_path
)
return
original_state_dict
def
convert_sd3_controlnet_checkpoint_to_diffusers
(
original_state_dict
):
converted_state_dict
=
{}
# Direct mappings for controlnet blocks
for
i
in
range
(
19
):
# 19 controlnet blocks
converted_state_dict
[
f
"controlnet_blocks.
{
i
}
.weight"
]
=
original_state_dict
[
f
"controlnet_blocks.
{
i
}
.weight"
]
converted_state_dict
[
f
"controlnet_blocks.
{
i
}
.bias"
]
=
original_state_dict
[
f
"controlnet_blocks.
{
i
}
.bias"
]
# Positional embeddings
converted_state_dict
[
"pos_embed_input.proj.weight"
]
=
original_state_dict
[
"pos_embed_input.proj.weight"
]
converted_state_dict
[
"pos_embed_input.proj.bias"
]
=
original_state_dict
[
"pos_embed_input.proj.bias"
]
# Time and text embeddings
time_text_mappings
=
{
"time_text_embed.timestep_embedder.linear_1.weight"
:
"time_text_embed.timestep_embedder.linear_1.weight"
,
"time_text_embed.timestep_embedder.linear_1.bias"
:
"time_text_embed.timestep_embedder.linear_1.bias"
,
"time_text_embed.timestep_embedder.linear_2.weight"
:
"time_text_embed.timestep_embedder.linear_2.weight"
,
"time_text_embed.timestep_embedder.linear_2.bias"
:
"time_text_embed.timestep_embedder.linear_2.bias"
,
"time_text_embed.text_embedder.linear_1.weight"
:
"time_text_embed.text_embedder.linear_1.weight"
,
"time_text_embed.text_embedder.linear_1.bias"
:
"time_text_embed.text_embedder.linear_1.bias"
,
"time_text_embed.text_embedder.linear_2.weight"
:
"time_text_embed.text_embedder.linear_2.weight"
,
"time_text_embed.text_embedder.linear_2.bias"
:
"time_text_embed.text_embedder.linear_2.bias"
,
}
for
new_key
,
old_key
in
time_text_mappings
.
items
():
if
old_key
in
original_state_dict
:
converted_state_dict
[
new_key
]
=
original_state_dict
[
old_key
]
# Transformer blocks
for
i
in
range
(
19
):
# Split QKV into separate Q, K, V
qkv_weight
=
original_state_dict
[
f
"transformer_blocks.
{
i
}
.attn.qkv.weight"
]
qkv_bias
=
original_state_dict
[
f
"transformer_blocks.
{
i
}
.attn.qkv.bias"
]
q
,
k
,
v
=
torch
.
chunk
(
qkv_weight
,
3
,
dim
=
0
)
q_bias
,
k_bias
,
v_bias
=
torch
.
chunk
(
qkv_bias
,
3
,
dim
=
0
)
block_mappings
=
{
f
"transformer_blocks.
{
i
}
.attn.to_q.weight"
:
q
,
f
"transformer_blocks.
{
i
}
.attn.to_q.bias"
:
q_bias
,
f
"transformer_blocks.
{
i
}
.attn.to_k.weight"
:
k
,
f
"transformer_blocks.
{
i
}
.attn.to_k.bias"
:
k_bias
,
f
"transformer_blocks.
{
i
}
.attn.to_v.weight"
:
v
,
f
"transformer_blocks.
{
i
}
.attn.to_v.bias"
:
v_bias
,
# Output projections
f
"transformer_blocks.
{
i
}
.attn.to_out.0.weight"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.attn.proj.weight"
],
f
"transformer_blocks.
{
i
}
.attn.to_out.0.bias"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.attn.proj.bias"
],
# Feed forward
f
"transformer_blocks.
{
i
}
.ff.net.0.proj.weight"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.mlp.fc1.weight"
],
f
"transformer_blocks.
{
i
}
.ff.net.0.proj.bias"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.mlp.fc1.bias"
],
f
"transformer_blocks.
{
i
}
.ff.net.2.weight"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.mlp.fc2.weight"
],
f
"transformer_blocks.
{
i
}
.ff.net.2.bias"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.mlp.fc2.bias"
],
# Norms
f
"transformer_blocks.
{
i
}
.norm1.linear.weight"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.adaLN_modulation.1.weight"
],
f
"transformer_blocks.
{
i
}
.norm1.linear.bias"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.adaLN_modulation.1.bias"
],
}
converted_state_dict
.
update
(
block_mappings
)
return
converted_state_dict
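The loop above relies on the original checkpoint storing query, key, and value as one fused `qkv` projection stacked along the output dimension, which `torch.chunk(..., 3, dim=0)` splits back into the three separate Diffusers weights. A minimal sketch of that split on dummy tensors (shapes chosen to match the 38-head × 64-dim configuration used further down; not taken from a real checkpoint):

```python
import torch

inner_dim = 38 * 64  # num_attention_heads * attention_head_dim = 2432
qkv_weight = torch.randn(3 * inner_dim, inner_dim)  # stacked [q; k; v] rows
qkv_bias = torch.randn(3 * inner_dim)

q, k, v = torch.chunk(qkv_weight, 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)

assert q.shape == k.shape == v.shape == (inner_dim, inner_dim)
assert q_bias.shape == (inner_dim,)
```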
def main(args):
    original_ckpt = load_original_checkpoint(args)
    original_dtype = next(iter(original_ckpt.values())).dtype

    # Initialize dtype with fp32 as default
    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")

    if dtype != original_dtype:
        print(
            f"Converting checkpoint from {original_dtype} to {dtype}. This can lead to unexpected results, proceed with caution."
        )

    converted_controlnet_state_dict = convert_sd3_controlnet_checkpoint_to_diffusers(original_ckpt)

    controlnet = SD3ControlNetModel(
        patch_size=2,
        in_channels=16,
        num_layers=19,
        attention_head_dim=64,
        num_attention_heads=38,
        joint_attention_dim=None,
        caption_projection_dim=2048,
        pooled_projection_dim=2048,
        out_channels=16,
        pos_embed_max_size=None,
        pos_embed_type=None,
        use_pos_embed=False,
        force_zeros_for_pooled_projection=False,
    )

    controlnet.load_state_dict(converted_controlnet_state_dict, strict=True)

    print(f"Saving SD3 ControlNet in Diffusers format in {args.output_path}.")
    controlnet.to(dtype).save_pretrained(args.output_path)


if __name__ == "__main__":
    main(args)
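Once converted, the checkpoint loads like any other Diffusers ControlNet. A minimal usage sketch, assuming the weights were saved to `output/sd35-controlnet-canny` as in the docstring example and that the SD3.5 Large base model is available from the Hub (paths and generation settings here are illustrative, not part of the commit):

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

controlnet = SD3ControlNetModel.from_pretrained("output/sd35-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

control_image = load_image("canny_edges.png")  # a preprocessed Canny edge map (placeholder path)
image = pipe(
    prompt="a photo of a cat",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
).images[0]
image.save("output.png")
```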
src/diffusers/models/controlnets/controlnet_sd3.py
@@ -27,6 +27,7 @@ from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
from .controlnet import BaseOutput, zero_module

@@ -58,12 +59,16 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
        extra_conditioning_channels: int = 0,
        dual_attention_layers: Tuple[int, ...] = (),
        qk_norm: Optional[str] = None,
        pos_embed_type: Optional[str] = "sincos",
        use_pos_embed: bool = True,
        force_zeros_for_pooled_projection: bool = True,
    ):
        super().__init__()
        default_out_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        if use_pos_embed:
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,

@@ -71,10 +76,14 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
                in_channels=in_channels,
                embed_dim=self.inner_dim,
                pos_embed_max_size=pos_embed_max_size,
                pos_embed_type=pos_embed_type,
            )
        else:
            self.pos_embed = None
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
        if joint_attention_dim is not None:
            self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

            # `attention_head_dim` is doubled to account for the mixing.

@@ -92,6 +101,18 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
                    for i in range(num_layers)
                ]
            )
        else:
            self.context_embedder = None
            self.transformer_blocks = nn.ModuleList(
                [
                    SD3SingleTransformerBlock(
                        dim=self.inner_dim,
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=self.config.attention_head_dim,
                    )
                    for _ in range(num_layers)
                ]
            )

        # controlnet_blocks
        self.controlnet_blocks = nn.ModuleList([])

@@ -318,8 +339,26 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

        if self.pos_embed is not None and hidden_states.ndim != 4:
            raise ValueError("hidden_states must be 4D when pos_embed is used")
        # SD3.5 8b controlnet does not have a `pos_embed`,
        # it use the `pos_embed` from the transformer to process input before passing to controlnet
        elif self.pos_embed is None and hidden_states.ndim != 3:
            raise ValueError("hidden_states must be 3D when pos_embed is not used")

        if self.context_embedder is not None and encoder_hidden_states is None:
            raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
        # SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
        elif self.context_embedder is None and encoder_hidden_states is not None:
            raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")

        if self.pos_embed is not None:
            hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.

        temb = self.time_text_embed(timestep, pooled_projections)

        if self.context_embedder is not None:
            encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # add

@@ -349,9 +388,13 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
                )
            else:
                if self.context_embedder is not None:
                    encoder_hidden_states, hidden_states = block(
                        hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
                    )
                else:
                    # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
                    hidden_states = block(hidden_states, temb)

            block_res_samples = block_res_samples + (hidden_states,)
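The validation added to `forward` encodes two input contracts: an InstantX-style SD3 ControlNet keeps its own `pos_embed` and takes 4D image latents, while the official SD3.5 Large (8b) ControlNet has no `pos_embed` or `context_embedder` and expects the 3D patch-embedded token sequence produced by the base transformer. A rough shape sketch under assumed sizes (1024×1024 images, 16-channel latents, patch size 2, inner dim 2432; illustrative only):

```python
import torch

batch, channels, height, width = 2, 16, 128, 128  # SD3 latent grid for 1024x1024 images
patch_size, inner_dim = 2, 38 * 64

# ControlNet with its own pos_embed (e.g. InstantX SD3): 4D latents go in directly.
hidden_states_4d = torch.randn(batch, channels, height, width)

# SD3.5 8b controlnet: the pipeline first runs transformer.pos_embed, so the
# controlnet receives an already-patchified 3D token sequence.
seq_len = (height // patch_size) * (width // patch_size)
hidden_states_3d = torch.randn(batch, seq_len, inner_dim)

print(hidden_states_4d.shape)  # torch.Size([2, 16, 128, 128])
print(hidden_states_3d.shape)  # torch.Size([2, 4096, 2432])
```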
src/diffusers/models/transformers/transformer_sd3.py
@@ -18,14 +18,21 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import JointTransformerBlock  # removed
from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0  # removed
from ...models.attention import FeedForward, JointTransformerBlock  # added
from ...models.attention_processor import (  # added
    Attention,
    AttentionProcessor,
    FusedJointAttnProcessor2_0,
    JointAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous  # removed
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero  # added
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput

@@ -33,6 +40,72 @@ from ..modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
    r"""
    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
    ):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)

        if hasattr(F, "scaled_dot_product_attention"):
            processor = JointAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )

        self.attn = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=processor,
            eps=1e-6,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        # Attention.
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output

        return hidden_states


class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    The Transformer model introduced in Stable Diffusion 3.
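The new `SD3SingleTransformerBlock` can be exercised on its own for a quick shape check. A minimal sketch with dummy tensors, sizes matching the 8b ControlNet configuration used by the conversion script (assumes a diffusers build that includes this commit):

```python
import torch
from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

dim = 38 * 64  # 2432, num_attention_heads * attention_head_dim
block = SD3SingleTransformerBlock(dim=dim, num_attention_heads=38, attention_head_dim=64)

hidden_states = torch.randn(1, 4096, dim)  # patchified latent tokens
temb = torch.randn(1, dim)                 # combined timestep / pooled-text embedding

out = block(hidden_states, temb)
print(out.shape)  # torch.Size([1, 4096, 2432])
```

Unlike `JointTransformerBlock`, this block attends only over the image tokens (no `encoder_hidden_states`), with adaLN-Zero gating around both the attention and feed-forward branches.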
src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -858,6 +858,12 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        controlnet_config = (
            self.controlnet.config
            if isinstance(self.controlnet, SD3ControlNetModel)
            else self.controlnet.nets[0].config
        )

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]

@@ -932,6 +938,11 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        # 3. Prepare control image
        if controlnet_config.force_zeros_for_pooled_projection:
            # instantx sd3 controlnet does not apply shift factor
            vae_shift_factor = 0
        else:
            vae_shift_factor = self.vae.config.shift_factor

        if isinstance(self.controlnet, SD3ControlNetModel):
            control_image = self.prepare_image(
                image=control_image,

@@ -947,8 +958,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
            height, width = control_image.shape[-2:]

            control_image = self.vae.encode(control_image).latent_dist.sample()
            control_image = control_image * self.vae.config.scaling_factor  # removed
            control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor  # added
        elif isinstance(self.controlnet, SD3MultiControlNetModel):
            control_images = []

@@ -966,7 +976,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                )

                control_image_ = self.vae.encode(control_image_).latent_dist.sample()
                control_image_ = control_image_ * self.vae.config.scaling_factor  # removed
                control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor  # added
                control_images.append(control_image_)

@@ -974,11 +984,6 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
        else:
            assert False

        # this block is dropped here; the equivalent logic now appears later,
        # gated on `force_zeros_for_pooled_projection`
        if controlnet_pooled_projections is None:
            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
        else:
            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

@@ -1006,6 +1011,18 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
            ]
            controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)

        if controlnet_config.force_zeros_for_pooled_projection:
            # instantx sd3 controlnet used zero pooled projection
            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
        else:
            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds

        if controlnet_config.joint_attention_dim is not None:
            controlnet_encoder_hidden_states = prompt_embeds
        else:
            # SD35 official 8b controlnet does not use encoder_hidden_states
            controlnet_encoder_hidden_states = None

        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):

@@ -1025,11 +1042,17 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                    controlnet_cond_scale = controlnet_cond_scale[0]
                cond_scale = controlnet_cond_scale * controlnet_keep[i]

                if controlnet_config.use_pos_embed is False:
                    # sd35 (offical) 8b controlnet
                    controlnet_model_input = self.transformer.pos_embed(latent_model_input)
                else:
                    controlnet_model_input = latent_model_input

                # controlnet(s) inference
                control_block_samples = self.controlnet(
                    hidden_states=latent_model_input,  # removed
                    hidden_states=controlnet_model_input,  # added
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,  # removed
                    encoder_hidden_states=controlnet_encoder_hidden_states,  # added
                    pooled_projections=controlnet_pooled_projections,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    controlnet_cond=control_image,
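The `cond_scale` used inside the denoising loop combines the user-supplied `controlnet_conditioning_scale` with a per-step keep mask built from `control_guidance_start`/`control_guidance_end`. A standalone sketch of that schedule, following the pattern used by Diffusers ControlNet pipelines (values illustrative, not part of this diff):

```python
# ControlNet strength schedule: active only between the chosen fractions of the run.
num_steps = 10
control_guidance_start, control_guidance_end = 0.0, 0.5
controlnet_conditioning_scale = 1.0

controlnet_keep = [
    1.0 - float(i / num_steps < control_guidance_start or (i + 1) / num_steps > control_guidance_end)
    for i in range(num_steps)
]
cond_scales = [controlnet_conditioning_scale * keep for keep in controlnet_keep]
print(cond_scales)  # ControlNet applied for the first half of the steps, 0.0 afterwards
```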