# Copyright 2024 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
16
from typing import List, Optional, Tuple, Union
17

18
import numpy as np
19
20
21
22
23
24
25
26
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput


class IPNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    A fourth-order Improved Pseudo Linear Multistep scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        trained_betas (`np.ndarray` or `List[float]`, *optional*):
            Pass an array of betas directly to the constructor. When given, `set_timesteps` uses these values instead
            of the sine-based schedule it would otherwise compute.
    """

    order = 1

    @register_to_config
    def __init__(
        self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
    ):
        # set `betas`, `alphas`, `timesteps`
        self.set_timesteps(num_train_timesteps)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # Order of the linear multistep formula. `step` hard-codes the matching coefficients and combines at most
        # four stored `ets`; it does not read this attribute.
        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # running values
        self.ets = []
        self._step_index = None
        self._begin_index = None

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Besides rebuilding `betas`, `alphas` and `timesteps`, this clears the stored `ets` history and resets both the
        step index and the begin index to `None`.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
        # `num_inference_steps` points going down from 1 (endpoint excluded), followed by an explicit trailing 0.0,
        # i.e. `num_inference_steps + 1` values in total.
        steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
        steps = torch.cat([steps, torch.tensor([0.0])])

        if self.config.trained_betas is not None:
            self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
        else:
            self.betas = torch.sin(steps * math.pi / 2) ** 2

        self.alphas = (1.0 - self.betas**2) ** 0.5

        # The last entry is dropped: it is only ever used as the "next" coefficient of the final step.
        timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
        self.timesteps = timesteps.to(device)

        self.ets = []
        self._step_index = None
        self._begin_index = None

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the linear multistep method: the estimate computed for the current step is combined with up to three
        estimates stored from earlier calls.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain. It is only consulted on the first call after
                `set_timesteps`, to initialise the step index when no begin index has been set.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            `ValueError`: If `num_inference_steps` is `None`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )
        if self.step_index is None:
            self._init_step_index(timestep)

        timestep_index = self.step_index
        prev_timestep_index = self.step_index + 1

        ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
        self.ets.append(ets)
        # The formulas below never look further back than four entries, so older tensors are dropped (in place)
        # instead of being kept alive for the whole sampling run.
        del self.ets[:-4]

        # Use the highest-order formula the available history allows (orders 1 to 4).
        if len(self.ets) == 1:
            ets = self.ets[-1]
        elif len(self.ets) == 2:
            ets = (3 * self.ets[-1] - self.ets[-2]) / 2
        elif len(self.ets) == 3:
            ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
        else:
            ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

        prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. This scheduler applies no scaling and returns `sample` unchanged.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                The input sample, unmodified.
        """
        return sample

    def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
        """
        Compute the sample at `prev_timestep_index` from the sample at `timestep_index` and the combined estimate.

        Args:
            sample (`torch.Tensor`):
                The current sample.
            timestep_index (`int`):
                Index into `self.alphas` / `self.betas` for the current step.
            prev_timestep_index (`int`):
                Index into `self.alphas` / `self.betas` for the step being moved to.
            ets (`torch.Tensor`):
                The multistep estimate produced in `step`.

        Returns:
            `torch.Tensor`:
                The sample for the next position in the schedule.
        """
        alpha = self.alphas[timestep_index]
        sigma = self.betas[timestep_index]

        next_alpha = self.alphas[prev_timestep_index]
        next_sigma = self.betas[prev_timestep_index]

        # `alpha` is floored at 1e-8 so the division stays finite when it is (close to) zero.
        pred = (sample - sigma * ets) / max(alpha, 1e-8)
        prev_sample = next_alpha * pred + ets * next_sigma

        return prev_sample

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps