"vscode:/vscode.git/clone" did not exist on "b35d88c536081c7e12c4605cf5b7c9ae2e20af72"
adapter.py 23.9 KB
Newer Older
Will Berman's avatar
Will Berman committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15
import os
from typing import Callable, List, Optional, Union
Will Berman's avatar
Will Berman committed
16
17
18
19
20

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
21
from ..utils import logging
Will Berman's avatar
Will Berman committed
22
23
24
from .modeling_utils import ModelMixin


25
26
27
# Module-level logger; `MultiAdapter.from_pretrained` uses it to report how many adapters were loaded.
logger = logging.get_logger(__name__)


Will Berman's avatar
Will Berman committed
28
29
30
31
32
class MultiAdapter(ModelMixin):
    r"""
    MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
    user-assigned weighting.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as downloading
    or saving.

    Args:
        adapters (`List[T2IAdapter]`):
            A list of at least two `T2IAdapter` model instances. All adapters must share the same `downscale_factor`
            and `total_downscale_factor`, because their outputs are summed feature map by feature map.
    """

    def __init__(self, adapters: List["T2IAdapter"]):
        super(MultiAdapter, self).__init__()

        self.num_adapter = len(adapters)
        self.adapters = nn.ModuleList(adapters)

        if len(adapters) == 0:
            raise ValueError("Expecting at least one adapter")

        if len(adapters) == 1:
            raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")

        # The outputs from each adapter are added together with a weight.
        # This means that the change in dimensions from downsampling must
        # be the same for all adapters. Inductively, it also means the
        # downscale_factor and total_downscale_factor must be the same for all
        # adapters.
        first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
        first_adapter_downscale_factor = adapters[0].downscale_factor
        for idx in range(1, len(adapters)):
            if (
                adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
                or adapters[idx].downscale_factor != first_adapter_downscale_factor
            ):
                raise ValueError(
                    f"Expecting all adapters to have the same downscaling behavior, but got:\n"
                    f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
                    f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
                    f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
                    f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
                )

        self.total_downscale_factor = first_adapter_total_downscale_factor
        self.downscale_factor = first_adapter_downscale_factor

    def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
        r"""
        Run every adapter on its own input and return the weighted sum of their feature maps.

        Args:
            xs (`torch.Tensor` or `List[torch.Tensor]`):
                A sequence with exactly one entry per adapter: entry `i` is passed to `self.adapters[i]`. Typically a
                list of control-image tensors, one per adapter (a tensor whose first dimension is `num_adapter` is
                also accepted, since it is iterated along dimension 0).
            adapter_weights (`List[float]`, *optional*, defaults to None):
                A list of floats representing the weights which will be multiplied by each adapter's output before
                summing them together. Must contain one weight per adapter. If `None`, equal weights
                (`1 / num_adapter`) will be used for all adapters.

        Returns:
            `List[torch.Tensor]`: the element-wise weighted sum of the adapters' feature-map lists.

        Raises:
            ValueError: if `xs` or `adapter_weights` does not contain exactly one entry per adapter.
        """
        if adapter_weights is None:
            adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
        else:
            # `as_tensor` avoids the copy-construct warning when a tensor is passed in.
            adapter_weights = torch.as_tensor(adapter_weights)

        # `zip` below would silently drop adapters (or inputs) on a length mismatch.
        if len(xs) != self.num_adapter:
            raise ValueError(f"Expecting {self.num_adapter} adapter inputs in `xs`, but got {len(xs)}.")
        if len(adapter_weights) != self.num_adapter:
            raise ValueError(
                f"Expecting {self.num_adapter} values in `adapter_weights`, but got {len(adapter_weights)}."
            )

        accume_state = None
        for x, w, adapter in zip(xs, adapter_weights, self.adapters):
            features = adapter(x)
            if accume_state is None:
                accume_state = [w * feature for feature in features]
            else:
                for i, feature in enumerate(features):
                    accume_state[i] += w * feature
        return accume_state

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        save_function: Callable = None,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
    ):
        """
        Save a model and its configuration file to a specified directory, allowing it to be re-loaded with the
        `[`~models.adapter.MultiAdapter.from_pretrained`]` class method.

        The first adapter is saved to `save_directory`, and adapter `i` (for `i >= 1`) to `{save_directory}_{i}`,
        which is the layout `from_pretrained` expects.

        Args:
            save_directory (`str` or `os.PathLike`):
                The directory where the model will be saved. If the directory does not exist, it will be created.
            is_main_process (`bool`, optional, defaults=True):
                Indicates whether current process is the main process or not. Useful for distributed training (e.g.,
                TPUs) and need to call this function on all processes. In this case, set `is_main_process=True` only
                for the main process to avoid race conditions.
            save_function (`Callable`):
                Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace
                `torch.save` with another method. Can also be configured using `DIFFUSERS_SAVE_MODE` environment
                variable.
            safe_serialization (`bool`, optional, defaults=True):
                If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`.
            variant (`str`, *optional*):
                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
        """
        # Suffixes are always derived from the *base* path (not the previous adapter's path), so the third adapter
        # goes to `<dir>_2` rather than `<dir>_1_2`. `os.fspath` makes this work for `os.PathLike` inputs too.
        base_path = os.fspath(save_directory)
        for idx, adapter in enumerate(self.adapters):
            model_path_to_save = save_directory if idx == 0 else base_path + f"_{idx}"
            adapter.save_pretrained(
                model_path_to_save,
                is_main_process=is_main_process,
                save_function=save_function,
                safe_serialization=safe_serialization,
                variant=variant,
            )

    @classmethod
    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, set it back to training mode using `model.train()`.

        Warnings:
            *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained
            with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. *Weights
            from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded.

        Args:
            pretrained_model_path (`os.PathLike`):
                A path to a *directory* containing model weights saved using
                [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
                Further adapters are looked up in `<pretrained_model_path>_1`, `<pretrained_model_path>_2`, ...
            torch_dtype (`torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory
                available for each GPU and the available CPU RAM if unset.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                setting this argument to `True` will raise an error.
            variant (`str`, *optional*):
                If specified, load weights from a `variant` file (*e.g.* pytorch_model.<variant>.bin). `variant` will
                be ignored when using `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If `None`, the `safetensors` weights will be downloaded if available **and** if `safetensors` library
                is installed. If `True`, the model will be forcibly loaded from `safetensors` weights. If `False`,
                `safetensors` is not used.

        Raises:
            ValueError: if no adapter directory exists at `pretrained_model_path`.
        """
        # Load adapters and append them to the list until no adapter directory exists anymore.
        # The first adapter has to be saved under `./mydirectory/adapter` to be compliant with
        # `DiffusionPipeline.from_pretrained`; the second, third, ... adapters under `./mydirectory/adapter_1`,
        # `./mydirectory/adapter_2`, ... (the same layout `save_pretrained` writes).
        base_path = os.fspath(pretrained_model_path)
        adapters = []

        model_path_to_load = pretrained_model_path
        while os.path.isdir(model_path_to_load):
            adapters.append(T2IAdapter.from_pretrained(model_path_to_load, **kwargs))
            model_path_to_load = base_path + f"_{len(adapters)}"

        logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.")

        if len(adapters) == 0:
            # The first adapter is expected at the base path itself (there is no `_0` suffix).
            raise ValueError(
                f"No T2IAdapters found under {os.path.dirname(base_path)}. Expected at least {base_path}."
            )

        return cls(adapters)

Will Berman's avatar
Will Berman committed
215
216
217
218
219
220
221
222
223
224

class T2IAdapter(ModelMixin, ConfigMixin):
    r"""
    A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
    generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
    architecture follows the original implementation of
    [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
    and
    [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as
    downloading or saving.

    Args:
        in_channels (`int`, *optional*, defaults to `3`):
            The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale
            image.
        channels (`List[int]`, *optional*, defaults to `[320, 640, 1280, 1280]`):
            The number of channels in each block's output hidden state. `len(channels)` determines the number of
            blocks, and therefore of returned feature maps, in the adapter (`light_adapter` adds one extra block).
        num_res_blocks (`int`, *optional*, defaults to `2`):
            Number of ResNet blocks in each block.
        downscale_factor (`int`, *optional*, defaults to `8`):
            The factor of the adapter's initial pixel unshuffle; the adapter's `total_downscale_factor` is derived
            from it.
        adapter_type (`str`, *optional*, defaults to `full_adapter`):
            Adapter type (`full_adapter` or `full_adapter_xl` or `light_adapter`) to use.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280, 1280],
        num_res_blocks: int = 2,
        downscale_factor: int = 8,
        adapter_type: str = "full_adapter",
    ):
        super().__init__()

        # Map each supported `adapter_type` to the backbone class implementing it.
        known_backbones = (
            ("full_adapter", FullAdapter),
            ("full_adapter_xl", FullAdapterXL),
            ("light_adapter", LightAdapter),
        )
        for type_name, backbone_cls in known_backbones:
            if adapter_type == type_name:
                self.adapter = backbone_cls(in_channels, channels, num_res_blocks, downscale_factor)
                break
        else:
            raise ValueError(
                f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or "
                "'full_adapter_xl' or 'light_adapter'."
            )

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Delegate to the selected backbone (`self.adapter`) and return its list of feature tensors, one per block,
        each extracted at a different stage of the backbone. The list length therefore depends on `channels` and on
        `adapter_type`.
        """
        backbone = self.adapter
        return backbone(x)

    @property
    def total_downscale_factor(self):
        """Total spatial downscaling of the last feature map, as reported by the backbone."""
        backbone = self.adapter
        return backbone.total_downscale_factor

    @property
    def downscale_factor(self):
        """The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation.

        If an input image's dimensions are not evenly divisible by the downscale_factor then an exception will be
        raised.
        """
        pixel_unshuffle = self.adapter.unshuffle
        return pixel_unshuffle.downscale_factor

Will Berman's avatar
Will Berman committed
286
287
288
289
290

# full adapter


class FullAdapter(nn.Module):
    r"""
    See [`T2IAdapter`] for more information.
    """

    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280, 1280],
        num_res_blocks: int = 2,
        downscale_factor: int = 8,
    ):
        super().__init__()

        # Pixel unshuffle folds each `downscale_factor x downscale_factor` patch into the channel dimension.
        unshuffled_channels = in_channels * downscale_factor**2

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
        self.conv_in = nn.Conv2d(unshuffled_channels, channels[0], kernel_size=3, padding=1)

        # The first block keeps the resolution; every following block halves it while changing
        # channels[i - 1] -> channels[i]. Blocks are built in order so parameter initialization is deterministic.
        stages = [AdapterBlock(channels[0], channels[0], num_res_blocks)]
        for prev_width, next_width in zip(channels[:-1], channels[1:]):
            stages.append(AdapterBlock(prev_width, next_width, num_res_blocks, down=True))
        self.body = nn.ModuleList(stages)

        # One downsampling step (x2) per block after the first.
        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Pixel-unshuffle `x`, apply `conv_in`, then run the stack of AdapterBlocks, collecting the output of every
        block. The returned list has one tensor per block (i.e. `len(channels)` tensors).
        """
        hidden = self.conv_in(self.unshuffle(x))

        outputs = []
        for stage in self.body:
            hidden = stage(hidden)
            outputs.append(hidden)
        return outputs


340
class FullAdapterXL(nn.Module):
    r"""
    See [`T2IAdapter`] for more information.
    """

    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280, 1280],
        num_res_blocks: int = 2,
        downscale_factor: int = 16,
    ):
        super().__init__()

        in_channels = in_channels * downscale_factor**2

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
        self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)

        # Each block consumes the previous block's output, so block `i` maps channels[i - 1] -> channels[i];
        # block 0 consumes `conv_in`'s output, which already has channels[0] channels. Only block 2 downsamples.
        # With the default channels, downscale_factor=16 and a 1024x1024 input, this extracts XL features with
        # dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32].
        # (Blocks at index >= 3 previously used channels[i] as input width, which broke whenever
        # channels[i - 1] != channels[i]; for the default config the modules built are identical.)
        body = []
        for i, out_channels in enumerate(channels):
            block_in_channels = channels[i - 1] if i > 0 else out_channels
            body.append(AdapterBlock(block_in_channels, out_channels, num_res_blocks, down=(i == 2)))
        self.body = nn.ModuleList(body)

        # XL has at most one downsampling AdapterBlock (block 2), which only exists when len(channels) > 2.
        self.total_downscale_factor = downscale_factor * (2 if len(channels) > 2 else 1)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Pixel-unshuffle `x`, apply the input convolution, then run every AdapterBlock in turn, collecting each block's
        output. Returns one feature tensor per block (i.e. `len(channels)` tensors).
        """
        x = self.unshuffle(x)
        x = self.conv_in(x)

        features = []

        for block in self.body:
            x = block(x)
            features.append(x)

        return features


Will Berman's avatar
Will Berman committed
390
class AdapterBlock(nn.Module):
    r"""
    An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
    `FullAdapterXL` models.

    Args:
        in_channels (`int`):
            Number of channels of AdapterBlock's input.
        out_channels (`int`):
            Number of channels of AdapterBlock's output.
        num_res_blocks (`int`):
            Number of ResNet blocks in the AdapterBlock.
        down (`bool`, *optional*, defaults to `False`):
            If `True`, perform downsampling on AdapterBlock's input.
    """

    def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
        super().__init__()

        # Optional 2x average-pool downsampling (ceil_mode keeps odd sizes from dropping a row/column).
        self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) if down else None

        # A 1x1 projection is only needed when the channel width changes.
        self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None

        self.resnets = nn.Sequential(*(AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Apply the optional downsampling and the optional 1x1 input projection (each only if configured), then pass
        the result through the stack of residual blocks.
        """
        for optional_stage in (self.downsample, self.in_conv):
            if optional_stage is not None:
                x = optional_stage(x)
        return self.resnets(x)


class AdapterResnetBlock(nn.Module):
    r"""
    An `AdapterResnetBlock` is a helper model that implements a ResNet-like block: a 3x3 convolution, a ReLU, and a
    1x1 convolution, with the input added back to the result.

    Args:
        channels (`int`):
            Number of channels of AdapterResnetBlock's input and output.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Compute `block2(relu(block1(x)))` and return it added to the input `x` (residual connection).
        """
        residual = self.block2(self.act(self.block1(x)))
        return residual + x


# light adapter


class LightAdapter(nn.Module):
    r"""
    See [`T2IAdapter`] for more information.
    """

    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280],
        num_res_blocks: int = 4,
        downscale_factor: int = 8,
    ):
        super().__init__()

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)

        # Block layout: one full-resolution block on the unshuffled input, one downsampling block per channel
        # transition, and a final downsampling block that keeps the last width. Built in order so parameter
        # initialization is deterministic.
        stages = [LightAdapterBlock(in_channels * downscale_factor**2, channels[0], num_res_blocks)]
        stages.extend(
            LightAdapterBlock(width_in, width_out, num_res_blocks, down=True)
            for width_in, width_out in zip(channels[:-1], channels[1:])
        )
        stages.append(LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True))
        self.body = nn.ModuleList(stages)

        # len(channels) - 1 transition blocks plus the final block each halve the resolution.
        self.total_downscale_factor = downscale_factor * (2 ** len(channels))

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Pixel-unshuffle `x`, then run it through each LightAdapterBlock in turn, collecting every block's output.
        Returns `len(channels) + 1` feature tensors.
        """
        hidden = self.unshuffle(x)

        feature_maps = []
        for stage in self.body:
            hidden = stage(hidden)
            feature_maps.append(hidden)
        return feature_maps


class LightAdapterBlock(nn.Module):
    r"""
    A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
    `LightAdapter` model. It works in a bottleneck of `out_channels // 4` channels between a 1x1 input and a 1x1
    output convolution.

    Args:
        in_channels (`int`):
            Number of channels of LightAdapterBlock's input.
        out_channels (`int`):
            Number of channels of LightAdapterBlock's output.
        num_res_blocks (`int`):
            Number of LightAdapterResnetBlocks in the LightAdapterBlock.
        down (`bool`, *optional*, defaults to `False`):
            If `True`, perform downsampling on LightAdapterBlock's input.
    """

    def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
        super().__init__()
        bottleneck = out_channels // 4

        self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) if down else None

        self.in_conv = nn.Conv2d(in_channels, bottleneck, kernel_size=1)
        self.resnets = nn.Sequential(*(LightAdapterResnetBlock(bottleneck) for _ in range(num_res_blocks)))
        self.out_conv = nn.Conv2d(bottleneck, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Downsample `x` if configured, then apply the input projection, the residual blocks and the output projection
        in sequence.
        """
        if self.downsample is not None:
            x = self.downsample(x)

        for layer in (self.in_conv, self.resnets, self.out_conv):
            x = layer(x)
        return x


class LightAdapterResnetBlock(nn.Module):
    """
    A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
    architecture than `AdapterResnetBlock`: both of its convolutions are 3x3 (padding 1), separated by a ReLU, and
    the input is added back to the result.

    Args:
        channels (`int`):
            Number of channels of LightAdapterResnetBlock's input and output.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Compute `block2(relu(block1(x)))` and return it added to the input `x` (residual connection).
        """
        residual = self.block2(self.act(self.block1(x)))
        return residual + x