Commit d1e20be6 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent af3854d6
from typing import Optional, Union, List, Callable, Dict, Any from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import PIL import PIL
import torch import torch
from diffusers import StableDiffusionImg2ImgPipeline from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
debug_save = False debug_save = False
@torch.no_grad() @torch.no_grad()
...@@ -38,13 +38,13 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): ...@@ -38,13 +38,13 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
callback_steps: int = 1, callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
mask: Union[ mask: Union[
torch.FloatTensor, torch.FloatTensor,
PIL.Image.Image, PIL.Image.Image,
np.ndarray, np.ndarray,
List[torch.FloatTensor], List[torch.FloatTensor],
List[PIL.Image.Image], List[PIL.Image.Image],
List[np.ndarray], List[np.ndarray],
] = None, ] = None,
): ):
r""" r"""
The call function to the pipeline for generation. The call function to the pipeline for generation.
...@@ -158,7 +158,8 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): ...@@ -158,7 +158,8 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
# mean of the latent distribution # mean of the latent distribution
init_latents = [ init_latents = [
self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean for i in range(batch_size) self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean
for i in range(batch_size)
] ]
init_latents = torch.cat(init_latents, dim=0) init_latents = torch.cat(init_latents, dim=0)
...@@ -194,7 +195,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): ...@@ -194,7 +195,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
latents = torch.lerp(init_latents * self.vae.config.scaling_factor, latents, latent_mask) latents = torch.lerp(init_latents * self.vae.config.scaling_factor, latents, latent_mask)
noise_pred = torch.lerp(torch.zeros_like(noise_pred), noise_pred, latent_mask) noise_pred = torch.lerp(torch.zeros_like(noise_pred), noise_pred, latent_mask)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided # call the callback, if provided
...@@ -236,7 +237,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): ...@@ -236,7 +237,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
def _make_latent_mask(self, latents, mask): def _make_latent_mask(self, latents, mask):
if mask is not None: if mask is not None:
latent_mask = list() latent_mask = []
if not isinstance(mask, list): if not isinstance(mask, list):
tmp_mask = [mask] tmp_mask = [mask]
else: else:
...@@ -250,7 +251,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): ...@@ -250,7 +251,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
m = m / 255.0 m = m / 255.0
m = self.image_processor.numpy_to_pil(m)[0] m = self.image_processor.numpy_to_pil(m)[0]
if m.mode != "L": if m.mode != "L":
m = m.convert('L') m = m.convert("L")
resized = self.image_processor.resize(m, l_height, l_width) resized = self.image_processor.resize(m, l_height, l_width)
if self.debug_save: if self.debug_save:
resized.save("latent_mask.png") resized.save("latent_mask.png")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment