"vscode:/vscode.git/clone" did not exist on "4142c3ef7c4c543bf9735cdddb99d4570071c5bd"
utils.py 113 KB
Newer Older
1
# coding=utf-8
2
# Copyright 2025 HuggingFace Inc.
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import re
import tempfile
import unittest
from itertools import product

import numpy as np
import pytest
import torch
from parameterized import parameterized

from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available

from ..testing_utils import (
    CaptureLogger,
    check_if_dicts_are_equal,
    floats_tensor,
    is_torch_version,
    require_peft_backend,
    require_peft_version_greater,
    require_torch_accelerator,
    require_transformers_version_greater,
    skip_mps,
    torch_device,
)


if is_peft_available():
    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
    from peft.tuners.tuners_utils import BaseTunerLayer
    from peft.utils import get_peft_model_state_dict


def state_dicts_almost_equal(sd1, sd2):
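    # Two state dicts count as "almost equal" when every tensor pair (matched by sorted key order)
    # differs by at most 1e-3; this assumes both dicts contain the same keys.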
    sd1 = dict(sorted(sd1.items()))
    sd2 = dict(sorted(sd2.items()))

    models_are_equal = True
    for ten1, ten2 in zip(sd1.values(), sd2.values()):
        if (ten1 - ten2).abs().max() > 1e-3:
            models_are_equal = False

    return models_are_equal


def check_if_lora_correctly_set(model) -> bool:
    """
    Checks if the LoRA layers are correctly set with peft
    """
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            return True
    return False


def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str):
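    # Strips the "<module_key>." prefix from the parsed metadata keys and compares the result against the
    # metadata originally recorded for that module under "<module_key>_lora_adapter_metadata".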
    extracted = {
        k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.")
    }
    check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"])


def initialize_dummy_state_dict(state_dict):
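    # Replaces every (meta-device) tensor with a random tensor of the same shape/dtype on `torch_device`;
    # used by the `low_cpu_mem_usage` tests below to materialize LoRA weights after injection.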
    if not all(v.device.type == "meta" for _, v in state_dict.items()):
        raise ValueError("`state_dict` has non-meta values.")
    return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()}


POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]
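# Different pipeline families expose the per-call LoRA/attention arguments under different keyword names;
# `determine_attention_kwargs_name` below picks whichever of these the pipeline's `__call__` accepts.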


def determine_attention_kwargs_name(pipeline_class):
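    # Example: UNet-based pipelines typically accept `cross_attention_kwargs`, while several
    # transformer-based pipelines use `joint_attention_kwargs` or `attention_kwargs`.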
    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()

    # TODO(diffusers): Discuss a common naming convention across the library for the 1.0.0 release.
    attention_kwargs_name = None
    for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
        if possible_attention_kwargs in call_signature_keys:
            attention_kwargs_name = possible_attention_kwargs
            break
    assert attention_kwargs_name is not None
    return attention_kwargs_name


@require_peft_backend
class PeftLoraLoaderMixinTests:
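    """
    Mixin bundling the common LoRA/PEFT tests. Concrete pipeline test classes are expected to subclass it
    together with `unittest.TestCase`, fill in the class attributes below (pipeline class, scheduler,
    denoiser/text-encoder configs, ...) and implement the `output_shape` property.

    Minimal sketch (illustrative names only, not an actual pipeline):

        class MyPipelineLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
            pipeline_class = MyPipeline
            scheduler_cls = MyScheduler
            scheduler_kwargs = {}
            transformer_cls = MyTransformer
            transformer_kwargs = {...}
            vae_kwargs = {...}

            @property
            def output_shape(self):
                return (1, 8, 8, 3)
    """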
    pipeline_class = None

    scheduler_cls = None
    scheduler_kwargs = None

    has_two_text_encoders = False
    has_three_text_encoders = False

    text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, ""
    text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, ""
    text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, ""
    tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
    tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
    tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""

    unet_kwargs = None
    transformer_cls = None
    transformer_kwargs = None
    vae_cls = AutoencoderKL
    vae_kwargs = None

    text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

    def get_dummy_components(self, use_dora=False, lora_alpha=None):
        if self.unet_kwargs and self.transformer_kwargs:
            raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
        if self.has_two_text_encoders and self.has_three_text_encoders:
            raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

        scheduler_cls = self.scheduler_cls
        rank = 4
        lora_alpha = rank if lora_alpha is None else lora_alpha

        torch.manual_seed(0)
        if self.unet_kwargs is not None:
            unet = UNet2DConditionModel(**self.unet_kwargs)
        else:
            transformer = self.transformer_cls(**self.transformer_kwargs)

        scheduler = scheduler_cls(**self.scheduler_kwargs)

        torch.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)

        text_encoder = self.text_encoder_cls.from_pretrained(
            self.text_encoder_id, subfolder=self.text_encoder_subfolder
        )
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder)

        if self.text_encoder_2_cls is not None:
            text_encoder_2 = self.text_encoder_2_cls.from_pretrained(
                self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder
            )
            tokenizer_2 = self.tokenizer_2_cls.from_pretrained(
                self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder
            )

        if self.text_encoder_3_cls is not None:
            text_encoder_3 = self.text_encoder_3_cls.from_pretrained(
                self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder
            )
            tokenizer_3 = self.tokenizer_3_cls.from_pretrained(
                self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder
            )

        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=self.text_encoder_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )

        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=self.denoiser_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )

        pipeline_components = {
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        # Denoiser
        if self.unet_kwargs is not None:
            pipeline_components.update({"unet": unet})
        elif self.transformer_kwargs is not None:
            pipeline_components.update({"transformer": transformer})

        # Remaining text encoders.
        if self.text_encoder_2_cls is not None:
            pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2})
        if self.text_encoder_3_cls is not None:
            pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3})

        # Remaining stuff
        init_params = inspect.signature(self.pipeline_class.__init__).parameters
        if "safety_checker" in init_params:
            pipeline_components.update({"safety_checker": None})
        if "feature_extractor" in init_params:
            pipeline_components.update({"feature_extractor": None})
        if "image_encoder" in init_params:
            pipeline_components.update({"image_encoder": None})
        return pipeline_components, text_lora_config, denoiser_lora_config

    @property
    def output_shape(self):
        raise NotImplementedError

    def get_dummy_inputs(self, with_generator=True):
        batch_size = 1
        sequence_length = 10
        num_channels = 4
        sizes = (32, 32)

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "num_inference_steps": 5,
            "guidance_scale": 6.0,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs

    # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
    def get_dummy_tokens(self):
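        # Random token ids with a fixed max sequence length of 77, for driving the text encoder directly.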
        max_seq_length = 77

        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))

        prepared_inputs = {}
        prepared_inputs["input_ids"] = inputs
        return prepared_inputs

    def _get_lora_state_dicts(self, modules_to_save):
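        # The "<module>_lora_layers" keys mirror the keyword arguments of `save_lora_weights`, so the
        # returned dict can be splatted directly into that call (see the save/load tests below).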
        state_dicts = {}
        for module_name, module in modules_to_save.items():
            if module is not None:
                state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
        return state_dicts

    def _get_lora_adapter_metadata(self, modules_to_save):
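        # The "<module>_lora_adapter_metadata" keys follow the convention expected by
        # `check_module_lora_metadata` above.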
        metadatas = {}
        for module_name, module in modules_to_save.items():
            if module is not None:
                metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
        return metadatas

    def _get_modules_to_save(self, pipe, has_denoiser=False):
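        # Collects the text encoders that actually carry a `peft_config` and, when `has_denoiser=True`,
        # whichever denoiser (unet or transformer) the pipeline exposes.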
        modules_to_save = {}
        lora_loadable_modules = self.pipeline_class._lora_loadable_modules

        if (
            "text_encoder" in lora_loadable_modules
            and hasattr(pipe, "text_encoder")
            and getattr(pipe.text_encoder, "peft_config", None) is not None
        ):
            modules_to_save["text_encoder"] = pipe.text_encoder

        if (
            "text_encoder_2" in lora_loadable_modules
            and hasattr(pipe, "text_encoder_2")
            and getattr(pipe.text_encoder_2, "peft_config", None) is not None
        ):
            modules_to_save["text_encoder_2"] = pipe.text_encoder_2

        if has_denoiser:
            if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
                modules_to_save["unet"] = pipe.unet

            if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
                modules_to_save["transformer"] = pipe.transformer

        return modules_to_save

    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
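        # Attaches a single adapter (named `adapter_name`) to the text encoder(s) and/or the denoiser,
        # asserting each injection succeeded. Returns (pipe, denoiser); `denoiser` is None when no
        # denoiser LoRA config is given.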
        if text_lora_config is not None:
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

        if denoiser_lora_config is not None:
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
        else:
            denoiser = None

        if text_lora_config is not None and (self.has_two_text_encoders or self.has_three_text_encoders):
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
        return pipe, denoiser

    def test_simple_inference(self):
        """
        Tests a simple inference and makes sure it works as expected
        """
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs()
        output_no_lora = pipe(**inputs)[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

    def test_simple_inference_with_text_lora(self):
        """
        Tests a simple inference with lora attached on the text encoder
        and makes sure it works as expected
        """
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
        )

    @require_peft_version_greater("0.13.1")
    def test_low_cpu_mem_usage_with_injection(self):
        """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
            self.assertTrue(
                "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
                "The LoRA params should be on 'meta' device.",
            )

            te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
            set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
            self.assertTrue(
                "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
                "No param should be on 'meta' device.",
            )

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
        self.assertTrue(
            "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
        )

        denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
        set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
        self.assertTrue(
            "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
        )

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
                self.assertTrue(
                    "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
                    "The LoRA params should be on 'meta' device.",
                )

                te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
                set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
                self.assertTrue(
                    "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
                    "No param should be on 'meta' device.",
                )

        _, _, inputs = self.get_dummy_inputs()
        output_lora = pipe(**inputs)[0]
        self.assertTrue(output_lora.shape == self.output_shape)

    @require_peft_version_greater("0.13.1")
    @require_transformers_version_greater("4.45.2")
    def test_low_cpu_mem_usage_with_loading(self):
        """Tests if we can load LoRA state dict with low_cpu_mem_usage."""
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results.",
            )
            # Now, check for `low_cpu_mem_usage`.
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
            images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(
                np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
            )

    def test_simple_inference_with_text_lora_and_scale(self):
        """
        Tests a simple inference with lora attached on the text encoder + scale argument
        and makes sure it works as expected
        """
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
        )
        attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
        output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
        self.assertTrue(
            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
            "Lora + scale should change the output",
        )
        attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
        output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
            "Lora + 0 scale should lead to same result as no LoRA",
        )

    def test_simple_inference_with_text_lora_fused(self):
        """
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected
        """
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
        output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )

    def test_simple_inference_with_text_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder, then unloads the lora weights
        and makes sure it works as expected
        """
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
        pipe.unload_lora_weights()
        # unloading should remove the LoRA layers
        self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertFalse(
                    check_if_lora_correctly_set(pipe.text_encoder_2),
                    "Lora not correctly unloaded in text encoder 2",
                )
        output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
            "Unloading the LoRA weights should restore the original output",
        )

    def test_simple_inference_with_text_lora_save_load(self):
        """
        Tests a simple usecase where users could use saving utilities for LoRA.
        """
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
        for module_name, module in modules_to_save.items():
            self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
        images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results.",
        )

    def test_simple_inference_with_partial_text_lora(self):
        """
        Tests a simple inference with lora attached on the text encoder
        with different ranks and some adapters removed
        and makes sure it works as expected
        """
        components, _, _ = self.get_dummy_components()
        # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
        text_lora_config = LoraConfig(
            r=4,
            rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
            lora_alpha=4,
            target_modules=self.text_encoder_target_modules,
            init_lora_weights=False,
            use_dora=False,
        )
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
        state_dict = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
            # supports missing layers (PR#8324).
            state_dict = {
                f"text_encoder.{module_name}": param
                for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
                if "text_model.encoder.layers.4" not in module_name
            }
        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                state_dict.update(
                    {
                        f"text_encoder_2.{module_name}": param
                        for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
                        if "text_model.encoder.layers.4" not in module_name
                    }
                )
        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
        )
        # Unload lora and load it back using the pipe.load_lora_weights machinery
        pipe.unload_lora_weights()
        pipe.load_lora_weights(state_dict)
        output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
            "Removing adapters should change the output",
        )

    def test_simple_inference_save_pretrained_with_text_lora(self):
        """
        Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
        """
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)
            pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
            pipe_from_pretrained.to(torch_device)
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            self.assertTrue(
                check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
                "Lora not correctly set in text encoder",
            )

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
                    "Lora not correctly set in text encoder 2",
                )

        images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results.",
        )

    def test_simple_inference_with_text_denoiser_lora_save_load(self):
        """
        Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
        for module_name, module in modules_to_save.items():
            self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
        images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results.",
        )

    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
        """
        Tests a simple inference with lora attached on the text encoder + Unet + scale argument
        and makes sure it works as expected
        """
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
        )
        attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
        output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
        self.assertTrue(
            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
            "Lora + scale should change the output",
        )
        attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
        output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
            "Lora + 0 scale should lead to same result as no LoRA",
        )

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            self.assertTrue(
                pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
                "The scaling parameter has not been correctly restored!",
            )

    def test_simple_inference_with_text_lora_denoiser_fused(self):
        """
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected - with unet
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
        # Fusing should still keep the LoRA layers
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
        output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )

    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)
        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
        pipe.unload_lora_weights()
        # unloading should remove the LoRA layers
        self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
        self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")
        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertFalse(
                    check_if_lora_correctly_set(pipe.text_encoder_2),
                    "Lora not correctly unloaded in text encoder 2",
                )
        output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
            "Fused lora should change the output",
        )

    def test_simple_inference_with_text_denoiser_lora_unfused(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Tests a simple inference with lora attached to text encoder and unet, fuses and then unfuses
        the lora weights, and makes sure the fused and unfused outputs match
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
        output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
        output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        # Unfusing should still keep the LoRA layers attached
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                )
        # Fuse and unfuse should lead to the same results
        self.assertTrue(
            np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
            "Fused lora should not change the output",
        )

    def test_simple_inference_with_text_denoiser_multi_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set them
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
        pipe.set_adapters("adapter-1")
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
            "Adapter outputs should be different.",
        )
        pipe.set_adapters("adapter-2")
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter outputs should be different.",
        )
        pipe.set_adapters(["adapter-1", "adapter-2"])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter outputs should be different.",
        )
        # Different adapters (and the mixed adapter) should all produce different outputs
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )
        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )

    def test_wrong_adapter_name_raises_error(self):
        adapter_name = "adapter-1"

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
        )

        with self.assertRaises(ValueError) as err_context:
            pipe.set_adapters("test")

        self.assertTrue("not in the list of present adapters" in str(err_context.exception))

        # test this works.
        pipe.set_adapters(adapter_name)
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    def test_multiple_wrong_adapter_name_raises_error(self):
        adapter_name = "adapter-1"
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
        )

        scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
        logger = logging.get_logger("diffusers.loaders.lora_base")
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)

        wrong_components = sorted(set(scale_with_wrong_components.keys()))
        msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. "
        self.assertTrue(msg in str(cap_logger.out))

        # test this works.
        pipe.set_adapters(adapter_name)
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    def test_simple_inference_with_text_denoiser_block_scale(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        one adapter and set different weights for different blocks (i.e. block lora)
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
        weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
        pipe.set_adapters("adapter-1", weights_1)
        output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        weights_2 = {"unet": {"up": 5}}
        pipe.set_adapters("adapter-1", weights_2)
        output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
            "LoRA weights 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
            "No adapter and LoRA weights 1 should give different results",
        )
        self.assertFalse(
            np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
            "No adapter and LoRA weights 2 should give different results",
        )
        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )

    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set different weights for different blocks (i.e. block lora)
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
        scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
        scales_2 = {"unet": {"down": 5, "mid": 5}}
        pipe.set_adapters("adapter-1", scales_1)
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        pipe.set_adapters("adapter-2", scales_2)
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
UmerHA's avatar
UmerHA committed
1114

1115
1116
1117
1118
1119
        # Fuse and unfuse should lead to the same results
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )
UmerHA's avatar
UmerHA committed
1120

1121
1122
1123
1124
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )
UmerHA's avatar
UmerHA committed
1125

1126
1127
1128
1129
        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )
UmerHA's avatar
UmerHA committed
1130

1131
1132
        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
UmerHA's avatar
UmerHA committed
1133

1134
1135
1136
1137
        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )
UmerHA's avatar
UmerHA committed
1138

1139
1140
1141
        # a mismatching number of adapter_names and adapter_weights should raise an error
        with self.assertRaises(ValueError):
            pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
UmerHA's avatar
UmerHA committed
1142

1143
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""

        def updown_options(blocks_with_tf, layers_per_block, value):
            """
            Generate every possible combination for how a lora weight dict for the up/down part can be.
            E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
            """
            num_val = value
            list_val = [value] * layers_per_block

            node_opts = [None, num_val, list_val]
            node_opts_foreach_block = [node_opts] * len(blocks_with_tf)

            updown_opts = [num_val]
            for nodes in product(*node_opts_foreach_block):
                if all(n is None for n in nodes):
                    continue
                opt = {}
                for b, n in zip(blocks_with_tf, nodes):
                    if n is not None:
                        opt["block_" + str(b)] = n
                updown_opts.append(opt)
            return updown_opts

        def all_possible_dict_opts(unet, value):
            """
            Generate every possible combination for how a lora weight dict can be.
            E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ...
            """

            down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")]
            up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")]

            layers_per_block = unet.config.layers_per_block

            text_encoder_opts = [None, value]
            text_encoder_2_opts = [None, value]
            mid_opts = [None, value]
            down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value)
            up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value)

            opts = []

            for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts):
                if all(o is None for o in (t1, t2, d, m, u)):
                    continue
                opt = {}
                if t1 is not None:
                    opt["text_encoder"] = t1
                if t2 is not None:
                    opt["text_encoder_2"] = t2
                if all(o is None for o in (d, m, u)):
                    # no unet scaling
                    continue
                opt["unet"] = {}
                if d is not None:
                    opt["unet"]["down"] = d
                if m is not None:
                    opt["unet"]["mid"] = m
                if u is not None:
                    opt["unet"]["up"] = u
                opts.append(opt)

            return opts

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")

        for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
            # test if lora block scales can be set with this scale_dict
            if not self.has_two_text_encoders and "text_encoder_2" in scale_dict:
                del scale_dict["text_encoder_2"]

            pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error

    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets/deletes them
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        pipe.set_adapters("adapter-1")
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters("adapter-2")
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1", "adapter-2"])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )

        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )

        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )

        pipe.delete_adapters("adapter-1")
        output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertTrue(
            np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Deleting adapter 1 should leave only adapter 2 active, giving the same results",
        )

        pipe.delete_adapters("adapter-2")
        output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertTrue(
            np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
            "output with no lora and output with all adapters deleted should give same results",
        )

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        pipe.set_adapters(["adapter-1", "adapter-2"])
        pipe.delete_adapters(["adapter-1", "adapter-2"])

        output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertTrue(
            np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
            "output with no lora and output with all adapters deleted should give same results",
        )

    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets them
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        pipe.set_adapters("adapter-1")
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters("adapter-2")
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1", "adapter-2"])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

        # Different adapter configurations should lead to different results
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )

        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )

        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )

        pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
        output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(
            np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Weighted adapter and mixed adapter should give different results",
        )

        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )

    @skip_mps
    @pytest.mark.xfail(
        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
        strict=False,
    )
    def test_lora_fuse_nan(self):
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        # corrupt one LoRA weight with `inf` values
        with torch.no_grad():
            if self.unet_kwargs:
                pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
                    "inf"
                )
            else:
                named_modules = [name for name, _ in pipe.transformer.named_modules()]
                possible_tower_names = [
                    "transformer_blocks",
                    "blocks",
                    "joint_transformer_blocks",
                    "single_transformer_blocks",
                ]
                filtered_tower_names = [
                    tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
                ]
                if len(filtered_tower_names) == 0:
                    reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
                    raise ValueError(reason)
                for tower_name in filtered_tower_names:
                    transformer_tower = getattr(pipe.transformer, tower_name)
                    has_attn1 = any("attn1" in name for name in named_modules)
                    if has_attn1:
                        transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
                    else:
                        transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

        # with `safe_fusing=True` we should see an Error
        with self.assertRaises(ValueError):
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

        # without it, we should not see an error, but every image will be black
        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
        out = pipe(**inputs)[0]

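        # The `inf`-corrupted LoRA weight propagates through the fused layers, so every output value is NaN.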
        self.assertTrue(np.isnan(out).all())

    def test_get_adapters(self):
        """
        Tests a simple use case where we attach multiple adapters and check if the results
        are the expected results
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")

        adapter_names = pipe.get_active_adapters()
        self.assertListEqual(adapter_names, ["adapter-1"])

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")

        adapter_names = pipe.get_active_adapters()
        self.assertListEqual(adapter_names, ["adapter-2"])

        pipe.set_adapters(["adapter-1", "adapter-2"])
        self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])

    def test_get_list_adapters(self):
        """
        Tests a simple use case where we attach multiple adapters and check if the results
        are the expected results
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # 1.
        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            dicts_to_be_checked = {"text_encoder": ["adapter-1"]}

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
            dicts_to_be_checked.update({"unet": ["adapter-1"]})
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
            dicts_to_be_checked.update({"transformer": ["adapter-1"]})

        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

        # 2.
        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
            dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
            dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

        # 3.
        pipe.set_adapters(["adapter-1", "adapter-2"])

        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

        if self.unet_kwargs is not None:
            dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
        else:
            dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

        self.assertDictEqual(
            pipe.get_list_adapters(),
            dicts_to_be_checked,
        )

        # 4.
        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
            dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
            dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})

        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

    @require_peft_version_greater(peft_version="0.6.2")
    def test_simple_inference_with_text_lora_denoiser_fused_multi(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Tests a simple inference with lora attached to the text encoder + fuses the lora weights into the base model
        and makes sure it works as expected - with unet and multi-adapter case
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")

        # set them to multi-adapter inference mode
        pipe.set_adapters(["adapter-1", "adapter-2"])
        outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1"])
        outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
        self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

        # Fusing should still keep the LoRA layers so output should remain the same
        outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertTrue(
            np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
            "Fused lora should not change the output",
        )

        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

        self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                )

        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"])
        self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

        # Fusing should still keep the LoRA layers
        output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
            "Fused lora should not change the output",
        )
        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

    def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for lora_scale in [1.0, 0.8]:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly set in text encoder 2",
                    )

            pipe.set_adapters(["adapter-1"])
            attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

            pipe.fuse_lora(
                components=self.pipeline_class._lora_loadable_modules,
                adapter_names=["adapter-1"],
                lora_scale=lora_scale,
            )
            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )
            self.assertFalse(
                np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
                "LoRA should change the output",
            )

    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_dora_lora.shape == self.output_shape)

        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(
            np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
            "DoRA lora should change the output",
        )

    def test_missing_keys_warning(self):
        # Skip text encoder check for now as that is handled with `transformers`.
        components, _, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
            )
            pipe.unload_lora_weights()
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)

        # To make things dynamic since we cannot settle with a single key for all the models where we
        # offer PEFT support.
        missing_key = [k for k in state_dict if "lora_A" in k][0]
        del state_dict[missing_key]

        logger = logging.get_logger("diffusers.utils.peft_utils")
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(state_dict)

        # The logged keys contain the adapter name ("default_0") while `missing_key` does not, so strip it
        # from the captured log. Also strip out the component prefix (such as "unet.") from `missing_key`.
        component = list({k.split(".")[0] for k in state_dict})[0]
        self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))

    def test_unexpected_keys_warning(self):
        # Skip text encoder check for now as that is handled with `transformers`.
        components, _, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
            )
            pipe.unload_lora_weights()
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)

        unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
        state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)

        logger = logging.get_logger("diffusers.utils.peft_utils")
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(state_dict)

        self.assertTrue(".diffusers_cat" in cap_logger.out)

    @unittest.skip("This is failing for now - need to investigate")
    def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then compiles the
        models with torch.compile and makes sure inference works as expected
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
        pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)

        if self.has_two_text_encoders or self.has_three_text_encoders:
            pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

        # Just makes sure it works.
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    def test_modify_padding_mode(self):
        def set_pad_mode(network, mode="circular"):
            for _, module in network.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    module.padding_mode = mode

        components, _, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _pad_mode = "circular"
        set_pad_mode(pipe.vae, _pad_mode)
        set_pad_mode(pipe.unet, _pad_mode)

        _, _, inputs = self.get_dummy_inputs()
        _ = pipe(**inputs)[0]

    def test_logs_info_when_no_lora_keys_found(self):
        # Skip text encoder check for now as that is handled with `transformers`.
        components, _, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

        no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
        logger = logging.get_logger("diffusers.loaders.peft")
        logger.setLevel(logging.WARNING)

        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(no_op_state_dict)
        out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0]

        denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
        self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
        self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))

        # test only for text encoder
        for lora_module in self.pipeline_class._lora_loadable_modules:
            if "text_encoder" in lora_module:
                text_encoder = getattr(pipe, lora_module)
                if lora_module == "text_encoder":
                    prefix = "text_encoder"
                elif lora_module == "text_encoder_2":
                    prefix = "text_encoder_2"

                logger = logging.get_logger("diffusers.loaders.lora_base")
                logger.setLevel(logging.WARNING)

                with CaptureLogger(logger) as cap_logger:
                    self.pipeline_class.load_lora_into_text_encoder(
                        no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix
                    )

                self.assertTrue(
                    cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
                )

    def test_set_adapters_match_attention_kwargs(self):
        """Test to check if outputs after `set_adapters()` and attention kwargs match."""
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        lora_scale = 0.5
        attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
        output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
        self.assertFalse(
            np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
            "Lora + scale should change the output",
        )

        pipe.set_adapters("default", lora_scale)
        output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
            "Lora + scale should change the output",
        )
        self.assertTrue(
            np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
            "Lora + scale should match the output of `set_adapters()`.",
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
            )

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
            self.assertTrue(
                not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Lora + scale should change the output",
            )
            self.assertTrue(
                np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results as attention_kwargs.",
            )
            self.assertTrue(
                np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
                "Loading from saved checkpoints should give same results as set_adapters().",
            )

    @require_peft_version_greater("0.13.2")
    def test_lora_B_bias(self):
        # Currently, this test is only relevant for Flux Control LoRA as we are not
        # aware of any other LoRA checkpoint that has its `lora_B` biases trained.
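        # (Setting `lora_bias=True` on the PEFT `LoraConfig` adds a trainable bias term to the `lora_B` projection.)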
        components, _, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # keep track of the bias values of the base layers to perform checks later.
        bias_values = {}
        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        for name, module in denoiser.named_modules():
            if any(k in name for k in self.denoiser_target_modules):
                if module.bias is not None:
                    bias_values[name] = module.bias.data.clone()

        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

        denoiser_lora_config.lora_bias = False
        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
        lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
        pipe.delete_adapters("adapter-1")

        denoiser_lora_config.lora_bias = True
        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
        lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
        self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
        self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))

    def test_correct_lora_configs_with_different_ranks(self):
        components, _, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")

        lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if self.unet_kwargs is not None:
            pipe.unet.delete_adapters("adapter-1")
        else:
            pipe.transformer.delete_adapters("adapter-1")

        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        for name, _ in denoiser.named_modules():
            if "to_k" in name and "attn" in name and "lora" not in name:
                module_name_to_rank_update = name.replace(".base_layer.", ".")
                break

        # change the rank_pattern
        updated_rank = denoiser_lora_config.r * 2
        denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
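        # (`rank_pattern` in PEFT's `LoraConfig` overrides the default rank `r` for modules matching the given keys.)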

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
            updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
            updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern

        self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})

        lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
        self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))

        if self.unet_kwargs is not None:
            pipe.unet.delete_adapters("adapter-1")
        else:
            pipe.transformer.delete_adapters("adapter-1")

        # similarly change the alpha_pattern
        updated_alpha = denoiser_lora_config.lora_alpha * 2
        denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(
                pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
            )
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(
                pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
            )

        lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
        self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

    def test_layerwise_casting_inference_denoiser(self):
        from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN

        def check_linear_dtype(module, storage_dtype, compute_dtype):
            patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
            if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
                patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
            for name, submodule in module.named_modules():
                if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                    continue
                dtype_to_check = storage_dtype
                if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
                    dtype_to_check = compute_dtype
                if getattr(submodule, "weight", None) is not None:
                    self.assertEqual(submodule.weight.dtype, dtype_to_check)
                if getattr(submodule, "bias", None) is not None:
                    self.assertEqual(submodule.bias.dtype, dtype_to_check)

        def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device, dtype=compute_dtype)
            pipe.set_progress_bar_config(disable=None)

            pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

            if storage_dtype is not None:
                denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
                check_linear_dtype(denoiser, storage_dtype, compute_dtype)

            return pipe

        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe_fp32 = initialize_pipeline(storage_dtype=None)
        pipe_fp32(**inputs, generator=torch.manual_seed(0))[0]

        pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
        pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]

        pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
        pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]

    @require_peft_version_greater("0.14.0")
    def test_layerwise_casting_peft_input_autocast_denoiser(self):
        r"""
        A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This
        is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise
        cast hooks on the PEFT layers (relevant logic in `models.modeling_utils.ModelMixin.enable_layerwise_casting`).
        In this test, we enable the layerwise casting on the PEFT layers as well. If run with PEFT version <= 0.14.0,
        this test will fail with the following error:

        ```
        RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float
        ```

        See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
        """

        from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
        from diffusers.hooks.layerwise_casting import (
            _PEFT_AUTOCAST_DISABLE_HOOK,
            DEFAULT_SKIP_MODULES_PATTERN,
            apply_layerwise_casting,
        )

        storage_dtype = torch.float8_e4m3fn
        compute_dtype = torch.float32

        def check_module(denoiser):
            # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
            for name, module in denoiser.named_modules():
                if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                    continue
                dtype_to_check = storage_dtype
                if any(re.search(pattern, name) for pattern in patterns_to_check):
                    dtype_to_check = compute_dtype
                if getattr(module, "weight", None) is not None:
                    self.assertEqual(module.weight.dtype, dtype_to_check)
                if getattr(module, "bias", None) is not None:
                    self.assertEqual(module.bias.dtype, dtype_to_check)
                if isinstance(module, BaseTunerLayer):
                    self.assertTrue(getattr(module, "_diffusers_hook", None) is not None)
                    self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)

        # 1. Test forward with add_adapter
        components, _, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device, dtype=compute_dtype)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
        if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
            patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)

        apply_layerwise_casting(
            denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
        )
        check_module(denoiser)

        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        pipe(**inputs, generator=torch.manual_seed(0))[0]

        # 2. Test forward with load_lora_weights
        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
            )

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            components, _, _ = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device, dtype=compute_dtype)
            pipe.set_progress_bar_config(disable=None)
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            apply_layerwise_casting(
                denoiser,
                storage_dtype=storage_dtype,
                compute_dtype=compute_dtype,
                skip_modules_pattern=patterns_to_check,
            )
            check_module(denoiser)

            _, _, inputs = self.get_dummy_inputs(with_generator=False)
            pipe(**inputs, generator=torch.manual_seed(0))[0]

    @parameterized.expand([4, 8, 16])
    def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
        pipe = self.pipeline_class(**components)

        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
            pipe.unload_lora_weights()

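            # With `return_lora_metadata=True`, the parsed adapter metadata is returned as the last element; some
            # pipelines return an extra element in between, hence the length check below.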
            out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True)
            if len(out) == 3:
                _, _, parsed_metadata = out
            elif len(out) == 2:
                _, parsed_metadata = out

            denoiser_key = (
                f"{self.pipeline_class.transformer_name}"
                if self.transformer_kwargs is not None
                else f"{self.pipeline_class.unet_name}"
            )
            self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
            check_module_lora_metadata(
                parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
            )

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                text_encoder_key = self.pipeline_class.text_encoder_name
                self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
                check_module_lora_metadata(
                    parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
                )

            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                text_encoder_2_key = "text_encoder_2"
                self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
                check_module_lora_metadata(
                    parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
                )

    @parameterized.expand([4, 8, 16])
    def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
        pipe = self.pipeline_class(**components).to(torch_device)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
        )
        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

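        # Round-trip: save the LoRA weights together with their adapter metadata, unload, reload, and verify that
        # the reloaded adapter reproduces the same outputs.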
        with tempfile.TemporaryDirectory() as tmpdir:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
            pipe.unload_lora_weights()
            pipe.load_lora_weights(tmpdir)

            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

            self.assertTrue(
                np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
            )

    def test_lora_unload_add_adapter(self):
        """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
        )
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

        # unload and then add.
        pipe.unload_lora_weights()
        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
        )
        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    def test_inference_load_delete_load_adapters(self):
        "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

            # First, delete adapter and compare.
            pipe.delete_adapters(pipe.get_active_adapters()[0])
            output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
            self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))

            # Then load adapter and compare.
            pipe.load_lora_weights(tmpdirname)
            output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
            self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))

    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
        from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook

        onload_device = torch_device
        offload_device = torch.device("cpu")

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

            components, _, _ = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet

            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Test group offloading with load_lora_weights
            denoiser.enable_group_offload(
                onload_device=onload_device,
                offload_device=offload_device,
                offload_type=offload_type,
                num_blocks_per_group=1,
                use_stream=use_stream,
            )
            # Place other model-level components on `torch_device`.
            for _, component in pipe.components.items():
                if isinstance(component, torch.nn.Module):
                    component.to(torch_device)
            group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
            self.assertTrue(group_offload_hook_1 is not None)
            output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # Test group offloading after removing the lora
            pipe.unload_lora_weights()
            group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
            self.assertTrue(group_offload_hook_2 is not None)
            output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]  # noqa: F841

            # Add the lora again and check if group offloading works
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
            group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
            self.assertTrue(group_offload_hook_3 is not None)
            output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]

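            # Re-loading the LoRA after unloading should reproduce the original LoRA output.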
            self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))

    @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
    @require_torch_accelerator
    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
        for cls in inspect.getmro(self.__class__):
            if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
                # Skip this test if it is overridden by a child class. We need to do this because `parameterized`
                # materializes the test methods on invocation, and those materialized methods cannot be overridden.
                return
        self._test_group_offloading_inference_denoiser(offload_type, use_stream)

    @require_torch_accelerator
    def test_lora_loading_model_cpu_offload(self):
        components, _, denoiser_lora_config = self.get_dummy_components()
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
            )
            # reinitialize the pipeline to mimic the inference workflow.
            components, _, denoiser_lora_config = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe.enable_model_cpu_offload(device=torch_device)
            pipe.load_lora_weights(tmpdirname)
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))