test_4bit.py 24.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
16
import importlib.metadata
17
import tempfile
18
19
import unittest

20
21
from packaging import version

22
from transformers import (
23
    AutoConfig,
24
25
26
27
28
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
29
    BitsAndBytesConfig,
30
31
    pipeline,
)
32
from transformers.testing_utils import (
33
    is_bitsandbytes_available,
34
35
36
37
38
39
40
    is_torch_available,
    require_accelerate,
    require_bitsandbytes,
    require_torch,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
41
    torch_device,
42
43
44
)


def get_some_linear_layer(model):
    """Return one representative linear layer of ``model``, used by the tests to inspect quantized weights.

    The attribute path depends on ``model.config.model_type``:

    * ``"gpt2"``: ``model.transformer.h[0].mlp.c_fc``
    * ``"opt"``: ``model.decoder.layers[0].fc1`` when ``model`` exposes ``decoder`` directly, otherwise
      ``model.model.decoder.layers[0].fc1`` (the layout of ``AutoModelForCausalLM``)
    * anything else: ``model.transformer.h[0].mlp.dense_4h_to_h``
    """
    if model.config.model_type == "gpt2":
        return model.transformer.h[0].mlp.c_fc
    elif model.config.model_type == "opt":
        try:
            return model.decoder.layers[0].fc1
        except AttributeError:
            # for AutoModelforCausalLM the decoder is nested one level down under `.model`
            return model.model.decoder.layers[0].fc1
    else:
        return model.transformer.h[0].mlp.dense_4h_to_h


if is_torch_available():
    import torch
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Wrap a linear layer with a LoRA-like low-rank adapter - used for testing purposes only.

        ``forward`` returns ``module(input) + adapter(input)``. The adapter is two bias-free linear layers
        (``in_features -> rank -> out_features``). Its second layer is zero-initialised, so right after
        construction the wrapper produces the same output as the wrapped module.
        """

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module
            self.adapter = nn.Sequential(
                nn.Linear(module.in_features, rank, bias=False),
                nn.Linear(rank, module.out_features, bias=False),
            )
            # Small random init for the down-projection, zeros for the up-projection (adapter starts as a no-op).
            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
            nn.init.normal_(self.adapter[0].weight, std=small_std)
            nn.init.zeros_(self.adapter[1].weight)
            # Keep the adapter on the same device as the wrapped layer's weight.
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            return self.module(input, *args, **kwargs) + self.adapter(input)


# `bnb` is only importable when bitsandbytes is installed; the test classes using it are gated by
# `@require_bitsandbytes`.
if is_bitsandbytes_available():
    import bitsandbytes as bnb


@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class Base4bitTest(unittest.TestCase):
    """Shared constants and tokenizer setup for the 4-bit quantization tests.

    Constants are class attributes. ``setUp`` loads only the tokenizer; subclasses load the models they need.
    """

    # We need to test on relatively large models (aka >1b parameters), otherwise the quantization may not work as
    # expected. Therefore here we use bloom-1b7 to test our module.
    model_name = "bigscience/bloom-1b7"

    # Ratio fp16 memory footprint / 4-bit memory footprint for `model_name`.
    # This was obtained on an RTX Titan, so the number might change slightly on other GPUs.
    EXPECTED_RELATIVE_DIFFERENCE = 2.109659552692574

    input_text = "Hello my name is"

    # Any of these greedy continuations of `input_text` is accepted (outputs can differ across GPUs).
    EXPECTED_OUTPUTS = {
        "Hello my name is John and I am a professional photographer. I",
        "Hello my name is John.\nI am a friend of your father.\n",
        "Hello my name is John Doe, I am a student at the University",
    }
    MAX_NEW_TOKENS = 10

    def setUp(self):
        # Tokenizer for `self.model_name`; models are loaded by the subclasses.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)


class Bnb4BitTest(Base4bitTest):
    """Compare a 4-bit quantized causal LM against the same checkpoint loaded in fp16."""

    def setUp(self):
        super().setUp()

        # Load the same checkpoint twice: an fp16 reference and a 4-bit quantized copy.
        self.model_fp16 = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.float16, device_map="auto"
        )
        self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        del self.model_fp16
        del self.model_4bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_quantization_num_parameters(self):
        r"""
        Test if the number of returned parameters is correct

        See: https://github.com/huggingface/transformers/issues/25978
        """
        num_params_4bit = self.model_4bit.num_parameters()
        num_params_fp16 = self.model_fp16.num_parameters()

        self.assertEqual(num_params_4bit, num_params_fp16)

    def test_quantization_config_json_serialization(self):
        r"""
        A simple test to check if the quantization config is correctly serialized and deserialized
        """
        config = self.model_4bit.config

        self.assertTrue(hasattr(config, "quantization_config"))

        # Each of these must run without raising.
        _ = config.to_dict()
        _ = config.to_diff_dict()

        _ = config.to_json_string()

    def test_memory_footprint(self):
        r"""
        A simple test to check if the model conversion has been done correctly by checking on the
        memory footprint of the converted model and the class type of the linear layers of the converted models
        """
        from bitsandbytes.nn import Params4bit

        mem_fp16 = self.model_fp16.get_memory_footprint()
        mem_4bit = self.model_4bit.get_memory_footprint()

        self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)

        linear = get_some_linear_layer(self.model_4bit)
        self.assertIs(linear.weight.__class__, Params4bit)

    def test_original_dtype(self):
        r"""
        A simple test to check if the model successfully stores the original dtype
        """
        self.assertTrue(hasattr(self.model_4bit.config, "_pre_quantization_dtype"))
        self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
        self.assertEqual(self.model_4bit.config._pre_quantization_dtype, torch.float16)

    def test_linear_are_4bit(self):
        r"""
        Check that every `torch.nn.Linear` of the 4-bit model stores its weight as packed `torch.uint8`, except
        `lm_head` and the module names listed in `T5PreTrainedModel._keep_in_fp32_modules`.
        """
        from transformers import T5PreTrainedModel

        # Smoke check: computing the footprint must not raise on either model.
        self.model_fp16.get_memory_footprint()
        self.model_4bit.get_memory_footprint()

        skipped_names = ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules
        for name, module in self.model_4bit.named_modules():
            if isinstance(module, torch.nn.Linear) and name not in skipped_names:
                # 4-bit parameters are packed in uint8 variables
                self.assertTrue(module.weight.dtype == torch.uint8)

    def test_rwkv_4bit(self):
        r"""
        Check that generation runs with a 4-bit, double-quantized RWKV model.
        """
        model_id = "RWKV/rwkv-4-169m-pile"

        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)

        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
        tok = AutoTokenizer.from_pretrained(model_id)

        text = "Hello my name is"
        input_ids = tok.encode(text, return_tensors="pt").to(0)

        # Smoke test: the output is not checked, generation only has to succeed.
        _ = model.generate(input_ids, max_new_tokens=30)

    def test_generate_quality(self):
        r"""
        Test the generation quality of the quantized model and see that we are matching the expected output.
        Given that we are operating on small numbers + the testing model is relatively small, we might not get
        the same output across GPUs. So we'll generate few tokens (5-10) and check their output.
        """
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = self.model_4bit.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_generate_quality_config(self):
        r"""
        Test that loading the model with a `BitsAndBytesConfig` is equivalent to passing `load_in_4bit=True`
        """
        bnb_config = BitsAndBytesConfig()
        bnb_config.load_in_4bit = True

        model_4bit_from_config = AutoModelForCausalLM.from_pretrained(
            self.model_name, quantization_config=bnb_config, device_map="auto"
        )

        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = model_4bit_from_config.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_device_and_dtype_assignment(self):
        r"""
        Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error.
        Checks also if other models are casted correctly.
        """
        with self.assertRaises(ValueError):
            # Tries with `str`
            self.model_4bit.to("cpu")

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            self.model_4bit.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device`
            self.model_4bit.to(torch.device("cuda:0"))

        with self.assertRaises(ValueError):
            # Tries with `float()`
            self.model_4bit.float()

        with self.assertRaises(ValueError):
            # Tries with `half()`
            self.model_4bit.half()

        # Test if we did not break anything: the non-quantized model must still be castable and usable
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

        self.model_fp16 = self.model_fp16.to(torch.float32)
        _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS)

        # Check this does not throw an error
        _ = self.model_fp16.to("cpu")

        # Check this does not throw an error
        _ = self.model_fp16.half()

        # Check this does not throw an error
        _ = self.model_fp16.float()

    def test_fp32_4bit_conversion(self):
        r"""
        Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.
        """
        model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_4bit=True, device_map="auto")
        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)

@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class Bnb4BitT5Test(unittest.TestCase):
    """4-bit inference on T5 checkpoints, with and without `_keep_in_fp32_modules`."""

    @classmethod
    def setUpClass(cls):
        cls.model_name = "t5-small"
        cls.dense_act_model_name = "google/flan-t5-small"  # flan-t5 uses dense-act instead of dense-relu-dense
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.input_text = "Translate in German: Hello, my dog is cute"

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        gc.collect()
        torch.cuda.empty_cache()

    def _generate_smoke_test(self, model):
        """Run `generate` on the shared input text; only requires that generation does not raise."""
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
        _ = model.generate(**encoded_input)

    def test_inference_without_keep_in_fp32(self):
        r"""
        Test whether 4-bit inference works when `_keep_in_fp32_modules` is disabled, i.e. without any `fp32` weights.

        `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
        both cases.
        """
        from transformers import T5ForConditionalGeneration

        modules = T5ForConditionalGeneration._keep_in_fp32_modules
        T5ForConditionalGeneration._keep_in_fp32_modules = None
        try:
            # test with `t5-small`
            model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
            self._generate_smoke_test(model)

            # test with `flan-t5-small`
            model = T5ForConditionalGeneration.from_pretrained(
                self.dense_act_model_name, load_in_4bit=True, device_map="auto"
            )
            self._generate_smoke_test(model)
        finally:
            # Always restore the class attribute, otherwise a failure here would leak `None` into other tests.
            T5ForConditionalGeneration._keep_in_fp32_modules = modules

    def test_inference_with_keep_in_fp32(self):
        r"""
        Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.

        `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
        both cases.
        """
        from transformers import T5ForConditionalGeneration

        # test with `t5-small`
        model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")

        # there was a bug with decoders - this test checks that it is fixed
        self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear4bit))

        self._generate_smoke_test(model)

        # test with `flan-t5-small`
        model = T5ForConditionalGeneration.from_pretrained(
            self.dense_act_model_name, load_in_4bit=True, device_map="auto"
        )
        self._generate_smoke_test(model)

class Classes4BitModelTest(Base4bitTest):
    """Check which weights get quantized for the different `Auto*` head classes."""

    def setUp(self):
        # Override the checkpoint *before* `super().setUp()` so the tokenizer loaded there matches the models below.
        self.model_name = "bigscience/bloom-560m"
        super().setUp()
        self.seq_to_seq_name = "t5-small"

        # Different types of model

        # Base model (no head)
        self.base_model = AutoModel.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
        # Sequence classification model
        self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name, load_in_4bit=True, device_map="auto"
        )
        # CausalLM model
        self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
        # Seq2seq model
        self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
            self.seq_to_seq_name, load_in_4bit=True, device_map="auto"
        )

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        del self.base_model
        del self.sequence_model
        del self.model_4bit
        del self.seq_to_seq_model

        gc.collect()
        torch.cuda.empty_cache()

    def test_correct_head_class(self):
        r"""
        A simple test to check if the last modules for some classes (AutoModelForCausalLM or SequenceClassification)
        are kept in their native class.
        """
        from bitsandbytes.nn import Params4bit

        # Inner layers of the base model are quantized
        self.assertIs(self.base_model.h[-1].mlp.dense_4h_to_h.weight.__class__, Params4bit)

        # Other heads should be nn.Parameter
        self.assertIs(self.model_4bit.lm_head.weight.__class__, torch.nn.Parameter)
        self.assertIs(self.sequence_model.score.weight.__class__, torch.nn.Parameter)
        self.assertIs(self.seq_to_seq_model.lm_head.weight.__class__, torch.nn.Parameter)


class Pipeline4BitTest(Base4bitTest):
    """Check that a 4-bit model can be used through the `text-generation` pipeline."""

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        # `self.pipe` only exists if `pipeline(...)` succeeded; an unconditional `del` would raise AttributeError
        # here and hide the original test failure.
        if hasattr(self, "pipe"):
            del self.pipe

        gc.collect()
        torch.cuda.empty_cache()

    def test_pipeline(self):
        r"""
        The aim of this test is to verify that the mixed 4bit is compatible with `pipeline` from transformers. Since
        we used pipeline for inference speed benchmarking we want to make sure that this feature does not break
        anything on pipeline.
        """
        self.pipe = pipeline(
            "text-generation",
            model=self.model_name,
            model_kwargs={"device_map": "auto", "load_in_4bit": True, "torch_dtype": torch.float16},
            max_new_tokens=self.MAX_NEW_TOKENS,
        )

        # Run generation through the pipeline and compare with the accepted outputs
        pipeline_output = self.pipe(self.input_text)
        self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)


@require_torch_multi_gpu
class Bnb4bitTestMultiGpu(Base4bitTest):
    """Load a 4-bit model split across two GPUs and check that generation still works."""

    def test_multi_gpu_loading(self):
        r"""
        This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
        Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB of total, 3GB should suffice
        """
        model_parallel = AutoModelForCausalLM.from_pretrained(
            self.model_name, load_in_4bit=True, device_map="balanced"
        )

        # Check correct device map
        self.assertEqual(set(model_parallel.hf_device_map.values()), {0, 1})

        # Check that inference pass works on the model
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

        output_parallel = model_parallel.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )
        self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)


class Bnb4BitTestTraining(Base4bitTest):
    """Check that gradients reach `LoRALayer` adapters attached to a frozen 4-bit OPT model."""

    def setUp(self):
        # Set the checkpoint before `super().setUp()` so the tokenizer matches the model under test.
        self.model_name = "facebook/opt-350m"
        super().setUp()

    def test_training(self):
        r"""
        Freeze a 4-bit model, wrap the q/k/v projections of every OPT attention block with `LoRALayer`, run one
        forward/backward pass under autocast, and check that the adapters' up-projections get non-zero gradients
        while embedding weights get none.
        """
        if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.37.0"):
            # Report a skip instead of silently passing on bitsandbytes versions this test does not support.
            self.skipTest("test requires bitsandbytes>=0.37.0")

        # Step 1: freeze all parameters
        model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True)

        self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})

        for param in model.parameters():
            param.requires_grad = False  # freeze the model - train adapters later
            if param.ndim == 1:
                # cast the small parameters (e.g. layernorm) to fp32 for stability
                param.data = param.data.to(torch.float32)

        # Step 2: add adapters
        for _, module in model.named_modules():
            if "OPTAttention" in repr(type(module)):
                module.q_proj = LoRALayer(module.q_proj, rank=16)
                module.k_proj = LoRALayer(module.k_proj, rank=16)
                module.v_proj = LoRALayer(module.v_proj, rank=16)

        # Step 3: dummy batch
        batch = self.tokenizer("Test batch ", return_tensors="pt").to(0)

        # Step 4: Check if the gradient is not None
        with torch.cuda.amp.autocast():
            out = model.forward(**batch)
            out.logits.norm().backward()

        for module in model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)
                self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
            elif isinstance(module, nn.Embedding):
                self.assertTrue(module.weight.grad is None)


class Bnb4BitGPT2Test(Bnb4BitTest):
    """Run the whole `Bnb4BitTest` suite against GPT-2 XL instead of Bloom."""

    model_name = "gpt2-xl"
    # fp16 / 4-bit memory-footprint ratio expected for `gpt2-xl` (overrides the Bloom value of the base class).
    EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187


@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class BaseSerializationTest(unittest.TestCase):
    """Save a 4-bit model, reload it, and check that weights, quant states and outputs round-trip exactly."""

    model_name = "facebook/opt-125m"
    input_text = "Mars colonists' favorite meals are"

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
        r"""
        Test whether it is possible to serialize a model in 4-bit. Uses most typical params as default.
        See ExtendedSerializationTest class for more params combinations.

        Args:
            quant_type: passed as `bnb_4bit_quant_type` (e.g. "nf4" or "fp4").
            double_quant: passed as `bnb_4bit_use_double_quant`.
            safe_serialization: forwarded to `save_pretrained`.
        """
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=quant_type,
            bnb_4bit_use_double_quant=double_quant,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model_0 = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=self.quantization_config,
            device_map=torch_device,
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization)

            config = AutoConfig.from_pretrained(tmpdirname)
            self.assertTrue(hasattr(config, "quantization_config"))

            # No `quantization_config` is passed: the reloaded model must get it from the saved checkpoint.
            model_1 = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device)

        # checking quantized linear module weight
        linear = get_some_linear_layer(model_1)
        self.assertIs(linear.weight.__class__, bnb.nn.Params4bit)
        self.assertTrue(hasattr(linear.weight, "quant_state"))
        self.assertIs(linear.weight.quant_state.__class__, bnb.functional.QuantState)

        # checking memory footprint
        self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)

        # Matching all parameters and their quant_state items:
        d0 = dict(model_0.named_parameters())
        d1 = dict(model_1.named_parameters())
        self.assertEqual(d0.keys(), d1.keys())

        for name, p0 in d0.items():
            p1 = d1[name]
            self.assertEqual(p0.shape, p1.shape)
            self.assertEqual(p0.device, p1.device)
            self.assertEqual(p0.dtype, p1.dtype)
            self.assertTrue(torch.equal(p0, p1.to(p0.device)))

            if isinstance(p0, bnb.nn.modules.Params4bit):
                qs0 = p0.quant_state.as_dict()
                qs1 = p1.quant_state.as_dict()
                # Compare keys explicitly: zipping `.values()` would silently ignore missing/extra entries and
                # could pair values that belong to different keys.
                self.assertEqual(qs0.keys(), qs1.keys())
                for key, v0 in qs0.items():
                    v1 = qs1[key]
                    if isinstance(v0, torch.Tensor):
                        self.assertTrue(torch.equal(v0, v1.to(v0.device)))
                    else:
                        self.assertEqual(v0, v1)

        # comparing forward() outputs
        encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
        out_0 = model_0(**encoded_input)
        out_1 = model_1(**encoded_input)
        self.assertTrue(torch.equal(out_0["logits"], out_1["logits"]))

        # comparing generate() outputs (same encoded input as above)
        output_sequences_0 = model_0.generate(**encoded_input, max_new_tokens=10)
        output_sequences_1 = model_1.generate(**encoded_input, max_new_tokens=10)

        def _decode(token):
            return tokenizer.decode(token, skip_special_tokens=True)

        self.assertEqual(
            [_decode(x) for x in output_sequences_0],
            [_decode(x) for x in output_sequences_1],
        )


class ExtendedSerializationTest(BaseSerializationTest):
    """
    Runs `test_serialization` over the remaining combinations of 4-bit quant type, double quantization and
    safetensors serialization.
    """

    # Every call below passes (quant_type, double_quant, safe_serialization) positionally.

    def test_nf4_single_unsafe(self):
        self.test_serialization("nf4", False, False)

    def test_nf4_single_safe(self):
        self.test_serialization("nf4", False, True)

    def test_nf4_double_unsafe(self):
        self.test_serialization("nf4", True, False)

    # ("nf4", True, True) is the default argument set, so the inherited `test_serialization` already covers it.

    def test_fp4_single_unsafe(self):
        self.test_serialization("fp4", False, False)

    def test_fp4_single_safe(self):
        self.test_serialization("fp4", False, True)

    def test_fp4_double_unsafe(self):
        self.test_serialization("fp4", True, False)

    def test_fp4_double_safe(self):
        self.test_serialization("fp4", True, True)


class BloomSerializationTest(BaseSerializationTest):
    """
    Runs the default serialization round-trip of BaseSerializationTest on a Bloom-family checkpoint.
    """

    # checkpoint used by the inherited `test_serialization`
    model_name = "bigscience/bloom-560m"


class GPTSerializationTest(BaseSerializationTest):
    """
    Runs the default serialization round-trip of BaseSerializationTest on a GPT-family checkpoint.
    """

    # checkpoint used by the inherited `test_serialization`
    model_name = "gpt2-xl"