Commit 13994b2d (unverified), authored by Anton Lozhkov, committed by GitHub
Browse files

RePaint fast tests and API conforming (#1701)

* add fast tests

* better tests and fp16

* batch fixes

* Reuse preprocessing

* quickfix
parent ea90bf2b
...@@ -13,33 +13,61 @@ ...@@ -13,33 +13,61 @@
# limitations under the License. # limitations under the License.
from typing import Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import PIL import PIL
from tqdm.auto import tqdm
from ...models import UNet2DModel from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import RePaintScheduler from ...schedulers import RePaintScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging
def _preprocess_image(image: PIL.Image.Image): logger = logging.get_logger(__name__) # pylint: disable=invalid-name
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image return image
def _preprocess_mask(mask: PIL.Image.Image): def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
mask = np.array(mask.convert("L")) if isinstance(mask, torch.Tensor):
mask = mask.astype(np.float32) / 255.0 return mask
mask = mask[None, None] elif isinstance(mask, PIL.Image.Image):
mask[mask < 0.5] = 0 mask = [mask]
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask) if isinstance(mask[0], PIL.Image.Image):
w, h = mask[0].size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask]
mask = np.concatenate(mask, axis=0)
mask = mask.astype(np.float32) / 255.0
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
elif isinstance(mask[0], torch.Tensor):
mask = torch.cat(mask, dim=0)
return mask return mask
...@@ -54,8 +82,8 @@ class RePaintPipeline(DiffusionPipeline): ...@@ -54,8 +82,8 @@ class RePaintPipeline(DiffusionPipeline):
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
original_image: Union[torch.FloatTensor, PIL.Image.Image], image: Union[torch.Tensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image], mask_image: Union[torch.Tensor, PIL.Image.Image],
num_inference_steps: int = 250, num_inference_steps: int = 250,
eta: float = 0.0, eta: float = 0.0,
jump_length: int = 10, jump_length: int = 10,
...@@ -63,10 +91,11 @@ class RePaintPipeline(DiffusionPipeline): ...@@ -63,10 +91,11 @@ class RePaintPipeline(DiffusionPipeline):
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]: ) -> Union[ImagePipelineOutput, Tuple]:
r""" r"""
Args: Args:
original_image (`torch.FloatTensor` or `PIL.Image.Image`): image (`torch.FloatTensor` or `PIL.Image.Image`):
The original image to inpaint on. The original image to inpaint on.
mask_image (`torch.FloatTensor` or `PIL.Image.Image`): mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
The mask_image where 0.0 values define which part of the original image to inpaint (change). The mask_image where 0.0 values define which part of the original image to inpaint (change).
...@@ -97,12 +126,14 @@ class RePaintPipeline(DiffusionPipeline): ...@@ -97,12 +126,14 @@ class RePaintPipeline(DiffusionPipeline):
generated images. generated images.
""" """
if not isinstance(original_image, torch.FloatTensor): message = "Please use `image` instead of `original_image`."
original_image = _preprocess_image(original_image) original_image = deprecate("original_image", "0.15.0", message, take_from=kwargs)
original_image = original_image.to(self.device) original_image = original_image or image
if not isinstance(mask_image, torch.FloatTensor):
mask_image = _preprocess_mask(mask_image) original_image = _preprocess_image(original_image)
mask_image = mask_image.to(self.device) original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
mask_image = _preprocess_mask(mask_image)
mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)
# sample gaussian noise to begin the loop # sample gaussian noise to begin the loop
image = torch.randn( image = torch.randn(
...@@ -110,14 +141,14 @@ class RePaintPipeline(DiffusionPipeline): ...@@ -110,14 +141,14 @@ class RePaintPipeline(DiffusionPipeline):
generator=generator, generator=generator,
device=self.device, device=self.device,
) )
image = image.to(self.device) image = image.to(device=self.device, dtype=self.unet.dtype)
# set step values # set step values
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device) self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
self.scheduler.eta = eta self.scheduler.eta = eta
t_last = self.scheduler.timesteps[0] + 1 t_last = self.scheduler.timesteps[0] + 1
for i, t in enumerate(tqdm(self.scheduler.timesteps)): for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
if t < t_last: if t < t_last:
# predict the noise residual # predict the noise residual
model_output = self.unet(image, t).sample model_output = self.unet(image, t).sample
......
...@@ -270,9 +270,13 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): ...@@ -270,9 +270,13 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
# been observed. # been observed.
# 5. Add noise # 5. Add noise
noise = torch.randn( device = model_output.device
model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device if device.type == "mps":
) # randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
noise = noise.to(device)
else:
noise = torch.randn(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
std_dev_t = self.eta * self._get_variance(timestep) ** 0.5 std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
variance = 0 variance = 0
...@@ -305,7 +309,12 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): ...@@ -305,7 +309,12 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
for i in range(n): for i in range(n):
beta = self.betas[timestep + i] beta = self.betas[timestep + i]
noise = torch.randn(sample.shape, generator=generator, device=sample.device) if sample.device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(sample.shape, dtype=sample.dtype, generator=generator)
noise = noise.to(sample.device)
else:
noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
# 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
......
...@@ -21,10 +21,68 @@ import torch ...@@ -21,10 +21,68 @@ import torch
from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
from diffusers.utils.testing_utils import load_image, load_numpy, require_torch_gpu, slow, torch_device from diffusers.utils.testing_utils import load_image, load_numpy, require_torch_gpu, slow, torch_device
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast tests for `RePaintPipeline` built from a tiny, randomly initialised UNet.

    The shared checks come from `PipelineTesterMixin`; this class supplies the
    dummy components and inputs and adds one numerical regression test.
    """

    pipeline_class = RePaintPipeline
    # NOTE(review): presumably read by PipelineTesterMixin to skip its CPU-offload checks - confirm.
    test_cpu_offload = False

    def get_dummy_components(self):
        """Return the keyword arguments used to construct a small `RePaintPipeline`.

        Returns:
            dict: `{"unet": UNet2DModel, "scheduler": RePaintScheduler}`.
        """
        # Seed once so the randomly initialised UNet weights are reproducible.
        torch.manual_seed(0)
        unet = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        scheduler = RePaintScheduler()
        return {"unet": unet, "scheduler": scheduler}

    def get_dummy_inputs(self, device, seed=0):
        """Return deterministic call arguments for the pipeline.

        Args:
            device: Device (string or `torch.device`) the tensors and generator are created for.
            seed: Seed for both the NumPy image noise and the torch generator.

        Returns:
            dict: Keyword arguments for `RePaintPipeline.__call__`.
        """
        if str(device).startswith("mps"):
            # No device-bound generator is created on mps; seed and use the
            # global default generator returned by torch.manual_seed instead.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = np.random.RandomState(seed).standard_normal((1, 3, 32, 32))
        image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
        # Binary mask: 1.0 where the random image is positive, 0.0 elsewhere.
        mask = (image > 0).to(device=device, dtype=torch.float32)

        return {
            "image": image,
            "mask_image": mask,
            "generator": generator,
            "num_inference_steps": 5,
            "eta": 0.0,
            "jump_length": 2,
            "jump_n_sample": 2,
            "output_type": "numpy",
        }

    def test_repaint(self):
        """Run the dummy pipeline on CPU and compare a 3x3 output slice to reference values."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = RePaintPipeline(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        # Bottom-right 3x3 patch of the last channel of the first image.
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([1.0000, 0.5426, 0.5497, 0.2200, 1.0000, 1.0000, 0.5623, 1.0000, 0.6274])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@slow @slow
@require_torch_gpu @require_torch_gpu
class RepaintPipelineIntegrationTests(unittest.TestCase): class RepaintPipelineIntegrationTests(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment