Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d1e20be6
Commit
d1e20be6
authored
Aug 30, 2023
by
Patrick von Platen
Browse files
make style
parent
af3854d6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
13 deletions
+14
-13
examples/community/masked_stable_diffusion_img2img.py
examples/community/masked_stable_diffusion_img2img.py
+14
-13
No files found.
examples/community/masked_stable_diffusion_img2img.py
View file @
d1e20be6
from
typing
import
Optional
,
Union
,
List
,
Callable
,
Dict
,
Any
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
PIL
import
PIL
import
torch
import
torch
from
diffusers
import
StableDiffusionImg2ImgPipeline
from
diffusers
import
StableDiffusionImg2ImgPipeline
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionPipelineOutput
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionPipelineOutput
class
MaskedStableDiffusionImg2ImgPipeline
(
StableDiffusionImg2ImgPipeline
):
class
MaskedStableDiffusionImg2ImgPipeline
(
StableDiffusionImg2ImgPipeline
):
debug_save
=
False
debug_save
=
False
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -38,13 +38,13 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
...
@@ -38,13 +38,13 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
callback_steps
:
int
=
1
,
callback_steps
:
int
=
1
,
cross_attention_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
cross_attention_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
mask
:
Union
[
mask
:
Union
[
torch
.
FloatTensor
,
torch
.
FloatTensor
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
np
.
ndarray
,
np
.
ndarray
,
List
[
torch
.
FloatTensor
],
List
[
torch
.
FloatTensor
],
List
[
PIL
.
Image
.
Image
],
List
[
PIL
.
Image
.
Image
],
List
[
np
.
ndarray
],
List
[
np
.
ndarray
],
]
=
None
,
]
=
None
,
):
):
r
"""
r
"""
The call function to the pipeline for generation.
The call function to the pipeline for generation.
...
@@ -158,7 +158,8 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
...
@@ -158,7 +158,8 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
# mean of the latent distribution
# mean of the latent distribution
init_latents
=
[
init_latents
=
[
self
.
vae
.
encode
(
image
.
to
(
device
=
device
,
dtype
=
prompt_embeds
.
dtype
)[
i
:
i
+
1
]).
latent_dist
.
mean
for
i
in
range
(
batch_size
)
self
.
vae
.
encode
(
image
.
to
(
device
=
device
,
dtype
=
prompt_embeds
.
dtype
)[
i
:
i
+
1
]).
latent_dist
.
mean
for
i
in
range
(
batch_size
)
]
]
init_latents
=
torch
.
cat
(
init_latents
,
dim
=
0
)
init_latents
=
torch
.
cat
(
init_latents
,
dim
=
0
)
...
@@ -194,7 +195,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
...
@@ -194,7 +195,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
latents
=
torch
.
lerp
(
init_latents
*
self
.
vae
.
config
.
scaling_factor
,
latents
,
latent_mask
)
latents
=
torch
.
lerp
(
init_latents
*
self
.
vae
.
config
.
scaling_factor
,
latents
,
latent_mask
)
noise_pred
=
torch
.
lerp
(
torch
.
zeros_like
(
noise_pred
),
noise_pred
,
latent_mask
)
noise_pred
=
torch
.
lerp
(
torch
.
zeros_like
(
noise_pred
),
noise_pred
,
latent_mask
)
# compute the previous noisy sample x_t -> x_t-1
# compute the previous noisy sample x_t -> x_t-1
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
,
return_dict
=
False
)[
0
]
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
,
return_dict
=
False
)[
0
]
# call the callback, if provided
# call the callback, if provided
...
@@ -236,7 +237,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
...
@@ -236,7 +237,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
def
_make_latent_mask
(
self
,
latents
,
mask
):
def
_make_latent_mask
(
self
,
latents
,
mask
):
if
mask
is
not
None
:
if
mask
is
not
None
:
latent_mask
=
list
()
latent_mask
=
[]
if
not
isinstance
(
mask
,
list
):
if
not
isinstance
(
mask
,
list
):
tmp_mask
=
[
mask
]
tmp_mask
=
[
mask
]
else
:
else
:
...
@@ -250,7 +251,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
...
@@ -250,7 +251,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
m
=
m
/
255.0
m
=
m
/
255.0
m
=
self
.
image_processor
.
numpy_to_pil
(
m
)[
0
]
m
=
self
.
image_processor
.
numpy_to_pil
(
m
)[
0
]
if
m
.
mode
!=
"L"
:
if
m
.
mode
!=
"L"
:
m
=
m
.
convert
(
'L'
)
m
=
m
.
convert
(
"L"
)
resized
=
self
.
image_processor
.
resize
(
m
,
l_height
,
l_width
)
resized
=
self
.
image_processor
.
resize
(
m
,
l_height
,
l_width
)
if
self
.
debug_save
:
if
self
.
debug_save
:
resized
.
save
(
"latent_mask.png"
)
resized
.
save
(
"latent_mask.png"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment