utils.py 86.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Sayak Paul's avatar
Sayak Paul committed
15
import inspect
16
17
18
import os
import tempfile
import unittest
UmerHA's avatar
UmerHA committed
19
from itertools import product
20
21
22
23
24
25
26

import numpy as np
import torch

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
27
    FlowMatchEulerDiscreteScheduler,
28
29
30
31
32
33
34
35
    LCMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
    floats_tensor,
    require_peft_backend,
    require_peft_version_greater,
36
    skip_mps,
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    torch_device,
)


if is_peft_available():
    from peft import LoraConfig
    from peft.tuners.tuners_utils import BaseTunerLayer
    from peft.utils import get_peft_model_state_dict


def state_dicts_almost_equal(sd1, sd2):
    """Return True when the tensors of two state dicts match within an absolute tolerance of 1e-3.

    Tensors are paired positionally after sorting both dicts by key.
    NOTE: `zip` truncates to the shorter dict, so extra entries in one state
    dict are silently ignored — callers are expected to pass dicts with the
    same key set.
    """
    sd1 = dict(sorted(sd1.items()))
    sd2 = dict(sorted(sd2.items()))

    # `all` short-circuits on the first mismatching tensor instead of scanning every pair.
    return all((ten1 - ten2).abs().max() <= 1e-3 for ten1, ten2 in zip(sd1.values(), sd2.values()))


def check_if_lora_correctly_set(model) -> bool:
    """Return True when at least one submodule of `model` is a peft `BaseTunerLayer`."""
    return any(isinstance(submodule, BaseTunerLayer) for submodule in model.modules())


@require_peft_backend
class PeftLoraLoaderMixinTests:
    """Mixin of LoRA/PEFT tests shared across pipeline-specific test classes.

    Subclasses configure the attributes below (pipeline class, scheduler,
    model classes / pretrained ids, ...) to describe the dummy components
    the tests should build.
    """

    # Pipeline under test and its scheduler configuration (set by subclasses).
    pipeline_class = None
    scheduler_cls = None
    scheduler_kwargs = None
    # Selects FlowMatchEulerDiscreteScheduler instead of DDIM/LCM in the tests below.
    uses_flow_matching = False

    # Number of text encoders in the pipeline; at most one of these may be True.
    has_two_text_encoders = False
    has_three_text_encoders = False

    # (class, pretrained id) pairs for each text encoder / tokenizer slot.
    text_encoder_cls, text_encoder_id = None, None
    text_encoder_2_cls, text_encoder_2_id = None, None
    text_encoder_3_cls, text_encoder_3_id = None, None
    tokenizer_cls, tokenizer_id = None, None
    tokenizer_2_cls, tokenizer_2_id = None, None
    tokenizer_3_cls, tokenizer_3_id = None, None

    # Denoiser configuration: exactly one of `unet_kwargs` / `transformer_kwargs` is set.
    unet_kwargs = None
    transformer_cls = None
    transformer_kwargs = None

    # VAE configuration.
    vae_cls = AutoencoderKL
    vae_kwargs = None

    # Default LoRA target modules for the text encoder(s).
    text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]

    def get_dummy_components(self, scheduler_cls=None, use_dora=False):
        """Build the pipeline components plus LoRA configs used by the tests.

        Args:
            scheduler_cls: scheduler class to instantiate; defaults to `self.scheduler_cls`.
            use_dora: forwarded to both `LoraConfig`s to toggle DoRA.

        Returns:
            (pipeline_components dict, text-encoder LoraConfig, denoiser LoraConfig)
        """
        if self.unet_kwargs and self.transformer_kwargs:
            raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
        if self.has_two_text_encoders and self.has_three_text_encoders:
            raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
        rank = 4

        # Seed before each randomly-initialized model so component weights are reproducible.
        torch.manual_seed(0)
        if self.unet_kwargs is not None:
            unet = UNet2DConditionModel(**self.unet_kwargs)
        else:
            transformer = self.transformer_cls(**self.transformer_kwargs)

        scheduler = scheduler_cls(**self.scheduler_kwargs)

        torch.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)

        # Primary text encoder / tokenizer are always present.
        text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)

        # Optional second and third text encoders, depending on the subclass config.
        if self.text_encoder_2_cls is not None:
            text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
            tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)

        if self.text_encoder_3_cls is not None:
            text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
            tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)

        # LoRA config for the text encoder(s); random init so outputs visibly change.
        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            target_modules=self.text_encoder_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )

        # LoRA config for the denoiser's attention projections.
        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=use_dora,
        )

        pipeline_components = {
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        # Denoiser
        if self.unet_kwargs is not None:
            pipeline_components.update({"unet": unet})
        elif self.transformer_kwargs is not None:
            pipeline_components.update({"transformer": transformer})

        # Remaining text encoders.
        if self.text_encoder_2_cls is not None:
            pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2})
        if self.text_encoder_3_cls is not None:
            pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3})

        # Remaining stuff: pass None for optional components the pipeline's __init__ accepts.
        init_params = inspect.signature(self.pipeline_class.__init__).parameters
        if "safety_checker" in init_params:
            pipeline_components.update({"safety_checker": None})
        if "feature_extractor" in init_params:
            pipeline_components.update({"feature_extractor": None})
        if "image_encoder" in init_params:
            pipeline_components.update({"image_encoder": None})

        return pipeline_components, text_lora_config, denoiser_lora_config
    @property
    def output_shape(self):
        """Expected shape of the pipeline output; must be overridden by subclasses."""
        raise NotImplementedError

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
    def get_dummy_inputs(self, with_generator=True):
        """Return dummy (noise, input_ids, pipeline call kwargs) for a test run.

        When `with_generator` is True, a seeded torch generator is added to the
        pipeline kwargs so runs are reproducible.
        """
        noise_shape = (1, 4, 32, 32)  # (batch, channels, height, width)
        seq_len = 10

        generator = torch.manual_seed(0)
        noise = floats_tensor(noise_shape)
        input_ids = torch.randint(1, seq_len, size=(1, seq_len), generator=generator)

        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "num_inference_steps": 5,
            "guidance_scale": 6.0,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs["generator"] = generator

        return noise, input_ids, pipeline_inputs

194
    # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
195
196
197
198
199
200
201
202
203
204
205
206
207
    def get_dummy_tokens(self):
        """Return deterministic dummy tokenizer output: `input_ids` of shape (1, 77)."""
        token_ids = torch.randint(2, 56, size=(1, 77), generator=torch.manual_seed(0))
        return {"input_ids": token_ids}

    def test_simple_inference(self):
        """Runs the pipeline without any LoRA attached and checks the output shape."""
        # TODO(aryan): Some of the assumptions made here in many different tests are incorrect for CogVideoX.
        # For example, we need to test with CogVideoXDDIMScheduler and CogVideoDPMScheduler instead of DDIMScheduler
        # and LCMScheduler, which are not supported by it.
        if self.uses_flow_matching:
            scheduler_classes = [FlowMatchEulerDiscreteScheduler]
        else:
            scheduler_classes = [DDIMScheduler, LCMScheduler]

        for scheduler_cls in scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            _, _, inputs = self.get_dummy_inputs()
            output_no_lora = pipe(**inputs)[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

    def test_simple_inference_with_text_lora(self):
        """With a LoRA adapter attached to the text encoder(s), the output must differ from the baseline."""
        if self.uses_flow_matching:
            scheduler_classes = [FlowMatchEulerDiscreteScheduler]
        else:
            scheduler_classes = [DDIMScheduler, LCMScheduler]

        for scheduler_cls in scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output before attaching any adapter.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

    def test_simple_inference_with_text_lora_and_scale(self):
        """
        Tests a simple inference with lora attached on the text encoder + scale argument
        and makes sure it works as expected
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        # Find which attention-kwargs parameter this pipeline's __call__ accepts;
        # the LoRA `scale` is forwarded through it.
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        # Fix: initialize before the search so a pipeline with none of the candidate
        # parameters fails the assert below instead of raising NameError.
        attention_kwargs_name = None
        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None

        for scheduler_cls in scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no adapter attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Full-strength LoRA must change the output.
            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            # Half scale must differ from full-strength LoRA.
            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            # Zero scale must reproduce the no-LoRA baseline.
            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

    def test_simple_inference_with_text_lora_fused(self):
        """Attaches a text-encoder LoRA, fuses it into the base weights, and checks the output still differs
        from the no-LoRA baseline."""
        if self.uses_flow_matching:
            scheduler_classes = [FlowMatchEulerDiscreteScheduler]
        else:
            scheduler_classes = [DDIMScheduler, LCMScheduler]

        for scheduler_cls in scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora()
            # Fusing should still keep the LoRA layers
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )

    def test_simple_inference_with_text_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder, then unloads the lora weights
        and makes sure it works as expected
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no adapter attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            # After unloading, the pipeline must reproduce the no-LoRA baseline.
            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                # Fix: the previous message ("Fused lora should change the output") was
                # copy-pasted from the fused test and described the opposite check.
                "Unloading LoRA should give the same result as no LoRA",
            )

    def test_simple_inference_with_text_lora_save_load(self):
        """
        Tests a simple usecase where users could use saving utilities for LoRA.
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no adapter attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Reference output with the adapters attached.
            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
                # Save path 1: pipelines with a loadable second text encoder save both sets of layers.
                # NOTE(review): when `has_three_text_encoders` is True but `text_encoder_2` is not
                # loadable, no branch below saves anything — presumably no such pipeline exists; verify.
                if self.has_two_text_encoders or self.has_three_text_encoders:
                    if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                        text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)

                        self.pipeline_class.save_lora_weights(
                            save_directory=tmpdirname,
                            text_encoder_lora_layers=text_encoder_state_dict,
                            text_encoder_2_lora_layers=text_encoder_2_state_dict,
                            safe_serialization=False,
                        )
                # Save path 2: single-text-encoder pipelines save only the first set.
                else:
                    self.pipeline_class.save_lora_weights(
                        save_directory=tmpdirname,
                        text_encoder_lora_layers=text_encoder_state_dict,
                        safe_serialization=False,
                    )

                # Save path 3: two text encoders but the second one is not LoRA-loadable.
                if self.has_two_text_encoders:
                    if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules:
                        self.pipeline_class.save_lora_weights(
                            save_directory=tmpdirname,
                            text_encoder_lora_layers=text_encoder_state_dict,
                            safe_serialization=False,
                        )

                # Round-trip: unload the live adapters, then reload from the saved file.
                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()

                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # The reloaded adapters must reproduce the original LoRA output.
            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

487
488
489
490
491
492
    def test_simple_inference_with_partial_text_lora(self):
        """
        Tests a simple inference with lora attached on the text encoder
        with different ranks and some adapters removed
        and makes sure it works as expected
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
            text_lora_config = LoraConfig(
                r=4,
                rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
                lora_alpha=4,
                target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
                init_lora_weights=False,
                use_dora=False,
            )
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output with no adapter attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
            # supports missing layers (PR#8324).
            state_dict = {
                f"text_encoder.{module_name}": param
                for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
                if "text_model.encoder.layers.4" not in module_name
            }

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
                    # Same partial state dict for the second text encoder, key-prefixed accordingly.
                    state_dict.update(
                        {
                            f"text_encoder_2.{module_name}": param
                            for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
                            if "text_model.encoder.layers.4" not in module_name
                        }
                    )

            # Output with the full adapters attached must differ from the baseline.
            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            # Unload lora and load it back using the pipe.load_lora_weights machinery
            pipe.unload_lora_weights()
            pipe.load_lora_weights(state_dict)

            # The partially-loaded adapters (layers.4 dropped) must not reproduce the full-LoRA output.
            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
                "Removing adapters should change the output",
            )

554
555
556
557
    def test_simple_inference_save_pretrained(self):
        """Round-trips a LoRA-equipped pipeline through `save_pretrained`/`from_pretrained` and checks
        the reloaded pipeline keeps the adapters and reproduces the same output."""
        if self.uses_flow_matching:
            scheduler_classes = [FlowMatchEulerDiscreteScheduler]
        else:
            scheduler_classes = [DDIMScheduler, LCMScheduler]

        for scheduler_cls in scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdir:
                pipe.save_pretrained(tmpdir)

                restored_pipe = self.pipeline_class.from_pretrained(tmpdir)
                restored_pipe.to(torch_device)

            self.assertTrue(
                check_if_lora_correctly_set(restored_pipe.text_encoder),
                "Lora not correctly set in text encoder",
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(restored_pipe.text_encoder_2),
                        "Lora not correctly set in text encoder 2",
                    )

            images_lora_save_pretrained = restored_pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )

608
    def test_simple_inference_with_text_denoiser_lora_save_load(self):
        """
        Tests a simple use case where users could use saving utilities for LoRA for Unet + text encoder.

        Attaches LoRA adapters to the text encoder(s) and the denoiser (UNet or
        transformer), saves the LoRA layers via ``save_lora_weights``, unloads and
        reloads them via ``load_lora_weights``, and checks the reloaded pipeline
        reproduces the same images.
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is a UNet or a transformer depending on the pipeline.
            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config)
            else:
                pipe.transformer.add_adapter(denoiser_lora_config)

            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                # Collect the LoRA-only state dicts for every LoRA-loadable module.
                text_encoder_state_dict = (
                    get_peft_model_state_dict(pipe.text_encoder)
                    if "text_encoder" in self.pipeline_class._lora_loadable_modules
                    else None
                )

                if self.unet_kwargs is not None:
                    denoiser_state_dict = get_peft_model_state_dict(pipe.unet)
                else:
                    denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)

                saving_kwargs = {
                    "save_directory": tmpdirname,
                    "safe_serialization": False,
                }

                if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                    saving_kwargs.update({"text_encoder_lora_layers": text_encoder_state_dict})

                if self.unet_kwargs is not None:
                    saving_kwargs.update({"unet_lora_layers": denoiser_state_dict})
                else:
                    saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict})

                if self.has_two_text_encoders or self.has_three_text_encoders:
                    if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                        text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
                        saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict})

                self.pipeline_class.save_lora_weights(**saving_kwargs)

                # `safe_serialization=False` writes the legacy torch `.bin` format.
                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
                pipe.unload_lora_weights()

                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly set in text encoder 2",
                    )

            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )
    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
        """
        Tests a simple inference with lora attached on the text encoder + Unet + scale argument
        and makes sure it works as expected.

        Verifies that passing ``{"scale": 0.5}`` changes the output relative to full-scale
        LoRA, that ``{"scale": 0.0}`` reproduces the no-LoRA output, and that the scaling
        parameter is restored to 1.0 afterwards.
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        # Pipelines expose the LoRA scale under different kwarg names; find the one
        # this pipeline's __call__ accepts.
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        # Initialize to None so a pipeline with no matching kwarg fails the assert
        # below instead of raising an unrelated NameError.
        attention_kwargs_name = None
        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None

        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is a UNet or a transformer depending on the pipeline.
            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config)
            else:
                pipe.transformer.add_adapter(denoiser_lora_config)

            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
            )

            # Half-scale LoRA should differ from full-scale LoRA.
            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )

            # Zero scale should be equivalent to having no LoRA at all.
            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                "Lora + 0 scale should lead to same result as no LoRA",
            )

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
                    "The scaling parameter has not been correctly restored!",
                )
    def test_simple_inference_with_text_lora_denoiser_fused(self):
        """
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected - with unet.

        After ``fuse_lora`` the LoRA layers must still be present and the fused output
        must differ from the no-LoRA baseline.
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is a UNet or a transformer depending on the pipeline.
            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config)
            else:
                pipe.transformer.add_adapter(denoiser_lora_config)

            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

            # Fusing should still keep the LoRA layers
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(
                np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
            )
    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected.

        After ``unload_lora_weights`` no LoRA layers may remain and the output must
        match the no-LoRA baseline.
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is a UNet or a transformer depending on the pipeline.
            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config)
            else:
                pipe.transformer.add_adapter(denoiser_lora_config)
            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )
            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertFalse(
                check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly unloaded in denoiser"
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading LoRA should give the same output as no LoRA",
            )
    def test_simple_inference_with_text_denoiser_lora_unfused(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Tests a simple inference with lora attached to text encoder and unet, fuses then unfuses
        the lora weights and makes sure it works as expected.

        A fuse/unfuse round trip must keep the LoRA layers attached and produce
        numerically matching outputs (within the given tolerances).
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is a UNet or a transformer depending on the pipeline.
            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config)
            else:
                pipe.transformer.add_adapter(denoiser_lora_config)

            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # unfusing should still keep the LoRA layers attached
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                    )

            # Fuse and unfuse should lead to the same results
            self.assertTrue(
                np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )
    def test_simple_inference_with_text_denoiser_multi_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set them.

        Each adapter (and their combination) must produce distinct outputs, and
        ``disable_lora`` must reproduce the no-LoRA baseline.
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # The denoiser is a UNet or a transformer depending on the pipeline;
            # attach both adapters to it.
            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Each adapter configuration should produce a distinct output.
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )
    def test_simple_inference_with_text_denoiser_block_scale(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        one adapter and set different weights for different blocks (i.e. block lora).

        Applying per-block scales must change the output relative to both the
        no-LoRA baseline and other scale configurations, and ``disable_lora`` must
        reproduce the baseline.
        """
        # Block-level scales are not supported for these transformer-based pipelines.
        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "CogVideoXPipeline"]:
            return

        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output without any LoRA attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            # The denoiser is a UNet or a transformer depending on the pipeline.
            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")

            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Scale only the down blocks (plus the text encoder) ...
            weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
            pipe.set_adapters("adapter-1", weights_1)
            output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # ... then only the up blocks; the two configurations must differ.
            weights_2 = {"unet": {"up": 5}}
            pipe.set_adapters("adapter-1", weights_2)
            output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
                "LoRA weights 1 and 2 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 1 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 2 should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets different weights for different blocks (i.e. block lora)
        """
        # Block-wise lora scaling is not exercised for SD3 here.
        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
            return

        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output before any adapter is attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # Attach both adapters to whichever denoiser this pipeline uses.
            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Per-block scales: adapter-1 scales the text encoder and down blocks,
            # adapter-2 scales the down and mid blocks only.
            scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
            scales_2 = {"unet": {"down": 5, "mid": 5}}

            pipe.set_adapters("adapter-1", scales_1)
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2", scales_2)
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Each adapter configuration should produce a distinct output.
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            # Disabling lora should restore the baseline behavior.
            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

            # a mismatching number of adapter_names and adapter_weights should raise an error
            with self.assertRaises(ValueError):
                pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline", "CogVideoXPipeline"]:
            return

        def updown_options(blocks_with_tf, layers_per_block, value):
            """
            Generate every possible combination for how a lora weight dict for the up/down part can be.
            E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
            """
            scalar_val = value
            per_layer_val = [value] * layers_per_block

            # Each transformer block may be omitted (None), scaled by a scalar,
            # or scaled per layer with a list.
            block_choices = [[None, scalar_val, per_layer_val]] * len(blocks_with_tf)

            options = [scalar_val]
            for choice in product(*block_choices):
                if all(c is None for c in choice):
                    continue
                options.append(
                    {f"block_{block}": c for block, c in zip(blocks_with_tf, choice) if c is not None}
                )
            return options

        def all_possible_dict_opts(unet, value):
            """
            Generate every possible combination for how a lora weight dict can be.
            E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ...
            """
            # Only blocks that actually carry attention layers can take a lora scale.
            down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")]
            up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")]

            layers_per_block = unet.config.layers_per_block

            text_encoder_opts = [None, value]
            text_encoder_2_opts = [None, value]
            mid_opts = [None, value]
            down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value)
            # Up blocks have one extra layer than down blocks.
            up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value)

            opts = []
            for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts):
                if all(o is None for o in (t1, t2, d, m, u)):
                    continue
                opt = {}
                if t1 is not None:
                    opt["text_encoder"] = t1
                if t2 is not None:
                    opt["text_encoder_2"] = t2
                if all(o is None for o in (d, m, u)):
                    # no unet scaling
                    continue
                opt["unet"] = {}
                if d is not None:
                    opt["unet"]["down"] = d
                if m is not None:
                    opt["unet"]["mid"] = m
                if u is not None:
                    opt["unet"]["up"] = u
                opts.append(opt)

            return opts

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")

        for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
            # test if lora block scales can be set with this scale_dict
            if not self.has_two_text_encoders and "text_encoder_2" in scale_dict:
                del scale_dict["text_encoder_2"]

            pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error
    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set/delete them
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output before any adapter is attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.delete_adapters("adapter-1")
            output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # After deleting adapter 1 only adapter 2 remains, so the outputs must
            # MATCH the adapter-2-only run. (The previous failure message claimed
            # the opposite of what this assertion checks.)
            self.assertTrue(
                np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Deleting adapter 1 should leave only adapter 2 active and match its output",
            )

            pipe.delete_adapters("adapter-2")
            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with all adapters deleted should give same results",
            )

            # Re-attach both adapters and make sure batch deletion also restores
            # the baseline behavior.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            pipe.set_adapters(["adapter-1", "adapter-2"])
            pipe.delete_adapters(["adapter-1", "adapter-2"])

            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with all adapters deleted should give same results",
            )
    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set them
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output before any adapter is attached.
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            # Attach both adapters to whichever denoiser this pipeline uses.
            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Each adapter configuration should produce a distinct output.
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            # Non-default adapter weights must change the mixed result.
            pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
            output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertFalse(
                np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Weighted adapter and mixed adapter should give different results",
            )

            # Disabling lora should restore the baseline behavior.
            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )
    @skip_mps
    def test_lora_fuse_nan(self):
        """Fusing a lora whose weights contain `inf` must raise with safe_fusing=True
        and silently produce NaN outputs with safe_fusing=False."""
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            # corrupt one LoRA weight with `inf` values
            with torch.no_grad():
                if self.unet_kwargs:
                    pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
                        "adapter-1"
                    ].weight += float("inf")
                else:
                    pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

            # with `safe_fusing=True` we should see an Error
            with self.assertRaises(ValueError):
                pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

            # without we should not see an error, but every image will be black
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
            out = pipe("test", num_inference_steps=2, output_type="np")[0]

            self.assertTrue(np.isnan(out).all())
    def test_get_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer

            # The most recently added adapter becomes the active one.
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertListEqual(pipe.get_active_adapters(), ["adapter-1"])

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertListEqual(pipe.get_active_adapters(), ["adapter-2"])

            # Explicitly activating both makes both active, in order.
            pipe.set_adapters(["adapter-1", "adapter-2"])
            self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
    def test_get_list_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            denoiser_key = "unet" if self.unet_kwargs is not None else "transformer"
            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            has_text_lora = "text_encoder" in self.pipeline_class._lora_loadable_modules

            # 1. One adapter on the text encoder (when supported) and the denoiser.
            expected = {}
            if has_text_lora:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                expected = {"text_encoder": ["adapter-1"]}
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            expected[denoiser_key] = ["adapter-1"]
            self.assertDictEqual(pipe.get_list_adapters(), expected)

            # 2. A second adapter on both components.
            expected = {}
            if has_text_lora:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                expected = {"text_encoder": ["adapter-1", "adapter-2"]}
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            expected[denoiser_key] = ["adapter-1", "adapter-2"]
            self.assertDictEqual(pipe.get_list_adapters(), expected)

            # 3. Activating both adapters does not change the listing.
            pipe.set_adapters(["adapter-1", "adapter-2"])
            expected = {}
            if has_text_lora:
                expected = {"text_encoder": ["adapter-1", "adapter-2"]}
            expected[denoiser_key] = ["adapter-1", "adapter-2"]
            self.assertDictEqual(
                pipe.get_list_adapters(),
                expected,
            )

            # 4. A denoiser-only adapter shows up only under the denoiser key.
            denoiser.add_adapter(denoiser_lora_config, "adapter-3")
            expected = {}
            if has_text_lora:
                expected = {"text_encoder": ["adapter-1", "adapter-2"]}
            expected[denoiser_key] = ["adapter-1", "adapter-2", "adapter-3"]
            self.assertDictEqual(pipe.get_list_adapters(), expected)
    @require_peft_version_greater(peft_version="0.6.2")
Aryan's avatar
Aryan committed
1649
1650
1651
    def test_simple_inference_with_text_lora_denoiser_fused_multi(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
1652
1653
1654
1655
        """
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected - with unet and multi-adapter case
        """
1656
        scheduler_classes = (
Sayak Paul's avatar
Sayak Paul committed
1657
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
1658
1659
1660
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
1661
1662
1663
1664
1665
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

Aryan's avatar
Aryan committed
1666
            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
Sayak Paul's avatar
Sayak Paul committed
1667
            self.assertTrue(output_no_lora.shape == self.output_shape)
1668

Aryan's avatar
Aryan committed
1669
1670
1671
1672
1673
1674
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

1675
1676
1677
1678
            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
1679
1680

            # Attach a second adapter
Aryan's avatar
Aryan committed
1681
1682
1683
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

1684
1685
1686
1687
            if self.unet_kwargs is not None:
                pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
            else:
                pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
1688

1689
1690
            denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
1691

1692
            if self.has_two_text_encoders or self.has_three_text_encoders:
Sayak Paul's avatar
Sayak Paul committed
1693
1694
1695
1696
1697
1698
1699
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
1700
1701
1702

            # set them to multi-adapter inference mode
            pipe.set_adapters(["adapter-1", "adapter-2"])
Aryan's avatar
Aryan committed
1703
            outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
1704
1705

            pipe.set_adapters(["adapter-1"])
Aryan's avatar
Aryan committed
1706
            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
1707

Aryan's avatar
Aryan committed
1708
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
1709
1710

            # Fusing should still keep the LoRA layers so outpout should remain the same
Aryan's avatar
Aryan committed
1711
            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
1712
1713

            self.assertTrue(
Aryan's avatar
Aryan committed
1714
                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
1715
1716
1717
                "Fused lora should not change the output",
            )

Aryan's avatar
Aryan committed
1718
1719
1720
1721
            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            pipe.fuse_lora(
                components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
            )
1722
1723

            # Fusing should still keep the LoRA layers
Aryan's avatar
Aryan committed
1724
            output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
1725
            self.assertTrue(
Aryan's avatar
Aryan committed
1726
                np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
1727
1728
1729
                "Fused lora should not change the output",
            )

1730
1731
    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
        """
        Runs inference with DoRA (weight-decomposed LoRA) adapters attached to the text
        encoder(s) and the denoiser, and asserts that the adapted output differs from the
        un-adapted baseline output.
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            # `use_dora=True` makes `get_dummy_components` return DoRA-enabled LoRA configs.
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
                scheduler_cls, use_dora=True
            )
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output before any adapter is attached.
            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_dora_lora.shape == self.output_shape)

            pipe.text_encoder.add_adapter(text_lora_config)

            # Select the denoiser once (UNet- vs transformer-based pipeline) instead of
            # branching on `unet_kwargs` twice.
            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config)

            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # The test's expectation: attaching the DoRA adapters must perturb the output.
            self.assertFalse(
                np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
                "DoRA lora should change the output",
            )

1772
    @unittest.skip("This is failing for now - need to investigate")
    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
        """
        Attaches LoRA to the text encoder(s) and the denoiser, compiles those modules with
        `torch.compile`, and runs a single inference to make sure the compiled pipeline
        still works with the adapters attached.
        """
        scheduler_classes = (
            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
        )
        for scheduler_cls in scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config)

            # Attach LoRA to whichever denoiser this pipeline uses (UNet or transformer).
            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config)

            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            # Compile the denoiser the pipeline actually has; the previous code
            # unconditionally compiled `pipe.unet`, which breaks for transformer-based
            # pipelines (where `unet_kwargs` is None).
            if self.unet_kwargs is not None:
                pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            else:
                pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
            pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)

            if self.has_two_text_encoders or self.has_three_text_encoders:
                pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

            # Just makes sure it works..
            _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    def test_modify_padding_mode(self):
        """
        Smoke test: switch every `Conv2d` in the VAE and UNet to circular padding and run
        the pipeline end to end to make sure inference still succeeds.
        """
        # NOTE(review): presumably skipped because these pipelines have no `pipe.unet`
        # (transformer-based) — confirm before extending this list.
        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline", "CogVideoXPipeline"]:
            return

        def _apply_padding_mode(network, mode="circular"):
            # Flip the padding mode on every Conv2d layer in the module tree.
            for _, layer in network.named_modules():
                if isinstance(layer, torch.nn.Conv2d):
                    layer.padding_mode = mode

        if self.uses_flow_matching:
            scheduler_classes = [FlowMatchEulerDiscreteScheduler]
        else:
            scheduler_classes = [DDIMScheduler, LCMScheduler]

        for scheduler_cls in scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components).to(torch_device)
            pipe.set_progress_bar_config(disable=None)

            _apply_padding_mode(pipe.vae, "circular")
            _apply_padding_mode(pipe.unet, "circular")

            _, _, inputs = self.get_dummy_inputs()
            _ = pipe(**inputs)[0]