Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
79329715
Unverified
Commit
79329715
authored
Dec 05, 2022
by
Patrick von Platen
Committed by
GitHub
Dec 05, 2022
Browse files
[Upscaling] Fix batch size (#1525)
parent
720dbfc9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
2 deletions
+55
-2
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
...nes/stable_diffusion/pipeline_stable_diffusion_upscale.py
+4
-2
tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
...lines/stable_diffusion_2/test_stable_diffusion_upscale.py
+51
-0
No files found.
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
View file @
79329715
...
@@ -459,8 +459,10 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
...
@@ -459,8 +459,10 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
else
:
else
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
,
device
=
device
,
dtype
=
text_embeddings
.
dtype
)
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
,
device
=
device
,
dtype
=
text_embeddings
.
dtype
)
image
=
self
.
low_res_scheduler
.
add_noise
(
image
,
noise
,
noise_level
)
image
=
self
.
low_res_scheduler
.
add_noise
(
image
,
noise
,
noise_level
)
image
=
torch
.
cat
([
image
]
*
2
)
if
do_classifier_free_guidance
else
image
noise_level
=
torch
.
cat
([
noise_level
]
*
2
)
if
do_classifier_free_guidance
else
noise_level
batch_multiplier
=
2
if
do_classifier_free_guidance
else
1
image
=
torch
.
cat
([
image
]
*
batch_multiplier
*
num_images_per_prompt
)
noise_level
=
torch
.
cat
([
noise_level
]
*
image
.
shape
[
0
])
# 6. Prepare latent variables
# 6. Prepare latent variables
height
,
width
=
image
.
shape
[
2
:]
height
,
width
=
image
.
shape
[
2
:]
...
...
tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
View file @
79329715
...
@@ -161,6 +161,57 @@ class StableDiffusionUpscalePipelineFastTests(PipelineTesterMixin, unittest.Test
...
@@ -161,6 +161,57 @@ class StableDiffusionUpscalePipelineFastTests(PipelineTesterMixin, unittest.Test
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_stable_diffusion_upscale_batch
(
self
):
device
=
"cpu"
# ensure determinism for the device-dependent torch.Generator
unet
=
self
.
dummy_cond_unet_upscale
low_res_scheduler
=
DDPMScheduler
()
scheduler
=
DDIMScheduler
(
prediction_type
=
"v_prediction"
)
vae
=
self
.
dummy_vae
text_encoder
=
self
.
dummy_text_encoder
tokenizer
=
CLIPTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-clip"
)
image
=
self
.
dummy_image
.
cpu
().
permute
(
0
,
2
,
3
,
1
)[
0
]
low_res_image
=
Image
.
fromarray
(
np
.
uint8
(
image
)).
convert
(
"RGB"
).
resize
((
64
,
64
))
# make sure here that pndm scheduler skips prk
sd_pipe
=
StableDiffusionUpscalePipeline
(
unet
=
unet
,
low_res_scheduler
=
low_res_scheduler
,
scheduler
=
scheduler
,
vae
=
vae
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
max_noise_level
=
350
,
)
sd_pipe
=
sd_pipe
.
to
(
device
)
sd_pipe
.
set_progress_bar_config
(
disable
=
None
)
prompt
=
"A painting of a squirrel eating a burger"
output
=
sd_pipe
(
2
*
[
prompt
],
image
=
2
*
[
low_res_image
],
guidance_scale
=
6.0
,
noise_level
=
20
,
num_inference_steps
=
2
,
output_type
=
"np"
,
)
image
=
output
.
images
assert
image
.
shape
[
0
]
==
2
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
0
)
output
=
sd_pipe
(
[
prompt
],
image
=
low_res_image
,
generator
=
generator
,
num_images_per_prompt
=
2
,
guidance_scale
=
6.0
,
noise_level
=
20
,
num_inference_steps
=
2
,
output_type
=
"np"
,
)
image
=
output
.
images
assert
image
.
shape
[
0
]
==
2
@
unittest
.
skipIf
(
torch_device
!=
"cuda"
,
"This test requires a GPU"
)
@
unittest
.
skipIf
(
torch_device
!=
"cuda"
,
"This test requires a GPU"
)
def
test_stable_diffusion_upscale_fp16
(
self
):
def
test_stable_diffusion_upscale_fp16
(
self
):
"""Test that stable diffusion upscale works with fp16"""
"""Test that stable diffusion upscale works with fp16"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment