test_4bit.py 20 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
16
import importlib.metadata
17
import tempfile
18
19
import unittest

20
21
from packaging import version

22
23
24
25
26
27
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
28
    BitsAndBytesConfig,
29
30
    pipeline,
)
31
32
33
34
35
36
37
38
39
40
41
from transformers.testing_utils import (
    is_torch_available,
    require_accelerate,
    require_bitsandbytes,
    require_torch,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
)


42
43
44
45
46
47
def get_some_linear_layer(model):
    """Return one MLP linear layer taken from the first transformer block of `model`.

    Models whose `config.model_type` is `"gpt2"` yield `mlp.c_fc`; every other model is expected to expose
    `mlp.dense_4h_to_h` instead.
    """
    layer_name = "c_fc" if model.config.model_type == "gpt2" else "dense_4h_to_h"
    return getattr(model.transformer.h[0].mlp, layer_name)


48
49
if is_torch_available():
    import torch
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Test-only wrapper that adds a low-rank (LoRA-like) adapter on top of a linear layer."""

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module

            # Two bias-free projections: in_features -> rank -> out_features
            down_proj = nn.Linear(module.in_features, rank, bias=False)
            up_proj = nn.Linear(rank, module.out_features, bias=False)
            self.adapter = nn.Sequential(down_proj, up_proj)

            # The up projection starts at zero, so the adapter initially adds nothing to the wrapped output
            narrowest_dim = min(module.in_features, module.out_features)
            small_std = (2.0 / (5 * narrowest_dim)) ** 0.5
            nn.init.normal_(down_proj.weight, std=small_std)
            nn.init.zeros_(up_proj.weight)

            # Keep the adapter on the same device as the wrapped layer's weight
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            wrapped_output = self.module(input, *args, **kwargs)
            adapter_output = self.adapter(input)
            return wrapped_output + adapter_output
69
70
71
72
73
74
75


@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class Base4bitTest(unittest.TestCase):
    """Shared constants and tokenizer setup for the 4-bit bitsandbytes tests."""

    # Constants are kept as class attributes; the tokenizer is loaded in `setUp` and subclasses load the models
    # they need themselves.

    # We need to test on relatively large models (aka >1b parameters, otherwise the quantization may not work as
    # expected). Therefore here we use only bloom-1b7 to test our module.
    model_name = "bigscience/bloom-1b7"

    # Constant values
    EXPECTED_RELATIVE_DIFFERENCE = (
        2.109659552692574  # This was obtained on a RTX Titan so the number might slightly change
    )

    input_text = "Hello my name is"
    # Any one of these generations is accepted
    EXPECTED_OUTPUTS = {
        "Hello my name is John and I am a professional photographer. I",
        "Hello my name is John.\nI am a friend of your father.\n",
        "Hello my name is John Doe, I am a student at the University",
    }
    MAX_NEW_TOKENS = 10

    def setUp(self):
        # Models and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)


100
class Bnb4BitTest(Base4bitTest):
    """Checks on a causal LM loaded in 4-bit, compared against the same checkpoint loaded in fp16."""

    def setUp(self):
        super().setUp()

        # Models and tokenizer
        self.model_fp16 = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.float16, device_map="auto"
        )
        self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        del self.model_fp16
        del self.model_4bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_quantization_num_parameters(self):
        r"""
        Test if the number of returned parameters is correct

        See: https://github.com/huggingface/transformers/issues/25978
        """
        num_params_4bit = self.model_4bit.num_parameters()
        num_params_fp16 = self.model_fp16.num_parameters()

        self.assertEqual(num_params_4bit, num_params_fp16)

    def test_quantization_config_json_serialization(self):
        r"""
        A simple test to check if the quantization config is correctly serialized and deserialized
        """
        config = self.model_4bit.config

        self.assertTrue(hasattr(config, "quantization_config"))

        # Only checks that none of the serialization entry points raises
        _ = config.to_dict()
        _ = config.to_diff_dict()

        _ = config.to_json_string()

    def test_memory_footprint(self):
        r"""
        A simple test to check if the model conversion has been done correctly by checking on the
        memory footprint of the converted model and the class type of the linear layers of the converted models
        """
        from bitsandbytes.nn import Params4bit

        mem_fp16 = self.model_fp16.get_memory_footprint()
        mem_4bit = self.model_4bit.get_memory_footprint()

        self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)
        linear = get_some_linear_layer(self.model_4bit)
        self.assertIs(linear.weight.__class__, Params4bit)

    def test_original_dtype(self):
        r"""
        A simple test to check if the model successfully stores the original dtype
        """
        self.assertTrue(hasattr(self.model_4bit.config, "_pre_quantization_dtype"))
        self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
        self.assertEqual(self.model_4bit.config._pre_quantization_dtype, torch.float16)

    def test_linear_are_4bit(self):
        r"""
        A simple test to check that every linear layer of the 4-bit model, except `lm_head` and the modules listed
        in `T5PreTrainedModel._keep_in_fp32_modules`, stores its weight as `torch.uint8`
        """
        from transformers import T5PreTrainedModel

        # The results are unused; the calls only verify that `get_memory_footprint` runs on both models
        self.model_fp16.get_memory_footprint()
        self.model_4bit.get_memory_footprint()

        # Names of the linear modules that are allowed to stay un-quantized (built once, outside the loop)
        skipped_modules = ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules

        for name, module in self.model_4bit.named_modules():
            if isinstance(module, torch.nn.Linear) and name not in skipped_modules:
                # 4-bit parameters are packed in uint8 variables
                self.assertEqual(module.weight.dtype, torch.uint8)

    def test_rwkv_4bit(self):
        r"""
        A simple test to check if 4-bit RWKV inference works as expected.
        """
        model_id = "RWKV/rwkv-4-169m-pile"

        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)

        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
        tok = AutoTokenizer.from_pretrained(model_id)

        text = "Hello my name is"
        input_ids = tok.encode(text, return_tensors="pt").to(0)

        _ = model.generate(input_ids, max_new_tokens=30)

    def test_generate_quality(self):
        r"""
        Test the generation quality of the quantized model and see that we are matching the expected output.
        Given that we are operating on small numbers + the testing model is relatively small, we might not get
        the same output across GPUs. So we'll generate few tokens (5-10) and check their output.
        """
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = self.model_4bit.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_generate_quality_config(self):
        r"""
        Test that loading the model with the config is equivalent
        """
        bnb_config = BitsAndBytesConfig()
        bnb_config.load_in_4bit = True

        model_4bit_from_config = AutoModelForCausalLM.from_pretrained(
            self.model_name, quantization_config=bnb_config, device_map="auto"
        )

        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = model_4bit_from_config.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_raise_on_save_pretrained(self):
        r"""
        Test whether trying to save a model after converting it in 4-bit raises a `NotImplementedError`.
        """
        with self.assertRaises(NotImplementedError), tempfile.TemporaryDirectory() as tmpdirname:
            self.model_4bit.save_pretrained(tmpdirname)

    def test_raise_if_config_and_load_in_4bit(self):
        r"""
        Test that loading the model with the config and `load_in_4bit` raises an error
        """
        bnb_config = BitsAndBytesConfig()

        with self.assertRaises(ValueError):
            _ = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=bnb_config,
                load_in_4bit=True,
                device_map="auto",
                bnb_4bit_quant_type="nf4",
            )

    def test_device_and_dtype_assignment(self):
        r"""
        Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an
        error. Checks also if other models are casted correctly.
        """
        with self.assertRaises(ValueError):
            # Tries with `str`
            self.model_4bit.to("cpu")

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            self.model_4bit.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device`
            self.model_4bit.to(torch.device("cuda:0"))

        with self.assertRaises(ValueError):
            # Tries with `float()`
            self.model_4bit.float()

        with self.assertRaises(ValueError):
            # Tries with `half()`
            self.model_4bit.half()

        # Test if we did not break anything
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

        self.model_fp16 = self.model_fp16.to(torch.float32)
        _ = self.model_fp16.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )

        # Check this does not throw an error
        _ = self.model_fp16.to("cpu")

        # Check this does not throw an error
        _ = self.model_fp16.half()

        # Check this does not throw an error
        _ = self.model_fp16.float()

    def test_fp32_4bit_conversion(self):
        r"""
        Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.
        """
        model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_4bit=True, device_map="auto")
        self.assertEqual(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype, torch.float32)

297

298
299
300
301
302
@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class Bnb4BitT5Test(unittest.TestCase):
    """4-bit inference smoke tests on `t5-small` and `google/flan-t5-small`."""

    @classmethod
    def setUpClass(cls):
        cls.model_name = "t5-small"
        cls.dense_act_model_name = "google/flan-t5-small"  # flan-t5 uses dense-act instead of dense-relu-dense
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.input_text = "Translate in German: Hello, my dog is cute"

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        gc.collect()
        torch.cuda.empty_cache()

    def test_inference_without_keep_in_fp32(self):
        r"""
        Test whether 4-bit loading and generation still work when `_keep_in_fp32_modules` is disabled (set to
        `None`) on `T5ForConditionalGeneration`.
        `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
        both cases.
        """
        from transformers import T5ForConditionalGeneration

        modules = T5ForConditionalGeneration._keep_in_fp32_modules
        T5ForConditionalGeneration._keep_in_fp32_modules = None

        try:
            # test with `t5-small`
            model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
            encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
            _ = model.generate(**encoded_input)

            # test with `flan-t5-small`
            model = T5ForConditionalGeneration.from_pretrained(
                self.dense_act_model_name, load_in_4bit=True, device_map="auto"
            )
            encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
            _ = model.generate(**encoded_input)
        finally:
            # Restore the class attribute even when the test fails, so the patch never leaks into other tests
            T5ForConditionalGeneration._keep_in_fp32_modules = modules

    def test_inference_with_keep_in_fp32(self):
        r"""
        Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.
        `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
        both cases.
        """
        import bitsandbytes as bnb

        from transformers import T5ForConditionalGeneration

        # test with `t5-small`
        model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")

        # there was a bug with decoders - this test checks that it is fixed
        self.assertIsInstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear4bit)

        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
        _ = model.generate(**encoded_input)

        # test with `flan-t5-small`
        model = T5ForConditionalGeneration.from_pretrained(
            self.dense_act_model_name, load_in_4bit=True, device_map="auto"
        )
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
        _ = model.generate(**encoded_input)

369

370
class Classes4BitModelTest(Base4bitTest):
    """Checks which weights are quantized when different auto-model classes are loaded in 4-bit."""

    def setUp(self):
        # Override the checkpoint before calling the parent `setUp`, so the tokenizer it loads belongs to the
        # same checkpoint as the models below
        self.model_name = "bigscience/bloom-560m"
        self.seq_to_seq_name = "t5-small"
        super().setUp()

        # Different types of model

        self.base_model = AutoModel.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
        # Sequence classification model
        self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name, load_in_4bit=True, device_map="auto"
        )
        # CausalLM model
        self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
        # Seq2seq model
        self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
            self.seq_to_seq_name, load_in_4bit=True, device_map="auto"
        )

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        del self.base_model
        del self.sequence_model
        del self.model_4bit
        del self.seq_to_seq_model

        gc.collect()
        torch.cuda.empty_cache()

    def test_correct_head_class(self):
        r"""
        A simple test to check if the last modules for some classes (AutoModelForCausalLM or SequenceClassification)
        are kept in their native class.
        """
        from bitsandbytes.nn import Params4bit

        self.assertIs(self.base_model.h[-1].mlp.dense_4h_to_h.weight.__class__, Params4bit)

        # Other heads should be nn.Parameter
        self.assertIs(self.model_4bit.lm_head.weight.__class__, torch.nn.Parameter)
        self.assertIs(self.sequence_model.score.weight.__class__, torch.nn.Parameter)
        self.assertIs(self.seq_to_seq_model.lm_head.weight.__class__, torch.nn.Parameter)
417
418


419
class Pipeline4BitTest(Base4bitTest):
    """Checks that a 4-bit model can be driven through the `pipeline` API."""

    def setUp(self):
        super().setUp()
        # The pipeline is created inside `test_pipeline`; declare the attribute up front so that `tearDown` can
        # always delete it, even when building the pipeline raised
        self.pipe = None

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        del self.pipe

        gc.collect()
        torch.cuda.empty_cache()

    def test_pipeline(self):
        r"""
        The aim of this test is to verify that the mixed 4bit is compatible with `pipeline` from transformers. Since
        we used pipeline for inference speed benchmarking we want to make sure that this feature does not break
        anything on pipeline.
        """
        self.pipe = pipeline(
            "text-generation",
            model=self.model_name,
            model_kwargs={"device_map": "auto", "load_in_4bit": True, "torch_dtype": torch.float16},
            max_new_tokens=self.MAX_NEW_TOKENS,
        )

        # Run the pipeline on the prompt and compare the generated text
        pipeline_output = self.pipe(self.input_text)
        self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)
450
451
452


@require_torch_multi_gpu
class Bnb4bitTestMultiGpu(Base4bitTest):
    """Checks that a 4-bit model can be spread over two GPUs and still generate."""

    def setUp(self):
        super().setUp()

    def test_multi_gpu_loading(self):
        r"""
        This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
        Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB of total, 3GB should suffice
        """

        model_parallel = AutoModelForCausalLM.from_pretrained(
            self.model_name, load_in_4bit=True, device_map="balanced"
        )

        # Check correct device map
        self.assertEqual(set(model_parallel.hf_device_map.values()), {0, 1})

        # Check that inference pass works on the model
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

        # Generate from the sharded model and compare with the accepted outputs
        output_parallel = model_parallel.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )
        self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
476
477


478
class Bnb4BitTestTraining(Base4bitTest):
    """Checks that gradients reach LoRA-like adapters attached to a frozen 4-bit model."""

    def setUp(self):
        # Must be set before the parent `setUp`, which loads the tokenizer for `self.model_name`
        self.model_name = "facebook/opt-350m"
        super().setUp()

    def test_training(self):
        r"""
        Freeze a 4-bit OPT model, wrap its attention q/k/v projections in `LoRALayer`, run one forward/backward
        pass and check that only the adapters receive gradients.
        """
        if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.37.0"):
            # Report the test as skipped rather than silently passing without running anything
            self.skipTest("This test requires bitsandbytes >= 0.37.0")

        model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True)

        self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})

        # Step 1: freeze all parameters
        for param in model.parameters():
            param.requires_grad = False  # freeze the model - train adapters later
            if param.ndim == 1:
                # cast the small parameters (e.g. layernorm) to fp32 for stability
                param.data = param.data.to(torch.float32)

        # Step 2: add adapters
        for _, module in model.named_modules():
            if "OPTAttention" in repr(type(module)):
                module.q_proj = LoRALayer(module.q_proj, rank=16)
                module.k_proj = LoRALayer(module.k_proj, rank=16)
                module.v_proj = LoRALayer(module.v_proj, rank=16)

        # Step 3: dummy batch
        batch = self.tokenizer("Test batch ", return_tensors="pt").to(0)

        # Step 4: Check if the gradient is not None
        with torch.cuda.amp.autocast():
            out = model.forward(**batch)
            out.logits.norm().backward()

        for module in model.modules():
            if isinstance(module, LoRALayer):
                self.assertIsNotNone(module.adapter[1].weight.grad)
                self.assertGreater(module.adapter[1].weight.grad.norm().item(), 0)
            elif isinstance(module, nn.Embedding):
                # Frozen weights must not have accumulated any gradient
                self.assertIsNone(module.weight.grad)
519
520
521
522
523


class Bnb4BitGPT2Test(Bnb4BitTest):
    """Runs the tests inherited from `Bnb4BitTest` against the `gpt2-xl` checkpoint."""

    model_name = "gpt2-xl"
    # Expected fp16 / 4-bit memory-footprint ratio for this checkpoint (compared in `test_memory_footprint`)
    EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187