Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b756ec6e
Unverified
Commit
b756ec6e
authored
Dec 20, 2024
by
djm
Committed by
GitHub
Dec 19, 2024
Browse files
unet's `sample_size` attribute is to accept tuple(h, w) in `StableDiffusionPipeline` (#10181)
parent
d8825e76
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
4 deletions
+27
-4
src/diffusers/models/unets/unet_2d_condition.py
src/diffusers/models/unets/unet_2d_condition.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
...s/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+18
-3
tests/pipelines/stable_diffusion/test_stable_diffusion.py
tests/pipelines/stable_diffusion/test_stable_diffusion.py
+8
-0
No files found.
src/diffusers/models/unets/unet_2d_condition.py
View file @
b756ec6e
...
@@ -170,7 +170,7 @@ class UNet2DConditionModel(
...
@@ -170,7 +170,7 @@ class UNet2DConditionModel(
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
self
,
self
,
sample_size
:
Optional
[
int
]
=
None
,
sample_size
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]
]
=
None
,
in_channels
:
int
=
4
,
in_channels
:
int
=
4
,
out_channels
:
int
=
4
,
out_channels
:
int
=
4
,
center_input_sample
:
bool
=
False
,
center_input_sample
:
bool
=
False
,
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
View file @
b756ec6e
...
@@ -255,7 +255,12 @@ class StableDiffusionPipeline(
...
@@ -255,7 +255,12 @@ class StableDiffusionPipeline(
is_unet_version_less_0_9_0
=
hasattr
(
unet
.
config
,
"_diffusers_version"
)
and
version
.
parse
(
is_unet_version_less_0_9_0
=
hasattr
(
unet
.
config
,
"_diffusers_version"
)
and
version
.
parse
(
version
.
parse
(
unet
.
config
.
_diffusers_version
).
base_version
version
.
parse
(
unet
.
config
.
_diffusers_version
).
base_version
)
<
version
.
parse
(
"0.9.0.dev0"
)
)
<
version
.
parse
(
"0.9.0.dev0"
)
is_unet_sample_size_less_64
=
hasattr
(
unet
.
config
,
"sample_size"
)
and
unet
.
config
.
sample_size
<
64
self
.
_is_unet_config_sample_size_int
=
isinstance
(
unet
.
config
.
sample_size
,
int
)
is_unet_sample_size_less_64
=
(
hasattr
(
unet
.
config
,
"sample_size"
)
and
self
.
_is_unet_config_sample_size_int
and
unet
.
config
.
sample_size
<
64
)
if
is_unet_version_less_0_9_0
and
is_unet_sample_size_less_64
:
if
is_unet_version_less_0_9_0
and
is_unet_sample_size_less_64
:
deprecation_message
=
(
deprecation_message
=
(
"The configuration file of the unet has set the default `sample_size` to smaller than"
"The configuration file of the unet has set the default `sample_size` to smaller than"
...
@@ -902,8 +907,18 @@ class StableDiffusionPipeline(
...
@@ -902,8 +907,18 @@ class StableDiffusionPipeline(
callback_on_step_end_tensor_inputs
=
callback_on_step_end
.
tensor_inputs
callback_on_step_end_tensor_inputs
=
callback_on_step_end
.
tensor_inputs
# 0. Default height and width to unet
# 0. Default height and width to unet
height
=
height
or
self
.
unet
.
config
.
sample_size
*
self
.
vae_scale_factor
if
not
height
or
not
width
:
width
=
width
or
self
.
unet
.
config
.
sample_size
*
self
.
vae_scale_factor
height
=
(
self
.
unet
.
config
.
sample_size
if
self
.
_is_unet_config_sample_size_int
else
self
.
unet
.
config
.
sample_size
[
0
]
)
width
=
(
self
.
unet
.
config
.
sample_size
if
self
.
_is_unet_config_sample_size_int
else
self
.
unet
.
config
.
sample_size
[
1
]
)
height
,
width
=
height
*
self
.
vae_scale_factor
,
width
*
self
.
vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct
# 1. Check inputs. Raise error if not correct
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion.py
View file @
b756ec6e
...
@@ -840,6 +840,14 @@ class StableDiffusionPipelineFastTests(
...
@@ -840,6 +840,14 @@ class StableDiffusionPipelineFastTests(
# they should be the same
# they should be the same
assert
torch
.
allclose
(
intermediate_latent
,
output_interrupted
,
atol
=
1e-4
)
assert
torch
.
allclose
(
intermediate_latent
,
output_interrupted
,
atol
=
1e-4
)
def
test_pipeline_accept_tuple_type_unet_sample_size
(
self
):
# the purpose of this test is to see whether the pipeline would accept a unet with the tuple-typed sample size
sd_repo_id
=
"stable-diffusion-v1-5/stable-diffusion-v1-5"
sample_size
=
[
60
,
80
]
customised_unet
=
UNet2DConditionModel
(
sample_size
=
sample_size
)
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
sd_repo_id
,
unet
=
customised_unet
)
assert
pipe
.
unet
.
config
.
sample_size
==
sample_size
@
slow
@
slow
@
require_torch_gpu
@
require_torch_gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment