renzhc / diffusers_dcu · Commit 75bd1e83 (unverified)
Authored Nov 27, 2024 by YiYi Xu; committed by GitHub on Nov 27, 2024
Parent: 8d477dae

Sd35 controlnet (#10020)

* add model/pipeline

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

4 changed files with 367 additions and 43 deletions (+367 −43):
scripts/convert_sd3_controlnet_to_diffusers.py                                     +185 −0
src/diffusers/models/controlnets/controlnet_sd3.py                                  +73 −30
src/diffusers/models/transformers/transformer_sd3.py                                 +76 −3
src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py     +33 −10
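For orientation, here is a minimal, hypothetical usage sketch of the model/pipeline added by this commit: it loads an SD3.5 ControlNet (as produced by the conversion script below) into `StableDiffusion3ControlNetPipeline`. The checkpoint path, prompt, and conditioning image are placeholders, not part of the commit.

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Placeholder path: any SD3.5 ControlNet converted by scripts/convert_sd3_controlnet_to_diffusers.py.
controlnet = SD3ControlNetModel.from_pretrained("output/sd35-controlnet-canny", torch_dtype=torch.float16)

pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

control_image = load_image("canny_edges.png")  # placeholder Canny edge map
image = pipe(
    prompt="a photo of a cat",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
).images[0]
image.save("sd35_controlnet_out.png")
```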
scripts/convert_sd3_controlnet_to_diffusers.py (new file, mode 100644)
"""
A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.
Example:
Convert a SD3.5 ControlNet checkpoint to Diffusers format using local file:
```bash
python scripts/convert_sd3_controlnet_to_diffusers.py
\
--checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors"
\
--output_path "output/sd35-controlnet-canny"
\
--dtype "fp16" # optional, defaults to fp32
```
Or download and convert from HuggingFace repository:
```bash
python scripts/convert_sd3_controlnet_to_diffusers.py
\
--original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets"
\
--filename "sd3.5_large_controlnet_canny.safetensors"
\
--output_path "/raid/yiyi/sd35-controlnet-canny-diffusers"
\
--dtype "fp32" # optional, defaults to fp32
```
Note:
The script supports the following ControlNet types from SD3.5:
- Canny edge detection
- Depth estimation
- Blur detection
The checkpoint files can be downloaded from:
https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
"""
import
argparse
import
safetensors.torch
import
torch
from
huggingface_hub
import
hf_hub_download
from
diffusers
import
SD3ControlNetModel
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--checkpoint_path"
,
type
=
str
,
default
=
None
,
help
=
"Path to local checkpoint file"
)
parser
.
add_argument
(
"--original_state_dict_repo_id"
,
type
=
str
,
default
=
None
,
help
=
"HuggingFace repo ID containing the checkpoint"
)
parser
.
add_argument
(
"--filename"
,
type
=
str
,
default
=
None
,
help
=
"Filename of the checkpoint in the HF repo"
)
parser
.
add_argument
(
"--output_path"
,
type
=
str
,
required
=
True
,
help
=
"Path to save the converted model"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"fp32"
,
help
=
"Data type for the converted model (fp16, bf16, or fp32)"
)
args
=
parser
.
parse_args
()
def
load_original_checkpoint
(
args
):
if
args
.
original_state_dict_repo_id
is
not
None
:
if
args
.
filename
is
None
:
raise
ValueError
(
"When using `original_state_dict_repo_id`, `filename` must also be specified"
)
print
(
f
"Downloading checkpoint from
{
args
.
original_state_dict_repo_id
}
/
{
args
.
filename
}
"
)
ckpt_path
=
hf_hub_download
(
repo_id
=
args
.
original_state_dict_repo_id
,
filename
=
args
.
filename
)
elif
args
.
checkpoint_path
is
not
None
:
print
(
f
"Loading checkpoint from local path:
{
args
.
checkpoint_path
}
"
)
ckpt_path
=
args
.
checkpoint_path
else
:
raise
ValueError
(
"Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`"
)
original_state_dict
=
safetensors
.
torch
.
load_file
(
ckpt_path
)
return
original_state_dict
def
convert_sd3_controlnet_checkpoint_to_diffusers
(
original_state_dict
):
converted_state_dict
=
{}
# Direct mappings for controlnet blocks
for
i
in
range
(
19
):
# 19 controlnet blocks
converted_state_dict
[
f
"controlnet_blocks.
{
i
}
.weight"
]
=
original_state_dict
[
f
"controlnet_blocks.
{
i
}
.weight"
]
converted_state_dict
[
f
"controlnet_blocks.
{
i
}
.bias"
]
=
original_state_dict
[
f
"controlnet_blocks.
{
i
}
.bias"
]
# Positional embeddings
converted_state_dict
[
"pos_embed_input.proj.weight"
]
=
original_state_dict
[
"pos_embed_input.proj.weight"
]
converted_state_dict
[
"pos_embed_input.proj.bias"
]
=
original_state_dict
[
"pos_embed_input.proj.bias"
]
# Time and text embeddings
time_text_mappings
=
{
"time_text_embed.timestep_embedder.linear_1.weight"
:
"time_text_embed.timestep_embedder.linear_1.weight"
,
"time_text_embed.timestep_embedder.linear_1.bias"
:
"time_text_embed.timestep_embedder.linear_1.bias"
,
"time_text_embed.timestep_embedder.linear_2.weight"
:
"time_text_embed.timestep_embedder.linear_2.weight"
,
"time_text_embed.timestep_embedder.linear_2.bias"
:
"time_text_embed.timestep_embedder.linear_2.bias"
,
"time_text_embed.text_embedder.linear_1.weight"
:
"time_text_embed.text_embedder.linear_1.weight"
,
"time_text_embed.text_embedder.linear_1.bias"
:
"time_text_embed.text_embedder.linear_1.bias"
,
"time_text_embed.text_embedder.linear_2.weight"
:
"time_text_embed.text_embedder.linear_2.weight"
,
"time_text_embed.text_embedder.linear_2.bias"
:
"time_text_embed.text_embedder.linear_2.bias"
,
}
for
new_key
,
old_key
in
time_text_mappings
.
items
():
if
old_key
in
original_state_dict
:
converted_state_dict
[
new_key
]
=
original_state_dict
[
old_key
]
# Transformer blocks
for
i
in
range
(
19
):
# Split QKV into separate Q, K, V
qkv_weight
=
original_state_dict
[
f
"transformer_blocks.
{
i
}
.attn.qkv.weight"
]
qkv_bias
=
original_state_dict
[
f
"transformer_blocks.
{
i
}
.attn.qkv.bias"
]
q
,
k
,
v
=
torch
.
chunk
(
qkv_weight
,
3
,
dim
=
0
)
q_bias
,
k_bias
,
v_bias
=
torch
.
chunk
(
qkv_bias
,
3
,
dim
=
0
)
block_mappings
=
{
f
"transformer_blocks.
{
i
}
.attn.to_q.weight"
:
q
,
f
"transformer_blocks.
{
i
}
.attn.to_q.bias"
:
q_bias
,
f
"transformer_blocks.
{
i
}
.attn.to_k.weight"
:
k
,
f
"transformer_blocks.
{
i
}
.attn.to_k.bias"
:
k_bias
,
f
"transformer_blocks.
{
i
}
.attn.to_v.weight"
:
v
,
f
"transformer_blocks.
{
i
}
.attn.to_v.bias"
:
v_bias
,
# Output projections
f
"transformer_blocks.
{
i
}
.attn.to_out.0.weight"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.attn.proj.weight"
],
f
"transformer_blocks.
{
i
}
.attn.to_out.0.bias"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.attn.proj.bias"
],
# Feed forward
f
"transformer_blocks.
{
i
}
.ff.net.0.proj.weight"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.mlp.fc1.weight"
],
f
"transformer_blocks.
{
i
}
.ff.net.0.proj.bias"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.mlp.fc1.bias"
],
f
"transformer_blocks.
{
i
}
.ff.net.2.weight"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.mlp.fc2.weight"
],
f
"transformer_blocks.
{
i
}
.ff.net.2.bias"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.mlp.fc2.bias"
],
# Norms
f
"transformer_blocks.
{
i
}
.norm1.linear.weight"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.adaLN_modulation.1.weight"
],
f
"transformer_blocks.
{
i
}
.norm1.linear.bias"
:
original_state_dict
[
f
"transformer_blocks.
{
i
}
.adaLN_modulation.1.bias"
],
}
converted_state_dict
.
update
(
block_mappings
)
return
converted_state_dict
def
main
(
args
):
original_ckpt
=
load_original_checkpoint
(
args
)
original_dtype
=
next
(
iter
(
original_ckpt
.
values
())).
dtype
# Initialize dtype with fp32 as default
if
args
.
dtype
==
"fp16"
:
dtype
=
torch
.
float16
elif
args
.
dtype
==
"bf16"
:
dtype
=
torch
.
bfloat16
elif
args
.
dtype
==
"fp32"
:
dtype
=
torch
.
float32
else
:
raise
ValueError
(
f
"Unsupported dtype:
{
args
.
dtype
}
. Must be one of: fp16, bf16, fp32"
)
if
dtype
!=
original_dtype
:
print
(
f
"Converting checkpoint from
{
original_dtype
}
to
{
dtype
}
. This can lead to unexpected results, proceed with caution."
)
converted_controlnet_state_dict
=
convert_sd3_controlnet_checkpoint_to_diffusers
(
original_ckpt
)
controlnet
=
SD3ControlNetModel
(
patch_size
=
2
,
in_channels
=
16
,
num_layers
=
19
,
attention_head_dim
=
64
,
num_attention_heads
=
38
,
joint_attention_dim
=
None
,
caption_projection_dim
=
2048
,
pooled_projection_dim
=
2048
,
out_channels
=
16
,
pos_embed_max_size
=
None
,
pos_embed_type
=
None
,
use_pos_embed
=
False
,
force_zeros_for_pooled_projection
=
False
,
)
controlnet
.
load_state_dict
(
converted_controlnet_state_dict
,
strict
=
True
)
print
(
f
"Saving SD3 ControlNet in Diffusers format in
{
args
.
output_path
}
."
)
controlnet
.
to
(
dtype
).
save_pretrained
(
args
.
output_path
)
if
__name__
==
"__main__"
:
main
(
args
)
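The main structural transformation above is un-fusing the original checkpoint's fused QKV projections into separate `to_q`/`to_k`/`to_v` tensors. A tiny sketch of that convention, assuming the fused weight is laid out as [q; k; v] along dim 0, which is what the `torch.chunk(..., 3, dim=0)` call relies on:

```python
import torch

inner_dim = 8  # toy size; the SD3.5 large ControlNet uses 38 heads * 64 dims = 2432
qkv_weight = torch.randn(3 * inner_dim, inner_dim)  # fused [q; k; v] projection weight

q, k, v = torch.chunk(qkv_weight, 3, dim=0)  # each chunk is (inner_dim, inner_dim)
assert torch.equal(q, qkv_weight[:inner_dim])
assert torch.equal(v, qkv_weight[2 * inner_dim :])
```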
src/diffusers/models/controlnets/controlnet_sd3.py (+73 −30)
...
@@ -27,6 +27,7 @@ from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
 from .controlnet import BaseOutput, zero_module
...
@@ -58,12 +59,16 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         extra_conditioning_channels: int = 0,
         dual_attention_layers: Tuple[int, ...] = (),
         qk_norm: Optional[str] = None,
+        pos_embed_type: Optional[str] = "sincos",
+        use_pos_embed: bool = True,
+        force_zeros_for_pooled_projection: bool = True,
     ):
         super().__init__()
         default_out_channels = in_channels
         self.out_channels = out_channels if out_channels is not None else default_out_channels
         self.inner_dim = num_attention_heads * attention_head_dim

-        self.pos_embed = PatchEmbed(
-            height=sample_size,
-            width=sample_size,
+        if use_pos_embed:
+            self.pos_embed = PatchEmbed(
+                height=sample_size,
+                width=sample_size,
...
@@ -71,10 +76,14 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
-            in_channels=in_channels,
-            embed_dim=self.inner_dim,
-            pos_embed_max_size=pos_embed_max_size,
-        )
+                in_channels=in_channels,
+                embed_dim=self.inner_dim,
+                pos_embed_max_size=pos_embed_max_size,
+                pos_embed_type=pos_embed_type,
+            )
+        else:
+            self.pos_embed = None

         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+        if joint_attention_dim is not None:
+            self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

         # `attention_head_dim` is doubled to account for the mixing.
...
@@ -92,6 +101,18 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                     for i in range(num_layers)
                 ]
             )
+        else:
+            self.context_embedder = None
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    SD3SingleTransformerBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=self.config.attention_head_dim,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )

         # controlnet_blocks
         self.controlnet_blocks = nn.ModuleList([])
...
@@ -318,8 +339,26 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )

+        if self.pos_embed is not None and hidden_states.ndim != 4:
+            raise ValueError("hidden_states must be 4D when pos_embed is used")
+        # SD3.5 8b controlnet does not have a `pos_embed`,
+        # it use the `pos_embed` from the transformer to process input before passing to controlnet
+        elif self.pos_embed is None and hidden_states.ndim != 3:
+            raise ValueError("hidden_states must be 3D when pos_embed is not used")
+
+        if self.context_embedder is not None and encoder_hidden_states is None:
+            raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
+        # SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
+        elif self.context_embedder is None and encoder_hidden_states is not None:
+            raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
+
-        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+        if self.pos_embed is not None:
+            hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+
         temb = self.time_text_embed(timestep, pooled_projections)
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if self.context_embedder is not None:
+            encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         # add
...
@@ -349,9 +388,13 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                 )
             else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
-                )
+                if self.context_embedder is not None:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    )
+                else:
+                    # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
+                    hidden_states = block(hidden_states, temb)

             block_res_samples = block_res_samples + (hidden_states,)
...
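To make the two code paths concrete, here is a small, hypothetical construction sketch. It mirrors the configuration the conversion script uses for the official SD3.5 large ControlNet (no `pos_embed`, no `context_embedder`, single-stream blocks), but with reduced depth and width so the toy model is cheap to instantiate; the real checkpoint uses 19 layers and 38 heads.

```python
from diffusers import SD3ControlNetModel

sd35_style = SD3ControlNetModel(
    patch_size=2,
    in_channels=16,
    num_layers=2,               # real checkpoint: 19
    attention_head_dim=64,
    num_attention_heads=4,      # real checkpoint: 38
    joint_attention_dim=None,   # -> context_embedder is None, encoder_hidden_states unused
    caption_projection_dim=2048,
    pooled_projection_dim=2048,
    out_channels=16,
    pos_embed_max_size=None,
    pos_embed_type=None,
    use_pos_embed=False,        # -> pos_embed is None; the pipeline reuses the transformer's pos_embed
    force_zeros_for_pooled_projection=False,
)
print(sd35_style.pos_embed, sd35_style.context_embedder)  # None None
```

With the new defaults (`use_pos_embed=True`, `joint_attention_dim` set, `force_zeros_for_pooled_projection=True`), the model keeps the previous InstantX-style behaviour, so existing SD3 ControlNet checkpoints are unaffected.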
src/diffusers/models/transformers/transformer_sd3.py (+76 −3)
...
@@ -18,14 +18,21 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ...models.attention import FeedForward, JointTransformerBlock
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FusedJointAttnProcessor2_0,
+    JointAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
+from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
...
@@ -33,6 +40,72 @@ from ..modeling_outputs import Transformer2DModelOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+@maybe_allow_in_graph
+class SD3SingleTransformerBlock(nn.Module):
+    r"""
+    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
+
+    Reference: https://arxiv.org/abs/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+    ):
+        super().__init__()
+
+        self.norm1 = AdaLayerNormZero(dim)
+
+        if hasattr(F, "scaled_dot_product_attention"):
+            processor = JointAttnProcessor2_0()
+        else:
+            raise ValueError(
+                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+            )
+
+        self.attn = Attention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=True,
+            processor=processor,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        # Attention.
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=None,
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        return hidden_states
+
+
 class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     """
     The Transformer model introduced in Stable Diffusion 3.
...
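As a quick shape check, a minimal sketch of running the new single-stream block on dummy tensors (sizes are illustrative only, not the ones used by the SD3.5 checkpoint):

```python
import torch
from diffusers.models.transformers.transformer_sd3 import SD3SingleTransformerBlock

block = SD3SingleTransformerBlock(dim=256, num_attention_heads=4, attention_head_dim=64)

hidden_states = torch.randn(2, 77, 256)  # (batch, tokens, dim) patchified latents
temb = torch.randn(2, 256)               # combined timestep/pooled-text embedding
out = block(hidden_states, temb)
print(out.shape)  # torch.Size([2, 77, 256])
```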
src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py (+33 −10)
...
@@ -858,6 +858,12 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor

+        controlnet_config = (
+            self.controlnet.config
+            if isinstance(self.controlnet, SD3ControlNetModel)
+            else self.controlnet.nets[0].config
+        )
+
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
...
@@ -932,6 +938,11 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

         # 3. Prepare control image
+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet does not apply shift factor
+            vae_shift_factor = 0
+        else:
+            vae_shift_factor = self.vae.config.shift_factor
+
         if isinstance(self.controlnet, SD3ControlNetModel):
             control_image = self.prepare_image(
                 image=control_image,
...
@@ -947,8 +958,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
             height, width = control_image.shape[-2:]

             control_image = self.vae.encode(control_image).latent_dist.sample()
-            control_image = control_image * self.vae.config.scaling_factor
+            control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor

         elif isinstance(self.controlnet, SD3MultiControlNetModel):
             control_images = []
...
@@ -966,7 +976,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
                 )

                 control_image_ = self.vae.encode(control_image_).latent_dist.sample()
-                control_image_ = control_image_ * self.vae.config.scaling_factor
+                control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor

                 control_images.append(control_image_)
...
@@ -974,11 +984,6 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
         else:
             assert False

-        if controlnet_pooled_projections is None:
-            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
-        else:
-            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
-
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
...
@@ -1006,6 +1011,18 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
             ]
             controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)

+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet used zero pooled projection
+            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
+        else:
+            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
+
+        if controlnet_config.joint_attention_dim is not None:
+            controlnet_encoder_hidden_states = prompt_embeds
+        else:
+            # SD35 official 8b controlnet does not use encoder_hidden_states
+            controlnet_encoder_hidden_states = None
+
         # 7. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
...
@@ -1025,11 +1042,17 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
                     controlnet_cond_scale = controlnet_cond_scale[0]
                 cond_scale = controlnet_cond_scale * controlnet_keep[i]

+                if controlnet_config.use_pos_embed is False:
+                    # sd35 (offical) 8b controlnet
+                    controlnet_model_input = self.transformer.pos_embed(latent_model_input)
+                else:
+                    controlnet_model_input = latent_model_input
+
                 # controlnet(s) inference
                 control_block_samples = self.controlnet(
-                    hidden_states=latent_model_input,
+                    hidden_states=controlnet_model_input,
                     timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_encoder_hidden_states,
                     pooled_projections=controlnet_pooled_projections,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     controlnet_cond=control_image,
...
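For reference, the control-image change above makes the latent normalization branch on the ControlNet type. A toy sketch of the two branches follows; the shift/scale numbers are typical SD3 VAE config values quoted from memory, not taken from this diff.

```python
import torch

latents = torch.randn(1, 16, 64, 64)           # toy VAE posterior sample for the control image
scaling_factor, shift_factor = 1.5305, 0.0609  # assumed SD3 VAE config values

# Official SD3.5 ControlNets (force_zeros_for_pooled_projection=False): shift then scale,
# matching how the pipeline normalizes its denoising latents.
official = (latents - shift_factor) * scaling_factor

# InstantX SD3 ControlNets (force_zeros_for_pooled_projection=True): no shift is applied.
instantx = (latents - 0) * scaling_factor
```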