utils.py 78.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Sayak Paul's avatar
Sayak Paul committed
15
import inspect
16
17
18
import os
import tempfile
import unittest
UmerHA's avatar
UmerHA committed
19
from itertools import product
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

import numpy as np
import torch

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    LCMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
    floats_tensor,
    require_peft_backend,
    require_peft_version_greater,
35
    skip_mps,
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
    torch_device,
)


if is_peft_available():
    from peft import LoraConfig
    from peft.tuners.tuners_utils import BaseTunerLayer
    from peft.utils import get_peft_model_state_dict


def state_dicts_almost_equal(sd1, sd2):
    """Return True iff the two state dicts have the same keys and every pair of
    tensors agrees within a max absolute difference of 1e-3.

    Fix over the previous version: the key sets are compared explicitly. The old
    code zipped the sorted values of both dicts, so dicts with different keys
    (or different sizes — ``zip`` truncates silently) could compare as "equal".
    Also returns early on the first mismatching tensor.
    """
    # Different key sets can never be equal; this also guards against zip()
    # silently truncating to the shorter dict.
    if set(sd1.keys()) != set(sd2.keys()):
        return False

    # Sort both dicts so values are compared key-by-key.
    sd1 = dict(sorted(sd1.items()))
    sd2 = dict(sorted(sd2.items()))

    for ten1, ten2 in zip(sd1.values(), sd2.values()):
        if (ten1 - ten2).abs().max() > 1e-3:
            return False
    return True


def check_if_lora_correctly_set(model) -> bool:
    """Return True when at least one submodule of ``model`` is a PEFT tuner layer,
    i.e. LoRA adapters were actually injected into the model."""
    return any(isinstance(layer, BaseTunerLayer) for layer in model.modules())


@require_peft_backend
class PeftLoraLoaderMixinTests:
    """Mixin of shared LoRA/PEFT tests; concrete pipeline test classes set the class attributes below."""

    # Pipeline class under test; must be set by the subclass.
    pipeline_class = None

    # Default scheduler class/kwargs used by `get_dummy_components`.
    scheduler_cls = None
    scheduler_kwargs = None
    # Every test iterates over each of these scheduler classes.
    scheduler_classes = [DDIMScheduler, LCMScheduler]

    # Whether the pipeline carries a second/third text encoder (mutually exclusive flags,
    # validated in `get_dummy_components`).
    has_two_text_encoders = False
    has_three_text_encoders = False
    # (class, pretrained checkpoint id) pairs for the text encoders and tokenizers.
    text_encoder_cls, text_encoder_id = None, None
    text_encoder_2_cls, text_encoder_2_id = None, None
    text_encoder_3_cls, text_encoder_3_id = None, None
    tokenizer_cls, tokenizer_id = None, None
    tokenizer_2_cls, tokenizer_2_id = None, None
    tokenizer_3_cls, tokenizer_3_id = None, None

    # Denoiser configuration: exactly one of `unet_kwargs` / `transformer_kwargs`
    # should be set (validated in `get_dummy_components`).
    unet_kwargs = None
    transformer_cls = None
    transformer_kwargs = None
    vae_cls = AutoencoderKL
    vae_kwargs = None

    # Linear submodules of the text encoder targeted by the LoRA config.
    text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
93
    def get_dummy_components(self, scheduler_cls=None, use_dora=False):
        """Build all pipeline components plus the LoRA configs for text encoder(s) and denoiser.

        Args:
            scheduler_cls: Scheduler class to instantiate; defaults to ``self.scheduler_cls``.
            use_dora: Forwarded to both ``LoraConfig``s to enable DoRA.

        Returns:
            Tuple ``(pipeline_components, text_lora_config, denoiser_lora_config)`` where
            ``pipeline_components`` is a kwargs dict for ``self.pipeline_class``.
        """
        # Sanity-check the mutually exclusive class-level configuration flags.
        if self.unet_kwargs and self.transformer_kwargs:
            raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
        if self.has_two_text_encoders and self.has_three_text_encoders:
            raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
        rank = 4

        # Seed before building the denoiser so its random init is deterministic.
        torch.manual_seed(0)
        if self.unet_kwargs is not None:
            unet = UNet2DConditionModel(**self.unet_kwargs)
        else:
            transformer = self.transformer_cls(**self.transformer_kwargs)

        scheduler = scheduler_cls(**self.scheduler_kwargs)

        # Re-seed so the VAE init does not depend on the denoiser's RNG consumption.
        torch.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)

        text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)

        # Optional second/third text encoder + tokenizer, only when configured.
        if self.text_encoder_2_cls is not None:
            text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
            tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)

        if self.text_encoder_3_cls is not None:
            text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
            tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)

        # LoRA config applied to the text encoder(s).
        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            target_modules=self.text_encoder_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )

        # LoRA config applied to the denoiser's attention projections.
        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=use_dora,
        )

        pipeline_components = {
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        # Denoiser
        if self.unet_kwargs is not None:
            pipeline_components.update({"unet": unet})
        elif self.transformer_kwargs is not None:
            pipeline_components.update({"transformer": transformer})

        # Remaining text encoders.
        if self.text_encoder_2_cls is not None:
            pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2})
        if self.text_encoder_3_cls is not None:
            pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3})

        # Remaining stuff: pass None for optional components the pipeline accepts.
        init_params = inspect.signature(self.pipeline_class.__init__).parameters
        if "safety_checker" in init_params:
            pipeline_components.update({"safety_checker": None})
        if "feature_extractor" in init_params:
            pipeline_components.update({"feature_extractor": None})
        if "image_encoder" in init_params:
            pipeline_components.update({"image_encoder": None})

        return pipeline_components, text_lora_config, denoiser_lora_config
168

Sayak Paul's avatar
Sayak Paul committed
169
170
171
172
    @property
    def output_shape(self):
        # Expected `.shape` of the pipeline output; concrete subclasses must override.
        raise NotImplementedError

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
    def get_dummy_inputs(self, with_generator=False if False else True):
        """Build dummy noise, token ids and pipeline call kwargs for a single-sample run.

        Args:
            with_generator: When True, a seeded ``torch.Generator`` is included in the
                returned pipeline kwargs.

        Returns:
            Tuple of ``(noise, input_ids, pipeline_inputs)``.
        """
        batch_size, sequence_length = 1, 10
        num_channels, sizes = 4, (32, 32)

        # NOTE: the call order below is deliberate — `floats_tensor` consumes the
        # freshly seeded global RNG before `randint` draws from `generator`.
        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "num_inference_steps": 5,
            "guidance_scale": 6.0,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs["generator"] = generator

        return noise, input_ids, pipeline_inputs

194
    # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
    def get_dummy_tokens(self):
        """Return a dict with deterministic dummy ``input_ids`` of shape (1, 77)."""
        max_seq_length = 77
        token_ids = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
        return {"input_ids": token_ids}

204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
    def _get_lora_state_dicts(self, modules_to_save):
        """Collect the PEFT LoRA state dict of every non-None module, keyed as ``<name>_lora_layers``."""
        return {
            f"{name}_lora_layers": get_peft_model_state_dict(module)
            for name, module in modules_to_save.items()
            if module is not None
        }

    def _get_modules_to_save(self, pipe, has_denoiser=False):
        """Map component names to the pipeline modules whose LoRA layers should be saved.

        Only components listed in the pipeline's ``_lora_loadable_modules`` and
        actually present on ``pipe`` are included; the denoiser (unet/transformer)
        is considered only when ``has_denoiser`` is True.
        """
        candidates = ["text_encoder", "text_encoder_2"]
        if has_denoiser:
            candidates += ["unet", "transformer"]

        loadable = self.pipeline_class._lora_loadable_modules
        modules_to_save = {}
        for name in candidates:
            if name in loadable and hasattr(pipe, name):
                modules_to_save[name] = getattr(pipe, name)
        return modules_to_save

230
231
232
233
    def test_simple_inference(self):
        """Run each scheduler once without any LoRA and check the output shape."""
        for scheduler_cls in self.scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            _, _, inputs = self.get_dummy_inputs()
            output = pipe(**inputs)[0]
            self.assertEqual(output.shape, self.output_shape)
243
244
245
246
247
248

    def test_simple_inference_with_text_lora(self):
        """
        Tests a simple inference with lora attached on the text encoder
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            # Attach the same adapter to the second text encoder when the pipeline has one.
            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Same seed as baseline; the only difference is the attached LoRA.
            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

    def test_simple_inference_with_text_lora_and_scale(self):
        """
        Tests a simple inference with lora attached on the text encoder + scale argument
        and makes sure it works as expected
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()

        # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
        # Fix: initialize to None so that a pipeline exposing none of the known kwarg
        # names fails the assertion below instead of raising NameError on an
        # unbound local.
        attention_kwargs_name = None
        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Full-strength LoRA must change the output relative to the baseline.
            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            # scale=0.5 should differ from both the baseline and full-strength outputs.
            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            # scale=0.0 must reproduce the no-LoRA baseline exactly (within tolerance).
            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

    def test_simple_inference_with_text_lora_fused(self):
        """
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline with no LoRA.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertEqual(output_no_lora.shape, self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            has_extra_encoders = self.has_two_text_encoders or self.has_three_text_encoders
            te2_loadable = "text_encoder_2" in self.pipeline_class._lora_loadable_modules
            if has_extra_encoders and te2_loadable:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.fuse_lora()
            # Fusing should still keep the LoRA layers
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            if has_extra_encoders and te2_loadable:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )

    def test_simple_inference_with_text_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder, then unloads the lora weights
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            # After unloading, the pipeline must reproduce the no-LoRA baseline.
            # Fix: the failure message previously read "Fused lora should change the
            # output", which described the opposite of what is asserted here.
            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading LoRA should lead to the same result as no LoRA",
            )

    def test_simple_inference_with_text_lora_save_load(self):
        """
        Tests a simple usecase where users could use saving utilities for LoRA.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Reference output with LoRA attached, before the save/load round trip.
            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                # Save the LoRA layers, drop them from the pipe, then load them back.
                modules_to_save = self._get_modules_to_save(pipe)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)

                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

470
471
472
473
474
475
    def test_simple_inference_with_partial_text_lora(self):
        """
        Tests a simple inference with lora attached on the text encoder
        with different ranks and some adapters removed
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
            text_lora_config = LoraConfig(
                r=4,
                rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
                lora_alpha=4,
                target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
                init_lora_weights=False,
                use_dora=False,
            )
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
            # supports missing layers (PR#8324).
            state_dict = {
                f"text_encoder.{module_name}": param
                for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
                if "text_model.encoder.layers.4" not in module_name
            }

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
                    # Same exclusion for the second text encoder's adapter weights.
                    state_dict.update(
                        {
                            f"text_encoder_2.{module_name}": param
                            for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
                            if "text_model.encoder.layers.4" not in module_name
                        }
                    )

            # Output with the full adapter attached (all layers).
            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            # Unload lora and load it back using the pipe.load_lora_weights machinery
            pipe.unload_lora_weights()
            pipe.load_lora_weights(state_dict)

            # With `layers.4` missing, the partial adapter must produce a different output.
            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
                "Removing adapters should change the output",
            )

534
535
536
537
    def test_simple_inference_save_pretrained(self):
        """
        Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Reference output with LoRA attached, before the save/load round trip.
            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Round-trip the whole pipeline (incl. LoRA layers) through save_pretrained.
            with tempfile.TemporaryDirectory() as tmpdirname:
                pipe.save_pretrained(tmpdirname)

                pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
                pipe_from_pretrained.to(torch_device)

            self.assertTrue(
                check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
                "Lora not correctly set in text encoder",
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
                        "Lora not correctly set in text encoder 2",
                    )

            images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

585
    def test_simple_inference_with_text_denoiser_lora_save_load(self):
        """
        Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # Attach the denoiser adapter to whichever denoiser the pipeline uses.
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Reference output with all adapters attached, before the round trip.
            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                # Save text-encoder AND denoiser LoRA layers, unload, then reload.
                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

638
    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
        """
        Tests a simple inference with lora attached on the text encoder + Unet + scale argument
        and makes sure it works as expected
        """
        # Different pipelines expose the LoRA scale through differently named call
        # kwargs; detect which one this pipeline class accepts. Initialize to None
        # first so that a pipeline exposing none of the known kwargs fails the
        # assert below instead of raising a NameError on an unbound local.
        attention_kwargs_name = None
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # LoRA at its default scale must alter the output.
            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            # A scale of 0.5 must differ from the default-scale LoRA output.
            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            # A scale of 0.0 must be equivalent to running without LoRA at all.
            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

            # The per-call scaling must be restored to 1.0 after the call returns.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
                    "The scaling parameter has not been correctly restored!",
                )
    def test_simple_inference_with_text_lora_denoiser_fused(self):
        """
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected - with unet
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            # Attach LoRA adapters to all loadable components.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Merge the LoRA deltas into the base weights of every loadable module.
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

            # Fusing should still keep the LoRA layers
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # After fusing, the output must still reflect the LoRA weights.
            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )
    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            # Guard on _lora_loadable_modules, consistent with the attach step above,
            # so pipelines without a text encoder don't fail on attribute access.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertFalse(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
                )
            self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            # With the LoRA layers gone, the output must match the no-LoRA baseline.
            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading LoRA should give the same output as no LoRA",
            )
    def test_simple_inference_with_text_denoiser_lora_unfused(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Tests a simple inference with lora attached to text encoder and unet, fuses the lora
        weights into the base model and then unfuses them, making sure the LoRA layers are
        kept and that fused vs. unfused inference produces the same output.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Attach LoRA adapters to all loadable components.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Unfusing should keep the LoRA layers attached (it only reverts the
            # merged base weights, it does not unload the adapters).
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                    )

            # Fuse and unfuse should lead to the same results
            self.assertTrue(
                np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )
    def test_simple_inference_with_text_denoiser_multi_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Attach two named adapters to every loadable component.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Activate each adapter (and their mix) in turn; each must differ from baseline.
            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            # Each adapter and their combination should produce distinct outputs.
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            # Disabling LoRA entirely must restore the no-LoRA baseline.
            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )
    def test_wrong_adapter_name_raises_error(self):
        """Setting an adapter name that was never added must raise a ValueError."""
        scheduler_cls = self.scheduler_classes[0]
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        # Register a single adapter, "adapter-1", on every loadable component.
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        # Requesting an unknown adapter name must be rejected with an informative error.
        with self.assertRaises(ValueError) as err_ctx:
            pipe.set_adapters("test")
        self.assertIn("not in the list of present adapters", str(err_ctx.exception))

        # The registered adapter name still works end to end.
        pipe.set_adapters("adapter-1")
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
    def test_simple_inference_with_text_denoiser_block_scale(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        one adapter and set different weights for different blocks (i.e. block lora)
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # NOTE(review): text_encoder is accessed unguarded here, unlike sibling
            # tests that first check _lora_loadable_modules — confirm this test only
            # runs for pipelines that always have a text encoder.
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Two different per-block scale dicts for the same adapter.
            weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
            pipe.set_adapters("adapter-1", weights_1)
            output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            weights_2 = {"unet": {"up": 5}}
            pipe.set_adapters("adapter-1", weights_2)
            output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Different block scales must change the output, and both must differ from baseline.
            self.assertFalse(
                np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
                "LoRA weights 1 and 2 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 1 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 2 should give different results",
            )

            # Disabling LoRA entirely must restore the no-LoRA baseline.
            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set different weights for different blocks (i.e. block lora)
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Attach two named adapters to every loadable component.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Per-block scale dicts, one per adapter.
            scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
            scales_2 = {"unet": {"down": 5, "mid": 5}}

            pipe.set_adapters("adapter-1", scales_1)
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2", scales_2)
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Each adapter and their combination should produce distinct outputs.
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            # Disabling LoRA entirely must restore the no-LoRA baseline.
            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

            # a mismatching number of adapter_names and adapter_weights should raise an error
            with self.assertRaises(ValueError):
                pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""

        def _updown_variants(tf_block_ids, layers_per_block, value):
            """
            Enumerate every shape a scale spec for the up/down half can take.
            E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
            """
            scalar = value
            per_layer = [value] * layers_per_block
            # Each transformer block can be omitted (None), scaled by a scalar, or scaled per-layer.
            per_block_choices = [[None, scalar, per_layer]] * len(tf_block_ids)

            variants = [scalar]
            for choice in product(*per_block_choices):
                if not any(c is not None for c in choice):
                    # a dict with no entries at all is not a distinct variant
                    continue
                variants.append(
                    {"block_" + str(block_id): c for block_id, c in zip(tf_block_ids, choice) if c is not None}
                )
            return variants

        def _all_scale_dicts(unet, value):
            """
            Enumerate every shape a full lora-scale dict can take.
            E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ...
            """
            # only blocks that actually contain transformers can be scaled
            down_ids = [i for i, blk in enumerate(unet.down_blocks) if hasattr(blk, "attentions")]
            up_ids = [i for i, blk in enumerate(unet.up_blocks) if hasattr(blk, "attentions")]
            n_layers = unet.config.layers_per_block

            te_choices = [None, value]
            te2_choices = [None, value]
            mid_choices = [None, value]
            down_choices = [None] + _updown_variants(down_ids, n_layers, value)
            # up blocks have one extra layer each
            up_choices = [None] + _updown_variants(up_ids, n_layers + 1, value)

            result = []
            for te, te2, down, mid, up in product(te_choices, te2_choices, down_choices, mid_choices, up_choices):
                # a dict that scales no unet part at all is dropped (this also covers the all-None combo)
                if down is None and mid is None and up is None:
                    continue
                entry = {}
                if te is not None:
                    entry["text_encoder"] = te
                if te2 is not None:
                    entry["text_encoder_2"] = te2
                entry["unet"] = {}
                if down is not None:
                    entry["unet"]["down"] = down
                if mid is not None:
                    entry["unet"]["mid"] = mid
                if up is not None:
                    entry["unet"]["up"] = up
                result.append(entry)
            return result

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")

        for scale_dict in _all_scale_dicts(pipe.unet, value=1234):
            # pipelines with a single text encoder cannot scale "text_encoder_2"
            if not self.has_two_text_encoders and "text_encoder_2" in scale_dict:
                del scale_dict["text_encoder_2"]

            pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error

1195
    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set/delete them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Reference output before any adapter is attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Each adapter combination must produce a distinct output.
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.delete_adapters("adapter-1")
            output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # After deleting adapter-1, only adapter-2 remains, so the output must match
            # the earlier adapter-2-only output (the assertion is an allclose/assertTrue).
            self.assertTrue(
                np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Deleting adapter 1 should leave outputs identical to using adapter 2 alone",
            )

            pipe.delete_adapters("adapter-2")
            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # With every adapter deleted we should be back at the no-LoRA baseline.
            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with all adapters deleted should give same results",
            )

            # Re-attach both adapters and delete them with a single list call.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            pipe.set_adapters(["adapter-1", "adapter-2"])
            pipe.delete_adapters(["adapter-1", "adapter-2"])

            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with all adapters deleted should give same results",
            )

1289
    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output before any adapter is attached.
            baseline = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                for adapter_name in ("adapter-1", "adapter-2"):
                    pipe.text_encoder.add_adapter(text_lora_config, adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            for adapter_name in ("adapter-1", "adapter-2"):
                denoiser.add_adapter(denoiser_lora_config, adapter_name)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    for adapter_name in ("adapter-1", "adapter-2"):
                        pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            out_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            out_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            out_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Each adapter combination must produce a distinct output.
            self.assertFalse(
                np.allclose(out_adapter_1, out_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(out_adapter_1, out_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(out_adapter_2, out_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            # Re-weighting the two adapters must change the mixed output.
            pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
            out_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(out_weighted, out_mixed, atol=1e-3, rtol=1e-3),
                "Weighted adapter and mixed adapter should give different results",
            )

            # Disabling LoRA must restore the baseline behavior.
            pipe.disable_lora()
            out_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(baseline, out_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

1365
    @skip_mps
    def test_lora_fuse_nan(self):
        """Fusing a LoRA whose weights contain `inf` must raise with `safe_fusing=True`
        and silently produce all-NaN images with `safe_fusing=False`."""
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            # corrupt one LoRA weight with `inf` values
            with torch.no_grad():
                if self.unet_kwargs:
                    corrupted = pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q
                else:
                    corrupted = pipe.transformer.transformer_blocks[0].attn.to_q
                corrupted.lora_A["adapter-1"].weight += float("inf")

            # with `safe_fusing=True` we should see an Error
            with self.assertRaises(ValueError):
                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

            # without we should not see an error, but every image will be black
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
            out = pipe("test", num_inference_steps=2, output_type="np")[0]

            self.assertTrue(np.isnan(out).all())

    def test_get_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # A freshly added adapter becomes the active one.
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertListEqual(pipe.get_active_adapters(), ["adapter-1"])

            # Adding a second adapter switches the active adapter to it.
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertListEqual(pipe.get_active_adapters(), ["adapter-2"])

            # Explicitly activating both makes both active, in the given order.
            pipe.set_adapters(["adapter-1", "adapter-2"])
            self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])

    def test_get_list_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            has_text_encoder = "text_encoder" in self.pipeline_class._lora_loadable_modules
            denoiser_key = "unet" if self.unet_kwargs is not None else "transformer"
            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer

            def expected(te_adapters, denoiser_adapters):
                # Build the dict get_list_adapters() should return for the current state.
                out = {}
                if has_text_encoder:
                    out["text_encoder"] = te_adapters
                out[denoiser_key] = denoiser_adapters
                return out

            # 1. one adapter on the text encoder (when supported) and on the denoiser
            if has_text_encoder:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertDictEqual(pipe.get_list_adapters(), expected(["adapter-1"], ["adapter-1"]))

            # 2. a second adapter on both
            if has_text_encoder:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertDictEqual(
                pipe.get_list_adapters(), expected(["adapter-1", "adapter-2"], ["adapter-1", "adapter-2"])
            )

            # 3. activating both adapters does not change the listing
            pipe.set_adapters(["adapter-1", "adapter-2"])
            self.assertDictEqual(
                pipe.get_list_adapters(),
                expected(["adapter-1", "adapter-2"], ["adapter-1", "adapter-2"]),
            )

            # 4. a third adapter attached to the denoiser only
            denoiser.add_adapter(denoiser_lora_config, "adapter-3")
            self.assertDictEqual(
                pipe.get_list_adapters(),
                expected(["adapter-1", "adapter-2"], ["adapter-1", "adapter-2", "adapter-3"]),
            )
1503
1504

    @require_peft_version_greater(peft_version="0.6.2")
    def test_simple_inference_with_text_lora_denoiser_fused_multi(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected - with unet and multi-adapter case
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            # Attach a second adapter
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # set them to multi-adapter inference mode
            pipe.set_adapters(["adapter-1", "adapter-2"])
            out_both = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1"])
            out_one = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])

            # Fusing should still keep the LoRA layers so output should remain the same
            out_one_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(out_one, out_one_fused, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )

            # Unfuse, then fuse both adapters at once.
            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            pipe.fuse_lora(
                components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
            )

            # Fusing should still keep the LoRA layers
            out_both_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(out_both_fused, out_both, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )

1577
1578
    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
        """DoRA adapters attached to the text encoder(s) and denoiser must change the output."""
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
                scheduler_cls, use_dora=True
            )
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output before any DoRA adapter is attached.
            baseline = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(baseline.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            dora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(dora_output, baseline, atol=1e-3, rtol=1e-3),
                "DoRA lora should change the output",
            )

1613
    @unittest.skip("This is failing for now - need to investigate")
    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            multiple_text_encoders = self.has_two_text_encoders or self.has_three_text_encoders
            if multiple_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            # Compile the heavy components; the inference below must still run.
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
            if multiple_text_encoders:
                pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

            # Just makes sure it works..
            _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
1647
1648
1649
1650
1651
1652
1653

    def test_modify_padding_mode(self):
        def set_pad_mode(network, mode="circular"):
            for _, module in network.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    module.padding_mode = mode

Aryan's avatar
Aryan committed
1654
        for scheduler_cls in self.scheduler_classes:
1655
1656
1657
1658
1659
1660
1661
1662
1663
            components, _, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _pad_mode = "circular"
            set_pad_mode(pipe.vae, _pad_mode)
            set_pad_mode(pipe.unet, _pad_mode)

            _, _, inputs = self.get_dummy_inputs()
Aryan's avatar
Aryan committed
1664
            _ = pipe(**inputs)[0]