"...text-generation-inference.git" did not exist on "53ee09c0b0004777f029f594ce44cffa6350ed08"
Unverified commit 79ea8eb2, authored by Ishan Modi and committed by GitHub

[BUG] fixes in Kandinsky pipeline (#11080)

* bug fix in Kandinsky pipeline
parent e7f3a737
@@ -116,6 +116,7 @@ class VaeImageProcessor(ConfigMixin):
         vae_scale_factor: int = 8,
         vae_latent_channels: int = 4,
         resample: str = "lanczos",
+        reducing_gap: int = None,
         do_normalize: bool = True,
         do_binarize: bool = False,
         do_convert_rgb: bool = False,
@@ -498,7 +499,11 @@ class VaeImageProcessor(ConfigMixin):
             raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
         if isinstance(image, PIL.Image.Image):
             if resize_mode == "default":
-                image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+                image = image.resize(
+                    (width, height),
+                    resample=PIL_INTERPOLATION[self.config.resample],
+                    reducing_gap=self.config.reducing_gap,
+                )
             elif resize_mode == "fill":
                 image = self._resize_and_fill(image, width, height)
             elif resize_mode == "crop":
......
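For reference, reducing_gap is Pillow's optional two-step resize optimization, which the processor now forwards to Image.resize. A minimal sketch of the difference (the input file name is a placeholder):

from PIL import Image

# Placeholder input; any RGB image works.
img = Image.open("input.png").convert("RGB")

# Previous behaviour: reducing_gap is effectively None, a single bicubic resample.
slow = img.resize((512, 512), resample=Image.BICUBIC)

# With reducing_gap=1 (the value the Kandinsky pipelines configure below),
# Pillow may first shrink the image by an integer factor via Image.reduce(),
# then resample, matching the removed prepare_image() helpers.
fast = img.resize((512, 512), resample=Image.BICUBIC, reducing_gap=1)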
@@ -14,14 +14,13 @@
 from typing import Callable, List, Optional, Union
-import numpy as np
 import PIL.Image
 import torch
-from PIL import Image
 from transformers import (
     XLMRobertaTokenizer,
 )
+from ...image_processor import VaeImageProcessor
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDIMScheduler
 from ...utils import (
@@ -95,15 +94,6 @@ def get_new_h_w(h, w, scale_factor=8):
     return new_h * scale_factor, new_w * scale_factor
-def prepare_image(pil_image, w=512, h=512):
-    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
-    arr = np.array(pil_image.convert("RGB"))
-    arr = arr.astype(np.float32) / 127.5 - 1
-    arr = np.transpose(arr, [2, 0, 1])
-    image = torch.from_numpy(arr).unsqueeze(0)
-    return image
 class KandinskyImg2ImgPipeline(DiffusionPipeline):
     """
     Pipeline for image-to-image generation using Kandinsky
@@ -143,7 +133,16 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
             scheduler=scheduler,
             movq=movq,
         )
-        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+        self.movq_scale_factor = (
+            2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+        )
+        movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.movq_scale_factor,
+            vae_latent_channels=movq_latent_channels,
+            resample="bicubic",
+            reducing_gap=1,
+        )
     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
@@ -417,7 +416,7 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
             )
-        image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+        image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
         image = image.to(dtype=prompt_embeds.dtype, device=device)
         latents = self.movq.encode(image)["latents"]
@@ -498,13 +497,7 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
         if output_type not in ["pt", "np", "pil"]:
             raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
-        if output_type in ["np", "pil"]:
-            image = image * 0.5 + 0.5
-            image = image.clamp(0, 1)
-            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type)
         if not return_dict:
             return (image,)
......
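The removed prepare_image helper did a bicubic resize with reducing_gap=1 and scaled pixels to [-1, 1]; the VaeImageProcessor configured in __init__ above is meant to reproduce that through preprocess. A rough equivalence sketch, assuming a diffusers build that includes this change (the image path is a placeholder):

import numpy as np
import torch
from PIL import Image

from diffusers.image_processor import VaeImageProcessor

pil_image = Image.open("example.png").convert("RGB")  # placeholder input

# Old path: what the deleted prepare_image() helper computed.
resized = pil_image.resize((512, 512), resample=Image.BICUBIC, reducing_gap=1)
arr = np.array(resized).astype(np.float32) / 127.5 - 1  # scale to [-1, 1]
old_tensor = torch.from_numpy(np.transpose(arr, [2, 0, 1])).unsqueeze(0)

# New path: the processor the pipeline now builds in __init__.
processor = VaeImageProcessor(
    vae_scale_factor=8, vae_latent_channels=4, resample="bicubic", reducing_gap=1
)
new_tensor = processor.preprocess(pil_image, height=512, width=512)

# Both should be (1, 3, 512, 512) float tensors in roughly [-1, 1];
# minor numeric differences from the internal resize path are possible.
print(old_tensor.shape, new_tensor.shape)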
@@ -14,11 +14,10 @@
 from typing import Callable, List, Optional, Union
-import numpy as np
 import PIL.Image
 import torch
-from PIL import Image
+from ...image_processor import VaeImageProcessor
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -105,27 +104,6 @@ EXAMPLE_DOC_STRING = """
 """
-# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
-def downscale_height_and_width(height, width, scale_factor=8):
-    new_height = height // scale_factor**2
-    if height % scale_factor**2 != 0:
-        new_height += 1
-    new_width = width // scale_factor**2
-    if width % scale_factor**2 != 0:
-        new_width += 1
-    return new_height * scale_factor, new_width * scale_factor
-# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
-def prepare_image(pil_image, w=512, h=512):
-    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
-    arr = np.array(pil_image.convert("RGB"))
-    arr = arr.astype(np.float32) / 127.5 - 1
-    arr = np.transpose(arr, [2, 0, 1])
-    image = torch.from_numpy(arr).unsqueeze(0)
-    return image
 class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
     """
     Pipeline for image-to-image generation using Kandinsky
@@ -157,7 +135,14 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
             scheduler=scheduler,
             movq=movq,
         )
-        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+        movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+        movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=movq_scale_factor,
+            vae_latent_channels=movq_latent_channels,
+            resample="bicubic",
+            reducing_gap=1,
+        )
     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
     def get_timesteps(self, num_inference_steps, strength, device):
@@ -316,7 +301,7 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
             )
-        image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+        image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
         image = image.to(dtype=image_embeds.dtype, device=device)
         latents = self.movq.encode(image)["latents"]
@@ -324,7 +309,6 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
         latents = self.prepare_latents(
             latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
         )
@@ -379,13 +363,7 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
         if output_type not in ["pt", "np", "pil"]:
             raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
-        if output_type in ["np", "pil"]:
-            image = image * 0.5 + 0.5
-            image = image.clamp(0, 1)
-            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type)
         if not return_dict:
             return (image,)
......
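On the output side, a single postprocess call replaces the manual denormalize/clamp/permute/numpy_to_pil sequence each pipeline used to carry. A small sketch of what that call covers (the tensor is a dummy stand-in for the decoded MoVQ sample):

import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8, resample="bicubic")

# Dummy stand-in for movq.decode(...)["sample"]: values in [-1, 1], NCHW layout.
decoded = torch.rand(1, 3, 512, 512) * 2 - 1

# Covers the removed block: rescale to [0, 1], clamp, move to numpy in NHWC
# order for "np", and convert to a list of PIL images for "pil".
pil_images = processor.postprocess(decoded, output_type="pil")
np_images = processor.postprocess(decoded, output_type="np")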
@@ -14,11 +14,10 @@
 from typing import Callable, Dict, List, Optional, Union
-import numpy as np
 import PIL.Image
 import torch
-from PIL import Image
+from ...image_processor import VaeImageProcessor
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import deprecate, is_torch_xla_available, logging
@@ -76,27 +75,6 @@ EXAMPLE_DOC_STRING = """
 """
-# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
-def downscale_height_and_width(height, width, scale_factor=8):
-    new_height = height // scale_factor**2
-    if height % scale_factor**2 != 0:
-        new_height += 1
-    new_width = width // scale_factor**2
-    if width % scale_factor**2 != 0:
-        new_width += 1
-    return new_height * scale_factor, new_width * scale_factor
-# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
-def prepare_image(pil_image, w=512, h=512):
-    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
-    arr = np.array(pil_image.convert("RGB"))
-    arr = arr.astype(np.float32) / 127.5 - 1
-    arr = np.transpose(arr, [2, 0, 1])
-    image = torch.from_numpy(arr).unsqueeze(0)
-    return image
 class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
     """
     Pipeline for image-to-image generation using Kandinsky
@@ -129,7 +107,14 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
             scheduler=scheduler,
             movq=movq,
         )
-        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+        movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+        movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=movq_scale_factor,
+            vae_latent_channels=movq_latent_channels,
+            resample="bicubic",
+            reducing_gap=1,
+        )
     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
     def get_timesteps(self, num_inference_steps, strength, device):
@@ -319,7 +304,7 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
             )
-        image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+        image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
         image = image.to(dtype=image_embeds.dtype, device=device)
         latents = self.movq.encode(image)["latents"]
@@ -327,7 +312,6 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
         latents = self.prepare_latents(
             latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
         )
@@ -383,21 +367,9 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
                 if XLA_AVAILABLE:
                     xm.mark_step()
-        if output_type not in ["pt", "np", "pil", "latent"]:
-            raise ValueError(
-                f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
-            )
         if not output_type == "latent":
-            # post-processing
             image = self.movq.decode(latents, force_not_quantize=True)["sample"]
-            if output_type in ["np", "pil"]:
-                image = image * 0.5 + 0.5
-                image = image.clamp(0, 1)
-                image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-            if output_type == "pil":
-                image = self.numpy_to_pil(image)
+            image = self.image_processor.postprocess(image, output_type)
         else:
             image = latents
......
 import inspect
 from typing import Callable, Dict, List, Optional, Union
-import numpy as np
 import PIL
 import PIL.Image
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
+from ...image_processor import VaeImageProcessor
 from ...loaders import StableDiffusionLoraLoaderMixin
 from ...models import Kandinsky3UNet, VQModel
 from ...schedulers import DDPMScheduler
@@ -53,24 +53,6 @@ EXAMPLE_DOC_STRING = """
 """
-def downscale_height_and_width(height, width, scale_factor=8):
-    new_height = height // scale_factor**2
-    if height % scale_factor**2 != 0:
-        new_height += 1
-    new_width = width // scale_factor**2
-    if width % scale_factor**2 != 0:
-        new_width += 1
-    return new_height * scale_factor, new_width * scale_factor
-def prepare_image(pil_image):
-    arr = np.array(pil_image.convert("RGB"))
-    arr = arr.astype(np.float32) / 127.5 - 1
-    arr = np.transpose(arr, [2, 0, 1])
-    image = torch.from_numpy(arr).unsqueeze(0)
-    return image
 class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
     model_cpu_offload_seq = "text_encoder->movq->unet->movq"
     _callback_tensor_inputs = [
@@ -94,6 +76,14 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
         self.register_modules(
             tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
         )
+        movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+        movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=movq_scale_factor,
+            vae_latent_channels=movq_latent_channels,
+            resample="bicubic",
+            reducing_gap=1,
+        )
     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
@@ -566,7 +556,7 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
             )
-        image = torch.cat([prepare_image(i) for i in image], dim=0)
+        image = torch.cat([self.image_processor.preprocess(i) for i in image], dim=0)
         image = image.to(dtype=prompt_embeds.dtype, device=device)
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -630,20 +620,9 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
                     xm.mark_step()
         # post-processing
-        if output_type not in ["pt", "np", "pil", "latent"]:
-            raise ValueError(
-                f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
-            )
         if not output_type == "latent":
             image = self.movq.decode(latents, force_not_quantize=True)["sample"]
-            if output_type in ["np", "pil"]:
-                image = image * 0.5 + 0.5
-                image = image.clamp(0, 1)
-                image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-                if output_type == "pil":
-                    image = self.numpy_to_pil(image)
+            image = self.image_processor.postprocess(image, output_type)
         else:
             image = latents
......