vae.py 31.7 KB
Newer Older
1
# Copyright 2024 The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from ...utils import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..activations import get_activation
from ..attention_processor import SpatialNorm
from ..unets.unet_2d_blocks import (
    AutoencoderTinyBlock,
    UNetMidBlock2D,
    get_down_block,
    get_up_block,
)
patil-suraj's avatar
patil-suraj committed
31
32


33
34
35
36
37
38
39
40
41
42
43
44
45
@dataclass
class EncoderOutput(BaseOutput):
    r"""
    Container for the result of an encoding call.

    Args:
        latent (`torch.Tensor` of shape `(batch_size, num_channels, latent_height, latent_width)`):
            The latent representation produced by the encoder.
    """

    latent: torch.Tensor


46
47
@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
        commit_loss (`torch.FloatTensor`, *optional*, defaults to `None`):
            Optional loss tensor carried next to the sample. NOTE(review): presumably the commitment loss of a
            vector-quantized model; confirm against the callers that populate it.
    """

    sample: torch.Tensor
    commit_loss: Optional[torch.FloatTensor] = None
58
59


patil-suraj's avatar
patil-suraj committed
60
class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
            options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function passed to the down blocks and the mid block. The activation applied before
            `conv_out` is always `nn.SiLU`, independent of this value.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded as `add_attention` to the `UNetMidBlock2D` mid block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention: bool = True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # Only the last block keeps the spatial resolution (no downsample).
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        r"""
        The forward method of the `Encoder` class.

        Args:
            sample (`torch.Tensor`): Input fed to `conv_in`.

        Returns:
            `torch.Tensor`: Output of `conv_out`, with `2 * out_channels` channels when `double_z` was set and
            `out_channels` otherwise.
        """

        sample = self.conv_in(sample)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # NOTE(review): `_gradient_checkpointing_func` is not defined in this class; it is presumably attached
            # by the owning model when gradient checkpointing is enabled.
            # down
            for down_block in self.down_blocks:
                sample = self._gradient_checkpointing_func(down_block, sample)
            # middle
            sample = self._gradient_checkpointing_func(self.mid_block, sample)

        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample)

            # middle
            sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
patil-suraj's avatar
patil-suraj committed
179
180
181


class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function passed to the mid block and the up blocks. The activation applied before
            `conv_out` is always `nn.SiLU`, independent of this value.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded as `add_attention` to the `UNetMidBlock2D` mid block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention: bool = True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.up_blocks = nn.ModuleList([])

        # With spatial normalization the blocks receive an embedding with `in_channels` channels.
        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            # Only the last block keeps the spatial resolution (no upsample).
            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        latent_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Decoder` class.

        Args:
            sample (`torch.Tensor`): Input fed to `conv_in`.
            latent_embeds (`torch.Tensor`, *optional*):
                Passed as second argument to the mid block and every up block, and to `conv_norm_out` when it is
                not `None`.

        Returns:
            `torch.Tensor`: Output of `conv_out`.
        """

        sample = self.conv_in(sample)

        # The mid block output is cast to the dtype of the up-block parameters before upsampling.
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # NOTE(review): `_gradient_checkpointing_func` is not defined in this class; it is presumably attached
            # by the owning model when gradient checkpointing is enabled.
            # middle
            sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
patil-suraj's avatar
patil-suraj committed
316
317


Ruslan Vorovchenko's avatar
Ruslan Vorovchenko committed
318
class UpSample(nn.Module):
    r"""
    The `UpSample` layer of a variational autoencoder that upsamples its input.

    The input goes through a ReLU and then a transposed convolution with `kernel_size=4`, `stride=2`, `padding=1`,
    which doubles the height and width.

    Args:
        in_channels (`int`):
            The number of input channels.
        out_channels (`int`):
            The number of output channels.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""The forward method of the `UpSample` class: ReLU followed by the transposed convolution."""
        x = torch.relu(x)
        x = self.deconv(x)
        return x


class MaskConditionEncoder(nn.Module):
    """
    used in AsymmetricAutoencoderKL

    A plain stack of `nn.Conv2d` layers (stored in `self.layers`) whose intermediate outputs are all returned by
    `forward`. The first two convolutions are 3x3 with stride 1; every following one is 4x4 with stride 2.

    Args:
        in_ch (`int`):
            The number of input channels of the first convolution.
        out_ch (`int`, *optional*, defaults to 192):
            The number of output channels of the first convolution. It is doubled after every planning step and
            reset to `res_ch` whenever it exceeds `res_ch`.
        res_ch (`int`, *optional*, defaults to 768):
            Upper bound applied to the per-layer output channels, and the number of output channels of the last
            convolution.
        stride (`int`, *optional*, defaults to 16):
            Halved (integer division) until it reaches 1; each halving adds one entry to the channel plan, and the
            stack has one more convolution than the plan has entries. Must be greater than 1, otherwise the plan
            is empty and `channels[-1]` raises `IndexError`.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int = 192,
        res_ch: int = 768,
        stride: int = 16,
    ) -> None:
        super().__init__()

        # Plan one (in, out) channel pair per halving of `stride`.
        channels = []
        while stride > 1:
            stride = stride // 2
            in_ch_ = out_ch * 2
            if out_ch > res_ch:
                out_ch = res_ch
            if stride == 1:
                in_ch_ = res_ch
            channels.append((in_ch_, out_ch))
            out_ch *= 2

        # Output widths of the conv stack: every planned `out`, plus the `in` of the final pair.
        out_channels = [planned_out for _, planned_out in channels]
        out_channels.append(channels[-1][0])

        layers = []
        prev_ch = in_ch
        for idx, next_ch in enumerate(out_channels):
            if idx < 2:
                conv = nn.Conv2d(prev_ch, next_ch, kernel_size=3, stride=1, padding=1)
            else:
                conv = nn.Conv2d(prev_ch, next_ch, kernel_size=4, stride=2, padding=1)
            layers.append(conv)
            prev_ch = next_ch

        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, mask=None) -> Dict[str, torch.Tensor]:
        r"""
        The forward method of the `MaskConditionEncoder` class.

        Args:
            x (`torch.Tensor`): Input to the first convolution.
            mask: Accepted but not used by this method.

        Returns:
            `Dict[str, torch.Tensor]`: The output of every convolution (taken before the ReLU), keyed by
            `str(tuple(output.shape))`. When two layers produce the same shape, the later one overwrites the
            earlier entry.
        """
        out = {}
        for layer in self.layers:
            x = layer(x)
            out[str(tuple(x.shape))] = x
            x = torch.relu(x)
        return out


class MaskConditionDecoder(nn.Module):
    r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
    decoder with a conditioner on the mask and masked image.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function passed to the mid block and the up blocks. The activation applied before
            `conv_out` is always `nn.SiLU`, independent of this value.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.up_blocks = nn.ModuleList([])

        # With spatial normalization the blocks receive an embedding with `in_channels` channels.
        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            # Only the last block keeps the spatial resolution (no upsample).
            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)

        # condition encoder
        self.condition_encoder = MaskConditionEncoder(
            in_ch=out_channels,
            out_ch=block_out_channels[0],
            res_ch=block_out_channels[-1],
        )

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(
        self,
        z: torch.Tensor,
        image: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        latent_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        The forward method of the `MaskConditionDecoder` class.

        Args:
            z (`torch.Tensor`): Input fed to `conv_in`.
            image (`torch.Tensor`, *optional*):
                Image that is multiplied by `1 - mask` and passed to the condition encoder.
            mask (`torch.Tensor`, *optional*):
                Blending weights: the decoder features are kept where `mask` is 1 and replaced by the condition
                encoder features where it is 0. Conditioning is applied only when both `image` and `mask` are
                given.
            latent_embeds (`torch.Tensor`, *optional*):
                Passed as second argument to the mid block and every up block, and to `conv_norm_out` when it is
                not `None`.

        Returns:
            `torch.Tensor`: Output of `conv_out`.
        """
        use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
        use_condition = image is not None and mask is not None

        def run(module, *args):
            # Single dispatch point so the checkpointed and the plain path share one code path.
            # NOTE(review): `_gradient_checkpointing_func` is not defined in this class; it is presumably attached
            # by the owning model when gradient checkpointing is enabled.
            if use_checkpointing:
                return self._gradient_checkpointing_func(module, *args)
            return module(*args)

        sample = self.conv_in(z)

        # The mid block output is cast to the dtype of the up-block parameters before upsampling.
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        # middle
        sample = run(self.mid_block, sample, latent_embeds)
        sample = sample.to(upscale_dtype)

        # condition encoder
        im_x = None
        if use_condition:
            masked_image = (1 - mask) * image
            im_x = run(self.condition_encoder, masked_image, mask)

        # up
        for up_block in self.up_blocks:
            if use_condition:
                # Condition features are looked up by the exact shape of the current sample; a missing shape
                # raises `KeyError`.
                sample_ = im_x[str(tuple(sample.shape))]
                mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
                sample = sample * mask_ + sample_ * (1 - mask_)
            sample = run(up_block, sample, latent_embeds)
        if use_condition:
            sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


patil-suraj's avatar
patil-suraj committed
569
570
571
572
573
574
575
576
577
class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
    multiplications and allows for post-hoc remapping of indices.

    Args:
        n_e (`int`): Number of entries in the codebook (`self.embedding`).
        vq_embed_dim (`int`): Size of each codebook vector.
        beta (`float`): Weight of one of the two loss terms (which one depends on `legacy`).
        remap (*optional*): Anything `np.load` accepts; the loaded array is stored in the buffer `used` and
            enables index remapping. `None` disables remapping.
        unknown_index (`str` or `int`, *optional*, defaults to `"random"`):
            Replacement for indices missing from `used`: `"random"`, `"extra"` (one additional index after the
            used ones), or a fixed integer. Only read when `remap` is given.
        sane_index_shape (`bool`, *optional*, defaults to `False`):
            If set, `forward` returns the indices reshaped to `(batch, height, width)`.
        legacy (`bool`, *optional*, defaults to `True`): See the note below.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(
        self,
        n_e: int,
        vq_embed_dim: int,
        beta: float,
        remap=None,
        unknown_index: str = "random",
        sane_index_shape: bool = False,
        legacy: bool = True,
    ):
        super().__init__()
        self.n_e = n_e
        self.vq_embed_dim = vq_embed_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.used: torch.Tensor
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
        """
        Map codebook indices to their positions inside `self.used`.

        Indices that do not occur in `self.used` are replaced by a random position in `[0, re_embed)` when
        `unknown_index == "random"`, and by `unknown_index` otherwise. `inds` must have at least two dimensions;
        the result has the same shape as `inds`.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        # matches[b, i, k] == 1 iff inds[b, i] equals used[k]
        matches = (inds[:, :, None] == used[None, None, ...]).long()
        new = matches.argmax(-1)
        unknown = matches.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
        """
        Map positions inside `self.used` back to codebook indices.

        When an extra token exists (`re_embed > len(used)`), positions beyond the used range are mapped through
        position 0. `inds` must have at least two dimensions and is left unmodified; the result has the same
        shape as `inds`.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            # Out-of-place on purpose: `reshape` may return a view of the caller's tensor, so an in-place
            # assignment here would silently overwrite the caller's indices.
            inds = inds.masked_fill(inds >= self.used.shape[0], 0)  # simply set to zero
        back = torch.gather(used[None, :].expand(inds.shape[0], -1), 1, inds)
        return back.reshape(ishape)

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
        """
        Quantize `z` to its nearest codebook vectors.

        Args:
            z (`torch.Tensor`): Input of shape `(batch, channel, height, width)` with `channel == vq_embed_dim`.

        Returns:
            A tuple `(z_q, loss, (perplexity, min_encodings, min_encoding_indices))`: the quantized tensor in the
            input layout (with straight-through gradients to `z`), the quantization loss, and a tuple whose first
            two entries are always `None` and whose last entry holds the selected indices.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.vq_embed_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)

        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q: torch.Tensor = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
        """
        Look up the codebook vectors for `indices`.

        Args:
            indices (`torch.LongTensor`): Codebook indices (positions inside `used` when remapping is enabled).
            shape (`Tuple[int, ...]`):
                Target shape `(batch, height, width, channel)`. When not `None`, the result is viewed to this
                shape and permuted to `(batch, channel, height, width)`. It is required when remapping is
                enabled, because `shape[0]` is read to restore the batch axis.
        """
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q: torch.Tensor = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class DiagonalGaussianDistribution(object):
    r"""
    Gaussian distribution with a diagonal covariance, described by a mean and a log-variance.

    Args:
        parameters (`torch.Tensor`):
            Tensor that is split into two equal chunks along dimension 1: the first chunk is the mean, the second
            the log-variance. The log-variance is clamped to `[-30.0, 20.0]`.
        deterministic (`bool`, *optional*, defaults to `False`):
            If `True`, the variance and standard deviation are replaced by zeros, and `kl` / `nll` return a
            single-element zero tensor.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # clamp so that exp() below stays finite
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        r"""Return `mean + std * noise`, with the noise drawn through `randn_tensor` using `generator`."""
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        r"""
        Compute the KL divergence from this distribution to `other`, summed over dimensions 1, 2 and 3.

        Args:
            other (`DiagonalGaussianDistribution`, *optional*):
                Reference distribution. If `None`, a standard normal is used.

        Returns:
            `torch.Tensor`: The summed KL divergence, or a single-element zero tensor in deterministic mode.
        """
        if self.deterministic:
            # keep the placeholder on the same device/dtype as the parameters so it can be combined with
            # other loss terms without a device mismatch
            return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3],
            )

        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        r"""
        Compute the negative log-likelihood of `sample` under this distribution.

        Args:
            sample (`torch.Tensor`):
                Values to evaluate; must broadcast against the mean.
            dims (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
                Dimensions to sum over.

        Returns:
            `torch.Tensor`: The summed negative log-likelihood, or a single-element zero tensor in deterministic
            mode.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

        # plain Python float so the addition below does not rely on NumPy-scalar/tensor dispatch
        logtwopi = float(np.log(2.0 * np.pi))
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        r"""Return the mode of the distribution, which is its mean."""
        return self.mean
745
746
747


class EncoderTiny(nn.Module):
    r"""
    The `EncoderTiny` layer is a simpler version of the `Encoder` layer.

    Args:
        in_channels (`int`):
            The number of input channels.
        out_channels (`int`):
            The number of output channels.
        num_blocks (`Tuple[int, ...]`):
            Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
            use.
        block_out_channels (`Tuple[int, ...]`):
            The number of output channels for each block.
        act_fn (`str`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: Tuple[int, ...],
        block_out_channels: Tuple[int, ...],
        act_fn: str,
    ):
        super().__init__()

        layers = []
        for i, num_block in enumerate(num_blocks):
            num_channels = block_out_channels[i]

            if i == 0:
                # the first stage only maps the input channels, without downsampling
                layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
            else:
                # every later stage starts with a stride-2 convolution that halves the spatial size
                layers.append(
                    nn.Conv2d(
                        num_channels,
                        num_channels,
                        kernel_size=3,
                        padding=1,
                        stride=2,
                        bias=False,
                    )
                )

            for _ in range(num_block):
                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))

        layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))

        self.layers = nn.Sequential(*layers)
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""The forward method of the `EncoderTiny` class."""
        # scale image from [-1, 1] to [0, 1] to match TAESD convention; this is done before branching so
        # that the gradient-checkpointing path receives the same input as the regular path
        x = x.add(1).div(2)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            x = self._gradient_checkpointing_func(self.layers, x)
        else:
            x = self.layers(x)

        return x


class DecoderTiny(nn.Module):
    r"""
    The `DecoderTiny` layer is a simpler version of the `Decoder` layer.

    Args:
        in_channels (`int`):
            The number of input channels.
        out_channels (`int`):
            The number of output channels.
        num_blocks (`Tuple[int, ...]`):
            Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
            use.
        block_out_channels (`Tuple[int, ...]`):
            The number of output channels for each block.
        upsampling_scaling_factor (`int`):
            The scaling factor to use for upsampling.
        act_fn (`str`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        upsample_fn (`str`):
            The interpolation mode passed as `mode` to `nn.Upsample`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: Tuple[int, ...],
        block_out_channels: Tuple[int, ...],
        upsampling_scaling_factor: int,
        act_fn: str,
        upsample_fn: str,
    ):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
            get_activation(act_fn),
        ]

        for i, num_block in enumerate(num_blocks):
            is_final_block = i == (len(num_blocks) - 1)
            num_channels = block_out_channels[i]

            for _ in range(num_block):
                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))

            # every stage except the last one upsamples before its closing convolution
            if not is_final_block:
                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn))

            # only the last stage maps to `out_channels`, and only that convolution has a bias
            conv_out_channel = num_channels if not is_final_block else out_channels
            layers.append(
                nn.Conv2d(
                    num_channels,
                    conv_out_channel,
                    kernel_size=3,
                    padding=1,
                    bias=is_final_block,
                )
            )

        self.layers = nn.Sequential(*layers)
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""The forward method of the `DecoderTiny` class."""
        # Clamp.
        x = torch.tanh(x / 3) * 3

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            x = self._gradient_checkpointing_func(self.layers, x)
        else:
            x = self.layers(x)

        # scale image from [0, 1] to [-1, 1] to match diffusers convention
        return x.mul(2).sub(1)