"tests/vscode:/vscode.git/clone" did not exist on "5b5db28ddb215e85a8d57530fcd19b28150cc3d1"
utils.py 80.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Sayak Paul's avatar
Sayak Paul committed
15
import inspect
16
17
18
import os
import tempfile
import unittest
UmerHA's avatar
UmerHA committed
19
from itertools import product
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

import numpy as np
import torch

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    LCMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
    floats_tensor,
    require_peft_backend,
    require_peft_version_greater,
35
    skip_mps,
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
    torch_device,
)


if is_peft_available():
    from peft import LoraConfig
    from peft.tuners.tuners_utils import BaseTunerLayer
    from peft.utils import get_peft_model_state_dict


def state_dicts_almost_equal(sd1, sd2):
    """
    Return whether two state dicts hold (almost) the same tensors.

    Entries are matched by key. The dicts are considered equal only if they have exactly the same keys and, for
    every key, tensors of the same shape whose largest element-wise absolute difference does not exceed 1e-3.

    Args:
        sd1: Mapping of parameter name to tensor.
        sd2: Mapping of parameter name to tensor.

    Returns:
        `True` if both state dicts match within the tolerance, `False` otherwise.
    """
    # Pairing values positionally with `zip` silently truncates to the shorter dict and can pair unrelated tensors,
    # so require identical key sets before comparing anything.
    if sd1.keys() != sd2.keys():
        return False

    for key, ten1 in sd1.items():
        ten2 = sd2[key]
        # Differently shaped tensors are never equal (and must not be compared through broadcasting).
        if ten1.shape != ten2.shape:
            return False
        # `max()` is not defined for empty tensors; two empty tensors of the same shape are trivially equal.
        if ten1.numel() == 0:
            continue
        if (ten1 - ten2).abs().max() > 1e-3:
            return False

    return True


def check_if_lora_correctly_set(model) -> bool:
    """
    Report whether `model` contains at least one PEFT tuner layer.

    Iterates over `model.modules()` and returns `True` as soon as a `BaseTunerLayer` instance is found; returns
    `False` when no such module exists.
    """
    return any(isinstance(submodule, BaseTunerLayer) for submodule in model.modules())


@require_peft_backend
class PeftLoraLoaderMixinTests:
    pipeline_class = None
Aryan's avatar
Aryan committed
71

72
73
    scheduler_cls = None
    scheduler_kwargs = None
Aryan's avatar
Aryan committed
74
    scheduler_classes = [DDIMScheduler, LCMScheduler]
Sayak Paul's avatar
Sayak Paul committed
75

76
    has_two_text_encoders = False
77
    has_three_text_encoders = False
Sayak Paul's avatar
Sayak Paul committed
78
79
80
81
82
83
84
    text_encoder_cls, text_encoder_id = None, None
    text_encoder_2_cls, text_encoder_2_id = None, None
    text_encoder_3_cls, text_encoder_3_id = None, None
    tokenizer_cls, tokenizer_id = None, None
    tokenizer_2_cls, tokenizer_2_id = None, None
    tokenizer_3_cls, tokenizer_3_id = None, None

85
    unet_kwargs = None
Sayak Paul's avatar
Sayak Paul committed
86
    transformer_cls = None
87
    transformer_kwargs = None
Aryan's avatar
Aryan committed
88
    vae_cls = AutoencoderKL
89
90
    vae_kwargs = None

Aryan's avatar
Aryan committed
91
92
    text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]

93
    def get_dummy_components(self, scheduler_cls=None, use_dora=False):
        """
        Build the dummy pipeline components together with the LoRA configs used by the tests.

        Args:
            scheduler_cls: Scheduler class to instantiate. Falls back to `self.scheduler_cls` when `None`.
            use_dora: Forwarded as `use_dora` to both `LoraConfig` objects.

        Returns:
            A tuple `(pipeline_components, text_lora_config, denoiser_lora_config)`, where `pipeline_components`
            is the keyword-argument dict passed to `self.pipeline_class`.

        Raises:
            ValueError: If both `unet_kwargs` and `transformer_kwargs` are set, or if both
                `has_two_text_encoders` and `has_three_text_encoders` are `True`.
        """
        if self.unet_kwargs and self.transformer_kwargs:
            raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
        if self.has_two_text_encoders and self.has_three_text_encoders:
            raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

        if scheduler_cls is None:
            scheduler_cls = self.scheduler_cls
        lora_rank = 4
        uses_unet = self.unet_kwargs is not None
        has_text_encoder_2 = self.text_encoder_2_cls is not None
        has_text_encoder_3 = self.text_encoder_3_cls is not None

        # Re-seed right before each randomly initialised model so its weights are reproducible.
        torch.manual_seed(0)
        if uses_unet:
            unet = UNet2DConditionModel(**self.unet_kwargs)
        else:
            transformer = self.transformer_cls(**self.transformer_kwargs)

        scheduler = scheduler_cls(**self.scheduler_kwargs)

        torch.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)

        text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)

        if has_text_encoder_2:
            text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
            tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)

        if has_text_encoder_3:
            text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
            tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)

        # Both configs share everything except the modules they target.
        shared_lora_kwargs = {
            "r": lora_rank,
            "lora_alpha": lora_rank,
            "init_lora_weights": False,
            "use_dora": use_dora,
        }
        text_lora_config = LoraConfig(target_modules=self.text_encoder_target_modules, **shared_lora_kwargs)
        denoiser_lora_config = LoraConfig(target_modules=["to_q", "to_k", "to_v", "to_out.0"], **shared_lora_kwargs)

        pipeline_components = {
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }

        # Denoiser
        if uses_unet:
            pipeline_components["unet"] = unet
        elif self.transformer_kwargs is not None:
            pipeline_components["transformer"] = transformer

        # Remaining text encoders.
        if has_text_encoder_2:
            pipeline_components["tokenizer_2"] = tokenizer_2
            pipeline_components["text_encoder_2"] = text_encoder_2
        if has_text_encoder_3:
            pipeline_components["tokenizer_3"] = tokenizer_3
            pipeline_components["text_encoder_3"] = text_encoder_3

        # Optional components are passed as `None`, but only when the pipeline's `__init__` declares them.
        init_params = inspect.signature(self.pipeline_class.__init__).parameters
        for optional_component in ("safety_checker", "feature_extractor", "image_encoder"):
            if optional_component in init_params:
                pipeline_components[optional_component] = None

        return pipeline_components, text_lora_config, denoiser_lora_config
168

Sayak Paul's avatar
Sayak Paul committed
169
170
171
172
    @property
    def output_shape(self):
        """
        Expected shape of the first pipeline output.

        The tests in this class compare `pipe(**inputs)[0].shape` against this value. No default is provided here,
        so accessing the property on this class raises `NotImplementedError`.
        """
        raise NotImplementedError

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
    def get_dummy_inputs(self, with_generator=True):
        """
        Create dummy noise, token ids and pipeline call arguments.

        Args:
            with_generator: When `True`, the generator seeded here is added to the call arguments under
                `"generator"`; otherwise the caller is expected to supply one.

        Returns:
            A tuple `(noise, input_ids, pipeline_inputs)`. `noise` is whatever `floats_tensor` returns for the
            requested `(1, 4, 32, 32)` shape, `input_ids` is a `(1, 10)` tensor of random ints in `[1, 10)`, and
            `pipeline_inputs` is the keyword-argument dict for the pipeline call.
        """
        batch_size, sequence_length, num_channels = 1, 10, 4
        spatial_dims = (32, 32)

        # `torch.manual_seed` seeds the global RNG and hands back the generator that is reused below.
        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_channels, *spatial_dims))
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "num_inference_steps": 5,
            "guidance_scale": 6.0,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs["generator"] = generator

        return noise, input_ids, pipeline_inputs

194
    # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
195
196
197
198
199
200
201
202
203
204
205
206
207
    def get_dummy_tokens(self):
        """
        Return deterministic dummy token ids.

        Returns:
            A dict with a single key `"input_ids"` holding a `(1, 77)` integer tensor with values in `[2, 56)`,
            drawn from a generator seeded with 0.
        """
        token_count = 77
        seeded_generator = torch.manual_seed(0)
        token_ids = torch.randint(2, 56, size=(1, token_count), generator=seeded_generator)
        return {"input_ids": token_ids}

    def test_simple_inference(self):
        """
        Runs the pipeline without any LoRA for every scheduler in `scheduler_classes`
        and checks that the output has the expected shape
        """
        for scheduler_cls in self.scheduler_classes:
            # Only the components are needed here; both LoRA configs are ignored.
            components, _, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            inputs = self.get_dummy_inputs()[2]
            output_no_lora = pipe(**inputs)[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)
217
218
219
220
221
222

    def test_simple_inference_with_text_lora(self):
        """
        Attaches a LoRA adapter to the text encoder(s) and checks that the
        pipeline output differs from the output without LoRA
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            inputs = self.get_dummy_inputs(with_generator=False)[2]

            # Baseline without adapters; every call gets a freshly seeded generator so outputs are comparable.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            has_extra_text_encoders = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_text_encoders and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

    def test_simple_inference_with_text_lora_and_scale(self):
        """
        Tests a simple inference with lora attached on the text encoder + scale argument
        and makes sure it works as expected

        The scale is passed through whichever attention-kwargs argument the pipeline's `__call__` accepts. The test
        checks that a scale of 0.5 changes the LoRA output and that a scale of 0.0 reproduces the output without
        LoRA.
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()

        # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
        candidate_names = ("cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs")
        # Default to `None` so that a pipeline accepting none of the names fails with a clear assertion
        # instead of an unbound-variable error.
        attention_kwargs_name = next((name for name in candidate_names if name in call_signature_keys), None)
        self.assertIsNotNone(
            attention_kwargs_name,
            f"`{self.pipeline_class.__name__}.__call__` accepts none of {candidate_names}",
        )

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            # A non-default scale must change the LoRA output.
            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            # A scale of zero must reproduce the output obtained without LoRA.
            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

    def test_simple_inference_with_text_lora_fused(self):
        """
        Attaches a LoRA adapter to the text encoder(s), fuses the LoRA weights into the base model
        and checks that the fused pipeline output differs from the output without LoRA
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            inputs = self.get_dummy_inputs(with_generator=False)[2]

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            has_extra_text_encoders = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_text_encoders and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.fuse_lora()
            # The LoRA layers must still be present after fusing.
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if has_extra_text_encoders and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            fused_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(fused_output, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )

    def test_simple_inference_with_text_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder, then unloads the lora weights
        and makes sure it works as expected

        After `unload_lora_weights()` no LoRA layer may remain in the text encoder(s), and the pipeline output must
        match the output produced before any adapter was attached.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            # The message describes what this assertion checks: unloading restores the base output.
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading lora should restore the output of the base model",
            )

    def test_simple_inference_with_text_lora_save_load(self):
        """
        Tests a simple usecase where users could use saving utilities for LoRA.

        The text encoder LoRA layers (plus those of `text_encoder_2` when the pipeline can load LoRA into it) are
        saved with `save_lora_weights`, unloaded, loaded back from disk, and the output is compared with the output
        produced before saving.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            has_lora_text_encoder_2 = (
                self.has_two_text_encoders or self.has_three_text_encoders
            ) and "text_encoder_2" in self.pipeline_class._lora_loadable_modules

            if has_lora_text_encoder_2:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                # Always save the first text encoder's layers; add the second encoder's layers only when it
                # carries LoRA. This guarantees that exactly one save happens for every configuration,
                # including pipelines with three text encoders whose `text_encoder_2` is not LoRA-loadable.
                lora_layers_to_save = {"text_encoder_lora_layers": get_peft_model_state_dict(pipe.text_encoder)}
                if has_lora_text_encoder_2:
                    lora_layers_to_save["text_encoder_2_lora_layers"] = get_peft_model_state_dict(
                        pipe.text_encoder_2
                    )

                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname,
                    safe_serialization=False,
                    **lora_layers_to_save,
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()

                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if has_lora_text_encoder_2:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

468
469
470
471
472
473
    def test_simple_inference_with_partial_text_lora(self):
        """
        Runs inference with a text encoder LoRA that uses different ranks per module,
        then reloads it with the `layers.4` adapters removed and checks that
        both steps change the output
        """

        def prefixed_state_without_layer_4(prefix, module):
            # PEFT state dict of `module`, keys prefixed with `prefix`, with every `layers.4` entry dropped so
            # that `load_lora_into_text_encoder` has to cope with missing layers (PR#8324).
            return {
                f"{prefix}.{module_name}": param
                for module_name, param in get_peft_model_state_dict(module).items()
                if "text_model.encoder.layers.4" not in module_name
            }

        for scheduler_cls in self.scheduler_classes:
            components = self.get_dummy_components(scheduler_cls)[0]
            # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
            text_lora_config = LoraConfig(
                r=4,
                rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
                lora_alpha=4,
                target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
                init_lora_weights=False,
                use_dora=False,
            )
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            inputs = self.get_dummy_inputs(with_generator=False)[2]

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            state_dict = prefixed_state_without_layer_4("text_encoder", pipe.text_encoder)

            has_extra_text_encoders = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_text_encoders and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
                state_dict.update(prefixed_state_without_layer_4("text_encoder_2", pipe.text_encoder_2))

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            # Drop the adapters and load the reduced state dict back through `pipe.load_lora_weights`.
            pipe.unload_lora_weights()
            pipe.load_lora_weights(state_dict)

            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
                "Removing adapters should change the output",
            )

532
533
534
535
    def test_simple_inference_save_pretrained(self):
        """
        Saves a pipeline with text encoder LoRA through `save_pretrained`, reloads it with
        `from_pretrained` and checks that the LoRA layers and the output are preserved
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            inputs = self.get_dummy_inputs(with_generator=False)[2]

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            has_extra_text_encoders = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_text_encoders and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Round-trip the whole pipeline through a temporary directory.
            with tempfile.TemporaryDirectory() as tmpdirname:
                pipe.save_pretrained(tmpdirname)

                pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
                pipe_from_pretrained.to(torch_device)

            self.assertTrue(
                check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
                "Lora not correctly set in text encoder",
            )

            if has_extra_text_encoders and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
                    "Lora not correctly set in text encoder 2",
                )

            images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

583
    def test_simple_inference_with_text_denoiser_lora_save_load(self):
        """
        Round-trips LoRA weights for the denoiser (+ text encoders, when loadable) through `save_lora_weights`
        and `load_lora_weights`, and checks that the reloaded adapters reproduce the pre-save output.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline run without any adapter; only its shape is validated.
            base_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(base_output.shape == self.output_shape)

            lora_modules = self.pipeline_class._lora_loadable_modules

            if "text_encoder" in lora_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is the transformer when the test class defines no UNet kwargs, the UNet otherwise.
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if (self.has_two_text_encoders or self.has_three_text_encoders) and "text_encoder_2" in lora_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                # Collect one state dict per LoRA-carrying component under the keyword `save_lora_weights` expects.
                saving_kwargs = {"save_directory": tmpdirname, "safe_serialization": False}

                if "text_encoder" in lora_modules:
                    saving_kwargs["text_encoder_lora_layers"] = get_peft_model_state_dict(pipe.text_encoder)

                denoiser_state_dict = get_peft_model_state_dict(denoiser)
                denoiser_key = "unet_lora_layers" if self.unet_kwargs is not None else "transformer_lora_layers"
                saving_kwargs[denoiser_key] = denoiser_state_dict

                if (
                    self.has_two_text_encoders or self.has_three_text_encoders
                ) and "text_encoder_2" in lora_modules:
                    saving_kwargs["text_encoder_2_lora_layers"] = get_peft_model_state_dict(pipe.text_encoder_2)

                self.pipeline_class.save_lora_weights(**saving_kwargs)

                weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.bin")
                self.assertTrue(os.path.isfile(weights_path))

                # Drop the in-memory adapters so the next run can only succeed through the saved file.
                pipe.unload_lora_weights()
                pipe.load_lora_weights(weights_path)

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in lora_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            if (self.has_two_text_encoders or self.has_three_text_encoders) and "text_encoder_2" in lora_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

670
    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
671
672
673
674
        """
        Tests a simple inference with lora attached on the text encoder + Unet + scale argument
        and makes sure it works as expected
        """
Aryan's avatar
Aryan committed
675
676
677
678
679
680
681
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None

Aryan's avatar
Aryan committed
682
        for scheduler_cls in self.scheduler_classes:
683
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
684
685
686
687
688
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

Aryan's avatar
Aryan committed
689
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
Sayak Paul's avatar
Sayak Paul committed
690
            self.assertTrue(output_no_lora.shape == self.output_shape)
691

Aryan's avatar
Aryan committed
692
693
694
695
696
697
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

Aryan's avatar
Aryan committed
698
699
700
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
701

702
            if self.has_two_text_encoders or self.has_three_text_encoders:
Sayak Paul's avatar
Sayak Paul committed
703
704
705
706
707
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
708

Aryan's avatar
Aryan committed
709
            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
710
711
712
713
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

Aryan's avatar
Aryan committed
714
715
716
            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

717
718
719
720
721
            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

Aryan's avatar
Aryan committed
722
723
724
            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

725
726
727
728
729
            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

Aryan's avatar
Aryan committed
730
731
732
733
734
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
                    "The scaling parameter has not been correctly restored!",
                )
735

736
    def test_simple_inference_with_text_lora_denoiser_fused(self):
        """
        Attaches LoRA to the denoiser (+ text encoders, when loadable), fuses it through `fuse_lora`, and checks
        that the LoRA layers are still detected afterwards and that the fused output differs from the base output.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline run without any adapter.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            lora_modules = self.pipeline_class._lora_loadable_modules

            if "text_encoder" in lora_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is the transformer when the test class defines no UNet kwargs, the UNet otherwise.
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if (self.has_two_text_encoders or self.has_three_text_encoders) and "text_encoder_2" in lora_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

            # Fusing is expected to leave the LoRA layers in place on every component.
            if "text_encoder" in lora_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            if (self.has_two_text_encoders or self.has_three_text_encoders) and "text_encoder_2" in lora_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            outputs_match = np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3)
            self.assertFalse(outputs_match, "Fused lora should change the output")

789
    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
        """
        Tests a simple inference with LoRA attached to the text encoder(s) and denoiser, then unloads the LoRA
        weights and makes sure that no LoRA layer is detected anymore and that the output matches the output
        obtained before any LoRA was attached.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is the transformer when the test class defines no UNet kwargs, the UNet otherwise.
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers; only inspect the text encoder when it is LoRA-loadable,
            # mirroring the guard used when the adapter was attached.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertFalse(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
                )
            self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading lora should lead to the same output as no LoRA",
            )

Aryan's avatar
Aryan committed
841
842
843
    def test_simple_inference_with_text_denoiser_lora_unfused(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Attaches LoRA to the denoiser (+ text encoders, when loadable), fuses it, then unfuses it, and checks
        that the LoRA layers are still detected after unfusing and that the fused and unfused outputs agree
        within `expected_atol` / `expected_rtol`.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            lora_modules = self.pipeline_class._lora_loadable_modules

            if "text_encoder" in lora_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is the transformer when the test class defines no UNet kwargs, the UNet otherwise.
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if (self.has_two_text_encoders or self.has_three_text_encoders) and "text_encoder_2" in lora_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            # One inference in the fused state, one after unfusing, both from the same seed.
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Unfusing must keep the LoRA layers attached on every component.
            if "text_encoder" in lora_modules:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

            if (self.has_two_text_encoders or self.has_three_text_encoders) and "text_encoder_2" in lora_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                )

            # Fused and unfused runs should lead to the same results.
            outputs_match = np.allclose(
                output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol
            )
            self.assertTrue(outputs_match, "Fused lora should not change the output")

896
    def test_simple_inference_with_text_denoiser_multi_adapter(self):
        """
        Attaches two named adapters ("adapter-1" and "adapter-2") to the denoiser (+ text encoders, when
        loadable) and checks that activating each one, or both together, yields outputs that differ from the
        base output and from each other, and that disabling LoRA restores the base output.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline run without any adapter.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            lora_modules = self.pipeline_class._lora_loadable_modules

            if "text_encoder" in lora_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is the transformer when the test class defines no UNet kwargs, the UNet otherwise.
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if (self.has_two_text_encoders or self.has_three_text_encoders) and "text_encoder_2" in lora_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            # Each activation below must move the output away from the baseline.
            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            # The three adapter configurations must also differ from one another.
            pairwise_checks = (
                (output_adapter_1, output_adapter_2, "Adapter 1 and 2 should give different results"),
                (output_adapter_1, output_adapter_mixed, "Adapter 1 and mixed adapters should give different results"),
                (output_adapter_2, output_adapter_mixed, "Adapter 2 and mixed adapters should give different results"),
            )
            for first_output, second_output, failure_message in pairwise_checks:
                self.assertFalse(
                    np.allclose(first_output, second_output, atol=1e-3, rtol=1e-3),
                    failure_message,
                )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
    def test_wrong_adapter_name_raises_error(self):
        """
        Checks that `set_adapters` raises a `ValueError` for an adapter name that was never added, and that
        selecting the existing "adapter-1" afterwards still allows running inference.
        """
        # Only the first scheduler class is exercised here.
        scheduler_cls = self.scheduler_classes[0]
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        lora_modules = self.pipeline_class._lora_loadable_modules

        if "text_encoder" in lora_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        # The denoiser is the transformer when the test class defines no UNet kwargs, the UNet otherwise.
        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if (self.has_two_text_encoders or self.has_three_text_encoders) and "text_encoder_2" in lora_modules:
            pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(
                check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
            )

        # "test" was never registered as an adapter name.
        with self.assertRaises(ValueError) as err_context:
            pipe.set_adapters("test")

        error_text = str(err_context.exception)
        self.assertTrue("not in the list of present adapters" in error_text)

        # A registered adapter name must still work after the failed call.
        pipe.set_adapters("adapter-1")
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

1007
    def test_simple_inference_with_text_denoiser_block_scale(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        one adapter and set different weights for different blocks (i.e. block lora)
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            # The denoiser is the transformer when the test class defines no UNet kwargs, the UNet otherwise.
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            # Register the denoiser adapter under the same name as the text encoder adapter so that
            # `set_adapters("adapter-1", ...)` below addresses the denoiser too and the "unet" block weights
            # are applied to an adapter that actually exists on it.
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # First configuration: scale the text encoder and the "down" entry of the unet weights.
            weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
            pipe.set_adapters("adapter-1", weights_1)
            output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Second configuration: scale only the "up" entry of the unet weights.
            weights_2 = {"unet": {"up": 5}}
            pipe.set_adapters("adapter-1", weights_2)
            output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
                "LoRA weights 1 and 2 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 1 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 2 should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

1064
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        """
        Attaches two named adapters to the denoiser (+ text encoders, when loadable) and activates them with
        different per-block weight dicts (i.e. block lora). Checks that the individual and combined
        configurations give different outputs, that disabling LoRA restores the base output, and that passing
        fewer weight entries than adapter names raises a `ValueError`.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline run without any adapter.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            lora_modules = self.pipeline_class._lora_loadable_modules

            if "text_encoder" in lora_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is the transformer when the test class defines no UNet kwargs, the UNet otherwise.
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if (self.has_two_text_encoders or self.has_three_text_encoders) and "text_encoder_2" in lora_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
            scales_2 = {"unet": {"down": 5, "mid": 5}}

            # Run each adapter on its own with its weight dict, then both together.
            pipe.set_adapters("adapter-1", scales_1)
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2", scales_2)
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # The three adapter configurations must differ from one another.
            pairwise_checks = (
                (output_adapter_1, output_adapter_2, "Adapter 1 and 2 should give different results"),
                (output_adapter_1, output_adapter_mixed, "Adapter 1 and mixed adapters should give different results"),
                (output_adapter_2, output_adapter_mixed, "Adapter 2 and mixed adapters should give different results"),
            )
            for first_output, second_output, failure_message in pairwise_checks:
                self.assertFalse(
                    np.allclose(first_output, second_output, atol=1e-3, rtol=1e-3),
                    failure_message,
                )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

            # a mismatching number of adapter_names and adapter_weights should raise an error
            with self.assertRaises(ValueError):
                pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])

1138
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        """Checks that `pipe.set_adapters` accepts every valid layout of a lora block-scale dict without raising"""

        def updown_options(blocks_with_tf, layers_per_block, value):
            """
            Enumerate every accepted form of the scale spec for the up/down part.
            E.g. 2, {"block_1": 2}, {"block_1": [2, 2, 2]}, {"block_1": 2, "block_2": [2, 2, 2]}, ...
            """
            per_layer_scales = [value] * layers_per_block
            # every block is either left out (None), given one scalar, or given one scale per layer
            choices_per_block = [None, value, per_layer_scales]

            # a bare scalar for the whole part is always an option
            options = [value]
            for assignment in product(choices_per_block, repeat=len(blocks_with_tf)):
                block_scales = {
                    "block_" + str(block_idx): choice
                    for block_idx, choice in zip(blocks_with_tf, assignment)
                    if choice is not None
                }
                # drop the combination in which no block received a scale
                if block_scales:
                    options.append(block_scales)
            return options

        def all_possible_dict_opts(unet, value):
            """
            Enumerate every accepted form of a lora scale dict that scales at least one part of the unet.
            E.g. {"unet": {"down": 2}}, {"text_encoder": 2, "unet": {"mid": 2, "up": {"block_1": [2, 2, 2]}}}, ...
            """
            tf_down_blocks = [idx for idx, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")]
            tf_up_blocks = [idx for idx, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")]
            num_layers = unet.config.layers_per_block

            optional_scalar = [None, value]
            down_choices = [None] + updown_options(tf_down_blocks, num_layers, value)
            # the up part is enumerated with one more per-layer scale than the down part
            up_choices = [None] + updown_options(tf_up_blocks, num_layers + 1, value)

            scale_dicts = []
            for te_scale, te2_scale, down, mid, up in product(
                optional_scalar, optional_scalar, down_choices, optional_scalar, up_choices
            ):
                unet_scales = {
                    part: scale for part, scale in (("down", down), ("mid", mid), ("up", up)) if scale is not None
                }
                if not unet_scales:
                    # no unet scaling
                    continue

                scale_dict = {}
                if te_scale is not None:
                    scale_dict["text_encoder"] = te_scale
                if te2_scale is not None:
                    scale_dict["text_encoder_2"] = te2_scale
                scale_dict["unet"] = unet_scales
                scale_dicts.append(scale_dict)

            return scale_dicts

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")

        has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
        if has_extra_text_encoder and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")

        for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
            # a second text encoder scale only makes sense when the pipeline has two text encoders
            if not self.has_two_text_encoders:
                scale_dict.pop("text_encoder_2", None)

            # the test fails if setting the block scales with this scale_dict raises
            pipe.set_adapters("adapter-1", scale_dict)

1227
    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and denoiser, attaches
        multiple adapters and sets/deletes them.

        Checked in order: the two adapters and their mix give pairwise different outputs; deleting
        adapter 1 leaves an output equal to adapter 2 alone; deleting adapter 2 as well restores the
        no-lora output; re-adding both adapters and deleting them in a single call restores it again.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # reference output before any adapter is attached
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            # with adapter 1 gone, only adapter 2 is left to influence the output
            pipe.delete_adapters("adapter-1")
            output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Output after deleting adapter 1 should match the output of adapter 2 alone",
            )

            pipe.delete_adapters("adapter-2")
            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with all adapters deleted should give same results",
            )

            # re-attach both adapters to exercise deleting several adapters in one call
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            pipe.set_adapters(["adapter-1", "adapter-2"])
            pipe.delete_adapters(["adapter-1", "adapter-2"])

            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with all adapters deleted should give same results",
            )

1321
    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
        """
        Runs a simple inference with two lora adapters attached to the text encoder(s) and the denoiser,
        then compares the outputs of each adapter alone, both mixed, both mixed with explicit weights,
        and lora disabled
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            def infer():
                # every run is re-seeded so that outputs are comparable
                return pipe(**inputs, generator=torch.manual_seed(0))[0]

            def close(lhs, rhs):
                return np.allclose(lhs, rhs, atol=1e-3, rtol=1e-3)

            output_no_lora = infer()

            loadable_modules = self.pipeline_class._lora_loadable_modules
            if "text_encoder" in loadable_modules:
                for adapter_name in ("adapter-1", "adapter-2"):
                    pipe.text_encoder.add_adapter(text_lora_config, adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            for adapter_name in ("adapter-1", "adapter-2"):
                denoiser.add_adapter(denoiser_lora_config, adapter_name)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_text_encoder and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                for adapter_name in ("adapter-1", "adapter-2"):
                    pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = infer()

            pipe.set_adapters("adapter-2")
            output_adapter_2 = infer()

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = infer()

            # each adapter configuration should give a different output
            must_differ = (
                (output_adapter_1, output_adapter_2, "Adapter 1 and 2 should give different results"),
                (
                    output_adapter_1,
                    output_adapter_mixed,
                    "Adapter 1 and mixed adapters should give different results",
                ),
                (
                    output_adapter_2,
                    output_adapter_mixed,
                    "Adapter 2 and mixed adapters should give different results",
                ),
            )
            for lhs, rhs, message in must_differ:
                self.assertFalse(close(lhs, rhs), message)

            # explicit per-adapter weights should change the mixed result
            pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
            output_adapter_mixed_weighted = infer()

            self.assertFalse(
                close(output_adapter_mixed_weighted, output_adapter_mixed),
                "Weighted adapter and mixed adapter should give different results",
            )

            pipe.disable_lora()
            output_disabled = infer()

            self.assertTrue(
                close(output_no_lora, output_disabled),
                "output with no lora and output with lora disabled should give same results",
            )

1397
    @skip_mps
    def test_lora_fuse_nan(self):
        """
        Checks how `fuse_lora` handles a lora weight that was corrupted with `inf`:
        with `safe_fusing=True` a `ValueError` must be raised, with `safe_fusing=False` fusing
        must go through and the pipeline output must consist of NaN values only.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            # corrupt one LoRA weight with `inf` values
            with torch.no_grad():
                if self.unet_kwargs:
                    pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
                        "adapter-1"
                    ].weight += float("inf")
                else:
                    pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

            # with `safe_fusing=True` we should see an Error
            with self.assertRaises(ValueError):
                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

            # without we should not see an error, but every output value will be NaN
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
            # Call the pipeline with its own dummy inputs (as every other test here does) instead of a
            # hard-coded prompt/kwargs combination that not every pipeline signature has to accept.
            # NOTE(review): assumes the dummy inputs yield a numpy-compatible output - confirm
            out = pipe(**inputs)[0]

            self.assertTrue(np.isnan(out).all())

    def test_get_adapters(self):
        """
        Tests that `get_active_adapters` reports the expected adapter names: the only adapter after
        attaching one, the most recently attached adapter after attaching a second one, and both
        adapters after activating them together with `set_adapters`.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            # Only touch the text encoder when the pipeline declares it as lora-loadable,
            # in line with the other multi-adapter tests of this class.
            has_text_encoder_lora = "text_encoder" in self.pipeline_class._lora_loadable_modules

            if has_text_encoder_lora:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            adapter_names = pipe.get_active_adapters()
            self.assertListEqual(adapter_names, ["adapter-1"])

            if has_text_encoder_lora:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            adapter_names = pipe.get_active_adapters()
            self.assertListEqual(adapter_names, ["adapter-2"])

            pipe.set_adapters(["adapter-1", "adapter-2"])
            self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])

    def test_get_list_adapters(self):
        """
        Attaches adapters step by step and checks after every step that `get_list_adapters`
        returns the expected component -> adapter names mapping
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            has_text_encoder_lora = "text_encoder" in self.pipeline_class._lora_loadable_modules
            denoiser_name = "unet" if self.unet_kwargs is not None else "transformer"
            denoiser = getattr(pipe, denoiser_name)

            def expected_adapters(text_encoder_adapters, denoiser_adapters):
                # the text encoder only shows up when the pipeline can load lora into it
                mapping = {}
                if has_text_encoder_lora:
                    mapping["text_encoder"] = text_encoder_adapters
                mapping[denoiser_name] = denoiser_adapters
                return mapping

            # 1. a single adapter on every component
            if has_text_encoder_lora:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            self.assertDictEqual(pipe.get_list_adapters(), expected_adapters(["adapter-1"], ["adapter-1"]))

            # 2. a second adapter on every component
            if has_text_encoder_lora:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            self.assertDictEqual(
                pipe.get_list_adapters(),
                expected_adapters(["adapter-1", "adapter-2"], ["adapter-1", "adapter-2"]),
            )

            # 3. activating both adapters leaves the listing unchanged
            pipe.set_adapters(["adapter-1", "adapter-2"])

            self.assertDictEqual(
                pipe.get_list_adapters(),
                expected_adapters(["adapter-1", "adapter-2"], ["adapter-1", "adapter-2"]),
            )

            # 4. a third adapter on the denoiser only
            denoiser.add_adapter(denoiser_lora_config, "adapter-3")

            self.assertDictEqual(
                pipe.get_list_adapters(),
                expected_adapters(["adapter-1", "adapter-2"], ["adapter-1", "adapter-2", "adapter-3"]),
            )
1535
1536

    @require_peft_version_greater(peft_version="0.6.2")
    def test_simple_inference_with_text_lora_denoiser_fused_multi(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Runs a simple inference with two lora adapters attached to the text encoder(s) and the denoiser,
        fuses first one adapter and then both into the base weights, and checks that fusing leaves the
        output unchanged compared to the corresponding unfused multi-adapter run
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            def infer():
                # every run is re-seeded so that outputs are comparable
                return pipe(**inputs, generator=torch.manual_seed(0))[0]

            def outputs_match(lhs, rhs):
                return np.allclose(lhs, rhs, atol=expected_atol, rtol=expected_rtol)

            output_no_lora = infer()
            self.assertTrue(output_no_lora.shape == self.output_shape)

            has_text_encoder_lora = "text_encoder" in self.pipeline_class._lora_loadable_modules

            # Attach the first adapter
            if has_text_encoder_lora:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            # Attach a second adapter
            if has_text_encoder_lora:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_text_encoder and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                for adapter_name in ("adapter-1", "adapter-2"):
                    pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            # set them to multi-adapter inference mode
            pipe.set_adapters(["adapter-1", "adapter-2"])
            outputs_all_lora = infer()

            pipe.set_adapters(["adapter-1"])
            outputs_lora_1 = infer()

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])

            # Fusing should still keep the LoRA layers so the output should remain the same
            outputs_lora_1_fused = infer()
            self.assertTrue(
                outputs_match(outputs_lora_1, outputs_lora_1_fused),
                "Fused lora should not change the output",
            )

            # start over from unfused weights, then fuse both adapters at once
            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            pipe.fuse_lora(
                components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
            )

            # Fusing should still keep the LoRA layers
            output_all_lora_fused = infer()
            self.assertTrue(
                outputs_match(output_all_lora_fused, outputs_all_lora),
                "Fused lora should not change the output",
            )

1609
1610
    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
        """
        Tests a simple inference with adapters built from `get_dummy_components(..., use_dora=True)`
        attached to the text encoder(s) and the denoiser, and makes sure the output differs from
        the output of the plain pipeline.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
                scheduler_cls, use_dora=True
            )
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # reference output before any adapter is attached
            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_dora_lora.shape == self.output_shape)

            # Only touch the text encoder when the pipeline declares it as lora-loadable,
            # in line with the other tests of this class.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
                "DoRA lora should change the output",
            )

1645
    @unittest.skip("This is failing for now - need to investigate")
    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
        """
        Tests a simple inference with (unfused) lora attached to text encoder(s) and denoiser, then wraps
        these modules with `torch.compile` and makes sure a pipeline call still runs.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders

            # Only attach lora to text encoders the pipeline declares as lora-loadable,
            # in line with the other tests of this class.
            if "text_encoder" in lora_loadable_components:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if has_extra_text_encoder and "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            # Compile whichever denoiser the pipeline actually has; transformer-based
            # pipelines are selected above via `self.unet_kwargs is None`.
            if self.unet_kwargs is None:
                pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
            else:
                pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)

            if has_extra_text_encoder:
                pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

            # Just makes sure it works..
            _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
1679
1680
1681
1682
1683
1684
1685

    def test_modify_padding_mode(self):
        """
        Switches the padding mode of every `Conv2d` layer in the vae and the unet to "circular"
        and makes sure the pipeline can still be called
        """

        def set_pad_mode(network, mode="circular"):
            # collect the conv layers first, then overwrite their padding mode
            conv_layers = [layer for _, layer in network.named_modules() if isinstance(layer, torch.nn.Conv2d)]
            for conv in conv_layers:
                conv.padding_mode = mode

        for scheduler_cls in self.scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            _pad_mode = "circular"
            for component_name in ("vae", "unet"):
                set_pad_mode(getattr(pipe, component_name), _pad_mode)

            # the output itself is not inspected - the call only has to succeed
            _, _, inputs = self.get_dummy_inputs()
            _ = pipe(**inputs)[0]