Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
57239dac
Commit
57239dac
authored
Oct 16, 2023
by
Patrick von Platen
Browse files
make style
parent
de12776b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
16 deletions
+41
-16
src/diffusers/models/unet_2d_condition.py
src/diffusers/models/unet_2d_condition.py
+9
-7
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
...users/pipelines/versatile_diffusion/modeling_text_unet.py
+32
-9
No files found.
src/diffusers/models/unet_2d_condition.py
View file @
57239dac
...
...
@@ -20,7 +20,7 @@ import torch.utils.checkpoint
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..loaders
import
UNet2DConditionLoadersMixin
from
..utils
import
USE_PEFT_BACKEND
,
BaseOutput
,
logging
,
deprecate
,
scale_lora_layers
,
unscale_lora_layers
from
..utils
import
USE_PEFT_BACKEND
,
BaseOutput
,
deprecate
,
logging
,
scale_lora_layers
,
unscale_lora_layers
from
.activations
import
get_activation
from
.attention_processor
import
(
ADDED_KV_ATTENTION_PROCESSORS
,
...
...
@@ -824,8 +824,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added to UNet long skip connections from down blocks to up blocks
for
example from ControlNet side model(s)
additional residuals to be added to UNet long skip connections from down blocks to up blocks
for
example from ControlNet side model(s)
mid_block_additional_residual (`torch.Tensor`, *optional*):
additional residual to be added to UNet mid block output, for example from ControlNet side model
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
...
...
@@ -1014,12 +1014,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
# but can only use one or the other
if
not
is_adapter
and
mid_block_additional_residual
is
None
and
down_block_additional_residuals
is
not
None
:
deprecate
(
"T2I should not use down_block_additional_residuals"
,
"1.3.0"
,
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated
\
deprecate
(
"T2I should not use down_block_additional_residuals"
,
"1.3.0"
,
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated
\
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used
\
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. "
,
standard_warn
=
False
)
standard_warn
=
False
,
)
down_intrablock_additional_residuals
=
down_block_additional_residuals
is_adapter
=
True
...
...
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
View file @
57239dac
...
...
@@ -987,6 +987,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
added_cond_kwargs
:
Optional
[
Dict
[
str
,
torch
.
Tensor
]]
=
None
,
down_block_additional_residuals
:
Optional
[
Tuple
[
torch
.
Tensor
]]
=
None
,
mid_block_additional_residual
:
Optional
[
torch
.
Tensor
]
=
None
,
down_intrablock_additional_residuals
:
Optional
[
Tuple
[
torch
.
Tensor
]]
=
None
,
encoder_attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
return_dict
:
bool
=
True
,
)
->
Union
[
UNet2DConditionOutput
,
Tuple
]:
...
...
@@ -1031,6 +1032,13 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
added_cond_kwargs: (`dict`, *optional*):
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
example from ControlNet side model(s)
mid_block_additional_residual (`torch.Tensor`, *optional*):
additional residual to be added to UNet mid block output, for example from ControlNet side model
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
...
...
@@ -1216,15 +1224,31 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
scale_lora_layers
(
self
,
lora_scale
)
is_controlnet
=
mid_block_additional_residual
is
not
None
and
down_block_additional_residuals
is
not
None
is_adapter
=
mid_block_additional_residual
is
None
and
down_block_additional_residuals
is
not
None
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
is_adapter
=
down_intrablock_additional_residuals
is
not
None
# maintain backward compatibility for legacy usage, where
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
# but can only use one or the other
if
not
is_adapter
and
mid_block_additional_residual
is
None
and
down_block_additional_residuals
is
not
None
:
deprecate
(
"T2I should not use down_block_additional_residuals"
,
"1.3.0"
,
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
" and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only"
" be used for ControlNet. Please make sure use"
" `down_intrablock_additional_residuals` instead. "
,
standard_warn
=
False
,
)
down_intrablock_additional_residuals
=
down_block_additional_residuals
is_adapter
=
True
down_block_res_samples
=
(
sample
,)
for
downsample_block
in
self
.
down_blocks
:
if
hasattr
(
downsample_block
,
"has_cross_attention"
)
and
downsample_block
.
has_cross_attention
:
# For t2i-adapter CrossAttnDownBlockFlat
additional_residuals
=
{}
if
is_adapter
and
len
(
down_block_additional_residuals
)
>
0
:
additional_residuals
[
"additional_residuals"
]
=
down_block_additional_residuals
.
pop
(
0
)
if
is_adapter
and
len
(
down_
intra
block_additional_residuals
)
>
0
:
additional_residuals
[
"additional_residuals"
]
=
down_
intra
block_additional_residuals
.
pop
(
0
)
sample
,
res_samples
=
downsample_block
(
hidden_states
=
sample
,
...
...
@@ -1237,9 +1261,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
)
else
:
sample
,
res_samples
=
downsample_block
(
hidden_states
=
sample
,
temb
=
emb
,
scale
=
lora_scale
)
if
is_adapter
and
len
(
down_block_additional_residuals
)
>
0
:
sample
+=
down_block_additional_residuals
.
pop
(
0
)
if
is_adapter
and
len
(
down_intrablock_additional_residuals
)
>
0
:
sample
+=
down_intrablock_additional_residuals
.
pop
(
0
)
down_block_res_samples
+=
res_samples
...
...
@@ -1267,10 +1290,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# To support T2I-Adapter-XL
if
(
is_adapter
and
len
(
down_block_additional_residuals
)
>
0
and
sample
.
shape
==
down_block_additional_residuals
[
0
].
shape
and
len
(
down_
intra
block_additional_residuals
)
>
0
and
sample
.
shape
==
down_
intra
block_additional_residuals
[
0
].
shape
):
sample
+=
down_block_additional_residuals
.
pop
(
0
)
sample
+=
down_
intra
block_additional_residuals
.
pop
(
0
)
if
is_controlnet
:
sample
=
sample
+
mid_block_additional_residual
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment