"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c6629e6f111114e4432678ddd4beb83571dc25b9"
Unverified commit 8fd3a743, authored by Penn, committed by GitHub

Fix using non-square images with UNet2DModel and DDIM/DDPM pipelines (#1289)

* fix non-square images with UNet2DModel and DDIM/DDPM pipelines

* fix unet_2d `sample_size` docstring

* update pipeline tests for unet uncond

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 44e56de9
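In effect, a `UNet2DModel` configured with a tuple `sample_size` can now be run end-to-end through the unconditional pipelines. A minimal sketch of the new behavior, mirroring the tiny test fixture below (the `up_block_types` are an assumption, since the diff elides them; none of this is a real checkpoint):

```python
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

# Tiny rectangular UNet: sample_size=(height, width) instead of a single int.
# Block/channel settings are placeholders chosen for speed.
unet = UNet2DModel(
    sample_size=(32, 64),
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
pipe = DDPMPipeline(unet=unet, scheduler=DDPMScheduler())

# The pipeline now derives the noise shape from the tuple, so the
# output is non-square: (batch, height, width, channels).
image = pipe(num_inference_steps=2, output_type="np").images
assert image.shape == (1, 32, 64, 3)
```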
@@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     implements for all the model (such as downloading or saving, etc.)

     Parameters:
-        sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
-            Input sample size.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
         out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
@@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        sample_size: Optional[int] = None,
+        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
         in_channels: int = 3,
         out_channels: int = 3,
         center_input_sample: bool = False,
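Note: the widened annotation uses `Union` and `Tuple` from `typing`; the import line is not part of this hunk, so the module's existing `from typing import ...` must already cover them (or be extended elsewhere in the same commit).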
@@ -56,7 +56,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
     implements for all the models (such as downloading or saving, etc.)

     Parameters:
-        sample_size (`int`, *optional*): The size of the input sample.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
@@ -89,7 +89,11 @@ class DDIMPipeline(DiffusionPipeline):
             generator = None

         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
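This branch is the whole fix: square models keep the old `(batch, channels, size, size)` shape, while tuple-sized models unpack `(height, width)`. As a standalone sketch (hypothetical helper name, not part of the commit):

```python
from typing import Tuple, Union

def noise_shape(
    batch_size: int,
    in_channels: int,
    sample_size: Union[int, Tuple[int, int]],
) -> Tuple[int, int, int, int]:
    """Initial-noise shape: square for an int, (height, width) for a tuple."""
    if isinstance(sample_size, int):
        return (batch_size, in_channels, sample_size, sample_size)
    return (batch_size, in_channels, *sample_size)

assert noise_shape(1, 3, 32) == (1, 3, 32, 32)
assert noise_shape(2, 3, (32, 64)) == (2, 3, 32, 64)
```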
@@ -94,7 +94,11 @@ class DDPMPipeline(DiffusionPipeline):
             generator = None

         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
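For context on the `mps` branch visible in both hunks: `torch.randn` is not reproducible with a seeded generator on Apple's MPS backend, so the noise is drawn on CPU and moved to the device afterwards. A sketch of that pattern (hypothetical helper, mirroring the context lines above):

```python
import torch

def seeded_noise(shape, seed: int, device: torch.device) -> torch.Tensor:
    if device.type == "mps":
        # randn does not work reproducibly on mps: sample on CPU, then move.
        generator = torch.manual_seed(seed)
        return torch.randn(shape, generator=generator).to(device)
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(shape, generator=generator, device=device)
```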
@@ -18,6 +18,7 @@ import os
 import random
 import tempfile
 import unittest
+from functools import partial

 import numpy as np
 import torch
@@ -46,6 +47,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
+from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
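Two new imports support the parameterized test added below: `functools.partial` from the standard library, and `parameterized`, a separate PyPI package (`pip install parameterized`) needed only to run the test suite.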
@@ -247,7 +249,6 @@ class CustomPipelineTests(unittest.TestCase):
 class PipelineFastTests(unittest.TestCase):
-    @property
     def dummy_image(self):
         batch_size = 1
         num_channels = 3
@@ -256,13 +257,12 @@ class PipelineFastTests(unittest.TestCase):
         image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
         return image

-    @property
-    def dummy_uncond_unet(self):
+    def dummy_uncond_unet(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=3,
             out_channels=3,
             down_block_types=("DownBlock2D", "AttnDownBlock2D"),
@@ -270,13 +270,12 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
-    def dummy_cond_unet(self):
+    def dummy_cond_unet(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=4,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -285,13 +284,12 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
-    def dummy_cond_unet_inpaint(self):
+    def dummy_cond_unet_inpaint(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=9,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -300,7 +298,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
     def dummy_vq_model(self):
         torch.manual_seed(0)
         model = VQModel(
@@ -313,7 +310,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
     def dummy_vae(self):
         torch.manual_seed(0)
         model = AutoencoderKL(
@@ -326,7 +322,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return model

-    @property
     def dummy_text_encoder(self):
         torch.manual_seed(0)
         config = CLIPTextConfig(
@@ -342,7 +337,6 @@ class PipelineFastTests(unittest.TestCase):
         )
         return CLIPTextModel(config)

-    @property
     def dummy_extractor(self):
         def extract(*args, **kwargs):
             class Out:
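The pattern across all of the fixture hunks above is the same: each `dummy_*` helper changes from a `@property` to a plain method, so that the UNet fixtures can accept a `sample_size` argument. This is why every call site in the final hunks below gains parentheses (e.g. `self.dummy_vae()`).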
@@ -357,15 +351,43 @@ class PipelineFastTests(unittest.TestCase):
         return extract

-    def test_components(self):
+    @parameterized.expand(
+        [
+            [DDIMScheduler, DDIMPipeline, 32],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, 32],
+            [DDIMScheduler, DDIMPipeline, (32, 64)],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, (64, 32)],
+        ]
+    )
+    def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
+        unet = self.dummy_uncond_unet(sample_size)
+        # DDIM doesn't take `predict_epsilon`, and DDPM requires it -- so using partial in parameterized decorator
+        scheduler = scheduler_fn()
+        pipeline = pipeline_fn(unet, scheduler).to(torch_device)
+
+        # Device type MPS is not supported for torch.Generator() api.
+        if torch_device == "mps":
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        out_image = pipeline(
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+        ).images
+        sample_size = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size
+        assert out_image.shape == (1, *sample_size, 3)
+
+    def test_stable_diffusion_components(self):
         """Test that components property works correctly"""
-        unet = self.dummy_cond_unet
+        unet = self.dummy_cond_unet()
         scheduler = PNDMScheduler(skip_prk_steps=True)
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        vae = self.dummy_vae()
+        bert = self.dummy_text_encoder()
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

-        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
         init_image = Image.fromarray(np.uint8(image)).convert("RGB")
         mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
@@ -377,7 +399,7 @@ class PipelineFastTests(unittest.TestCase):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=self.dummy_extractor(),
         ).to(torch_device)

         img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
         text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
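`@parameterized.expand` turns `test_uncond_unet_components` into four separate tests, one per scheduler/pipeline/`sample_size` row, with the parameter index appended to each generated test name. They can be selected together with pytest's keyword filter, e.g. `pytest -k test_uncond_unet_components` against the test file (whose path this diff does not show). The `partial(DDPMScheduler, predict_epsilon=True)` entries exist because, as the inline comment notes, `DDIMScheduler` does not accept `predict_epsilon` while `DDPMScheduler` requires it here.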