Commit d1e20be6 authored by Patrick von Platen

make style

parent af3854d6
-from typing import Optional, Union, List, Callable, Dict, Any
+from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
debug_save = False
@torch.no_grad()
@@ -158,7 +158,8 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
# mean of the latent distribution
init_latents = [
-    self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean for i in range(batch_size)
+    self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean
+    for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
@@ -236,7 +237,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
def _make_latent_mask(self, latents, mask):
if mask is not None:
-latent_mask = list()
+latent_mask = []
if not isinstance(mask, list):
tmp_mask = [mask]
else:
@@ -250,7 +251,7 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
m = m / 255.0
m = self.image_processor.numpy_to_pil(m)[0]
if m.mode != "L":
-m = m.convert('L')
+m = m.convert("L")
resized = self.image_processor.resize(m, l_height, l_width)
if self.debug_save:
resized.save("latent_mask.png")
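For context, a minimal usage sketch of the community pipeline this style-only commit touches. The checkpoint name, the `custom_pipeline` identifier, and the `mask` keyword are assumptions inferred from the class shown above; they are not confirmed by this commit.

```python
# Minimal usage sketch (assumptions: checkpoint name, custom_pipeline id, and
# the `mask` keyword are inferred from the class above, not part of this commit).
import torch
from diffusers import DiffusionPipeline
from PIL import Image

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="masked_stable_diffusion_img2img",  # assumed community pipeline id
    torch_dtype=torch.float16,
).to("cuda")

init_image = Image.open("input.png").convert("RGB")
# Grayscale mask; _make_latent_mask converts/resizes it to latent resolution.
mask_image = Image.open("mask.png").convert("L")

result = pipe(
    prompt="a photo of a cat",
    image=init_image,
    mask=mask_image,  # assumed extra argument consumed by the masked pipeline
    strength=0.75,
).images[0]
result.save("out.png")
```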