test_mixed_int8.py 28.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright 2024 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest

import numpy as np
20
import pytest
21
from huggingface_hub import hf_hub_download
22

hlky's avatar
hlky committed
23
24
25
26
27
28
29
30
from diffusers import (
    BitsAndBytesConfig,
    DiffusionPipeline,
    FluxTransformer2DModel,
    SanaTransformer2DModel,
    SD3Transformer2DModel,
    logging,
)
31
from diffusers.utils import is_accelerate_version
32
33
34
35
36
37
38
39
40
from diffusers.utils.testing_utils import (
    CaptureLogger,
    is_bitsandbytes_available,
    is_torch_available,
    is_transformers_available,
    load_pt,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_bitsandbytes_version_greater,
41
    require_peft_version_greater,
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
    require_torch,
    require_torch_gpu,
    require_transformers_version_greater,
    slow,
    torch_device,
)


def get_some_linear_layer(model):
    """Return a representative quantizable linear layer of ``model``.

    For ``SD3Transformer2DModel`` and ``FluxTransformer2DModel`` this is the query projection of
    the first transformer block's attention, i.e. ``model.transformer_blocks[0].attn.to_q``.

    Raises:
        NotImplementedError: if ``model`` is an instance of any other class.
    """
    if model.__class__.__name__ in ("SD3Transformer2DModel", "FluxTransformer2DModel"):
        return model.transformer_blocks[0].attn.to_q
    # Raise (rather than return) the exception so unsupported models fail loudly here instead of
    # handing callers a non-layer object that only breaks later with an unrelated AttributeError.
    raise NotImplementedError("Don't know what layer to retrieve here.")


if is_transformers_available():
58
    from transformers import BitsAndBytesConfig as BnbConfig
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
    from transformers import T5EncoderModel

if is_torch_available():
    import torch
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Test-only wrapper that adds a trainable low-rank (LoRA-style) branch to a linear layer.

        The wrapped ``module`` is called as-is. The output of a bias-free two-layer adapter
        (``in_features -> rank -> out_features``) is added to its result. The adapter's second layer
        starts at zero, so right after construction the extra branch contributes nothing.

        Adapted from
        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
        """

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module

            fan_in, fan_out = module.in_features, module.out_features
            down_proj = nn.Linear(fan_in, rank, bias=False)
            up_proj = nn.Linear(rank, fan_out, bias=False)
            self.adapter = nn.Sequential(down_proj, up_proj)

            # Small random init for the down-projection and a zero up-projection, so the adapter
            # initially adds zero but still receives non-zero gradients.
            init_std = (2.0 / (5 * min(fan_in, fan_out))) ** 0.5
            nn.init.normal_(down_proj.weight, std=init_std)
            nn.init.zeros_(up_proj.weight)

            # Keep the adapter on the same device as the wrapped layer's weight.
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            base_out = self.module(input, *args, **kwargs)
            return base_out + self.adapter(input)


if is_bitsandbytes_available():
    import bitsandbytes as bnb


@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class Base8bitTests(unittest.TestCase):
    """Shared configuration and inputs for the 8-bit bitsandbytes tests.

    Quantization may not behave as expected on small models (below roughly 1B parameters), so the
    SD3 transformer is used as the reference model.
    """

    model_name = "stabilityai/stable-diffusion-3-medium-diffusers"

    # Expected ratio of fp16 to 8-bit memory footprint. It was measured on one machine ("audace"),
    # so the value may drift slightly elsewhere.
    expected_rel_difference = 1.94

    prompt = "a beautiful sunset amidst the mountains."
    num_inference_steps = 10
    seed = 0

    def get_dummy_inputs(self):
        """Load the stored SD3 transformer inputs and return them as forward() keyword arguments."""
        encoder_hidden_states = load_pt(
            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
        )
        pooled_projections = load_pt(
            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
        )
        hidden_states = load_pt(
            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
        )

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_projections,
            "timestep": torch.Tensor([1.0]),
            "return_dict": False,
        }


class BnB8bitBasicTests(Base8bitTests):
    """Basic checks on an SD3 transformer loaded in 8-bit, using an fp16 copy of the same checkpoint as reference."""

    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

        # Models
        self.model_fp16 = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", torch_dtype=torch.float16
        )
        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
        )

    def tearDown(self):
        del self.model_fp16
        del self.model_8bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_quantization_num_parameters(self):
        r"""
        Test if the number of returned parameters is correct
        """
        num_params_8bit = self.model_8bit.num_parameters()
        num_params_fp16 = self.model_fp16.num_parameters()

        self.assertEqual(num_params_8bit, num_params_fp16)

    def test_quantization_config_json_serialization(self):
        r"""
        A simple test to check if the quantization config is correctly serialized and deserialized
        """
        config = self.model_8bit.config

        self.assertTrue("quantization_config" in config)

        _ = config["quantization_config"].to_dict()
        _ = config["quantization_config"].to_diff_dict()

        _ = config["quantization_config"].to_json_string()

    def test_memory_footprint(self):
        r"""
        A simple test to check if the model conversion has been done correctly by checking on the
        memory footprint of the converted model and the class type of the linear layers of the converted models
        """
        mem_fp16 = self.model_fp16.get_memory_footprint()
        mem_8bit = self.model_8bit.get_memory_footprint()

        self.assertAlmostEqual(mem_fp16 / mem_8bit, self.expected_rel_difference, delta=1e-2)
        linear = get_some_linear_layer(self.model_8bit)
        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)

    def test_original_dtype(self):
        r"""
        A simple test to check if the model successfully stores the original dtype
        """
        self.assertTrue("_pre_quantization_dtype" in self.model_8bit.config)
        self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config)
        self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16)

    def test_keep_modules_in_fp32(self):
        r"""
        A simple test to check if the modules under `_keep_in_fp32_modules` are kept in fp32.
        Also ensures that inference works.
        """
        fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
        SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
        try:
            mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
            model = SD3Transformer2DModel.from_pretrained(
                self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
            )

            for name, module in model.named_modules():
                if isinstance(module, torch.nn.Linear):
                    if name in model._keep_in_fp32_modules:
                        self.assertTrue(module.weight.dtype == torch.float32)
                    else:
                        # 8-bit parameters are packed in int8 variables
                        self.assertTrue(module.weight.dtype == torch.int8)

            # Test if inference works. The two context managers must be separated by a comma:
            # `a and b` evaluates to `b` alone and would silently skip `torch.no_grad()`.
            with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16):
                input_dict_for_transformer = self.get_dummy_inputs()
                model_inputs = {
                    k: v.to(device=torch_device)
                    for k, v in input_dict_for_transformer.items()
                    if not isinstance(v, bool)
                }
                model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
                _ = model(**model_inputs)
        finally:
            # Restore the class-level attribute even if an assertion fails, so later tests are unaffected.
            SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules

    def test_linear_are_8bit(self):
        r"""
        A simple test to check that every linear layer of the converted model (except `proj_out`)
        stores its weight as int8.
        """
        # Smoke-check that computing the memory footprint works on both models.
        self.model_fp16.get_memory_footprint()
        self.model_8bit.get_memory_footprint()

        for name, module in self.model_8bit.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name not in ["proj_out"]:
                    # 8-bit parameters are packed in int8 variables
                    self.assertTrue(module.weight.dtype == torch.int8)

    def test_llm_skip(self):
        r"""
        A simple test to check if `llm_int8_skip_modules` works as expected
        """
        config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"])
        model_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", quantization_config=config, device_map=torch_device
        )
        linear = get_some_linear_layer(model_8bit)
        self.assertTrue(linear.weight.dtype == torch.int8)
        self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt))

        self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear))
        self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8)

    def test_config_from_pretrained(self):
        r"""
        Loading a checkpoint saved in 8-bit without passing a quantization config should still yield
        bitsandbytes `Int8Params` weights carrying an `SCB` attribute.
        """
        transformer_8bit = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/flux.1-dev-int8-pkg", subfolder="transformer"
        )
        linear = get_some_linear_layer(transformer_8bit)
        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
        self.assertTrue(hasattr(linear.weight, "SCB"))

    def test_device_and_dtype_assignment(self):
        r"""
        Test that casting an 8-bit model (or moving it to a device) raises an error.
        Also checks that the non-quantized model can still be cast and moved normally.
        """
        with self.assertRaises(ValueError):
            # Tries with `str`
            self.model_8bit.to("cpu")

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            self.model_8bit.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device`
            self.model_8bit.to(torch.device("cuda:0"))

        with self.assertRaises(ValueError):
            # Tries with `float()`
            self.model_8bit.float()

        with self.assertRaises(ValueError):
            # Tries with `half()`
            self.model_8bit.half()

        # Test if we did not break anything
        self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
        input_dict_for_transformer = self.get_dummy_inputs()
        model_inputs = {
            k: v.to(dtype=torch.float32, device=torch_device)
            for k, v in input_dict_for_transformer.items()
            if not isinstance(v, bool)
        }
        model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
        with torch.no_grad():
            _ = self.model_fp16(**model_inputs)

        # Check this does not throw an error
        _ = self.model_fp16.to("cpu")

        # Check this does not throw an error
        _ = self.model_fp16.half()

        # Check this does not throw an error
        _ = self.model_fp16.float()

        # Check that this does not throw an error
        _ = self.model_fp16.cuda()


hlky's avatar
hlky committed
312
313
314
315
316
317
318
319
320
321
class Bnb8bitDeviceTests(Base8bitTests):
    """Checks device placement of a Sana transformer loaded in 8-bit with `device_map=torch_device`."""

    def setUp(self) -> None:
        gc.collect()
        torch.cuda.empty_cache()

        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model_8bit = SanaTransformer2DModel.from_pretrained(
            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
            subfolder="transformer",
            quantization_config=mixed_int8_config,
            device_map=torch_device,
        )

    def tearDown(self):
        del self.model_8bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_buffers_device_assignment(self):
        """Every registered buffer must end up on the same device type as `torch_device`."""
        for buffer_name, buffer in self.model_8bit.named_buffers():
            self.assertEqual(
                buffer.device.type,
                torch.device(torch_device).type,
                f"Expected device {torch_device} for {buffer_name} got {buffer.device}.",
            )


340
341
class BnB8bitTrainingTests(Base8bitTests):
    """Checks that LoRA-style adapters added on top of a frozen 8-bit SD3 transformer receive gradients."""

    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
        )

    def tearDown(self):
        # Release the quantized model like the sibling test classes do; unittest keeps TestCase
        # instances (and their attributes) alive for the whole run.
        del self.model_8bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_training(self):
        # Step 1: freeze all parameters
        for param in self.model_8bit.parameters():
            param.requires_grad = False  # freeze the model - train adapters later
            if param.ndim == 1:
                # cast the small parameters (e.g. layernorm) to fp32 for stability
                param.data = param.data.to(torch.float32)

        # Step 2: add adapters.
        # Collect the attention modules first so the module tree is not mutated while it is still
        # being traversed.
        attention_modules = [
            module for module in self.model_8bit.modules() if "Attention" in repr(type(module))
        ]
        for module in attention_modules:
            module.to_k = LoRALayer(module.to_k, rank=4)
            module.to_q = LoRALayer(module.to_q, rank=4)
            module.to_v = LoRALayer(module.to_v, rank=4)

        # Step 3: dummy batch
        input_dict_for_transformer = self.get_dummy_inputs()
        model_inputs = {
            k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
        }
        model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})

        # Step 4: Check if the gradient is not None
        with torch.amp.autocast("cuda", dtype=torch.float16):
            out = self.model_8bit(**model_inputs)[0]
            out.norm().backward()

        for module in self.model_8bit.modules():
            if isinstance(module, LoRALayer):
                self.assertIsNotNone(module.adapter[1].weight.grad)
                self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)


@require_transformers_version_greater("4.44.0")
class SlowBnb8bitTests(Base8bitTests):
    """End-to-end SD3 pipeline tests with an 8-bit transformer, plus offloading and device-map checks."""

    def setUp(self) -> None:
        gc.collect()
        torch.cuda.empty_cache()

        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
        model_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
        )
        self.pipeline_8bit = DiffusionPipeline.from_pretrained(
            self.model_name, transformer=model_8bit, torch_dtype=torch.float16
        )
        self.pipeline_8bit.enable_model_cpu_offload()

    def tearDown(self):
        del self.pipeline_8bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_quality(self):
        output = self.pipeline_8bit(
            prompt=self.prompt,
            num_inference_steps=self.num_inference_steps,
            generator=torch.manual_seed(self.seed),
            output_type="np",
        ).images
        out_slice = output[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.0674, 0.0623, 0.0364, 0.0632, 0.0671, 0.0430, 0.0317, 0.0493, 0.0583])

        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-2)

    def test_model_cpu_offload_raises_warning(self):
        model_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name,
            subfolder="transformer",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map=torch_device,
        )
        pipeline_8bit = DiffusionPipeline.from_pretrained(
            self.model_name, transformer=model_8bit, torch_dtype=torch.float16
        )
        logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
        logger.setLevel(30)

        with CaptureLogger(logger) as cap_logger:
            pipeline_8bit.enable_model_cpu_offload()

        self.assertIn("has been loaded in `bitsandbytes` 8bit", cap_logger.out)

    def test_moving_to_cpu_throws_warning(self):
        model_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name,
            subfolder="transformer",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map=torch_device,
        )
        logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
        logger.setLevel(30)

        with CaptureLogger(logger) as cap_logger:
            # Because `model.dtype` will return torch.float16 as SD3 transformer has
            # a conv layer as the first layer.
            _ = DiffusionPipeline.from_pretrained(
                self.model_name, transformer=model_8bit, torch_dtype=torch.float16
            ).to("cpu")

        self.assertIn("Pipelines loaded with `dtype=torch.float16`", cap_logger.out)

    def test_generate_quality_dequantize(self):
        r"""
        Test that loading the model and dequantizing it produces correct results.
        """
        self.pipeline_8bit.transformer.dequantize()
        output = self.pipeline_8bit(
            prompt=self.prompt,
            num_inference_steps=self.num_inference_steps,
            generator=torch.manual_seed(self.seed),
            output_type="np",
        ).images

        out_slice = output[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.0266, 0.0264, 0.0271, 0.0110, 0.0310, 0.0098, 0.0078, 0.0256, 0.0208])
        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-2)

        # 8bit models cannot be offloaded to CPU.
        self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda")
        # calling it again shouldn't be a problem
        _ = self.pipeline_8bit(
            prompt=self.prompt,
            num_inference_steps=2,
            generator=torch.manual_seed(self.seed),
            output_type="np",
        ).images

    @pytest.mark.xfail(
        condition=is_accelerate_version("<=", "1.1.1"),
        reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
        strict=True,
    )
    def test_pipeline_cuda_placement_works_with_mixed_int8(self):
        transformer_8bit_config = BitsAndBytesConfig(load_in_8bit=True)
        transformer_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name,
            subfolder="transformer",
            quantization_config=transformer_8bit_config,
            torch_dtype=torch.float16,
            device_map=torch_device,
        )
        text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
        text_encoder_3_8bit = T5EncoderModel.from_pretrained(
            self.model_name,
            subfolder="text_encoder_3",
            quantization_config=text_encoder_3_8bit_config,
            torch_dtype=torch.float16,
            device_map=torch_device,
        )
        # CUDA device placement works.
        pipeline_8bit = DiffusionPipeline.from_pretrained(
            self.model_name,
            transformer=transformer_8bit,
            text_encoder_3=text_encoder_3_8bit,
            torch_dtype=torch.float16,
        ).to("cuda")

        # Check if inference works.
        _ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)

        del pipeline_8bit

    def test_device_map(self):
        """
        Test if the quantized model works properly with `device_map="auto"`.
        CPU/disk offloading doesn't work with bnb.
        """

        def get_dummy_tensor_inputs(device=None, seed: int = 0):
            batch_size = 1
            num_latent_channels = 4
            num_image_channels = 3
            height = width = 4
            sequence_length = 48
            embedding_dim = 32

            torch.manual_seed(seed)
            hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
                device, dtype=torch.bfloat16
            )

            torch.manual_seed(seed)
            encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
                device, dtype=torch.bfloat16
            )

            torch.manual_seed(seed)
            pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)

            torch.manual_seed(seed)
            text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)

            torch.manual_seed(seed)
            image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)

            timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

            return {
                "hidden_states": hidden_states,
                "encoder_hidden_states": encoder_hidden_states,
                "pooled_projections": pooled_prompt_embeds,
                "txt_ids": text_ids,
                "img_ids": image_ids,
                "timestep": timestep,
            }

        inputs = get_dummy_tensor_inputs(torch_device)
        expected_slice = np.array(
            [
                0.33789062,
                -0.04736328,
                -0.00256348,
                -0.23144531,
                -0.49804688,
                0.4375,
                -0.15429688,
                -0.65234375,
                0.44335938,
            ]
        )

        # non sharded
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        quantized_model = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
        self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))

        output = quantized_model(**inputs)[0]
        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

        # sharded
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        quantized_model = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-sharded",
            subfolder="transformer",
            quantization_config=quantization_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
        self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))
        output = quantized_model(**inputs)[0]
        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

609
610
611
612

@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):
    """End-to-end Flux pipeline tests using an 8-bit checkpoint for both the transformer and the T5 encoder."""

    def setUp(self) -> None:
        gc.collect()
        torch.cuda.empty_cache()

        # This checkpoint is loaded without an explicit quantization config; the 8-bit setup is
        # expected to come from the checkpoint itself.
        model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
        t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
        transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
        self.pipeline_8bit = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            text_encoder_2=t5_8bit,
            transformer=transformer_8bit,
            torch_dtype=torch.float16,
        )
        self.pipeline_8bit.enable_model_cpu_offload()

    def tearDown(self):
        del self.pipeline_8bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_quality(self):
        # keep the resolution and max tokens to a lower number for faster execution.
        output = self.pipeline_8bit(
            prompt=self.prompt,
            num_inference_steps=self.num_inference_steps,
            generator=torch.manual_seed(self.seed),
            height=256,
            width=256,
            max_sequence_length=64,
            output_type="np",
        ).images
        out_slice = output[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.0574, 0.0554, 0.0581, 0.0686, 0.0676, 0.0759, 0.0757, 0.0803, 0.0930])

        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-3)

    @require_peft_version_greater("0.14.0")
    def test_lora_loading(self):
        """Loading and activating a LoRA on the 8-bit pipeline should produce the expected image slice."""
        self.pipeline_8bit.load_lora_weights(
            hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
        )
        self.pipeline_8bit.set_adapters("hyper-sd", adapter_weights=0.125)

        output = self.pipeline_8bit(
            prompt=self.prompt,
            height=256,
            width=256,
            max_sequence_length=64,
            output_type="np",
            num_inference_steps=8,
            generator=torch.manual_seed(42),
        ).images
        out_slice = output[0, -3:, -3:, -1].flatten()

        expected_slice = np.array([0.3916, 0.3916, 0.3887, 0.4243, 0.4155, 0.4233, 0.4570, 0.4531, 0.4248])

        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-3)

673
674
675
676

@slow
class BaseBnb8bitSerializationTests(Base8bitTests):
    """Round-trip (`save_pretrained` -> `from_pretrained`) tests for an 8-bit SD3 transformer."""

    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
        self.model_0 = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", quantization_config=quantization_config, device_map=torch_device
        )

    def tearDown(self):
        del self.model_0

        gc.collect()
        torch.cuda.empty_cache()

    def test_serialization(self):
        r"""
        Test whether it is possible to serialize a model in 8-bit. Uses most typical params as default.
        """
        self.assertTrue("_pre_quantization_dtype" in self.model_0.config)
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.model_0.save_pretrained(tmpdirname)

            config = SD3Transformer2DModel.load_config(tmpdirname)
            self.assertTrue("quantization_config" in config)
            self.assertTrue("_pre_quantization_dtype" not in config)

            model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname)

        # checking quantized linear module weight
        linear = get_some_linear_layer(model_1)
        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
        self.assertTrue(hasattr(linear.weight, "SCB"))

        # checking memory footprint
        self.assertAlmostEqual(self.model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)

        # Both models must expose the same parameter names; parameter values are checked indirectly
        # through the forward() comparison below.
        d0 = dict(self.model_0.named_parameters())
        d1 = dict(model_1.named_parameters())
        self.assertTrue(d0.keys() == d1.keys())

        # comparing forward() outputs
        dummy_inputs = self.get_dummy_inputs()
        inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)}
        inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs})
        out_0 = self.model_0(**inputs)[0]
        out_1 = model_1(**inputs)[0]
        self.assertTrue(torch.equal(out_0, out_1))

    def test_serialization_sharded(self):
        r"""
        Same round-trip as `test_serialization`, but saving with `max_shard_size="200MB"` so the
        checkpoint is split into shards.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.model_0.save_pretrained(tmpdirname, max_shard_size="200MB")

            config = SD3Transformer2DModel.load_config(tmpdirname)
            self.assertTrue("quantization_config" in config)
            self.assertTrue("_pre_quantization_dtype" not in config)

            model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname)

        # checking quantized linear module weight
        linear = get_some_linear_layer(model_1)
        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
        self.assertTrue(hasattr(linear.weight, "SCB"))

        # comparing forward() outputs
        dummy_inputs = self.get_dummy_inputs()
        inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)}
        inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs})
        out_0 = self.model_0(**inputs)[0]
        out_1 = model_1(**inputs)[0]
        self.assertTrue(torch.equal(out_0, out_1))