# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import tempfile
import unittest
from itertools import product

import numpy as np
import pytest
import torch

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    LCMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
    CaptureLogger,
    floats_tensor,
    is_torch_version,
    require_peft_backend,
    require_peft_version_greater,
    require_transformers_version_greater,
    skip_mps,
    torch_device,
)


if is_peft_available():
    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
    from peft.tuners.tuners_utils import BaseTunerLayer
    from peft.utils import get_peft_model_state_dict


def state_dicts_almost_equal(sd1, sd2):
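    # Sort by key so the two value sequences align; assumes both state dicts
    # contain the same set of keys.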
    sd1 = dict(sorted(sd1.items()))
    sd2 = dict(sorted(sd2.items()))

    models_are_equal = True
    for ten1, ten2 in zip(sd1.values(), sd2.values()):
        if (ten1 - ten2).abs().max() > 1e-3:
            models_are_equal = False

    return models_are_equal


def check_if_lora_correctly_set(model) -> bool:
    """
    Checks whether the LoRA layers have been correctly injected into the model with PEFT
    """
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            return True
    return False


def initialize_dummy_state_dict(state_dict):
    """Materializes a state dict of `meta`-device tensors as random tensors on `torch_device`."""
    if not all(v.device.type == "meta" for _, v in state_dict.items()):
        raise ValueError("`state_dict` has non-meta values.")
    return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()}


@require_peft_backend
class PeftLoraLoaderMixinTests:
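    """
    Mixin of common LoRA tests shared by pipeline-specific test classes.

    Concrete test classes are expected to combine this mixin with `unittest.TestCase`
    and fill in the configuration attributes declared below. A minimal sketch (the
    pipeline, component kwargs, and checkpoint ids here are illustrative placeholders,
    not a real configuration):

        class MyPipelineLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
            pipeline_class = MyPipeline
            scheduler_cls = DDIMScheduler
            scheduler_kwargs = {...}
            unet_kwargs = {...}  # or transformer_cls / transformer_kwargs
            vae_kwargs = {...}
            text_encoder_cls, text_encoder_id = MyTextEncoder, "<tiny-text-encoder-repo>"
            tokenizer_cls, tokenizer_id = MyTokenizer, "<tiny-tokenizer-repo>"
    """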
    pipeline_class = None

    scheduler_cls = None
    scheduler_kwargs = None
    scheduler_classes = [DDIMScheduler, LCMScheduler]

    has_two_text_encoders = False
    has_three_text_encoders = False
    text_encoder_cls, text_encoder_id = None, None
    text_encoder_2_cls, text_encoder_2_id = None, None
    text_encoder_3_cls, text_encoder_3_id = None, None
    tokenizer_cls, tokenizer_id = None, None
    tokenizer_2_cls, tokenizer_2_id = None, None
    tokenizer_3_cls, tokenizer_3_id = None, None

    unet_kwargs = None
    transformer_cls = None
    transformer_kwargs = None
    vae_cls = AutoencoderKL
    vae_kwargs = None

    text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]

    def get_dummy_components(self, scheduler_cls=None, use_dora=False):
        if self.unet_kwargs and self.transformer_kwargs:
            raise ValueError("`unet_kwargs` and `transformer_kwargs` cannot both be specified.")
        if self.has_two_text_encoders and self.has_three_text_encoders:
            raise ValueError("`has_two_text_encoders` and `has_three_text_encoders` cannot both be True.")

        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
        rank = 4

        torch.manual_seed(0)
        if self.unet_kwargs is not None:
            unet = UNet2DConditionModel(**self.unet_kwargs)
        else:
            transformer = self.transformer_cls(**self.transformer_kwargs)

        scheduler = scheduler_cls(**self.scheduler_kwargs)

        torch.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)

        text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)

        if self.text_encoder_2_cls is not None:
            text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
            tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)

        if self.text_encoder_3_cls is not None:
            text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
            tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)

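        # `init_lora_weights=False` initializes the LoRA matrices randomly instead of
        # making the adapter a no-op, so attaching it measurably changes the output.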
        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            target_modules=self.text_encoder_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )

        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=use_dora,
        )

        pipeline_components = {
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        # Denoiser
        if self.unet_kwargs is not None:
            pipeline_components.update({"unet": unet})
        elif self.transformer_kwargs is not None:
            pipeline_components.update({"transformer": transformer})

        # Remaining text encoders.
        if self.text_encoder_2_cls is not None:
            pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2})
        if self.text_encoder_3_cls is not None:
            pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3})

        # Remaining stuff
        init_params = inspect.signature(self.pipeline_class.__init__).parameters
        if "safety_checker" in init_params:
            pipeline_components.update({"safety_checker": None})
        if "feature_extractor" in init_params:
            pipeline_components.update({"feature_extractor": None})
        if "image_encoder" in init_params:
            pipeline_components.update({"image_encoder": None})

        return pipeline_components, text_lora_config, denoiser_lora_config

    @property
    def output_shape(self):
        raise NotImplementedError

    def get_dummy_inputs(self, with_generator=True):
        batch_size = 1
        sequence_length = 10
        num_channels = 4
        sizes = (32, 32)

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "num_inference_steps": 5,
            "guidance_scale": 6.0,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs

    # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
    def get_dummy_tokens(self):
        max_seq_length = 77

        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))

        prepared_inputs = {}
        prepared_inputs["input_ids"] = inputs
        return prepared_inputs

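    # The `<module_name>_lora_layers` keys produced here mirror the keyword arguments
    # expected by `save_lora_weights` (e.g. `unet_lora_layers`, `text_encoder_lora_layers`).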
    def _get_lora_state_dicts(self, modules_to_save):
        state_dicts = {}
        for module_name, module in modules_to_save.items():
            if module is not None:
                state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
        return state_dicts

    def _get_modules_to_save(self, pipe, has_denoiser=False):
        modules_to_save = {}
        lora_loadable_modules = self.pipeline_class._lora_loadable_modules

        if (
            "text_encoder" in lora_loadable_modules
            and hasattr(pipe, "text_encoder")
            and getattr(pipe.text_encoder, "peft_config", None) is not None
        ):
            modules_to_save["text_encoder"] = pipe.text_encoder

        if (
            "text_encoder_2" in lora_loadable_modules
            and hasattr(pipe, "text_encoder_2")
            and getattr(pipe.text_encoder_2, "peft_config", None) is not None
        ):
            modules_to_save["text_encoder_2"] = pipe.text_encoder_2

        if has_denoiser:
            if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
                modules_to_save["unet"] = pipe.unet

            if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
                modules_to_save["transformer"] = pipe.transformer

        return modules_to_save

    def test_simple_inference(self):
        """
        Tests a simple inference and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            _, _, inputs = self.get_dummy_inputs()
            output_no_lora = pipe(**inputs)[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

    def test_simple_inference_with_text_lora(self):
        """
        Tests a simple inference with LoRA attached to the text encoder
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

    @require_peft_version_greater("0.13.1")
    def test_low_cpu_mem_usage_with_injection(self):
        """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)

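            # With `low_cpu_mem_usage=True`, PEFT creates the injected LoRA weights on the
            # `meta` device; they only get real storage once a state dict is assigned via
            # `set_peft_model_state_dict(..., low_cpu_mem_usage=True)`.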
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
                )
                self.assertTrue(
                    "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
                    "The LoRA params should be on 'meta' device.",
                )

                te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
                set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
                self.assertTrue(
                    "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
                    "No param should be on 'meta' device.",
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
            self.assertTrue(
                "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
            )

            denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
            set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
            self.assertTrue(
                "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
                    self.assertTrue(
                        "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
                        "The LoRA params should be on 'meta' device.",
                    )

                    te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
                    set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
                    self.assertTrue(
                        "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
                        "No param should be on 'meta' device.",
                    )

            _, _, inputs = self.get_dummy_inputs()
            output_lora = pipe(**inputs)[0]
            self.assertTrue(output_lora.shape == self.output_shape)

    @require_peft_version_greater("0.13.1")
    @require_transformers_version_greater("4.45.2")
    def test_low_cpu_mem_usage_with_loading(self):
        """Tests if we can load LoRA state dict with low_cpu_mem_usage."""

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)

                for module_name, module in modules_to_save.items():
                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

                images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
                self.assertTrue(
                    np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                    "Loading from saved checkpoints should give same results.",
                )

                # Now, check for `low_cpu_mem_usage`.
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)

                for module_name, module in modules_to_save.items():
                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

                images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
                self.assertTrue(
                    np.allclose(
                        images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3
                    ),
                    "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
                )

    def test_simple_inference_with_text_lora_and_scale(self):
        """
        Tests a simple inference with LoRA attached to the text encoder plus a `scale` argument
        and makes sure it works as expected
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()

        # TODO(diffusers): Discuss a common naming convention across the library for 1.0.0 release
        attention_kwargs_name = None
        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

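            # The `scale` entry multiplies the LoRA delta added to the base weights,
            # so `scale=0.0` should be equivalent to running without LoRA.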
            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

    def test_simple_inference_with_text_lora_fused(self):
        """
        Tests a simple inference with LoRA attached to the text encoder, fuses the LoRA weights into the base model,
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

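            # `fuse_lora` merges the LoRA deltas into the base layer weights while keeping
            # the adapter modules attached, so the operation can later be undone.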
            pipe.fuse_lora()
            # Fusing should still keep the LoRA layers
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )

    def test_simple_inference_with_text_lora_unloaded(self):
        """
        Tests a simple inference with LoRA attached to the text encoder, then unloads the LoRA weights
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading LoRA should restore the original output",
            )

    def test_simple_inference_with_text_lora_save_load(self):
        """
        Tests a simple use case where users could use the saving utilities for LoRA.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

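            # With `safe_serialization=False`, the weights are saved as
            # `pytorch_lora_weights.bin` rather than the safetensors default.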
            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipe)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)

                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

    def test_simple_inference_with_partial_text_lora(self):
        """
        Tests a simple inference with LoRA attached to the text encoder,
        with different ranks per module and some adapter weights removed,
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
            text_lora_config = LoraConfig(
                r=4,
                rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
                lora_alpha=4,
                target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
                init_lora_weights=False,
                use_dora=False,
            )
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
            # supports missing layers (PR#8324).
            state_dict = {
                f"text_encoder.{module_name}": param
                for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
                if "text_model.encoder.layers.4" not in module_name
            }

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
                    state_dict.update(
                        {
                            f"text_encoder_2.{module_name}": param
                            for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
                            if "text_model.encoder.layers.4" not in module_name
                        }
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            # Unload lora and load it back using the pipe.load_lora_weights machinery
            pipe.unload_lora_weights()
            pipe.load_lora_weights(state_dict)

            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
                "Removing adapters should change the output",
            )

    def test_simple_inference_save_pretrained(self):
        """
        Tests a simple use case where users could use the saving utilities for LoRA through save_pretrained
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                pipe.save_pretrained(tmpdirname)

                pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
                pipe_from_pretrained.to(torch_device)

            self.assertTrue(
                check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
                "Lora not correctly set in text encoder",
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
                        "Lora not correctly set in text encoder 2",
                    )

            images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

    def test_simple_inference_with_text_denoiser_lora_save_load(self):
        """
        Tests a simple use case where users could use the saving utilities for LoRA on the UNet + text encoder
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
        """
        Tests a simple inference with LoRA attached to the text encoder and UNet, plus a `scale` argument,
        and makes sure it works as expected
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        attention_kwargs_name = None
        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
                    "The scaling parameter has not been correctly restored!",
                )

    def test_simple_inference_with_text_lora_denoiser_fused(self):
        """
        Tests a simple inference with LoRA attached to the text encoder and denoiser,
        fuses the LoRA weights into the base model, and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

            # Fusing should still keep the LoRA layers
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )

    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
        """
        Tests a simple inference with LoRA attached to the text encoder and UNet, then unloads the LoRA weights
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )
            self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Fused lora should change the output",
            )

    def test_simple_inference_with_text_denoiser_lora_unfused(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Tests a simple inference with LoRA attached to the text encoder and UNet,
        fuses and then unfuses the LoRA weights, and makes sure the outputs match
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Unfusing should restore the base weights but keep the LoRA layers attached
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                    )

            # Fuse and unfuse should lead to the same results
            self.assertTrue(
                np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )

    def test_simple_inference_with_text_denoiser_multi_adapter(self):
        """
        Tests a simple inference with LoRA attached to the text encoder and UNet, attaches
        multiple adapters, and sets them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            # Different adapter selections should lead to different results
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_wrong_adapter_name_raises_error(self):
        scheduler_cls = self.scheduler_classes[0]
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        with self.assertRaises(ValueError) as err_context:
            pipe.set_adapters("test")

        self.assertTrue("not in the list of present adapters" in str(err_context.exception))

        # test this works.
        pipe.set_adapters("adapter-1")
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    def test_simple_inference_with_text_denoiser_block_scale(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
Aryan's avatar
Aryan committed
1127
        one adapter and set different weights for different blocks (i.e. block lora)
UmerHA's avatar
UmerHA committed
1128
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

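            # `set_adapters` accepts a nested scale dict: one top-level key per component
            # ("text_encoder", "unet", ...) and, for the unet, per block group ("down"/"mid"/"up").
            # Components that are omitted keep the default scale of 1.0.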
            weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
            pipe.set_adapters("adapter-1", weights_1)
            output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            weights_2 = {"unet": {"up": 5}}
            pipe.set_adapters("adapter-1", weights_2)
            output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
                "LoRA weights 1 and 2 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 1 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 2 should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set differnt weights for different blocks (i.e. block lora)
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
            scales_2 = {"unet": {"down": 5, "mid": 5}}

            pipe.set_adapters("adapter-1", scales_1)
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2", scales_2)
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Different adapter/scale combinations should lead to different results
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

            # a mismatching number of adapter_names and adapter_weights should raise an error
            with self.assertRaises(ValueError):
                pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])

    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""

        def updown_options(blocks_with_tf, layers_per_block, value):
            """
            Generate every possible combination for how a lora weight dict for the up/down part can be.
            E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
            """
            num_val = value
            list_val = [value] * layers_per_block

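            # Each transformer block can be omitted (None), scaled with a single number applied
            # to all of its layers, or given one value per layer.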
            node_opts = [None, num_val, list_val]
            node_opts_foreach_block = [node_opts] * len(blocks_with_tf)

            updown_opts = [num_val]
            for nodes in product(*node_opts_foreach_block):
                if all(n is None for n in nodes):
                    continue
                opt = {}
                for b, n in zip(blocks_with_tf, nodes):
                    if n is not None:
                        opt["block_" + str(b)] = n
                updown_opts.append(opt)
            return updown_opts

        def all_possible_dict_opts(unet, value):
            """
            Generate every possible combination for how a lora weight dict can be.
            E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ...
            """

            down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")]
            up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")]

            layers_per_block = unet.config.layers_per_block

            text_encoder_opts = [None, value]
            text_encoder_2_opts = [None, value]
            mid_opts = [None, value]
            down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value)
            up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value)

            opts = []

            for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts):
                if all(o is None for o in (t1, t2, d, m, u)):
                    continue
                opt = {}
                if t1 is not None:
                    opt["text_encoder"] = t1
                if t2 is not None:
                    opt["text_encoder_2"] = t2
                if all(o is None for o in (d, m, u)):
                    # no unet scaling
                    continue
                opt["unet"] = {}
                if d is not None:
                    opt["unet"]["down"] = d
                if m is not None:
                    opt["unet"]["mid"] = m
                if u is not None:
                    opt["unet"]["up"] = u
                opts.append(opt)

            return opts

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")

        for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
            # test if lora block scales can be set with this scale_dict
            if not self.has_two_text_encoders and "text_encoder_2" in scale_dict:
                del scale_dict["text_encoder_2"]

            pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error

    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set/delete them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.delete_adapters("adapter-1")
            output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            pipe.delete_adapters("adapter-2")
            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            pipe.set_adapters(["adapter-1", "adapter-2"])
            pipe.delete_adapters(["adapter-1", "adapter-2"])

            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Different adapter selections should lead to different results
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
            output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Weighted adapter and mixed adapter should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    @skip_mps
    @pytest.mark.xfail(
        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
        strict=True,
    )
    def test_lora_fuse_nan(self):
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            # corrupt one LoRA weight with `inf` values
            with torch.no_grad():
                if self.unet_kwargs:
                    pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
                        "adapter-1"
                    ].weight += float("inf")
                else:
                    pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

            # with `safe_fusing=True`, fusing checks the merged weights for NaN values and should raise an error
            with self.assertRaises(ValueError):
                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

            # without safe fusing we should not see an error, but every image will be black
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
            out = pipe("test", num_inference_steps=2, output_type="np")[0]

            self.assertTrue(np.isnan(out).all())

    def test_get_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            adapter_names = pipe.get_active_adapters()
            self.assertListEqual(adapter_names, ["adapter-1"])

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

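            # Adding a new adapter makes it the active one, so "adapter-2" replaces "adapter-1".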
            adapter_names = pipe.get_active_adapters()
            self.assertListEqual(adapter_names, ["adapter-2"])

            pipe.set_adapters(["adapter-1", "adapter-2"])
            self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])

    def test_get_list_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            # 1.
            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                dicts_to_be_checked = {"text_encoder": ["adapter-1"]}

            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
                dicts_to_be_checked.update({"unet": ["adapter-1"]})
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
                dicts_to_be_checked.update({"transformer": ["adapter-1"]})

            self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

            # 2.
            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

            self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

            # 3.
            pipe.set_adapters(["adapter-1", "adapter-2"])

            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

            if self.unet_kwargs is not None:
                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
            else:
                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

            self.assertDictEqual(
                pipe.get_list_adapters(),
                dicts_to_be_checked,
            )

            # 4.
            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})

            self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

    @require_peft_version_greater(peft_version="0.6.2")
    def test_simple_inference_with_text_lora_denoiser_fused_multi(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected - with unet and multi-adapter case
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            # Attach a second adapter
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # set them to multi-adapter inference mode
            pipe.set_adapters(["adapter-1", "adapter-2"])
            outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1"])
            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])

            # Fusing should still keep the LoRA layers so output should remain the same
            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            pipe.fuse_lora(
                components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
            )

            # Fusing should still keep the LoRA layers
            output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )

    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
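        """
        Tests a simple inference with a DoRA (weight-decomposed LoRA) config attached to text
        encoder and denoiser and makes sure it changes the output
        """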
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
                scheduler_cls, use_dora=True
            )
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_dora_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
                "DoRA lora should change the output",
            )

    def test_missing_keys_warning(self):
        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
            )
            pipe.unload_lora_weights()
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)

        # To make things dynamic since we cannot settle with a single key for all the models where we
        # offer PEFT support.
        missing_key = [k for k in state_dict if "lora_A" in k][0]
        del state_dict[missing_key]

        logger = logging.get_logger("diffusers.loaders.peft")
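        # 30 == logging.WARNING; capture the warning emitted for the missing key.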
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(state_dict)

        # Since the missing key won't contain the adapter name ("default_0").
        # Also strip out the component prefix (such as "unet." from `missing_key`).
        component = list({k.split(".")[0] for k in state_dict})[0]
        self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))

    def test_unexpected_keys_warning(self):
        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
            )
            pipe.unload_lora_weights()
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)

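        # Append a bogus suffix to a real LoRA key so the loader has to report it as unexpected.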
        unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
        state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)

        logger = logging.get_logger("diffusers.loaders.peft")
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(state_dict)

        self.assertTrue(".diffusers_cat" in cap_logger.out)

    @unittest.skip("This is failing for now - need to investigate")
    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)

            if self.has_two_text_encoders or self.has_three_text_encoders:
                pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

            # Just makes sure inference runs end to end.
            _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    def test_modify_padding_mode(self):
        def set_pad_mode(network, mode="circular"):
            for _, module in network.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    module.padding_mode = mode

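        # "circular" padding (used e.g. for tileable outputs) should not break inference.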
        for scheduler_cls in self.scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _pad_mode = "circular"
            set_pad_mode(pipe.vae, _pad_mode)
            set_pad_mode(pipe.unet, _pad_mode)

            _, _, inputs = self.get_dummy_inputs()
            _ = pipe(**inputs)[0]