Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
3cfe187d
Unverified
Commit
3cfe187d
authored
Apr 19, 2024
by
Dhruv Nair
Committed by
GitHub
Apr 18, 2024
Browse files
Cleanup ControlnetXS (#7701)
* update * update
parent
90250d9e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
35 deletions
+61
-35
src/diffusers/models/controlnet_xs.py
src/diffusers/models/controlnet_xs.py
+29
-6
src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
...s/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+0
-29
tests/pipelines/controlnet_xs/test_controlnetxs.py
tests/pipelines/controlnet_xs/test_controlnetxs.py
+32
-0
No files found.
src/diffusers/models/controlnet_xs.py
View file @
3cfe187d
...
@@ -22,7 +22,14 @@ from torch import FloatTensor, nn
...
@@ -22,7 +22,14 @@ from torch import FloatTensor, nn
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
BaseOutput
,
is_torch_version
,
logging
from
..utils
import
BaseOutput
,
is_torch_version
,
logging
from
..utils.torch_utils
import
apply_freeu
from
..utils.torch_utils
import
apply_freeu
from
.attention_processor
import
Attention
,
AttentionProcessor
from
.attention_processor
import
(
ADDED_KV_ATTENTION_PROCESSORS
,
CROSS_ATTENTION_PROCESSORS
,
Attention
,
AttentionProcessor
,
AttnAddedKVProcessor
,
AttnProcessor
,
)
from
.controlnet
import
ControlNetConditioningEmbedding
from
.controlnet
import
ControlNetConditioningEmbedding
from
.embeddings
import
TimestepEmbedding
,
Timesteps
from
.embeddings
import
TimestepEmbedding
,
Timesteps
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
...
@@ -869,7 +876,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
...
@@ -869,7 +876,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
return
processors
return
processors
#
c
opied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
#
C
opied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
.set_attn_processor
def
set_attn_processor
(
self
,
processor
:
Union
[
AttentionProcessor
,
Dict
[
str
,
AttentionProcessor
]]):
def
set_attn_processor
(
self
,
processor
:
Union
[
AttentionProcessor
,
Dict
[
str
,
AttentionProcessor
]]):
r
"""
r
"""
Sets the attention processor to use to compute attention.
Sets the attention processor to use to compute attention.
...
@@ -904,7 +911,23 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
...
@@ -904,7 +911,23 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
for
name
,
module
in
self
.
named_children
():
for
name
,
module
in
self
.
named_children
():
fn_recursive_attn_processor
(
name
,
module
,
processor
)
fn_recursive_attn_processor
(
name
,
module
,
processor
)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def
set_default_attn_processor
(
self
):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if
all
(
proc
.
__class__
in
ADDED_KV_ATTENTION_PROCESSORS
for
proc
in
self
.
attn_processors
.
values
()):
processor
=
AttnAddedKVProcessor
()
elif
all
(
proc
.
__class__
in
CROSS_ATTENTION_PROCESSORS
for
proc
in
self
.
attn_processors
.
values
()):
processor
=
AttnProcessor
()
else
:
raise
ValueError
(
f
"Cannot call `set_default_attn_processor` when attention processors are of type
{
next
(
iter
(
self
.
attn_processors
.
values
()))
}
"
)
self
.
set_attn_processor
(
processor
)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def
enable_freeu
(
self
,
s1
:
float
,
s2
:
float
,
b1
:
float
,
b2
:
float
):
def
enable_freeu
(
self
,
s1
:
float
,
s2
:
float
,
b1
:
float
,
b2
:
float
):
r
"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
r
"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
...
@@ -929,7 +952,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
...
@@ -929,7 +952,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
setattr
(
upsample_block
,
"b1"
,
b1
)
setattr
(
upsample_block
,
"b1"
,
b1
)
setattr
(
upsample_block
,
"b2"
,
b2
)
setattr
(
upsample_block
,
"b2"
,
b2
)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
.disable_freeu
def
disable_freeu
(
self
):
def
disable_freeu
(
self
):
"""Disables the FreeU mechanism."""
"""Disables the FreeU mechanism."""
freeu_keys
=
{
"s1"
,
"s2"
,
"b1"
,
"b2"
}
freeu_keys
=
{
"s1"
,
"s2"
,
"b1"
,
"b2"
}
...
@@ -938,7 +961,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
...
@@ -938,7 +961,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
if
hasattr
(
upsample_block
,
k
)
or
getattr
(
upsample_block
,
k
,
None
)
is
not
None
:
if
hasattr
(
upsample_block
,
k
)
or
getattr
(
upsample_block
,
k
,
None
)
is
not
None
:
setattr
(
upsample_block
,
k
,
None
)
setattr
(
upsample_block
,
k
,
None
)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
.fuse_qkv_projections
def
fuse_qkv_projections
(
self
):
def
fuse_qkv_projections
(
self
):
"""
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
...
@@ -962,7 +985,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
...
@@ -962,7 +985,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
if
isinstance
(
module
,
Attention
):
if
isinstance
(
module
,
Attention
):
module
.
fuse_projections
(
fuse
=
True
)
module
.
fuse_projections
(
fuse
=
True
)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
.unfuse_qkv_projections
def
unfuse_qkv_projections
(
self
):
def
unfuse_qkv_projections
(
self
):
"""Disables the fused QKV projection if enabled.
"""Disables the fused QKV projection if enabled.
...
...
src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
View file @
3cfe187d
...
@@ -41,7 +41,6 @@ from ...models.lora import adjust_lora_scale_text_encoder
...
@@ -41,7 +41,6 @@ from ...models.lora import adjust_lora_scale_text_encoder
from
...schedulers
import
KarrasDiffusionSchedulers
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
(
from
...utils
import
(
USE_PEFT_BACKEND
,
USE_PEFT_BACKEND
,
deprecate
,
logging
,
logging
,
replace_example_docstring
,
replace_example_docstring
,
scale_lora_layers
,
scale_lora_layers
,
...
@@ -462,7 +461,6 @@ class StableDiffusionXLControlNetXSPipeline(
...
@@ -462,7 +461,6 @@ class StableDiffusionXLControlNetXSPipeline(
prompt
,
prompt
,
prompt_2
,
prompt_2
,
image
,
image
,
callback_steps
,
negative_prompt
=
None
,
negative_prompt
=
None
,
negative_prompt_2
=
None
,
negative_prompt_2
=
None
,
prompt_embeds
=
None
,
prompt_embeds
=
None
,
...
@@ -474,12 +472,6 @@ class StableDiffusionXLControlNetXSPipeline(
...
@@ -474,12 +472,6 @@ class StableDiffusionXLControlNetXSPipeline(
control_guidance_end
=
1.0
,
control_guidance_end
=
1.0
,
callback_on_step_end_tensor_inputs
=
None
,
callback_on_step_end_tensor_inputs
=
None
,
):
):
if
callback_steps
is
not
None
and
(
not
isinstance
(
callback_steps
,
int
)
or
callback_steps
<=
0
):
raise
ValueError
(
f
"`callback_steps` has to be a positive integer but is
{
callback_steps
}
of type"
f
"
{
type
(
callback_steps
)
}
."
)
if
callback_on_step_end_tensor_inputs
is
not
None
and
not
all
(
if
callback_on_step_end_tensor_inputs
is
not
None
and
not
all
(
k
in
self
.
_callback_tensor_inputs
for
k
in
callback_on_step_end_tensor_inputs
k
in
self
.
_callback_tensor_inputs
for
k
in
callback_on_step_end_tensor_inputs
):
):
...
@@ -749,7 +741,6 @@ class StableDiffusionXLControlNetXSPipeline(
...
@@ -749,7 +741,6 @@ class StableDiffusionXLControlNetXSPipeline(
clip_skip
:
Optional
[
int
]
=
None
,
clip_skip
:
Optional
[
int
]
=
None
,
callback_on_step_end
:
Optional
[
Callable
[[
int
,
int
,
Dict
],
None
]]
=
None
,
callback_on_step_end
:
Optional
[
Callable
[[
int
,
int
,
Dict
],
None
]]
=
None
,
callback_on_step_end_tensor_inputs
:
List
[
str
]
=
[
"latents"
],
callback_on_step_end_tensor_inputs
:
List
[
str
]
=
[
"latents"
],
**
kwargs
,
):
):
r
"""
r
"""
The call function to the pipeline for generation.
The call function to the pipeline for generation.
...
@@ -878,22 +869,6 @@ class StableDiffusionXLControlNetXSPipeline(
...
@@ -878,22 +869,6 @@ class StableDiffusionXLControlNetXSPipeline(
returned, otherwise a `tuple` is returned containing the output images.
returned, otherwise a `tuple` is returned containing the output images.
"""
"""
callback
=
kwargs
.
pop
(
"callback"
,
None
)
callback_steps
=
kwargs
.
pop
(
"callback_steps"
,
None
)
if
callback
is
not
None
:
deprecate
(
"callback"
,
"1.0.0"
,
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`"
,
)
if
callback_steps
is
not
None
:
deprecate
(
"callback_steps"
,
"1.0.0"
,
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`"
,
)
unet
=
self
.
unet
.
_orig_mod
if
is_compiled_module
(
self
.
unet
)
else
self
.
unet
unet
=
self
.
unet
.
_orig_mod
if
is_compiled_module
(
self
.
unet
)
else
self
.
unet
# 1. Check inputs. Raise error if not correct
# 1. Check inputs. Raise error if not correct
...
@@ -901,7 +876,6 @@ class StableDiffusionXLControlNetXSPipeline(
...
@@ -901,7 +876,6 @@ class StableDiffusionXLControlNetXSPipeline(
prompt
,
prompt
,
prompt_2
,
prompt_2
,
image
,
image
,
callback_steps
,
negative_prompt
,
negative_prompt
,
negative_prompt_2
,
negative_prompt_2
,
prompt_embeds
,
prompt_embeds
,
...
@@ -1089,9 +1063,6 @@ class StableDiffusionXLControlNetXSPipeline(
...
@@ -1089,9 +1063,6 @@ class StableDiffusionXLControlNetXSPipeline(
# call the callback, if provided
# call the callback, if provided
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
# manually for max memory savings
# manually for max memory savings
if
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
:
if
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
:
...
...
tests/pipelines/controlnet_xs/test_controlnetxs.py
View file @
3cfe187d
...
@@ -69,6 +69,13 @@ from ..test_pipelines_common import (
...
@@ -69,6 +69,13 @@ from ..test_pipelines_common import (
enable_full_determinism
()
enable_full_determinism
()
def
to_np
(
tensor
):
if
isinstance
(
tensor
,
torch
.
Tensor
):
tensor
=
tensor
.
detach
().
cpu
().
numpy
()
return
tensor
# Will be run via run_test_in_subprocess
# Will be run via run_test_in_subprocess
def
_test_stable_diffusion_compile
(
in_queue
,
out_queue
,
timeout
):
def
_test_stable_diffusion_compile
(
in_queue
,
out_queue
,
timeout
):
error
=
None
error
=
None
...
@@ -299,6 +306,31 @@ class ControlNetXSPipelineFastTests(
...
@@ -299,6 +306,31 @@ class ControlNetXSPipelineFastTests(
assert
out_vae_np
.
shape
==
out_np
.
shape
assert
out_vae_np
.
shape
==
out_np
.
shape
@
unittest
.
skipIf
(
torch_device
!=
"cuda"
,
reason
=
"CUDA and CPU are required to switch devices"
)
def
test_to_device
(
self
):
components
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
to
(
"cpu"
)
# pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the device from pipe.components
model_devices
=
[
component
.
device
.
type
for
component
in
pipe
.
components
.
values
()
if
hasattr
(
component
,
"device"
)
]
self
.
assertTrue
(
all
(
device
==
"cpu"
for
device
in
model_devices
))
output_cpu
=
pipe
(
**
self
.
get_dummy_inputs
(
"cpu"
))[
0
]
self
.
assertTrue
(
np
.
isnan
(
output_cpu
).
sum
()
==
0
)
pipe
.
to
(
"cuda"
)
model_devices
=
[
component
.
device
.
type
for
component
in
pipe
.
components
.
values
()
if
hasattr
(
component
,
"device"
)
]
self
.
assertTrue
(
all
(
device
==
"cuda"
for
device
in
model_devices
))
output_cuda
=
pipe
(
**
self
.
get_dummy_inputs
(
"cuda"
))[
0
]
self
.
assertTrue
(
np
.
isnan
(
to_np
(
output_cuda
)).
sum
()
==
0
)
@
slow
@
slow
@
require_torch_gpu
@
require_torch_gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment