# Copyright 2023 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


16
import warnings
17
from typing import List, Optional, Tuple, Union
Revist's avatar
Revist committed
18
19
20

import numpy as np
import PIL
21
import torch
Revist's avatar
Revist committed
22
23
24

from ...models import UNet2DModel
from ...schedulers import RePaintScheduler
25
from ...utils import PIL_INTERPOLATION, logging, randn_tensor
26
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
Revist's avatar
Revist committed
27
28


29
30
31
32
33
# Module-level logger obtained from the package's logging utilities (not referenced by the code visible in this file).
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
    """Convert the ``image`` argument of :class:`RePaintPipeline` into a ``torch.Tensor``.

    This started from the Stable Diffusion img2img ``preprocess`` helper, but it intentionally does not emit that
    helper's ``FutureWarning`` recommending ``VaeImageProcessor.preprocess``. That warning fired on every RePaint
    call, even for tensor inputs, and pointed to a VAE-based API that does not apply to this pipeline.

    Args:
        image: A tensor (returned unchanged), a single PIL image, or a list of PIL images or tensors.

    Returns:
        For PIL input: a float32 tensor of shape ``(batch, 3, H, W)`` with values in ``[-1, 1]``. ``H`` and ``W``
        are the first image's height and width rounded down to a multiple of 8, and every image is resized to that
        size. For a list of tensors: their concatenation along dim 0.
    """
    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        # Convert to RGB first. Grayscale ("L") or RGBA images would otherwise give 2-D or 4-channel arrays that
        # break the NHWC -> NCHW transpose below. RGB inputs are unaffected.
        image = [
            np.array(i.convert("RGB").resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image
        ]
        image = np.concatenate(image, axis=0).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        image = 2.0 * image - 1.0  # [0, 1] -> [-1, 1]
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


59
60
61
62
63
64
65
66
def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
    """Convert the ``mask_image`` argument of :class:`RePaintPipeline` into a ``torch.Tensor``.

    Args:
        mask: A tensor (returned unchanged), a single PIL image, or a list of PIL images or tensors.

    Returns:
        For PIL input: a float32 tensor of shape ``(batch, 1, H, W)`` containing only 0.0 and 1.0. ``H`` and ``W``
        are the first mask's height and width rounded down to a multiple of 8. For a list of tensors: their
        concatenation along dim 0.
    """
    if isinstance(mask, torch.Tensor):
        return mask
    if isinstance(mask, PIL.Image.Image):
        mask = [mask]

    if isinstance(mask[0], PIL.Image.Image):
        w, h = mask[0].size
        # Round to the same multiple (8) that ``_preprocess_image`` uses, so a PIL mask gets the same spatial size as
        # a PIL image with the same dimensions. Rounding to 32 here made e.g. 264x264 inputs produce a 264x264
        # image but a 256x256 mask.
        w, h = (x - x % 8 for x in (w, h))
        # ``[None, None]`` adds batch and channel axes, giving (1, 1, H, W) per mask. The singleton channel axis
        # lets the mask broadcast against (batch, C, H, W) image tensors for any batch size. A (batch, H, W) array
        # would instead align its batch axis with the image's channel axis.
        mask = [
            np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, None] for m in mask
        ]
        mask = np.concatenate(mask, axis=0).astype(np.float32) / 255.0
        # Binarize: values >= 0.5 become 1.0, everything else 0.0.
        mask = (mask >= 0.5).astype(np.float32)
        mask = torch.from_numpy(mask)
    elif isinstance(mask[0], torch.Tensor):
        mask = torch.cat(mask, dim=0)
    return mask


class RePaintPipeline(DiffusionPipeline):
    r"""
    Pipeline for image inpainting with RePaint (https://arxiv.org/pdf/2201.09865.pdf).

    Args:
        unet ([`UNet2DModel`]):
            Denoising model. It is called as `unet(sample, timestep).sample` at every denoising step.
        scheduler ([`RePaintScheduler`]):
            Scheduler providing the jump schedule (`set_timesteps`), the denoising update (`step`), and the
            re-noising update (`undo_step`).
    """

    unet: UNet2DModel
    scheduler: RePaintScheduler

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        image: Union[torch.Tensor, PIL.Image.Image],
        mask_image: Union[torch.Tensor, PIL.Image.Image],
        num_inference_steps: int = 250,
        eta: float = 0.0,
        jump_length: int = 10,
        jump_n_sample: int = 10,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Inpaint the masked region of `image` with RePaint.

        Args:
            image (`torch.FloatTensor` or `PIL.Image.Image`):
                The original image to inpaint on. PIL input is resized to a width and height that are multiples of 8
                and scaled to `[-1, 1]`. Tensor input is used as-is.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                The mask_image where 0.0 values define which part of the original image to inpaint (change). PIL
                input is binarized to 0.0/1.0.
            num_inference_steps (`int`, *optional*, defaults to 250):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`, *optional*, defaults to 0.0):
                The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0: 0.0 is DDIM
                and 1.0 is DDPM scheduler respectively.
            jump_length (`int`, *optional*, defaults to 10):
                The number of steps taken forward in time before going backward in time for a single jump ("j" in
                RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
            jump_n_sample (`int`, *optional*, defaults to 10):
                The number of times we will make forward time jump for a given chosen time sample. Take a look at
                Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic. A list must have one generator per batch element. It seeds the
                initial noise, and only its first generator is used for the per-step noise.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is
            `True`, otherwise a `tuple` whose first element is the generated images. These are PIL images (via
            `numpy_to_pil`) when `output_type == "pil"`, otherwise a float32 NumPy array of shape
            `(batch, height, width, channels)` with values in `[0, 1]`.

        Raises:
            ValueError: If the preprocessed mask and image differ in height/width, or if `generator` is a list whose
                length differs from the batch size.
        """
        device = self._execution_device
        dtype = self.unet.dtype

        original_image = _preprocess_image(image).to(device=device, dtype=dtype)
        mask_image = _preprocess_mask(mask_image).to(device=device, dtype=dtype)

        # Fail early with a clear message instead of an opaque broadcasting error deep inside the scheduler.
        if mask_image.shape[-2:] != original_image.shape[-2:]:
            raise ValueError(
                f"`mask_image` has spatial size {tuple(mask_image.shape[-2:])} after preprocessing, but `image` has"
                f" {tuple(original_image.shape[-2:])}. Make sure the mask and the image have the same height and"
                " width."
            )

        batch_size = original_image.shape[0]

        # sample gaussian noise to begin the loop
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        image = randn_tensor(original_image.shape, generator=generator, device=device, dtype=dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, device)
        self.scheduler.eta = eta

        # Start above the first timestep so the first iteration takes the denoising branch.
        t_last = self.scheduler.timesteps[0] + 1
        # Only a single generator is passed to the scheduler for per-step noise.
        generator = generator[0] if isinstance(generator, list) else generator
        for t in self.progress_bar(self.scheduler.timesteps):
            if t < t_last:
                # predict the noise residual
                model_output = self.unet(image, t).sample
                # compute previous image: x_t -> x_t-1
                image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample
            else:
                # timesteps went back up (a RePaint jump): re-noise x_t-1 -> x_t
                image = self.scheduler.undo_step(image, t_last, generator)
            t_last = t

        image = (image / 2 + 0.5).clamp(0, 1)
        # `.float()` before `.numpy()`: NumPy has no bfloat16, so a bf16 UNet would otherwise fail here.
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)