# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import tempfile
import unittest
from itertools import product

import numpy as np
import torch

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    LCMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
    floats_tensor,
    require_peft_backend,
    require_peft_version_greater,
    require_transformers_version_greater,
    skip_mps,
    torch_device,
)


if is_peft_available():
    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
    from peft.tuners.tuners_utils import BaseTunerLayer
    from peft.utils import get_peft_model_state_dict


def state_dicts_almost_equal(sd1, sd2):
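    """
    Returns True if every pair of corresponding tensors in the two state dicts matches within an absolute
    tolerance of 1e-3. Assumes both state dicts contain the same keys, compared in sorted order.
    """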
    sd1 = dict(sorted(sd1.items()))
    sd2 = dict(sorted(sd2.items()))

    models_are_equal = True
    for ten1, ten2 in zip(sd1.values(), sd2.values()):
        if (ten1 - ten2).abs().max() > 1e-3:
            models_are_equal = False

    return models_are_equal


def check_if_lora_correctly_set(model) -> bool:
    """
    Checks if the LoRA layers are correctly set with peft
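
    For example (sketch): after `pipe.text_encoder.add_adapter(lora_config)` the injected peft modules
    subclass `BaseTunerLayer`, so this returns True; after `pipe.unload_lora_weights()` those modules are
    removed again and this returns False.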
    """
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            return True
    return False


def initialize_dummy_state_dict(state_dict):
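    """
    Replaces every tensor of a meta-device `state_dict` with a random tensor of the same shape and dtype on
    `torch_device`. Used to materialize adapters injected with `low_cpu_mem_usage=True`, whose parameters
    initially live on the "meta" device.
    """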
    if not all(v.device.type == "meta" for _, v in state_dict.items()):
        raise ValueError("`state_dict` has non-meta values.")
    return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()}


@require_peft_backend
class PeftLoraLoaderMixinTests:
    pipeline_class = None

    scheduler_cls = None
    scheduler_kwargs = None
    scheduler_classes = [DDIMScheduler, LCMScheduler]

    has_two_text_encoders = False
    has_three_text_encoders = False
    text_encoder_cls, text_encoder_id = None, None
    text_encoder_2_cls, text_encoder_2_id = None, None
    text_encoder_3_cls, text_encoder_3_id = None, None
    tokenizer_cls, tokenizer_id = None, None
    tokenizer_2_cls, tokenizer_2_id = None, None
    tokenizer_3_cls, tokenizer_3_id = None, None

    unet_kwargs = None
    transformer_cls = None
    transformer_kwargs = None
    vae_cls = AutoencoderKL
    vae_kwargs = None

    text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]

    def get_dummy_components(self, scheduler_cls=None, use_dora=False):
        if self.unet_kwargs and self.transformer_kwargs:
            raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
        if self.has_two_text_encoders and self.has_three_text_encoders:
            raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
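        # A small rank keeps the dummy adapters light; with `lora_alpha` equal to the rank (set below),
        # the effective LoRA scaling lora_alpha / r is 1.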
        rank = 4

        torch.manual_seed(0)
        if self.unet_kwargs is not None:
            unet = UNet2DConditionModel(**self.unet_kwargs)
        else:
            transformer = self.transformer_cls(**self.transformer_kwargs)

        scheduler = scheduler_cls(**self.scheduler_kwargs)

        torch.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)

        text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)

        if self.text_encoder_2_cls is not None:
            text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
            tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)

        if self.text_encoder_3_cls is not None:
            text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
            tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)

        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            target_modules=self.text_encoder_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )

        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=use_dora,
        )

        pipeline_components = {
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        # Denoiser
        if self.unet_kwargs is not None:
            pipeline_components.update({"unet": unet})
        elif self.transformer_kwargs is not None:
            pipeline_components.update({"transformer": transformer})

        # Remaining text encoders.
        if self.text_encoder_2_cls is not None:
            pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2})
        if self.text_encoder_3_cls is not None:
            pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3})

        # Remaining stuff
        init_params = inspect.signature(self.pipeline_class.__init__).parameters
        if "safety_checker" in init_params:
            pipeline_components.update({"safety_checker": None})
        if "feature_extractor" in init_params:
            pipeline_components.update({"feature_extractor": None})
        if "image_encoder" in init_params:
            pipeline_components.update({"image_encoder": None})

        return pipeline_components, text_lora_config, denoiser_lora_config

    @property
    def output_shape(self):
        raise NotImplementedError

    def get_dummy_inputs(self, with_generator=True):
        batch_size = 1
        sequence_length = 10
        num_channels = 4
        sizes = (32, 32)

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "num_inference_steps": 5,
            "guidance_scale": 6.0,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs

    # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
    def get_dummy_tokens(self):
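        # 77 matches the default maximum sequence length of the CLIP tokenizers typically used in these tests.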
        max_seq_length = 77

        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))

        prepared_inputs = {}
        prepared_inputs["input_ids"] = inputs
        return prepared_inputs

    def _get_lora_state_dicts(self, modules_to_save):
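        """
        Collects the LoRA-only state dict of each module, keyed as `<module>_lora_layers` to line up with the
        keyword arguments of `save_lora_weights` (e.g. `unet_lora_layers=...`), so the result can be splatted
        directly into that call.
        """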
        state_dicts = {}
        for module_name, module in modules_to_save.items():
            if module is not None:
                state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
        return state_dicts

    def _get_modules_to_save(self, pipe, has_denoiser=False):
        modules_to_save = {}
        lora_loadable_modules = self.pipeline_class._lora_loadable_modules

        if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
            modules_to_save["text_encoder"] = pipe.text_encoder

        if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
            modules_to_save["text_encoder_2"] = pipe.text_encoder_2

        if has_denoiser:
            if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
                modules_to_save["unet"] = pipe.unet

            if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
                modules_to_save["transformer"] = pipe.transformer

        return modules_to_save

    def test_simple_inference(self):
        """
        Tests a simple inference and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            _, _, inputs = self.get_dummy_inputs()
            output_no_lora = pipe(**inputs)[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

    def test_simple_inference_with_text_lora(self):
        """
        Tests a simple inference with lora attached to the text encoder
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

    @require_peft_version_greater("0.13.1")
    def test_low_cpu_mem_usage_with_injection(self):
        """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
                )
                self.assertTrue(
                    "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
                    "The LoRA params should be on 'meta' device.",
                )

                te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
                set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
                self.assertTrue(
                    "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
                    "No param should be on 'meta' device.",
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
            self.assertTrue(
                "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
            )

            denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
            set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
            self.assertTrue(
                "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
                    self.assertTrue(
                        "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
                        "The LoRA params should be on 'meta' device.",
                    )

                    te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
                    set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
                    self.assertTrue(
                        "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
                        "No param should be on 'meta' device.",
                    )

            _, _, inputs = self.get_dummy_inputs()
            output_lora = pipe(**inputs)[0]
            self.assertTrue(output_lora.shape == self.output_shape)

    @require_peft_version_greater("0.13.1")
    @require_transformers_version_greater("4.45.2")
    def test_low_cpu_mem_usage_with_loading(self):
        """Tests if we can load LoRA state dict with low_cpu_mem_usage."""

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)

                for module_name, module in modules_to_save.items():
                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

                images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
                self.assertTrue(
                    np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                    "Loading from saved checkpoints should give same results.",
                )

                # Now, check for `low_cpu_mem_usage.`
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)

                for module_name, module in modules_to_save.items():
                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

                images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
                self.assertTrue(
                    np.allclose(
                        images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3
                    ),
                    "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
                )

    def test_simple_inference_with_text_lora_and_scale(self):
        """
        Tests a simple inference with lora attached to the text encoder + scale argument
        and makes sure it works as expected
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()

        # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
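        # Different pipelines expose the LoRA `scale` under different attention-kwargs names (e.g.
        # `cross_attention_kwargs` for UNet-based pipelines), so pick whichever name the pipeline's
        # `__call__` accepts.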
        attention_kwargs_name = None
        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

    def test_simple_inference_with_text_lora_fused(self):
        """
        Tests a simple inference with lora attached to the text encoder + fuses the lora weights into the base model
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora()
            # Fusing should still keep the LoRA layers
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )

    def test_simple_inference_with_text_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder, then unloads the lora weights
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading the LoRA weights should give the same output as no LoRA",
            )

    def test_simple_inference_with_text_lora_save_load(self):
        """
        Tests a simple use case where users could use saving utilities for LoRA.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipe)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)

                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

    def test_simple_inference_with_partial_text_lora(self):
        """
        Tests a simple inference with lora attached to the text encoder
        with different ranks and some adapters removed
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
            text_lora_config = LoraConfig(
                r=4,
                rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
                lora_alpha=4,
                target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
                init_lora_weights=False,
                use_dora=False,
            )
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
            # supports missing layers (PR#8324).
            state_dict = {
                f"text_encoder.{module_name}": param
                for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
                if "text_model.encoder.layers.4" not in module_name
            }

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
                    state_dict.update(
                        {
                            f"text_encoder_2.{module_name}": param
                            for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
                            if "text_model.encoder.layers.4" not in module_name
                        }
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            # Unload lora and load it back using the pipe.load_lora_weights machinery
            pipe.unload_lora_weights()
            pipe.load_lora_weights(state_dict)

            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
                "Removing adapters should change the output",
            )

    def test_simple_inference_save_pretrained(self):
        """
        Tests a simple use case where users could use saving utilities for LoRA through save_pretrained
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                pipe.save_pretrained(tmpdirname)

                pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
                pipe_from_pretrained.to(torch_device)

            self.assertTrue(
                check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
                "Lora not correctly set in text encoder",
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
                        "Lora not correctly set in text encoder 2",
                    )

            images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

    def test_simple_inference_with_text_denoiser_lora_save_load(self):
        """
        Tests a simple use case where users could use saving utilities for LoRA for the UNet + text encoder
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
        """
        Tests a simple inference with lora attached to the text encoder + Unet + scale argument
        and makes sure it works as expected
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        attention_kwargs_name = None
        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
                    "The scaling parameter has not been correctly restored!",
                )

    def test_simple_inference_with_text_lora_denoiser_fused(self):
        """
        Tests a simple inference with lora attached to the text encoder + fuses the lora weights into the base model
        and makes sure it works as expected - with unet
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

            # Fusing should still keep the LoRA layers
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )

    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )
            self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Fused lora should change the output",
            )

    def test_simple_inference_with_text_denoiser_lora_unfused(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Tests a simple inference with lora attached to the text encoder and denoiser, fuses and then unfuses
        the lora weights, and makes sure the fused and unfused outputs match
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # unfusing should keep the LoRA layers attached
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                    )

            # Fuse and unfuse should lead to the same results
            self.assertTrue(
                np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )

    def test_simple_inference_with_text_denoiser_multi_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

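            # Run each adapter on its own and then both together; every configuration should
            # produce an output that differs from the LoRA-free baseline.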
            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter outputs should be different.",
            )

            # Each adapter and the mixed combination should produce distinct outputs
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_wrong_adapter_name_raises_error(self):
        scheduler_cls = self.scheduler_classes[0]
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        with self.assertRaises(ValueError) as err_context:
            pipe.set_adapters("test")

        self.assertTrue("not in the list of present adapters" in str(err_context.exception))

        # setting an existing adapter name should work
        pipe.set_adapters("adapter-1")
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    def test_simple_inference_with_text_denoiser_block_scale(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        one adapter and sets different weights for different blocks (i.e. block lora)
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

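            # Block-wise scaling: first scale the text encoder LoRA and the denoiser down blocks,
            # then only the up blocks.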
            weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
            pipe.set_adapters("adapter-1", weights_1)
            output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            weights_2 = {"unet": {"up": 5}}
            pipe.set_adapters("adapter-1", weights_2)
            output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
                "LoRA weights 1 and 2 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 1 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 2 should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets different weights for different blocks (i.e. block lora)
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

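            # Each adapter gets its own block-wise scale dict.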
            scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
            scales_2 = {"unet": {"down": 5, "mid": 5}}

            pipe.set_adapters("adapter-1", scales_1)
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2", scales_2)
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Each adapter and the mixed combination should produce distinct outputs
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

            # a mismatching number of adapter_names and adapter_weights should raise an error
            with self.assertRaises(ValueError):
                pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])

    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""

        def updown_options(blocks_with_tf, layers_per_block, value):
            """
            Generate every possible combination for how a lora weight dict for the up/down part can be.
            E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
            """
            num_val = value
            list_val = [value] * layers_per_block

            node_opts = [None, num_val, list_val]
            node_opts_foreach_block = [node_opts] * len(blocks_with_tf)

            updown_opts = [num_val]
            for nodes in product(*node_opts_foreach_block):
                if all(n is None for n in nodes):
                    continue
                opt = {}
                for b, n in zip(blocks_with_tf, nodes):
                    if n is not None:
                        opt["block_" + str(b)] = n
                updown_opts.append(opt)
            return updown_opts

        def all_possible_dict_opts(unet, value):
            """
            Generate every possible combination for how a lora weight dict can be.
            E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ...
            """

            down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")]
            up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")]

            layers_per_block = unet.config.layers_per_block

            text_encoder_opts = [None, value]
            text_encoder_2_opts = [None, value]
            mid_opts = [None, value]
            down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value)
            up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value)

            opts = []

            for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts):
                if all(o is None for o in (t1, t2, d, m, u)):
                    continue
                opt = {}
                if t1 is not None:
                    opt["text_encoder"] = t1
                if t2 is not None:
                    opt["text_encoder_2"] = t2
                if all(o is None for o in (d, m, u)):
                    # no unet scaling
                    continue
                opt["unet"] = {}
                if d is not None:
                    opt["unet"]["down"] = d
                if m is not None:
                    opt["unet"]["mid"] = m
                if u is not None:
                    opt["unet"]["up"] = u
                opts.append(opt)

            return opts

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")

        for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
            # test if lora block scales can be set with this scale_dict
            if not self.has_two_text_encoders and "text_encoder_2" in scale_dict:
                del scale_dict["text_encoder_2"]

            pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error

    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets/deletes them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

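            # After deleting adapter-1, only adapter-2 should contribute to the output.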
            pipe.delete_adapters("adapter-1")
            output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            pipe.delete_adapters("adapter-2")
            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

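            # Deleting both adapters in a single call should restore the LoRA-free output.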
            pipe.set_adapters(["adapter-1", "adapter-2"])
            pipe.delete_adapters(["adapter-1", "adapter-2"])

            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets them with different weights
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Each adapter and the mixed combination should produce distinct outputs
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

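            # Weighted multi-adapter inference: per-adapter weights should change the result
            # compared to the unweighted combination.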
            pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
            output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Weighted adapter and mixed adapter should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    @skip_mps
    def test_lora_fuse_nan(self):
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            # corrupt one LoRA weight with `inf` values
            with torch.no_grad():
                if self.unet_kwargs:
                    pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
                        "adapter-1"
                    ].weight += float("inf")
                else:
                    pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

            # with `safe_fusing=True` we should see an Error
            with self.assertRaises(ValueError):
                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

            # without safe_fusing we should not see an error, but every image will be black
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
            out = pipe("test", num_inference_steps=2, output_type="np")[0]

            self.assertTrue(np.isnan(out).all())

    def test_get_adapters(self):
        """
        Tests a simple use case where we attach multiple adapters and check that
        `get_active_adapters` returns the expected adapters
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            adapter_names = pipe.get_active_adapters()
            self.assertListEqual(adapter_names, ["adapter-1"])

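            # Adding a new adapter makes it the active one, replacing adapter-1.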
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            adapter_names = pipe.get_active_adapters()
            self.assertListEqual(adapter_names, ["adapter-2"])

            pipe.set_adapters(["adapter-1", "adapter-2"])
            self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])

    def test_get_list_adapters(self):
        """
        Tests a simple use case where we attach multiple adapters and check that
        `get_list_adapters` returns the expected adapters
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            # 1. Attach "adapter-1" to the text encoder (if supported) and the denoiser.
            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                dicts_to_be_checked = {"text_encoder": ["adapter-1"]}

            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
                dicts_to_be_checked.update({"unet": ["adapter-1"]})
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
                dicts_to_be_checked.update({"transformer": ["adapter-1"]})

            self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

            # 2. Attach "adapter-2" as well; both adapters should now be listed.
            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

            self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

            # 3. Activating both adapters should not change the listing.
            pipe.set_adapters(["adapter-1", "adapter-2"])

            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

            if self.unet_kwargs is not None:
                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
            else:
                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

            self.assertDictEqual(
                pipe.get_list_adapters(),
                dicts_to_be_checked,
            )

            # 4. Attach "adapter-3" to the denoiser only.
            dicts_to_be_checked = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
                dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
                dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})

            self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

    @require_peft_version_greater(peft_version="0.6.2")
    def test_simple_inference_with_text_lora_denoiser_fused_multi(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Tests a simple inference with lora attached to the text encoder and denoiser, fuses the lora
        weights into the base model, and makes sure it works as expected in the multi-adapter case
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            # Attach a second adapter
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # set them to multi-adapter inference mode
            pipe.set_adapters(["adapter-1", "adapter-2"])
            outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1"])
            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])

            # Fusing should still keep the LoRA layers so output should remain the same
            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            pipe.fuse_lora(
                components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
            )

            # Fusing both adapters should reproduce the multi-adapter output
            output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )

    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
                scheduler_cls, use_dora=True
            )
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_dora_lora.shape == self.output_shape)

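            # These adapters were created with `use_dora=True`, so the output should differ
            # from the DoRA-free baseline.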
            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
                "DoRA lora should change the output",
            )

    @unittest.skip("This is failing for now - need to investigate")
    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, compiles the models
        with `torch.compile`, and makes sure inference works as expected
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

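            # Compile the UNet and the text encoder(s) with torch.compile before running inference.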
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)

            if self.has_two_text_encoders or self.has_three_text_encoders:
                pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

            # Just makes sure inference runs without errors.
            _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    def test_modify_padding_mode(self):
        def set_pad_mode(network, mode="circular"):
            for _, module in network.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    module.padding_mode = mode

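        # Inference should still run with every Conv2d switched to circular padding.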
        for scheduler_cls in self.scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _pad_mode = "circular"
            set_pad_mode(pipe.vae, _pad_mode)
            set_pad_mode(pipe.unet, _pad_mode)

            _, _, inputs = self.get_dummy_inputs()
            _ = pipe(**inputs)[0]