# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import tempfile
import unittest
from itertools import product

import numpy as np
import torch

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    LCMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
    CaptureLogger,
    floats_tensor,
    require_peft_backend,
    require_peft_version_greater,
    require_transformers_version_greater,
    skip_mps,
    torch_device,
)


if is_peft_available():
    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
    from peft.tuners.tuners_utils import BaseTunerLayer
    from peft.utils import get_peft_model_state_dict
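
# The tests below exercise the basic PEFT/LoRA flow on diffusers models. A minimal
# sketch (component and module names are illustrative, not tied to one pipeline):
#
#     config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q"], init_lora_weights=False)
#     pipe.unet.add_adapter(config)  # attach LoRA layers to a model
#     pipe.fuse_lora()               # merge LoRA weights into the base weights
#     pipe.unfuse_lora()             # undo the merge but keep the adapters
#     pipe.unload_lora_weights()     # remove the adapters entirely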


def state_dicts_almost_equal(sd1, sd2):
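    # Sorts both state dicts by key and checks that corresponding tensors differ
    # by at most 1e-3 elementwise.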
    sd1 = dict(sorted(sd1.items()))
    sd2 = dict(sorted(sd2.items()))

    models_are_equal = True
    for ten1, ten2 in zip(sd1.values(), sd2.values()):
        if (ten1 - ten2).abs().max() > 1e-3:
            models_are_equal = False

    return models_are_equal


def check_if_lora_correctly_set(model) -> bool:
    """
    Checks if the LoRA layers are correctly set with peft
    """
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            return True
    return False


def initialize_dummy_state_dict(state_dict):
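    # Replaces meta-device placeholders with random tensors of the same shape/dtype
    # on `torch_device`, so an adapter injected with `low_cpu_mem_usage=True` can be
    # materialized for testing.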
    if not all(v.device.type == "meta" for _, v in state_dict.items()):
        raise ValueError("`state_dict` has non-meta values.")
    return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()}


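# A minimal sketch of how a pipeline-specific test class is expected to configure
# this mixin (class and checkpoint names are illustrative; the real subclasses live
# in the pipeline-specific LoRA test files):
#
#     class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
#         pipeline_class = StableDiffusionPipeline
#         scheduler_kwargs = {...}
#         unet_kwargs = {...}
#         vae_kwargs = {...}
#         text_encoder_cls, text_encoder_id = CLIPTextModel, "<tiny-clip-checkpoint>"
#         tokenizer_cls, tokenizer_id = CLIPTokenizer, "<tiny-clip-checkpoint>"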
@require_peft_backend
class PeftLoraLoaderMixinTests:
    pipeline_class = None

    scheduler_cls = None
    scheduler_kwargs = None
    scheduler_classes = [DDIMScheduler, LCMScheduler]

    has_two_text_encoders = False
    has_three_text_encoders = False
    text_encoder_cls, text_encoder_id = None, None
    text_encoder_2_cls, text_encoder_2_id = None, None
    text_encoder_3_cls, text_encoder_3_id = None, None
    tokenizer_cls, tokenizer_id = None, None
    tokenizer_2_cls, tokenizer_2_id = None, None
    tokenizer_3_cls, tokenizer_3_id = None, None

    unet_kwargs = None
    transformer_cls = None
    transformer_kwargs = None
    vae_cls = AutoencoderKL
    vae_kwargs = None

    text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]

    def get_dummy_components(self, scheduler_cls=None, use_dora=False):
        if self.unet_kwargs and self.transformer_kwargs:
            raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
        if self.has_two_text_encoders and self.has_three_text_encoders:
            raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
        rank = 4

        torch.manual_seed(0)
        if self.unet_kwargs is not None:
            unet = UNet2DConditionModel(**self.unet_kwargs)
        else:
            transformer = self.transformer_cls(**self.transformer_kwargs)

        scheduler = scheduler_cls(**self.scheduler_kwargs)

        torch.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)

        text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)

        if self.text_encoder_2_cls is not None:
            text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
            tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)

        if self.text_encoder_3_cls is not None:
            text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
            tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)

        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            target_modules=self.text_encoder_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )

        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=use_dora,
        )

        pipeline_components = {
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        # Denoiser
        if self.unet_kwargs is not None:
            pipeline_components.update({"unet": unet})
        elif self.transformer_kwargs is not None:
            pipeline_components.update({"transformer": transformer})

        # Remaining text encoders.
        if self.text_encoder_2_cls is not None:
            pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2})
        if self.text_encoder_3_cls is not None:
            pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3})

        # Remaining stuff
        init_params = inspect.signature(self.pipeline_class.__init__).parameters
        if "safety_checker" in init_params:
            pipeline_components.update({"safety_checker": None})
        if "feature_extractor" in init_params:
            pipeline_components.update({"feature_extractor": None})
        if "image_encoder" in init_params:
            pipeline_components.update({"image_encoder": None})

        return pipeline_components, text_lora_config, denoiser_lora_config

    @property
    def output_shape(self):
        raise NotImplementedError

    def get_dummy_inputs(self, with_generator=True):
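        # Returns `(noise, input_ids, pipeline_inputs)`; most tests only consume
        # `pipeline_inputs` and discard the first two.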
        batch_size = 1
        sequence_length = 10
        num_channels = 4
        sizes = (32, 32)

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "num_inference_steps": 5,
            "guidance_scale": 6.0,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs

    # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
    def get_dummy_tokens(self):
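        # Random token ids at the usual CLIP max sequence length (77), for direct
        # text-encoder forward passes.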
        max_seq_length = 77

        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))

        prepared_inputs = {}
        prepared_inputs["input_ids"] = inputs
        return prepared_inputs

    def _get_lora_state_dicts(self, modules_to_save):
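        # Collects one `get_peft_model_state_dict` per module, keyed the way
        # `save_lora_weights` expects (e.g. `unet_lora_layers`).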
        state_dicts = {}
        for module_name, module in modules_to_save.items():
            if module is not None:
                state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
        return state_dicts

    def _get_modules_to_save(self, pipe, has_denoiser=False):
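        # Returns the LoRA-loadable pipeline modules that currently carry a PEFT
        # config, plus the denoiser (UNet or transformer) when `has_denoiser=True`.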
        modules_to_save = {}
        lora_loadable_modules = self.pipeline_class._lora_loadable_modules

        if (
            "text_encoder" in lora_loadable_modules
            and hasattr(pipe, "text_encoder")
            and getattr(pipe.text_encoder, "peft_config", None) is not None
        ):
            modules_to_save["text_encoder"] = pipe.text_encoder

        if (
            "text_encoder_2" in lora_loadable_modules
            and hasattr(pipe, "text_encoder_2")
            and getattr(pipe.text_encoder_2, "peft_config", None) is not None
        ):
            modules_to_save["text_encoder_2"] = pipe.text_encoder_2

        if has_denoiser:
            if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
                modules_to_save["unet"] = pipe.unet

            if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
                modules_to_save["transformer"] = pipe.transformer

        return modules_to_save

    def test_simple_inference(self):
        """
        Tests a simple inference and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            _, _, inputs = self.get_dummy_inputs()
            output_no_lora = pipe(**inputs)[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

    def test_simple_inference_with_text_lora(self):
        """
        Tests a simple inference with LoRA attached to the text encoder
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

    @require_peft_version_greater("0.13.1")
    def test_low_cpu_mem_usage_with_injection(self):
        """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
                )
                self.assertTrue(
                    "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
                    "The LoRA params should be on 'meta' device.",
                )

                te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
                set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
                self.assertTrue(
                    "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
                    "No param should be on 'meta' device.",
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
            self.assertTrue(
                "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
            )

            denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
            set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
            self.assertTrue(
                "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
                    self.assertTrue(
                        "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
                        "The LoRA params should be on 'meta' device.",
                    )

                    te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
                    set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
                    self.assertTrue(
                        "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
                        "No param should be on 'meta' device.",
                    )

            _, _, inputs = self.get_dummy_inputs()
            output_lora = pipe(**inputs)[0]
            self.assertTrue(output_lora.shape == self.output_shape)

    @require_peft_version_greater("0.13.1")
    @require_transformers_version_greater("4.45.2")
    def test_low_cpu_mem_usage_with_loading(self):
        """Tests if we can load LoRA state dict with low_cpu_mem_usage."""

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)

                for module_name, module in modules_to_save.items():
                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

                images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
                self.assertTrue(
                    np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                    "Loading from saved checkpoints should give same results.",
                )

                # Now, check for `low_cpu_mem_usage`.
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)

                for module_name, module in modules_to_save.items():
                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

                images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
                self.assertTrue(
                    np.allclose(
                        images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3
                    ),
                    "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
                )

    def test_simple_inference_with_text_lora_and_scale(self):
        """
        Tests a simple inference with LoRA attached to the text encoder + scale argument
        and makes sure it works as expected
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()

        # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
        attention_kwargs_name = None
        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

    def test_simple_inference_with_text_lora_fused(self):
        """
        Tests a simple inference with LoRA attached to the text encoder + fuses the LoRA weights
        into the base model and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora()
            # Fusing should still keep the LoRA layers
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )

    def test_simple_inference_with_text_lora_unloaded(self):
        """
        Tests a simple inference with LoRA attached to the text encoder, then unloads the LoRA weights
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading LoRA should restore the original output",
            )

    def test_simple_inference_with_text_lora_save_load(self):
        """
        Tests a simple use case where users could use saving utilities for LoRA.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipe)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

    def test_simple_inference_with_partial_text_lora(self):
        """
        Tests a simple inference with LoRA attached to the text encoder with different ranks
        and some adapters removed, and makes sure it works as expected.
        """
        for scheduler_cls in self.scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
            text_lora_config = LoraConfig(
                r=4,
                rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
                lora_alpha=4,
                target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
                init_lora_weights=False,
                use_dora=False,
            )
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
            # supports missing layers (PR#8324).
            state_dict = {
                f"text_encoder.{module_name}": param
                for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
                if "text_model.encoder.layers.4" not in module_name
            }

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
                    state_dict.update(
                        {
                            f"text_encoder_2.{module_name}": param
                            for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
                            if "text_model.encoder.layers.4" not in module_name
                        }
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            # Unload lora and load it back using the pipe.load_lora_weights machinery
            pipe.unload_lora_weights()
            pipe.load_lora_weights(state_dict)

            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
                "Removing adapters should change the output",
            )

    def test_simple_inference_save_pretrained(self):
        """
        Tests a simple use case where users could use saving utilities for LoRA through save_pretrained.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                pipe.save_pretrained(tmpdirname)

                pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
                pipe_from_pretrained.to(torch_device)

            self.assertTrue(
                check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
                "Lora not correctly set in text encoder",
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
                        "Lora not correctly set in text encoder 2",
                    )

            images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

    def test_simple_inference_with_text_denoiser_lora_save_load(self):
        """
        Tests a simple use case where users could use saving utilities for LoRA for UNet + text encoder.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
        """
        Tests a simple inference with LoRA attached to the text encoder + UNet + scale argument
        and makes sure it works as expected
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        attention_kwargs_name = None
        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
                    "The scaling parameter has not been correctly restored!",
                )

    def test_simple_inference_with_text_lora_denoiser_fused(self):
        """
        Tests a simple inference with LoRA attached to the text encoder and denoiser + fuses
        the LoRA weights into the base model and makes sure it works as expected.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

            # Fusing should still keep the LoRA layers
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )

    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
        """
        Tests a simple inference with LoRA attached to the text encoder and denoiser, then unloads
        the LoRA weights and makes sure it works as expected.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )
            self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading LoRA should restore the original output",
            )

    def test_simple_inference_with_text_denoiser_lora_unfused(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Tests a simple inference with LoRA attached to the text encoder and denoiser, then fuses
        and unfuses the LoRA weights and makes sure it works as expected.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # unfusing should not remove the LoRA layers
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                    )

            # Fuse and unfuse should lead to the same results
            self.assertTrue(
                np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
                "Fusing and unfusing should not change the output",
            )

1011
    def test_simple_inference_with_text_denoiser_multi_adapter(self):
1012
1013
1014
1015
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
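            # Activate each adapter individually, then both together; every combination should alter the output.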

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            # Each adapter, and the mix of both, should produce a distinct output
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_wrong_adapter_name_raises_error(self):
        scheduler_cls = self.scheduler_classes[0]
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        with self.assertRaises(ValueError) as err_context:
            pipe.set_adapters("test")

        self.assertTrue("not in the list of present adapters" in str(err_context.exception))

        # setting a registered adapter name should work.
        pipe.set_adapters("adapter-1")
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    def test_simple_inference_with_text_denoiser_block_scale(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
Aryan's avatar
Aryan committed
1125
        one adapter and set different weights for different blocks (i.e. block lora)
UmerHA's avatar
UmerHA committed
1126
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
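            # Scales the text encoder LoRA by 2 and only the unet down-blocks by 5; unspecified blocks keep a scale of 1.0.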
            pipe.set_adapters("adapter-1", weights_1)
            output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            weights_2 = {"unet": {"up": 5}}
            pipe.set_adapters("adapter-1", weights_2)
            output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
                "LoRA weights 1 and 2 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 1 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 2 should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set differnt weights for different blocks (i.e. block lora)
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
            scales_2 = {"unet": {"down": 5, "mid": 5}}
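            # Each adapter gets its own block-wise scale dict; omitted components and blocks default to a scale of 1.0.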

            pipe.set_adapters("adapter-1", scales_1)
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2", scales_2)
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Each adapter, and the mix of both, should produce a distinct output
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

            # a mismatching number of adapter_names and adapter_weights should raise an error
            with self.assertRaises(ValueError):
                pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])

    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""

        def updown_options(blocks_with_tf, layers_per_block, value):
            """
            Generate every possible combination for how a lora weight dict for the up/down part can be.
            E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
            """
            num_val = value
            list_val = [value] * layers_per_block

            node_opts = [None, num_val, list_val]
            node_opts_foreach_block = [node_opts] * len(blocks_with_tf)

            updown_opts = [num_val]
            for nodes in product(*node_opts_foreach_block):
                if all(n is None for n in nodes):
                    continue
                opt = {}
                for b, n in zip(blocks_with_tf, nodes):
                    if n is not None:
                        opt["block_" + str(b)] = n
                updown_opts.append(opt)
            return updown_opts

        def all_possible_dict_opts(unet, value):
            """
            Generate every possible combination for how a lora weight dict can be.
            E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ...
            """

            down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")]
            up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")]

            layers_per_block = unet.config.layers_per_block

            text_encoder_opts = [None, value]
            text_encoder_2_opts = [None, value]
            mid_opts = [None, value]
            down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value)
            up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value)

            opts = []

            for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts):
                if all(o is None for o in (t1, t2, d, m, u)):
                    continue
                opt = {}
                if t1 is not None:
                    opt["text_encoder"] = t1
                if t2 is not None:
                    opt["text_encoder_2"] = t2
                if all(o is None for o in (d, m, u)):
                    # no unet scaling
                    continue
                opt["unet"] = {}
                if d is not None:
                    opt["unet"]["down"] = d
                if m is not None:
                    opt["unet"]["mid"] = m
                if u is not None:
                    opt["unet"]["up"] = u
                opts.append(opt)

            return opts

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")

        for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
            # test if lora block scales can be set with this scale_dict
            if not self.has_two_text_encoders and "text_encoder_2" in scale_dict:
                del scale_dict["text_encoder_2"]

            pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error

    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set/delete them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.delete_adapters("adapter-1")
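            # After deleting "adapter-1", only "adapter-2" remains active.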
            output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Deleting adapter 1 should leave the output identical to adapter 2 alone",
            )

            pipe.delete_adapters("adapter-2")
            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output after deleting all adapters should give same results",
            )

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
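            # Delete both adapters at once; inference should then match the no-lora baseline.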

            pipe.set_adapters(["adapter-1", "adapter-2"])
            pipe.delete_adapters(["adapter-1", "adapter-2"])

            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output after deleting all adapters should give same results",
            )

    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Each adapter, and the mix of both, should produce a distinct output
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
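            # The per-adapter weights (0.5 and 0.6) scale each adapter's contribution when mixing.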
            output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Weighted adapter and mixed adapter should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    @skip_mps
    def test_lora_fuse_nan(self):
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            # corrupt one LoRA weight with `inf` values
            with torch.no_grad():
                if self.unet_kwargs:
                    pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
                        "adapter-1"
                    ].weight += float("inf")
                else:
                    pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

            # with `safe_fusing=True` we should see an Error
            with self.assertRaises(ValueError):
                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

            # without safe fusing we should not see an error, but every image will be black
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
            out = pipe("test", num_inference_steps=2, output_type="np")[0]

            self.assertTrue(np.isnan(out).all())

    def test_get_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            adapter_names = pipe.get_active_adapters()
            self.assertListEqual(adapter_names, ["adapter-1"])

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
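            # The most recently added adapter becomes the active one, replacing "adapter-1".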

            adapter_names = pipe.get_active_adapters()
            self.assertListEqual(adapter_names, ["adapter-2"])

            pipe.set_adapters(["adapter-1", "adapter-2"])
            self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])

    def test_get_list_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            # 1. Add "adapter-1" to every LoRA-loadable component.
            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                dicts_to_be_checked = {"text_encoder": ["adapter-1"]}

            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
                dicts_to_be_checked.update({"unet": ["adapter-1"]})
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
                dicts_to_be_checked.update({"transformer": ["adapter-1"]})

            self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

            # 2. Add "adapter-2" on top and check that both adapters are listed.
            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

            self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

            # 3. Activating both adapters should not change the listing.
            pipe.set_adapters(["adapter-1", "adapter-2"])

            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

            if self.unet_kwargs is not None:
                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
            else:
                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

            self.assertDictEqual(
                pipe.get_list_adapters(),
                dicts_to_be_checked,
            )

            # 4. Add "adapter-3" to the denoiser only.
            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})

            self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

    @require_peft_version_greater(peft_version="0.6.2")
    def test_simple_inference_with_text_lora_denoiser_fused_multi(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected - with unet and multi-adapter case
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            # Attach a second adapter
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # set them to multi-adapter inference mode
            pipe.set_adapters(["adapter-1", "adapter-2"])
            outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1"])
            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

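            # Fuse only "adapter-1"; the fused output should match the adapter-1-only inference above.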
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])

            # Fusing should still keep the LoRA layers so output should remain the same
            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            pipe.fuse_lora(
                components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
            )

            # Fusing should still keep the LoRA layers
            output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )

    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
                scheduler_cls, use_dora=True
            )
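            # DoRA (weight-decomposed low-rank adaptation) needs peft>=0.9.0, hence the decorator above.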
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_dora_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
                "DoRA lora should change the output",
            )

    def test_missing_keys_warning(self):
        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
            )
            pipe.unload_lora_weights()
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)

        # To make things dynamic since we cannot settle with a single key for all the models where we
        # offer PEFT support.
        missing_key = [k for k in state_dict if "lora_A" in k][0]
        del state_dict[missing_key]

        logger = (
            logging.get_logger("diffusers.loaders.unet")
            if self.unet_kwargs is not None
            else logging.get_logger("diffusers.loaders.peft")
        )
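        # 30 corresponds to logging.WARNING, so only warnings and above are captured.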
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(state_dict)

        # The missing key won't contain the adapter name ("default_0"), so strip it from the captured log.
        # Also strip the component prefix (such as "unet.") from `missing_key` before comparing.
        component = list({k.split(".")[0] for k in state_dict})[0]
        self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))

    def test_unexpected_keys_warning(self):
        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
            )
            pipe.unload_lora_weights()
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)

        unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
        state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
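        # Loading a state dict that contains this extra key should emit an "unexpected keys" warning.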

        logger = (
            logging.get_logger("diffusers.loaders.unet")
            if self.unet_kwargs is not None
            else logging.get_logger("diffusers.loaders.peft")
        )
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(state_dict)

        self.assertTrue(".diffusers_cat" in cap_logger.out)

    @unittest.skip("This is failing for now - need to investigate")
    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
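            # "reduce-overhead" mode uses CUDA graphs where possible; fullgraph=True disallows graph breaks.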

            if self.has_two_text_encoders or self.has_three_text_encoders:
                pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

            # Just makes sure inference runs without errors.
            _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    def test_modify_padding_mode(self):
        def set_pad_mode(network, mode="circular"):
            for _, module in network.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    module.padding_mode = mode

        for scheduler_cls in self.scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _pad_mode = "circular"
            set_pad_mode(pipe.vae, _pad_mode)
            set_pad_mode(pipe.unet, _pad_mode)
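            # Circular padding wraps convolutions around the image borders (useful for tileable outputs).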

            _, _, inputs = self.get_dummy_inputs()
            _ = pipe(**inputs)[0]