Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
d1e20be6
Commit
d1e20be6
authored
Aug 30, 2023
by
Patrick von Platen
Browse files
make style
parent
af3854d6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
13 deletions
+14
-13
examples/community/masked_stable_diffusion_img2img.py
examples/community/masked_stable_diffusion_img2img.py
+14
-13
No files found.
examples/community/masked_stable_diffusion_img2img.py
View file @
d1e20be6
from
typing
import
Optional
,
Union
,
List
,
Callable
,
Dict
,
Any
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
PIL
import
PIL
import
torch
import
torch
from
diffusers
import
StableDiffusionImg2ImgPipeline
from
diffusers
import
StableDiffusionImg2ImgPipeline
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionPipelineOutput
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionPipelineOutput
class
MaskedStableDiffusionImg2ImgPipeline
(
StableDiffusionImg2ImgPipeline
):
class
MaskedStableDiffusionImg2ImgPipeline
(
StableDiffusionImg2ImgPipeline
):
debug_save
=
False
debug_save
=
False
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -38,13 +38,13 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
...
@@ -38,13 +38,13 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
callback_steps
:
int
=
1
,
callback_steps
:
int
=
1
,
cross_attention_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
cross_attention_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
mask
:
Union
[
mask
:
Union
[
torch
.
FloatTensor
,
torch
.
FloatTensor
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
np
.
ndarray
,
np
.
ndarray
,
List
[
torch
.
FloatTensor
],
List
[
torch
.
FloatTensor
],
List
[
PIL
.
Image
.
Image
],
List
[
PIL
.
Image
.
Image
],
List
[
np
.
ndarray
],
List
[
np
.
ndarray
],
]
=
None
,
]
=
None
,
):
):
r
"""
r
"""
The call function to the pipeline for generation.
The call function to the pipeline for generation.
...
@@ -158,7 +158,8 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
...
@@ -158,7 +158,8 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
# mean of the latent distribution
# mean of the latent distribution
init_latents
=
[
init_latents
=
[
self
.
vae
.
encode
(
image
.
to
(
device
=
device
,
dtype
=
prompt_embeds
.
dtype
)[
i
:
i
+
1
]).
latent_dist
.
mean
for
i
in
range
(
batch_size
)
self
.
vae
.
encode
(
image
.
to
(
device
=
device
,
dtype
=
prompt_embeds
.
dtype
)[
i
:
i
+
1
]).
latent_dist
.
mean
for
i
in
range
(
batch_size
)
]
]
init_latents
=
torch
.
cat
(
init_latents
,
dim
=
0
)
init_latents
=
torch
.
cat
(
init_latents
,
dim
=
0
)
...
@@ -194,7 +195,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
...
@@ -194,7 +195,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
latents
=
torch
.
lerp
(
init_latents
*
self
.
vae
.
config
.
scaling_factor
,
latents
,
latent_mask
)
latents
=
torch
.
lerp
(
init_latents
*
self
.
vae
.
config
.
scaling_factor
,
latents
,
latent_mask
)
noise_pred
=
torch
.
lerp
(
torch
.
zeros_like
(
noise_pred
),
noise_pred
,
latent_mask
)
noise_pred
=
torch
.
lerp
(
torch
.
zeros_like
(
noise_pred
),
noise_pred
,
latent_mask
)
# compute the previous noisy sample x_t -> x_t-1
# compute the previous noisy sample x_t -> x_t-1
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
,
return_dict
=
False
)[
0
]
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
**
extra_step_kwargs
,
return_dict
=
False
)[
0
]
# call the callback, if provided
# call the callback, if provided
...
@@ -236,7 +237,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
...
@@ -236,7 +237,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
def
_make_latent_mask
(
self
,
latents
,
mask
):
def
_make_latent_mask
(
self
,
latents
,
mask
):
if
mask
is
not
None
:
if
mask
is
not
None
:
latent_mask
=
list
()
latent_mask
=
[]
if
not
isinstance
(
mask
,
list
):
if
not
isinstance
(
mask
,
list
):
tmp_mask
=
[
mask
]
tmp_mask
=
[
mask
]
else
:
else
:
...
@@ -250,7 +251,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
...
@@ -250,7 +251,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
m
=
m
/
255.0
m
=
m
/
255.0
m
=
self
.
image_processor
.
numpy_to_pil
(
m
)[
0
]
m
=
self
.
image_processor
.
numpy_to_pil
(
m
)[
0
]
if
m
.
mode
!=
"L"
:
if
m
.
mode
!=
"L"
:
m
=
m
.
convert
(
'L'
)
m
=
m
.
convert
(
"L"
)
resized
=
self
.
image_processor
.
resize
(
m
,
l_height
,
l_width
)
resized
=
self
.
image_processor
.
resize
(
m
,
l_height
,
l_width
)
if
self
.
debug_save
:
if
self
.
debug_save
:
resized
.
save
(
"latent_mask.png"
)
resized
.
save
(
"latent_mask.png"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment