test_mixed_int8.py 31.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright 2024 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest

import numpy as np
20
import pytest
21
from huggingface_hub import hf_hub_download
22
from PIL import Image
23

hlky's avatar
hlky committed
24
25
26
from diffusers import (
    BitsAndBytesConfig,
    DiffusionPipeline,
27
    FluxControlPipeline,
hlky's avatar
hlky committed
28
29
30
31
32
    FluxTransformer2DModel,
    SanaTransformer2DModel,
    SD3Transformer2DModel,
    logging,
)
33
from diffusers.quantizers import PipelineQuantizationConfig
34
from diffusers.utils import is_accelerate_version
35
36
from diffusers.utils.testing_utils import (
    CaptureLogger,
37
    backend_empty_cache,
38
39
40
41
42
43
44
    is_bitsandbytes_available,
    is_torch_available,
    is_transformers_available,
    load_pt,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_bitsandbytes_version_greater,
45
    require_peft_backend,
46
    require_peft_version_greater,
47
    require_torch,
48
    require_torch_accelerator,
49
50
51
52
53
54
55
56
57
58
59
60
61
62
    require_transformers_version_greater,
    slow,
    torch_device,
)


def get_some_linear_layer(model):
    if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
        return model.transformer_blocks[0].attn.to_q
    else:
        return NotImplementedError("Don't know what layer to retrieve here.")


if is_transformers_available():
63
    from transformers import BitsAndBytesConfig as BnbConfig
64
65
66
67
68
    from transformers import T5EncoderModel

if is_torch_available():
    import torch

69
    from ..utils import LoRALayer, get_memory_consumption_stat
70
71
72
73
74


if is_bitsandbytes_available():
    import bitsandbytes as bnb

75
76
    from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear

77
78
79
80

@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@require_torch
81
@require_torch_accelerator
82
83
84
85
86
87
88
89
90
@slow
class Base8bitTests(unittest.TestCase):
    # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
    # Therefore here we use only SD3 to test our module
    model_name = "stabilityai/stable-diffusion-3-medium-diffusers"

    # This was obtained on audace so the number might slightly change
    expected_rel_difference = 1.94

91
92
    expected_memory_saving_ratio = 0.7

93
94
95
96
97
98
    prompt = "a beautiful sunset amidst the mountains."
    num_inference_steps = 10
    seed = 0

    def get_dummy_inputs(self):
        prompt_embeds = load_pt(
99
100
            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
            map_location="cpu",
101
102
        )
        pooled_prompt_embeds = load_pt(
103
104
            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
            map_location="cpu",
105
106
        )
        latent_model_input = load_pt(
107
108
            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
            map_location="cpu",
109
110
111
112
113
114
115
116
117
118
119
120
121
122
        )

        input_dict_for_transformer = {
            "hidden_states": latent_model_input,
            "encoder_hidden_states": prompt_embeds,
            "pooled_projections": pooled_prompt_embeds,
            "timestep": torch.Tensor([1.0]),
            "return_dict": False,
        }
        return input_dict_for_transformer


class BnB8bitBasicTests(Base8bitTests):
    def setUp(self):
123
        gc.collect()
124
        backend_empty_cache(torch_device)
125

126
127
128
129
130
131
        # Models
        self.model_fp16 = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", torch_dtype=torch.float16
        )
        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model_8bit = SD3Transformer2DModel.from_pretrained(
132
            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
133
134
135
        )

    def tearDown(self):
136
137
138
139
        if hasattr(self, "model_fp16"):
            del self.model_fp16
        if hasattr(self, "model_8bit"):
            del self.model_8bit
140
141

        gc.collect()
142
        backend_empty_cache(torch_device)
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177

    def test_quantization_num_parameters(self):
        r"""
        Test if the number of returned parameters is correct
        """
        num_params_8bit = self.model_8bit.num_parameters()
        num_params_fp16 = self.model_fp16.num_parameters()

        self.assertEqual(num_params_8bit, num_params_fp16)

    def test_quantization_config_json_serialization(self):
        r"""
        A simple test to check if the quantization config is correctly serialized and deserialized
        """
        config = self.model_8bit.config

        self.assertTrue("quantization_config" in config)

        _ = config["quantization_config"].to_dict()
        _ = config["quantization_config"].to_diff_dict()

        _ = config["quantization_config"].to_json_string()

    def test_memory_footprint(self):
        r"""
        A simple test to check if the model conversion has been done correctly by checking on the
        memory footprint of the converted model and the class type of the linear layers of the converted models
        """
        mem_fp16 = self.model_fp16.get_memory_footprint()
        mem_8bit = self.model_8bit.get_memory_footprint()

        self.assertAlmostEqual(mem_fp16 / mem_8bit, self.expected_rel_difference, delta=1e-2)
        linear = get_some_linear_layer(self.model_8bit)
        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)

178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
    def test_model_memory_usage(self):
        # Delete to not let anything interfere.
        del self.model_8bit, self.model_fp16

        # Re-instantiate.
        inputs = self.get_dummy_inputs()
        inputs = {
            k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
        }
        model_fp16 = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", torch_dtype=torch.float16
        ).to(torch_device)
        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
        del model_fp16

        config = BitsAndBytesConfig(load_in_8bit=True)
        model_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", quantization_config=config, torch_dtype=torch.float16
        )
        quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs)
        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio

200
201
    def test_original_dtype(self):
        r"""
202
        A simple test to check if the model successfully stores the original dtype
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
        """
        self.assertTrue("_pre_quantization_dtype" in self.model_8bit.config)
        self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config)
        self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16)

    def test_keep_modules_in_fp32(self):
        r"""
        A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32.
        Also ensures if inference works.
        """
        fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
        SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]

        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
        model = SD3Transformer2DModel.from_pretrained(
218
            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
219
220
221
222
223
224
225
226
227
228
229
        )

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name in model._keep_in_fp32_modules:
                    self.assertTrue(module.weight.dtype == torch.float32)
                else:
                    # 8-bit parameters are packed in int8 variables
                    self.assertTrue(module.weight.dtype == torch.int8)

        # test if inference works.
jiqing-feng's avatar
jiqing-feng committed
230
        with torch.no_grad() and torch.autocast(model.device.type, dtype=torch.float16):
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
            input_dict_for_transformer = self.get_dummy_inputs()
            model_inputs = {
                k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
            }
            model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
            _ = model(**model_inputs)

        SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules

    def test_linear_are_8bit(self):
        r"""
        A simple test to check if the model conversion has been done correctly by checking on the
        memory footprint of the converted model and the class type of the linear layers of the converted models
        """
        self.model_fp16.get_memory_footprint()
        self.model_8bit.get_memory_footprint()

        for name, module in self.model_8bit.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name not in ["proj_out"]:
                    # 8-bit parameters are packed in int8 variables
                    self.assertTrue(module.weight.dtype == torch.int8)

    def test_llm_skip(self):
        r"""
        A simple test to check if `llm_int8_skip_modules` works as expected
        """
        config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"])
        model_8bit = SD3Transformer2DModel.from_pretrained(
260
            self.model_name, subfolder="transformer", quantization_config=config, device_map=torch_device
261
262
263
264
265
        )
        linear = get_some_linear_layer(model_8bit)
        self.assertTrue(linear.weight.dtype == torch.int8)
        self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt))

266
        self.assertTrue(isinstance(model_8bit.proj_out, torch.nn.Linear))
267
268
269
270
        self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8)

    def test_config_from_pretrained(self):
        transformer_8bit = FluxTransformer2DModel.from_pretrained(
271
            "hf-internal-testing/flux.1-dev-int8-pkg", subfolder="transformer"
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
        )
        linear = get_some_linear_layer(transformer_8bit)
        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
        self.assertTrue(hasattr(linear.weight, "SCB"))

    def test_device_and_dtype_assignment(self):
        r"""
        Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
        Checks also if other models are casted correctly.
        """
        with self.assertRaises(ValueError):
            # Tries with `str`
            self.model_8bit.to("cpu")

        with self.assertRaises(ValueError):
            # Tries with a `dtype``
            self.model_8bit.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device`
292
            self.model_8bit.to(torch.device(f"{torch_device}:0"))
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323

        with self.assertRaises(ValueError):
            # Tries with a `device`
            self.model_8bit.float()

        with self.assertRaises(ValueError):
            # Tries with a `device`
            self.model_8bit.half()

        # Test if we did not break anything
        self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
        input_dict_for_transformer = self.get_dummy_inputs()
        model_inputs = {
            k: v.to(dtype=torch.float32, device=torch_device)
            for k, v in input_dict_for_transformer.items()
            if not isinstance(v, bool)
        }
        model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
        with torch.no_grad():
            _ = self.model_fp16(**model_inputs)

        # Check this does not throw an error
        _ = self.model_fp16.to("cpu")

        # Check this does not throw an error
        _ = self.model_fp16.half()

        # Check this does not throw an error
        _ = self.model_fp16.float()

        # Check that this does not throw an error
324
        _ = self.model_fp16.to(torch_device)
325

326
327
328
329
330
331
332
333
334
335
336
337
    def test_bnb_8bit_logs_warning_for_no_quantization(self):
        model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
        assert (
            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
            in cap_logger.out
        )

338

hlky's avatar
hlky committed
339
340
341
class Bnb8bitDeviceTests(Base8bitTests):
    def setUp(self) -> None:
        gc.collect()
342
        backend_empty_cache(torch_device)
hlky's avatar
hlky committed
343
344
345
346
347
348

        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model_8bit = SanaTransformer2DModel.from_pretrained(
            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
            subfolder="transformer",
            quantization_config=mixed_int8_config,
349
            device_map=torch_device,
hlky's avatar
hlky committed
350
351
352
353
354
355
        )

    def tearDown(self):
        del self.model_8bit

        gc.collect()
356
        backend_empty_cache(torch_device)
hlky's avatar
hlky committed
357
358
359
360
361
362
363
364
365
366

    def test_buffers_device_assignment(self):
        for buffer_name, buffer in self.model_8bit.named_buffers():
            self.assertEqual(
                buffer.device.type,
                torch.device(torch_device).type,
                f"Expected device {torch_device} for {buffer_name} got {buffer.device}.",
            )


367
368
class BnB8bitTrainingTests(Base8bitTests):
    def setUp(self):
369
        gc.collect()
370
        backend_empty_cache(torch_device)
371

372
373
        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model_8bit = SD3Transformer2DModel.from_pretrained(
374
            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
        )

    def test_training(self):
        # Step 1: freeze all parameters
        for param in self.model_8bit.parameters():
            param.requires_grad = False  # freeze the model - train adapters later
            if param.ndim == 1:
                # cast the small parameters (e.g. layernorm) to fp32 for stability
                param.data = param.data.to(torch.float32)

        # Step 2: add adapters
        for _, module in self.model_8bit.named_modules():
            if "Attention" in repr(type(module)):
                module.to_k = LoRALayer(module.to_k, rank=4)
                module.to_q = LoRALayer(module.to_q, rank=4)
                module.to_v = LoRALayer(module.to_v, rank=4)

        # Step 3: dummy batch
        input_dict_for_transformer = self.get_dummy_inputs()
        model_inputs = {
            k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
        }
        model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})

        # Step 4: Check if the gradient is not None
Yao Matrix's avatar
Yao Matrix committed
400
        with torch.amp.autocast(torch_device, dtype=torch.float16):
401
402
403
404
405
406
407
408
409
410
411
412
            out = self.model_8bit(**model_inputs)[0]
            out.norm().backward()

        for module in self.model_8bit.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)
                self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)


@require_transformers_version_greater("4.44.0")
class SlowBnb8bitTests(Base8bitTests):
    def setUp(self) -> None:
413
        gc.collect()
414
        backend_empty_cache(torch_device)
415

416
417
        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
        model_8bit = SD3Transformer2DModel.from_pretrained(
418
            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
419
420
421
422
423
424
425
426
427
428
        )
        self.pipeline_8bit = DiffusionPipeline.from_pretrained(
            self.model_name, transformer=model_8bit, torch_dtype=torch.float16
        )
        self.pipeline_8bit.enable_model_cpu_offload()

    def tearDown(self):
        del self.pipeline_8bit

        gc.collect()
429
        backend_empty_cache(torch_device)
430
431
432
433
434
435
436
437
438

    def test_quality(self):
        output = self.pipeline_8bit(
            prompt=self.prompt,
            num_inference_steps=self.num_inference_steps,
            generator=torch.manual_seed(self.seed),
            output_type="np",
        ).images
        out_slice = output[0, -3:, -3:, -1].flatten()
439
        expected_slice = np.array([0.0674, 0.0623, 0.0364, 0.0632, 0.0671, 0.0430, 0.0317, 0.0493, 0.0583])
440
441
442
443
444
445

        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-2)

    def test_model_cpu_offload_raises_warning(self):
        model_8bit = SD3Transformer2DModel.from_pretrained(
446
447
448
449
            self.model_name,
            subfolder="transformer",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map=torch_device,
450
451
452
453
454
455
456
457
458
459
460
461
462
463
        )
        pipeline_8bit = DiffusionPipeline.from_pretrained(
            self.model_name, transformer=model_8bit, torch_dtype=torch.float16
        )
        logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
        logger.setLevel(30)

        with CaptureLogger(logger) as cap_logger:
            pipeline_8bit.enable_model_cpu_offload()

        assert "has been loaded in `bitsandbytes` 8bit" in cap_logger.out

    def test_moving_to_cpu_throws_warning(self):
        model_8bit = SD3Transformer2DModel.from_pretrained(
464
465
466
467
            self.model_name,
            subfolder="transformer",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map=torch_device,
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
        )
        logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
        logger.setLevel(30)

        with CaptureLogger(logger) as cap_logger:
            # Because `model.dtype` will return torch.float16 as SD3 transformer has
            # a conv layer as the first layer.
            _ = DiffusionPipeline.from_pretrained(
                self.model_name, transformer=model_8bit, torch_dtype=torch.float16
            ).to("cpu")

        assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out

    def test_generate_quality_dequantize(self):
        r"""
        Test that loading the model and unquantize it produce correct results.
        """
485
        torch.use_deterministic_algorithms(True)
486
487
488
489
490
491
492
493
494
495
496
497
498
499
        self.pipeline_8bit.transformer.dequantize()
        output = self.pipeline_8bit(
            prompt=self.prompt,
            num_inference_steps=self.num_inference_steps,
            generator=torch.manual_seed(self.seed),
            output_type="np",
        ).images

        out_slice = output[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.0266, 0.0264, 0.0271, 0.0110, 0.0310, 0.0098, 0.0078, 0.0256, 0.0208])
        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-2)

        # 8bit models cannot be offloaded to CPU.
500
        self.assertTrue(self.pipeline_8bit.transformer.device.type == torch_device)
501
502
503
504
505
506
507
508
        # calling it again shouldn't be a problem
        _ = self.pipeline_8bit(
            prompt=self.prompt,
            num_inference_steps=2,
            generator=torch.manual_seed(self.seed),
            output_type="np",
        ).images

509
510
511
512
513
514
515
516
517
518
519
520
    @pytest.mark.xfail(
        condition=is_accelerate_version("<=", "1.1.1"),
        reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
        strict=True,
    )
    def test_pipeline_cuda_placement_works_with_mixed_int8(self):
        transformer_8bit_config = BitsAndBytesConfig(load_in_8bit=True)
        transformer_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name,
            subfolder="transformer",
            quantization_config=transformer_8bit_config,
            torch_dtype=torch.float16,
521
            device_map=torch_device,
522
523
524
525
526
527
528
        )
        text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
        text_encoder_3_8bit = T5EncoderModel.from_pretrained(
            self.model_name,
            subfolder="text_encoder_3",
            quantization_config=text_encoder_3_8bit_config,
            torch_dtype=torch.float16,
529
            device_map=torch_device,
530
        )
531

532
        # CUDA device placement works.
533
        device = torch_device if torch_device != "rocm" else "cuda"
534
535
536
537
538
        pipeline_8bit = DiffusionPipeline.from_pretrained(
            self.model_name,
            transformer=transformer_8bit,
            text_encoder_3=text_encoder_3_8bit,
            torch_dtype=torch.float16,
539
        ).to(device)
540
541

        # Check if inference works.
542
        _ = pipeline_8bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
543
544
545

        del pipeline_8bit

546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
    def test_device_map(self):
        """
        Test if the quantized model is working properly with "auto"
        pu/disk offloading doesn't work with bnb.
        """

        def get_dummy_tensor_inputs(device=None, seed: int = 0):
            batch_size = 1
            num_latent_channels = 4
            num_image_channels = 3
            height = width = 4
            sequence_length = 48
            embedding_dim = 32

            torch.manual_seed(seed)
            hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
                device, dtype=torch.bfloat16
            )

            torch.manual_seed(seed)
            encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
                device, dtype=torch.bfloat16
            )

            torch.manual_seed(seed)
            pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)

            torch.manual_seed(seed)
            text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)

            torch.manual_seed(seed)
            image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)

            timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

            return {
                "hidden_states": hidden_states,
                "encoder_hidden_states": encoder_hidden_states,
                "pooled_projections": pooled_prompt_embeds,
                "txt_ids": text_ids,
                "img_ids": image_ids,
                "timestep": timestep,
            }

        inputs = get_dummy_tensor_inputs(torch_device)
        expected_slice = np.array(
            [
                0.33789062,
                -0.04736328,
                -0.00256348,
                -0.23144531,
                -0.49804688,
                0.4375,
                -0.15429688,
                -0.65234375,
                0.44335938,
            ]
        )

        # non sharded
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        quantized_model = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
        self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))

        output = quantized_model(**inputs)[0]
        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

        # sharded
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        quantized_model = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-sharded",
            subfolder="transformer",
            quantization_config=quantization_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
        self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))
        output = quantized_model(**inputs)[0]
        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

639
640
641
642

@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):
    def setUp(self) -> None:
643
        gc.collect()
644
        backend_empty_cache(torch_device)
645
646

        model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
647
648
649
650
651
652
653
654
655
656
657
658
659
660
        t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
        transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
        self.pipeline_8bit = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            text_encoder_2=t5_8bit,
            transformer=transformer_8bit,
            torch_dtype=torch.float16,
        )
        self.pipeline_8bit.enable_model_cpu_offload()

    def tearDown(self):
        del self.pipeline_8bit

        gc.collect()
661
        backend_empty_cache(torch_device)
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679

    def test_quality(self):
        # keep the resolution and max tokens to a lower number for faster execution.
        output = self.pipeline_8bit(
            prompt=self.prompt,
            num_inference_steps=self.num_inference_steps,
            generator=torch.manual_seed(self.seed),
            height=256,
            width=256,
            max_sequence_length=64,
            output_type="np",
        ).images
        out_slice = output[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.0574, 0.0554, 0.0581, 0.0686, 0.0676, 0.0759, 0.0757, 0.0803, 0.0930])

        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-3)

680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
    @require_peft_version_greater("0.14.0")
    def test_lora_loading(self):
        self.pipeline_8bit.load_lora_weights(
            hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
        )
        self.pipeline_8bit.set_adapters("hyper-sd", adapter_weights=0.125)

        output = self.pipeline_8bit(
            prompt=self.prompt,
            height=256,
            width=256,
            max_sequence_length=64,
            output_type="np",
            num_inference_steps=8,
            generator=torch.manual_seed(42),
        ).images
        out_slice = output[0, -3:, -3:, -1].flatten()

        expected_slice = np.array([0.3916, 0.3916, 0.3887, 0.4243, 0.4155, 0.4233, 0.4570, 0.4531, 0.4248])

        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-3)

703

704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
@require_transformers_version_greater("4.44.0")
@require_peft_backend
class SlowBnb4BitFluxControlWithLoraTests(Base8bitTests):
    def setUp(self) -> None:
        gc.collect()
        backend_empty_cache(torch_device)

        self.pipeline_8bit = FluxControlPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            quantization_config=PipelineQuantizationConfig(
                quant_backend="bitsandbytes_8bit",
                quant_kwargs={"load_in_8bit": True},
                components_to_quantize=["transformer", "text_encoder_2"],
            ),
            torch_dtype=torch.float16,
        )
        self.pipeline_8bit.enable_model_cpu_offload()

    def tearDown(self):
        del self.pipeline_8bit

        gc.collect()
        backend_empty_cache(torch_device)

    def test_lora_loading(self):
        self.pipeline_8bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

        output = self.pipeline_8bit(
            prompt=self.prompt,
            control_image=Image.new(mode="RGB", size=(256, 256)),
            height=256,
            width=256,
            max_sequence_length=64,
            output_type="np",
            num_inference_steps=8,
            generator=torch.Generator().manual_seed(42),
        ).images
        out_slice = output[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.2029, 0.2136, 0.2268, 0.1921, 0.1997, 0.2185, 0.2021, 0.2183, 0.2292])

        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")


748
749
750
@slow
class BaseBnb8bitSerializationTests(Base8bitTests):
    def setUp(self):
751
        gc.collect()
752
        backend_empty_cache(torch_device)
753

754
755
756
757
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
        self.model_0 = SD3Transformer2DModel.from_pretrained(
758
            self.model_name, subfolder="transformer", quantization_config=quantization_config, device_map=torch_device
759
760
761
762
763
764
        )

    def tearDown(self):
        del self.model_0

        gc.collect()
765
        backend_empty_cache(torch_device)
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823

    def test_serialization(self):
        r"""
        Test whether it is possible to serialize a model in 8-bit. Uses most typical params as default.
        """
        self.assertTrue("_pre_quantization_dtype" in self.model_0.config)
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.model_0.save_pretrained(tmpdirname)

            config = SD3Transformer2DModel.load_config(tmpdirname)
            self.assertTrue("quantization_config" in config)
            self.assertTrue("_pre_quantization_dtype" not in config)

            model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname)

        # checking quantized linear module weight
        linear = get_some_linear_layer(model_1)
        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
        self.assertTrue(hasattr(linear.weight, "SCB"))

        # checking memory footpring
        self.assertAlmostEqual(self.model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)

        # Matching all parameters and their quant_state items:
        d0 = dict(self.model_0.named_parameters())
        d1 = dict(model_1.named_parameters())
        self.assertTrue(d0.keys() == d1.keys())

        # comparing forward() outputs
        dummy_inputs = self.get_dummy_inputs()
        inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)}
        inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs})
        out_0 = self.model_0(**inputs)[0]
        out_1 = model_1(**inputs)[0]
        self.assertTrue(torch.equal(out_0, out_1))

    def test_serialization_sharded(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.model_0.save_pretrained(tmpdirname, max_shard_size="200MB")

            config = SD3Transformer2DModel.load_config(tmpdirname)
            self.assertTrue("quantization_config" in config)
            self.assertTrue("_pre_quantization_dtype" not in config)

            model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname)

        # checking quantized linear module weight
        linear = get_some_linear_layer(model_1)
        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
        self.assertTrue(hasattr(linear.weight, "SCB"))

        # comparing forward() outputs
        dummy_inputs = self.get_dummy_inputs()
        inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)}
        inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs})
        out_0 = self.model_0(**inputs)[0]
        out_1 = model_1(**inputs)[0]
        self.assertTrue(torch.equal(out_0, out_1))