Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
3cfe187d
Unverified
Commit
3cfe187d
authored
Apr 19, 2024
by
Dhruv Nair
Committed by
GitHub
Apr 18, 2024
Browse files
Cleanup ControlnetXS (#7701)
* update * update
parent
90250d9e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
35 deletions
+61
-35
src/diffusers/models/controlnet_xs.py
src/diffusers/models/controlnet_xs.py
+29
-6
src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
...s/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+0
-29
tests/pipelines/controlnet_xs/test_controlnetxs.py
tests/pipelines/controlnet_xs/test_controlnetxs.py
+32
-0
No files found.
src/diffusers/models/controlnet_xs.py
View file @
3cfe187d
...
...
@@ -22,7 +22,14 @@ from torch import FloatTensor, nn
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
BaseOutput
,
is_torch_version
,
logging
from
..utils.torch_utils
import
apply_freeu
from
.attention_processor
import
Attention
,
AttentionProcessor
from
.attention_processor
import
(
ADDED_KV_ATTENTION_PROCESSORS
,
CROSS_ATTENTION_PROCESSORS
,
Attention
,
AttentionProcessor
,
AttnAddedKVProcessor
,
AttnProcessor
,
)
from
.controlnet
import
ControlNetConditioningEmbedding
from
.embeddings
import
TimestepEmbedding
,
Timesteps
from
.modeling_utils
import
ModelMixin
...
...
@@ -869,7 +876,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
return
processors
#
c
opied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
#
C
opied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
.set_attn_processor
def
set_attn_processor
(
self
,
processor
:
Union
[
AttentionProcessor
,
Dict
[
str
,
AttentionProcessor
]]):
r
"""
Sets the attention processor to use to compute attention.
...
...
@@ -904,7 +911,23 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
for
name
,
module
in
self
.
named_children
():
fn_recursive_attn_processor
(
name
,
module
,
processor
)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def
set_default_attn_processor
(
self
):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if
all
(
proc
.
__class__
in
ADDED_KV_ATTENTION_PROCESSORS
for
proc
in
self
.
attn_processors
.
values
()):
processor
=
AttnAddedKVProcessor
()
elif
all
(
proc
.
__class__
in
CROSS_ATTENTION_PROCESSORS
for
proc
in
self
.
attn_processors
.
values
()):
processor
=
AttnProcessor
()
else
:
raise
ValueError
(
f
"Cannot call `set_default_attn_processor` when attention processors are of type
{
next
(
iter
(
self
.
attn_processors
.
values
()))
}
"
)
self
.
set_attn_processor
(
processor
)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def
enable_freeu
(
self
,
s1
:
float
,
s2
:
float
,
b1
:
float
,
b2
:
float
):
r
"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
...
...
@@ -929,7 +952,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
setattr
(
upsample_block
,
"b1"
,
b1
)
setattr
(
upsample_block
,
"b2"
,
b2
)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
.disable_freeu
def
disable_freeu
(
self
):
"""Disables the FreeU mechanism."""
freeu_keys
=
{
"s1"
,
"s2"
,
"b1"
,
"b2"
}
...
...
@@ -938,7 +961,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
if
hasattr
(
upsample_block
,
k
)
or
getattr
(
upsample_block
,
k
,
None
)
is
not
None
:
setattr
(
upsample_block
,
k
,
None
)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
.fuse_qkv_projections
def
fuse_qkv_projections
(
self
):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
...
...
@@ -962,7 +985,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
if
isinstance
(
module
,
Attention
):
module
.
fuse_projections
(
fuse
=
True
)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
.unfuse_qkv_projections
def
unfuse_qkv_projections
(
self
):
"""Disables the fused QKV projection if enabled.
...
...
src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
View file @
3cfe187d
...
...
@@ -41,7 +41,6 @@ from ...models.lora import adjust_lora_scale_text_encoder
from
...schedulers
import
KarrasDiffusionSchedulers
from
...utils
import
(
USE_PEFT_BACKEND
,
deprecate
,
logging
,
replace_example_docstring
,
scale_lora_layers
,
...
...
@@ -462,7 +461,6 @@ class StableDiffusionXLControlNetXSPipeline(
prompt
,
prompt_2
,
image
,
callback_steps
,
negative_prompt
=
None
,
negative_prompt_2
=
None
,
prompt_embeds
=
None
,
...
...
@@ -474,12 +472,6 @@ class StableDiffusionXLControlNetXSPipeline(
control_guidance_end
=
1.0
,
callback_on_step_end_tensor_inputs
=
None
,
):
if
callback_steps
is
not
None
and
(
not
isinstance
(
callback_steps
,
int
)
or
callback_steps
<=
0
):
raise
ValueError
(
f
"`callback_steps` has to be a positive integer but is
{
callback_steps
}
of type"
f
"
{
type
(
callback_steps
)
}
."
)
if
callback_on_step_end_tensor_inputs
is
not
None
and
not
all
(
k
in
self
.
_callback_tensor_inputs
for
k
in
callback_on_step_end_tensor_inputs
):
...
...
@@ -749,7 +741,6 @@ class StableDiffusionXLControlNetXSPipeline(
clip_skip
:
Optional
[
int
]
=
None
,
callback_on_step_end
:
Optional
[
Callable
[[
int
,
int
,
Dict
],
None
]]
=
None
,
callback_on_step_end_tensor_inputs
:
List
[
str
]
=
[
"latents"
],
**
kwargs
,
):
r
"""
The call function to the pipeline for generation.
...
...
@@ -878,22 +869,6 @@ class StableDiffusionXLControlNetXSPipeline(
returned, otherwise a `tuple` is returned containing the output images.
"""
callback
=
kwargs
.
pop
(
"callback"
,
None
)
callback_steps
=
kwargs
.
pop
(
"callback_steps"
,
None
)
if
callback
is
not
None
:
deprecate
(
"callback"
,
"1.0.0"
,
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`"
,
)
if
callback_steps
is
not
None
:
deprecate
(
"callback_steps"
,
"1.0.0"
,
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`"
,
)
unet
=
self
.
unet
.
_orig_mod
if
is_compiled_module
(
self
.
unet
)
else
self
.
unet
# 1. Check inputs. Raise error if not correct
...
...
@@ -901,7 +876,6 @@ class StableDiffusionXLControlNetXSPipeline(
prompt
,
prompt_2
,
image
,
callback_steps
,
negative_prompt
,
negative_prompt_2
,
prompt_embeds
,
...
...
@@ -1089,9 +1063,6 @@ class StableDiffusionXLControlNetXSPipeline(
# call the callback, if provided
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
progress_bar
.
update
()
if
callback
is
not
None
and
i
%
callback_steps
==
0
:
step_idx
=
i
//
getattr
(
self
.
scheduler
,
"order"
,
1
)
callback
(
step_idx
,
t
,
latents
)
# manually for max memory savings
if
self
.
vae
.
dtype
==
torch
.
float16
and
self
.
vae
.
config
.
force_upcast
:
...
...
tests/pipelines/controlnet_xs/test_controlnetxs.py
View file @
3cfe187d
...
...
@@ -69,6 +69,13 @@ from ..test_pipelines_common import (
enable_full_determinism
()
def
to_np
(
tensor
):
if
isinstance
(
tensor
,
torch
.
Tensor
):
tensor
=
tensor
.
detach
().
cpu
().
numpy
()
return
tensor
# Will be run via run_test_in_subprocess
def
_test_stable_diffusion_compile
(
in_queue
,
out_queue
,
timeout
):
error
=
None
...
...
@@ -299,6 +306,31 @@ class ControlNetXSPipelineFastTests(
assert
out_vae_np
.
shape
==
out_np
.
shape
@
unittest
.
skipIf
(
torch_device
!=
"cuda"
,
reason
=
"CUDA and CPU are required to switch devices"
)
def
test_to_device
(
self
):
components
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
to
(
"cpu"
)
# pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the device from pipe.components
model_devices
=
[
component
.
device
.
type
for
component
in
pipe
.
components
.
values
()
if
hasattr
(
component
,
"device"
)
]
self
.
assertTrue
(
all
(
device
==
"cpu"
for
device
in
model_devices
))
output_cpu
=
pipe
(
**
self
.
get_dummy_inputs
(
"cpu"
))[
0
]
self
.
assertTrue
(
np
.
isnan
(
output_cpu
).
sum
()
==
0
)
pipe
.
to
(
"cuda"
)
model_devices
=
[
component
.
device
.
type
for
component
in
pipe
.
components
.
values
()
if
hasattr
(
component
,
"device"
)
]
self
.
assertTrue
(
all
(
device
==
"cuda"
for
device
in
model_devices
))
output_cuda
=
pipe
(
**
self
.
get_dummy_inputs
(
"cuda"
))[
0
]
self
.
assertTrue
(
np
.
isnan
(
to_np
(
output_cuda
)).
sum
()
==
0
)
@
slow
@
require_torch_gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment