"git@developer.sourcefind.cn:xdb4_94051/vllm.git" did not exist on "f04908cae782e1a2404eb3e4f331718d311d1e0d"
test_4bit.py 19.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import importlib.metadata
import tempfile
import unittest

from packaging import version

from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
from transformers.testing_utils import (
    is_torch_available,
    require_accelerate,
    require_bitsandbytes,
    require_torch,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
)


def get_some_linear_layer(model):
    """Return one MLP projection from the first transformer block of ``model``.

    GPT-2 models expose ``mlp.c_fc``; any other model type is expected to expose ``mlp.dense_4h_to_h``
    (as bloom does).
    """
    is_gpt2 = model.config.model_type == "gpt2"
    first_mlp = model.transformer.h[0].mlp
    return first_mlp.c_fc if is_gpt2 else first_mlp.dense_4h_to_h


if is_torch_available():
    import torch
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Wraps a linear layer with a LoRA-like adapter - used for testing purposes only.

        The forward output is ``module(input) + adapter(input)``, where the adapter is a bias-free
        ``in_features -> rank -> out_features`` projection. The second projection is zero-initialised, so at
        construction time the wrapper returns exactly what the wrapped module returns, while that second
        projection still receives gradients on the first backward pass.
        """

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module
            self.adapter = nn.Sequential(
                nn.Linear(module.in_features, rank, bias=False),
                nn.Linear(rank, module.out_features, bias=False),
            )
            # Small random init for the down-projection, zeros for the up-projection (adapter starts as a no-op).
            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
            nn.init.normal_(self.adapter[0].weight, std=small_std)
            nn.init.zeros_(self.adapter[1].weight)
            # Keep the adapter on the same device as the wrapped layer's weight.
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            return self.module(input, *args, **kwargs) + self.adapter(input)


@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class Base4bitTest(unittest.TestCase):
    """Shared constants and tokenizer loading for the 4-bit (bitsandbytes) quantization tests.

    Subclasses are responsible for loading, and freeing, the models they test.
    """

    # We keep the constants at class level and the model loading inside the setUp functions.

    # We need to test on relatively large models (>1B parameters), otherwise the quantization may not work as
    # expected. Therefore we use bloom-1b7 here to test our module.
    model_name = "bigscience/bloom-1b7"

    # Expected ratio between the fp16 and the 4-bit memory footprints.
    # This was obtained on an RTX Titan, so the number might change slightly on other GPUs.
    EXPECTED_RELATIVE_DIFFERENCE = 2.109659552692574

    input_text = "Hello my name is"

    # Any of these completions is accepted: small numeric differences across GPUs can change the generated text.
    EXPECTED_OUTPUTS = {
        "Hello my name is John and I am a professional photographer. I",
        "Hello my name is John.\nI am a friend of your father.\n",
        "Hello my name is John Doe, I am a student at the University",
    }
    MAX_NEW_TOKENS = 10

    def setUp(self):
        # Tokenizer for `self.model_name`; subclasses overriding the model name must do so before calling this.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)


class Bnb4BitTest(Base4bitTest):
    """Core checks on a 4-bit quantized causal LM, compared against the same checkpoint loaded in fp16."""

    def setUp(self):
        super().setUp()

        # Models: an fp16 reference and the same checkpoint quantized to 4-bit
        self.model_fp16 = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.float16, device_map="auto"
        )
        self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        del self.model_fp16
        del self.model_4bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_quantization_num_parameters(self):
        r"""
        Test that quantizing to 4-bit does not change the number of parameters reported by `num_parameters()`.

        See: https://github.com/huggingface/transformers/issues/25978
        """
        num_params_4bit = self.model_4bit.num_parameters()
        num_params_fp16 = self.model_fp16.num_parameters()

        self.assertEqual(num_params_4bit, num_params_fp16)

    def test_quantization_config_json_serialization(self):
        r"""
        A simple test to check that the quantization config is attached to the model config and can be serialized
        (to a dict, a diff dict and a JSON string).
        """
        config = self.model_4bit.config

        self.assertTrue(hasattr(config, "quantization_config"))

        _ = config.to_dict()
        _ = config.to_diff_dict()

        _ = config.to_json_string()

    def test_memory_footprint(self):
        r"""
        A simple test to check if the model conversion has been done correctly by checking on the
        memory footprint of the converted model and the class type of the linear layers of the converted models
        """
        from bitsandbytes.nn import Params4bit

        mem_fp16 = self.model_fp16.get_memory_footprint()
        mem_4bit = self.model_4bit.get_memory_footprint()

        self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)

        # Exact class check (not isinstance): the weight must be a bitsandbytes `Params4bit`
        linear = get_some_linear_layer(self.model_4bit)
        self.assertTrue(linear.weight.__class__ == Params4bit)

    def test_linear_are_4bit(self):
        r"""
        Check that every `nn.Linear` of the converted model, except `lm_head` and the modules listed in
        `T5PreTrainedModel._keep_in_fp32_modules`, stores its weight as packed 4-bit data.
        """
        from transformers import T5PreTrainedModel

        # Computing the footprint must not fail on either model
        self.model_fp16.get_memory_footprint()
        self.model_4bit.get_memory_footprint()

        for name, module in self.model_4bit.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
                    # 4-bit parameters are packed in uint8 variables
                    self.assertTrue(module.weight.dtype == torch.uint8)

    def test_generate_quality(self):
        r"""
        Test the generation quality of the quantized model and see that we are matching the expected output.
        Given that we are operating on small numbers + the testing model is relatively small, we might not get
        the same output across GPUs. So we'll generate few tokens (5-10) and check their output.
        """
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = self.model_4bit.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_generate_quality_config(self):
        r"""
        Test that loading the model through a `BitsAndBytesConfig` with `load_in_4bit=True` is equivalent to
        passing `load_in_4bit=True` directly.
        """
        bnb_config = BitsAndBytesConfig()
        bnb_config.load_in_4bit = True

        model_4bit_from_config = AutoModelForCausalLM.from_pretrained(
            self.model_name, quantization_config=bnb_config, device_map="auto"
        )

        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = model_4bit_from_config.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_raise_on_save_pretrained(self):
        r"""
        Test that trying to save a model after converting it to 4-bit raises a `NotImplementedError`.
        """
        with self.assertRaises(NotImplementedError), tempfile.TemporaryDirectory() as tmpdirname:
            self.model_4bit.save_pretrained(tmpdirname)

    def test_raise_if_config_and_load_in_4bit(self):
        r"""
        Test that loading the model with both a quantization config and `load_in_4bit` raises an error
        """
        bnb_config = BitsAndBytesConfig()

        with self.assertRaises(ValueError):
            _ = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=bnb_config,
                load_in_4bit=True,
                device_map="auto",
                bnb_4bit_quant_type="nf4",
            )

    def test_device_and_dtype_assignment(self):
        r"""
        Test that trying to cast (or assign a device to) a model after converting it to 4-bit raises an error.
        Also checks that non-quantized models are still cast and moved correctly.
        """
        with self.assertRaises(ValueError):
            # Tries with a `str` device
            self.model_4bit.to("cpu")

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            self.model_4bit.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `torch.device`
            self.model_4bit.to(torch.device("cuda:0"))

        with self.assertRaises(ValueError):
            # Tries casting with `.float()`
            self.model_4bit.float()

        with self.assertRaises(ValueError):
            # Tries casting with `.half()`
            self.model_4bit.half()

        # Test if we did not break anything
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

        self.model_fp16 = self.model_fp16.to(torch.float32)
        _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS)

        # Check this does not throw an error
        _ = self.model_fp16.to("cpu")

        # Check this does not throw an error
        _ = self.model_fp16.half()

        # Check this does not throw an error
        _ = self.model_fp16.float()

    def test_fp32_4bit_conversion(self):
        r"""
        Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.
        """
        model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_4bit=True, device_map="auto")
        # The T5 `wo` projection is expected to be kept in fp32 rather than quantized
        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class Bnb4BitT5Test(unittest.TestCase):
    """4-bit loading and generation for T5 models, with and without `_keep_in_fp32_modules`."""

    @classmethod
    def setUpClass(cls):
        cls.model_name = "t5-small"
        cls.dense_act_model_name = "google/flan-t5-small"  # flan-t5 uses dense-act instead of dense-relu-dense
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.input_text = "Translate in German: Hello, my dog is cute"

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        gc.collect()
        torch.cuda.empty_cache()

    def test_inference_without_keep_in_fp32(self):
        r"""
        Test that 4-bit inference still works when `_keep_in_fp32_modules` is disabled, i.e. when no module is
        kept in fp32.

        `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
        both cases.
        """
        from transformers import T5ForConditionalGeneration

        modules = T5ForConditionalGeneration._keep_in_fp32_modules
        T5ForConditionalGeneration._keep_in_fp32_modules = None

        # Restore the class attribute even if loading or generation fails; otherwise the patched value would leak
        # into every test that runs afterwards in the same process.
        try:
            # test with `t5-small`
            model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
            encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
            _ = model.generate(**encoded_input)

            # test with `flan-t5-small`
            model = T5ForConditionalGeneration.from_pretrained(
                self.dense_act_model_name, load_in_4bit=True, device_map="auto"
            )
            encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
            _ = model.generate(**encoded_input)
        finally:
            T5ForConditionalGeneration._keep_in_fp32_modules = modules

    def test_inference_with_keep_in_fp32(self):
        r"""
        Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.

        `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
        both cases.
        """
        import bitsandbytes as bnb

        from transformers import T5ForConditionalGeneration

        # test with `t5-small`
        model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")

        # there was a bug with decoders - this test checks that it is fixed
        self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear4bit))

        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
        _ = model.generate(**encoded_input)

        # test with `flan-t5-small`
        model = T5ForConditionalGeneration.from_pretrained(
            self.dense_act_model_name, load_in_4bit=True, device_map="auto"
        )
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
        _ = model.generate(**encoded_input)

class Classes4BitModelTest(Base4bitTest):
    """Check which modules stay in their native class when different auto-model heads are loaded in 4-bit."""

    def setUp(self):
        # Pick the smaller checkpoint *before* the base setUp so the tokenizer is loaded from the same model
        # (consistent with `Bnb4BitTestTraining`).
        self.model_name = "bigscience/bloom-560m"
        super().setUp()
        self.seq_to_seq_name = "t5-small"

        # Different types of model

        self.base_model = AutoModel.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
        # Sequence classification model
        self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name, load_in_4bit=True, device_map="auto"
        )
        # CausalLM model
        self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
        # Seq2seq model
        self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
            self.seq_to_seq_name, load_in_4bit=True, device_map="auto"
        )

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        del self.base_model
        del self.sequence_model
        del self.model_4bit
        del self.seq_to_seq_model

        gc.collect()
        torch.cuda.empty_cache()

    def test_correct_head_class(self):
        r"""
        A simple test to check if the last modules for some classes (AutoModelForCausalLM or SequenceClassification)
        are kept in their native class.
        """
        from bitsandbytes.nn import Params4bit

        # Body layers are quantized
        self.assertTrue(self.base_model.h[-1].mlp.dense_4h_to_h.weight.__class__ == Params4bit)

        # Other heads should be nn.Parameter
        self.assertTrue(self.model_4bit.lm_head.weight.__class__ == torch.nn.Parameter)
        self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
        self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter)


class Pipeline4BitTest(Base4bitTest):
    """Check that a 4-bit model can be driven through the `text-generation` pipeline."""

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        # `self.pipe` only exists once `pipeline(...)` succeeded; don't add an AttributeError on top of the real
        # failure when pipeline construction itself raised.
        if hasattr(self, "pipe"):
            del self.pipe

        gc.collect()
        torch.cuda.empty_cache()

    def test_pipeline(self):
        r"""
        The aim of this test is to verify that the mixed 4bit is compatible with `pipeline` from transformers. Since
        we used pipeline for inference speed benchmarking we want to make sure that this feature does not break
        anything on pipeline.
        """
        self.pipe = pipeline(
            "text-generation",
            model=self.model_name,
            model_kwargs={"device_map": "auto", "load_in_4bit": True, "torch_dtype": torch.float16},
            max_new_tokens=self.MAX_NEW_TOKENS,
        )

        # Run generation through the pipeline
        pipeline_output = self.pipe(self.input_text)
        self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)


@require_torch_multi_gpu
class Bnb4bitTestMultiGpu(Base4bitTest):
    """Check that a 4-bit model can be split across two GPUs and still generate correctly."""

    def test_multi_gpu_loading(self):
        r"""
        This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
        Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB in total, 3GB
        should suffice.
        """
        model_parallel = AutoModelForCausalLM.from_pretrained(
            self.model_name, load_in_4bit=True, device_map="balanced"
        )

        # Check correct device map
        self.assertEqual(set(model_parallel.hf_device_map.values()), {0, 1})

        # Check that inference pass works on the model
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

        # Inputs are sent to GPU 0
        output_parallel = model_parallel.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )
        self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)


class Bnb4BitTestTraining(Base4bitTest):
    """Check that LoRA-style adapters on top of a frozen 4-bit model receive gradients."""

    def setUp(self):
        # Override the model before the base setUp so the tokenizer matches the trained model
        self.model_name = "facebook/opt-350m"
        super().setUp()

    def test_training(self):
        r"""
        Freeze a 4-bit OPT model, wrap its attention q/k/v projections with `LoRALayer` adapters, run one backward
        pass and check that the adapters receive non-zero gradients while the frozen embeddings receive none.
        """
        if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.37.0"):
            # Report as skipped rather than silently passing on an unsupported bitsandbytes version
            self.skipTest("test_training requires bitsandbytes>=0.37.0")

        # Step 1: freeze all parameters
        model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True)

        self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})

        for param in model.parameters():
            param.requires_grad = False  # freeze the model - train adapters later
            if param.ndim == 1:
                # cast the small parameters (e.g. layernorm) to fp32 for stability
                param.data = param.data.to(torch.float32)

        # Step 2: add adapters
        for _, module in model.named_modules():
            if "OPTAttention" in repr(type(module)):
                module.q_proj = LoRALayer(module.q_proj, rank=16)
                module.k_proj = LoRALayer(module.k_proj, rank=16)
                module.v_proj = LoRALayer(module.v_proj, rank=16)

        # Step 3: dummy batch
        batch = self.tokenizer("Test batch ", return_tensors="pt").to(0)

        # Step 4: Check if the gradient is not None
        with torch.cuda.amp.autocast():
            out = model.forward(**batch)
            out.logits.norm().backward()

        for module in model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)
                self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
            elif isinstance(module, nn.Embedding):
                self.assertTrue(module.weight.grad is None)


class Bnb4BitGPT2Test(Bnb4BitTest):
    """Runs the inherited `Bnb4BitTest` checks against `gpt2-xl` instead of the base bloom checkpoint."""

    model_name = "gpt2-xl"
    # fp16 / 4-bit memory-footprint ratio checked by the inherited `test_memory_footprint` for this model
    EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187