Unverified Commit 65d136e0 authored by Patrick von Platen, committed by GitHub

Add improved handling of pil (#1309)

* Better error message for transformers dummy

* [PIL] Better deprecation functionality

* up
parent 46893ada
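
Context for the diffs that follow: Pillow 9.1.0 moved the resampling constants into the `PIL.Image.Resampling` enum and deprecated the old module-level names such as `PIL.Image.LANCZOS` (removed in Pillow 10), so this commit routes every call site through a version-aware mapping, `PIL_INTERPOLATION`, defined in the new `pil_utils` module shown near the end of the diff. A minimal sketch of the call-site pattern, assuming a diffusers build that already includes this commit; the image below is a placeholder:

```python
import PIL.Image

from diffusers.utils import PIL_INTERPOLATION  # new export introduced by this commit

image = PIL.Image.new("RGB", (515, 389))   # stand-in for a real input image
w, h = (x - x % 32 for x in image.size)    # round each side down to a multiple of 32

# old call: image.resize((w, h), resample=PIL.Image.LANCZOS)  # deprecated since Pillow 9.1
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
print(image.size)  # (512, 384)
```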
@@ -39,9 +39,9 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
 ## LDMTextToImagePipeline
-[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline
+[[autodoc]] LDMTextToImagePipeline
     - __call__
 ## LDMSuperResolutionPipeline
-[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution.LDMSuperResolutionPipeline
+[[autodoc]] LDMSuperResolutionPipeline
     - __call__
@@ -50,6 +50,7 @@ available a colab notebook to directly try them out.
 | [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
 | [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
 | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
+| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
 | [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
 | [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
 | [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
...
@@ -40,6 +40,7 @@ available a colab notebook to directly try them out.
 | [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
 | [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
 | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
+| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
 | [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
 | [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
 | [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
...
@@ -17,7 +17,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import logging
+from diffusers.utils import PIL_INTERPOLATION, logging
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -28,7 +28,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 def preprocess(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
...
@@ -12,7 +12,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, is_accelerate_available, logging
+from diffusers.utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -358,7 +358,7 @@ def get_weighted_text_embeddings(
 def preprocess_image(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
@@ -369,7 +369,7 @@ def preprocess_mask(mask):
     mask = mask.convert("L")
     w, h = mask.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+    mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
...
@@ -10,7 +10,7 @@ from diffusers.onnx_utils import OnnxRuntimeModel
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import logging
+from diffusers.utils import PIL_INTERPOLATION, logging
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
@@ -365,7 +365,7 @@ def get_weighted_text_embeddings(
 def preprocess_image(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     return 2.0 * image - 1.0
@@ -375,7 +375,7 @@ def preprocess_mask(mask):
     mask = mask.convert("L")
     w, h = mask.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+    mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
...
@@ -12,13 +12,13 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.utils.data import Dataset
-import PIL
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from diffusers.utils import PIL_INTERPOLATION
 from huggingface_hub import HfFolder, Repository, whoami
 from PIL import Image
 from torchvision import transforms
@@ -260,10 +260,10 @@ class TextualInversionDataset(Dataset):
         self._length = self.num_images * repeats
         self.interpolation = {
-            "linear": PIL.Image.LINEAR,
-            "bilinear": PIL.Image.BILINEAR,
-            "bicubic": PIL.Image.BICUBIC,
-            "lanczos": PIL.Image.LANCZOS,
+            "linear": PIL_INTERPOLATION["linear"],
+            "bilinear": PIL_INTERPOLATION["bilinear"],
+            "bicubic": PIL_INTERPOLATION["bicubic"],
+            "lanczos": PIL_INTERPOLATION["lanczos"],
         }[interpolation]
         self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
...
@@ -14,7 +14,6 @@ from torch.utils.data import Dataset
 import jax
 import jax.numpy as jnp
 import optax
-import PIL
 import transformers
 from diffusers import (
     FlaxAutoencoderKL,
@@ -24,6 +23,7 @@ from diffusers import (
     FlaxUNet2DConditionModel,
 )
 from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import PIL_INTERPOLATION
 from flax import jax_utils
 from flax.training import train_state
 from flax.training.common_utils import shard
@@ -246,10 +246,10 @@ class TextualInversionDataset(Dataset):
         self._length = self.num_images * repeats
         self.interpolation = {
-            "linear": PIL.Image.LINEAR,
-            "bilinear": PIL.Image.BILINEAR,
-            "bicubic": PIL.Image.BICUBIC,
-            "lanczos": PIL.Image.LANCZOS,
+            "linear": PIL_INTERPOLATION["linear"],
+            "bilinear": PIL_INTERPOLATION["bilinear"],
+            "bicubic": PIL_INTERPOLATION["bicubic"],
+            "lanczos": PIL_INTERPOLATION["lanczos"],
         }[interpolation]
         self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
...
@@ -78,7 +78,7 @@ from setuptools import find_packages, setup
 # 1. all dependencies should be listed here with their version requirements if any
 # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
 _deps = [
-    "Pillow<10.0",  # keep the PIL.Image.Resampling deprecation away
+    "Pillow",  # keep the PIL.Image.Resampling deprecation away
     "accelerate>=0.11.0",
     "black==22.8",
     "datasets",
...
@@ -2,7 +2,7 @@
 # 1. modify the `_deps` dict in setup.py
 # 2. run `make deps_table_update`
 deps = {
-    "Pillow": "Pillow<10.0",
+    "Pillow": "Pillow",
     "accelerate": "accelerate>=0.11.0",
     "black": "black==22.8",
     "datasets": "datasets",
...
@@ -33,7 +33,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
@@ -41,10 +41,11 @@ from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
...
@@ -17,12 +17,13 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
+from ...utils import PIL_INTERPOLATION
 def preprocess(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
...
@@ -26,7 +26,7 @@ from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler
-from ...utils import deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 def preprocess(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
...
@@ -25,7 +25,7 @@ from ...configuration_utils import FrozenDict
 from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging
 from . import StableDiffusionPipelineOutput
@@ -35,7 +35,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 def preprocess(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     return 2.0 * image - 1.0
...
@@ -25,7 +25,7 @@ from ...configuration_utils import FrozenDict
 from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging
 from . import StableDiffusionPipelineOutput
@@ -44,7 +44,7 @@ def prepare_mask_and_masked_image(image, mask, latents_shape):
     image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
     masked_image = image * (image_mask < 127.5)
-    mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
+    mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"])
     mask = np.array(mask.convert("L"))
     mask = mask.astype(np.float32) / 255.0
     mask = mask[None, None]
...
@@ -33,7 +33,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -44,7 +44,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 def preprocess(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
...
@@ -33,7 +33,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -44,7 +44,7 @@ logger = logging.get_logger(__name__)
 def preprocess_image(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
@@ -55,7 +55,7 @@ def preprocess_mask(mask):
     mask = mask.convert("L")
     w, h = mask.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+    mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
...
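
A note on the `preprocess_mask` helpers above (whose trailing transpose carries the "what does this step do?" comment): the mask is downscaled by a factor of 8 so it matches the VAE latent resolution and tiled to 4 channels so it can be broadcast against the 4-channel Stable Diffusion latents, while `transpose(0, 1, 2, 3)` on a 4-D array is effectively a no-op. A minimal shape check, assuming an arbitrary 512x512 mask:

```python
import numpy as np
import PIL.Image

mask = PIL.Image.new("L", (512, 512))                        # placeholder single-channel mask
w, h = mask.size
mask_np = np.array(mask.resize((w // 8, h // 8))).astype(np.float32) / 255.0
print(mask_np.shape)          # (64, 64): the mask at latent resolution
mask_np = np.tile(mask_np, (4, 1, 1))                        # one copy per latent channel
print(mask_np[None].shape)    # (1, 4, 64, 64): lines up with the latent tensor layout
```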
@@ -38,6 +38,7 @@ from .import_utils import (
 )
 from .logging import get_logger
 from .outputs import BaseOutput
+from .pil_utils import PIL_INTERPOLATION
 if is_torch_available():
...
import PIL.Image
import PIL.ImageOps
from packaging import version


if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
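
The new module above is the heart of the change: on Pillow >= 9.1.0 it resolves filters through the `PIL.Image.Resampling` enum, and on older releases it falls back to the legacy module-level constants, so callers never touch the deprecated names directly. A small hedged sanity check of the mapping, using only the keys that appear in the diffs above:

```python
import PIL
import PIL.Image
from packaging import version

from diffusers.utils import PIL_INTERPOLATION  # re-exported from the new pil_utils module

img = PIL.Image.new("RGB", (96, 96))

# every key used across the updated pipelines resolves to a filter accepted by Image.resize
for name in ("linear", "bilinear", "bicubic", "lanczos", "nearest"):
    assert img.resize((64, 64), resample=PIL_INTERPOLATION[name]).size == (64, 64)

# on Pillow >= 9.1.0 the values come straight from the Resampling enum
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    assert PIL_INTERPOLATION["lanczos"] is PIL.Image.Resampling.LANCZOS
```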
@@ -19,9 +19,8 @@ import unittest
 import numpy as np
 import torch
-import PIL
 from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel
-from diffusers.utils import floats_tensor, load_image, slow, torch_device
+from diffusers.utils import PIL_INTERPOLATION, floats_tensor, load_image, slow, torch_device
 from diffusers.utils.testing_utils import require_torch
 from ...test_pipelines_common import PipelineTesterMixin
@@ -97,7 +96,7 @@ class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase):
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
             "/vq_diffusion/teddy_bear_pool.png"
         )
-        init_image = init_image.resize((64, 64), resample=PIL.Image.LANCZOS)
+        init_image = init_image.resize((64, 64), resample=PIL_INTERPOLATION["lanczos"])
         ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
         ldm.to(torch_device)
...