scheduling_utils_flax.py 12.1 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import importlib
15
import math
16
import os
17
from dataclasses import dataclass
Kashif Rasul's avatar
Kashif Rasul committed
18
from enum import Enum
Patrick von Platen's avatar
Patrick von Platen committed
19
from typing import Optional, Tuple, Union
20

21
import flax
22
import jax.numpy as jnp
23
from huggingface_hub.utils import validate_hf_hub_args
24

25
from ..utils import BaseOutput, PushToHubMixin
26
27
28


# File name of the scheduler configuration JSON; `FlaxSchedulerMixin` exposes it as its `config_name`.
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
Kashif Rasul's avatar
Kashif Rasul committed
29
30


31
32
33
34
# NOTE: We make this type an enum because it simplifies usage in docs and prevents
# circular imports when used for `_compatibles` within the schedulers module.
# When it's used as a type in pipelines, it really is a Union because the actual
# scheduler instance is passed in.
Kashif Rasul's avatar
Kashif Rasul committed
35
36
37
38
39
40
class FlaxKarrasDiffusionSchedulers(Enum):
    """
    Enumeration of Flax scheduler class names.

    Each member is named after a scheduler class and carries a distinct integer value; the members are referred to
    by name, so the enum can be used without importing the scheduler classes themselves.
    """

    FlaxDDIMScheduler = 1
    FlaxDDPMScheduler = 2
    FlaxPNDMScheduler = 3
    FlaxLMSDiscreteScheduler = 4
    FlaxDPMSolverMultistepScheduler = 5
    FlaxEulerDiscreteScheduler = 6
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57


@dataclass
class FlaxSchedulerOutput(BaseOutput):
    """
    Output container returned by a scheduler's step function.

    Args:
        prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
            The sample computed for the previous timestep (x_{t-1}). Feed `prev_sample` back in as the next model
            input while iterating the denoising loop.
    """

    prev_sample: jnp.ndarray


58
class FlaxSchedulerMixin(PushToHubMixin):
    """
    Mixin containing common functions for the schedulers.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
          `from_config` can be used from a class different than the one used to save the config (should be overridden
          by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    ignore_for_config = ["dtype"]
    _compatibles = []
    has_compatibles = True

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Instantiate a Scheduler class from a pre-defined JSON-file.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
                      organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`],
                      e.g., `./my_model_directory/`.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `transformers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.

        Returns:
            `(scheduler, state)`, or `(scheduler, state, unused_kwargs)` when `return_unused_kwargs` is `True`.
            `state` is the result of `scheduler.create_state()` when the scheduler defines `create_state` and has a
            truthy `has_state` attribute, and `None` otherwise.

        <Tip>

         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
         models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
        use this method in a firewalled environment.

        </Tip>

        """
        # NOTE(review): `load_config` and `from_config` are not defined in this mixin; they are expected to be
        # provided by the concrete scheduler class (presumably through a config mixin) -- confirm against subclasses.
        config, kwargs = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            **kwargs,
        )
        scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)

        # Bind `state` up front so that schedulers without a state still yield a well-defined return value instead
        # of failing on an unbound local in the return statements below.
        state = None
        if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
            state = scheduler.create_state()

        if return_unused_kwargs:
            return scheduler, state, unused_kwargs

        return scheduler, state

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~FlaxSchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        # All the work is delegated to `save_config`, which is not defined in this mixin.
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler

        Returns:
            `List[SchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        """
        Resolve the class's own name plus the names in `_compatibles` to attributes of the top-level package.

        Names that the top-level package does not expose are silently skipped. The order of the returned list is
        not defined because the names are de-duplicated through a set.
        """
        compatible_classes_str = list({cls.__name__, *cls._compatibles})
        # The first component of this module's dotted name is the top-level package to look the names up in.
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        compatible_classes = [
            getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
        ]
        return compatible_classes
189
190
191
192
193


def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int, ...]) -> jnp.ndarray:
    """
    Broadcast `x` to `shape`, aligning the existing axes of `x` with the leading axes of `shape`.

    `x` is first reshaped by appending trailing singleton dimensions until it has `len(shape)` dimensions, and is
    then broadcast to `shape`.

    Args:
        x (`jnp.ndarray`): array to broadcast; must not have more dimensions than `shape` has entries.
        shape (`Tuple[int, ...]`): the target shape.

    Returns:
        `jnp.ndarray`: an array of shape `shape`.

    Raises:
        AssertionError: if `x` has more dimensions than `shape`.
    """
    assert len(shape) >= x.ndim, (
        f"cannot broadcast an array with {x.ndim} dimensions to a shape with {len(shape)} dimensions"
    )
    missing_dims = len(shape) - x.ndim
    return jnp.broadcast_to(x.reshape(x.shape + (1,) * missing_dims), shape)
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265


def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray:
    """
    Build a beta schedule by discretizing a squared-cosine alpha_bar curve over t in [0, 1].

    alpha_bar(t) is the cumulative product of (1 - beta) up to time t. For each of the `num_diffusion_timesteps`
    equal-width intervals [t1, t2], the beta is `1 - alpha_bar(t2) / alpha_bar(t1)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        dtype: dtype of the returned array.

    Returns:
        betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    def beta_at(step):
        # Interval boundaries of the `step`-th slice of [0, 1].
        start = step / num_diffusion_timesteps
        end = (step + 1) / num_diffusion_timesteps
        return min(1 - alpha_bar(end) / alpha_bar(start), max_beta)

    return jnp.array([beta_at(step) for step in range(num_diffusion_timesteps)], dtype=dtype)


@flax.struct.dataclass
class CommonSchedulerState:
    """
    Container for the beta/alpha arrays derived from a scheduler's configuration.

    Fields:
        betas: the beta schedule selected by the scheduler config.
        alphas: `1.0 - betas`.
        alphas_cumprod: cumulative product of `alphas` along axis 0.
    """

    alphas: jnp.ndarray
    betas: jnp.ndarray
    alphas_cumprod: jnp.ndarray

    @classmethod
    def create(cls, scheduler):
        """
        Build the state from `scheduler.config`, creating arrays with `scheduler.dtype`.

        `config.trained_betas`, when not `None`, takes precedence over `config.beta_schedule`. Otherwise the
        schedule is chosen by name: `"linear"`, `"scaled_linear"` or `"squaredcos_cap_v2"`.

        Raises:
            NotImplementedError: if `config.beta_schedule` is none of the supported names.
        """
        cfg = scheduler.config

        if cfg.trained_betas is not None:
            beta_values = jnp.asarray(cfg.trained_betas, dtype=scheduler.dtype)
        else:
            schedule_name = cfg.beta_schedule
            if schedule_name == "linear":
                beta_values = jnp.linspace(
                    cfg.beta_start, cfg.beta_end, cfg.num_train_timesteps, dtype=scheduler.dtype
                )
            elif schedule_name == "scaled_linear":
                # this schedule is very specific to the latent diffusion model.
                # Interpolate linearly between the square roots of the endpoints, then square the result.
                sqrt_betas = jnp.linspace(
                    cfg.beta_start**0.5, cfg.beta_end**0.5, cfg.num_train_timesteps, dtype=scheduler.dtype
                )
                beta_values = sqrt_betas**2
            elif schedule_name == "squaredcos_cap_v2":
                # Glide cosine schedule
                beta_values = betas_for_alpha_bar(cfg.num_train_timesteps, dtype=scheduler.dtype)
            else:
                raise NotImplementedError(
                    f"beta_schedule {cfg.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}"
                )

        alpha_values = 1.0 - beta_values

        return cls(
            alphas=alpha_values,
            betas=beta_values,
            alphas_cumprod=jnp.cumprod(alpha_values, axis=0),
        )


266
def get_sqrt_alpha_prod(
    state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
    """
    Look up sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t) for `timesteps`, broadcast to the sample shape.

    Args:
        state: scheduler state providing `alphas_cumprod`.
        original_samples: array whose shape is the broadcast target.
        noise: not used by this function; kept so the signature stays unchanged for existing callers.
        timesteps: indices into `state.alphas_cumprod`.

    Returns:
        A tuple `(sqrt_alpha_prod, sqrt_one_minus_alpha_prod)`, each of shape `original_samples.shape`.
    """
    # Gather the cumulative alpha product once; both coefficients derive from it.
    alpha_prod = state.alphas_cumprod[timesteps]

    # Flatten the per-timestep values, then append singleton axes and broadcast to the sample shape.
    sqrt_alpha_prod = (alpha_prod**0.5).flatten()
    sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

    sqrt_one_minus_alpha_prod = ((1 - alpha_prod) ** 0.5).flatten()
    sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

    return sqrt_alpha_prod, sqrt_one_minus_alpha_prod


def add_noise_common(
    state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
    """
    Mix `original_samples` with `noise` using the coefficients for `timesteps`.

    Computes `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise`, where both coefficients come
    from `get_sqrt_alpha_prod` and are already broadcast to `original_samples.shape`.

    Args:
        state: scheduler state providing `alphas_cumprod`.
        original_samples: the samples to add noise to.
        noise: the noise to mix in.
        timesteps: indices into `state.alphas_cumprod`.

    Returns:
        The noisy samples.
    """
    sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps)
    noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
    return noisy_samples
288
289
290
291
292
293


def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray):
    """
    Compute the velocity `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample` for `timesteps`.

    Both coefficients are obtained from `get_sqrt_alpha_prod` and are broadcast to `sample.shape`.
    """
    signal_coeff, noise_coeff = get_sqrt_alpha_prod(state, sample, noise, timesteps)
    return signal_coeff * noise - noise_coeff * sample