"RealESRGAN/utils.py" did not exist on "86d4c940d96abd6106ec47141b4740d98ff24bed"
test_4bit.py 17.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
16
import tempfile
17
18
import unittest

19
20
from packaging import version

21
22
23
24
25
26
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
27
    BitsAndBytesConfig,
28
29
    pipeline,
)
30
31
32
33
34
35
36
37
38
from transformers.testing_utils import (
    is_torch_available,
    require_accelerate,
    require_bitsandbytes,
    require_torch,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
)
39
from transformers.utils.versions import importlib_metadata
40
41
42
43


if is_torch_available():
    import torch
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Wraps a linear layer with a LoRA-like adapter - used for testing purposes only.

        The forward output is ``module(input) + adapter(input)``, where ``adapter`` is a bias-free
        ``in_features -> rank -> out_features`` pair of linear layers.
        """

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module
            self.adapter = nn.Sequential(
                nn.Linear(module.in_features, rank, bias=False),
                nn.Linear(rank, module.out_features, bias=False),
            )
            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
            nn.init.normal_(self.adapter[0].weight, std=small_std)
            # Zero-init the up-projection so the wrapped layer initially behaves exactly like `module`.
            nn.init.zeros_(self.adapter[1].weight)
            # Keep the adapter on the same device as the wrapped layer's weights.
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            return self.module(input, *args, **kwargs) + self.adapter(input)


@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class Base4bitTest(unittest.TestCase):
    """Shared constants and tokenizer setup for the 4-bit bitsandbytes integration tests.

    Subclasses load the models they need in their own ``setUp``.
    """

    # Constants are kept as class attributes; model loading happens in each subclass' `setUp`.

    # We need to test on relatively large models (aka >1b parameters, otherwise the quantization may not work as
    # expected). Therefore here we use only bloom-1b7 to test our module.
    model_name = "bigscience/bloom-1b7"

    # Expected ratio between the fp16 and the 4-bit memory footprints of `model_name`.
    EXPECTED_RELATIVE_DIFFERENCE = (
        2.109659552692574  # This was obtained on a RTX Titan so the number might slightly change
    )

    input_text = "Hello my name is"
    # Several completions are accepted because the quantized model's output may differ across GPUs.
    EXPECTED_OUTPUTS = set()
    EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
    EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
    MAX_NEW_TOKENS = 10

    def setUp(self):
        # Tokenizer shared by all subclasses; it is loaded for whatever `self.model_name` is at this point.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
class Bnb4BitTest(Base4bitTest):
    def setUp(self):
        super().setUp()

        # Models and tokenizer
        self.model_fp16 = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.float16, device_map="auto"
        )
        self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        del self.model_fp16
        del self.model_4bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_memory_footprint(self):
        r"""
        A simple test to check if the model conversion has been done correctly by checking on the
        memory footprint of the converted model and the class type of the linear layers of the converted models
        """
        from bitsandbytes.nn import Params4bit

        mem_fp16 = self.model_fp16.get_memory_footprint()
        mem_4bit = self.model_4bit.get_memory_footprint()

        self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)
        self.assertTrue(self.model_4bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Params4bit)

    def test_linear_are_4bit(self):
        r"""
        Check that every `torch.nn.Linear` of the 4-bit model, except `lm_head` and the modules listed in
        `T5PreTrainedModel._keep_in_fp32_modules`, stores its weights as packed `torch.uint8` data.
        """
        from transformers import T5PreTrainedModel

        excluded_names = ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules

        for name, module in self.model_4bit.named_modules():
            if isinstance(module, torch.nn.Linear) and name not in excluded_names:
                # 4-bit parameters are packed in uint8 variables
                self.assertEqual(module.weight.dtype, torch.uint8)

    def test_generate_quality(self):
        r"""
        Test the generation quality of the quantized model and see that we are matching the expected output.
        Given that we are operating on small numbers + the testing model is relatively small, we might not get
        the same output across GPUs. So we'll generate few tokens (5-10) and check their output.
        """
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = self.model_4bit.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_generate_quality_config(self):
        r"""
        Test that loading the model with a `BitsAndBytesConfig` (with `load_in_4bit=True`) is equivalent to
        passing `load_in_4bit=True` directly.
        """
        bnb_config = BitsAndBytesConfig()
        bnb_config.load_in_4bit = True

        model_4bit_from_config = AutoModelForCausalLM.from_pretrained(
            self.model_name, quantization_config=bnb_config, device_map="auto"
        )

        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = model_4bit_from_config.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )

        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_raise_on_save_pretrained(self):
        r"""
        Test that trying to save a model after converting it to 4-bit raises a `NotImplementedError`.
        """
        with self.assertRaises(NotImplementedError), tempfile.TemporaryDirectory() as tmpdirname:
            self.model_4bit.save_pretrained(tmpdirname)

    def test_raise_if_config_and_load_in_4bit(self):
        r"""
        Test that loading the model with the config and `load_in_4bit` raises an error
        """
        bnb_config = BitsAndBytesConfig()

        with self.assertRaises(ValueError):
            _ = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=bnb_config,
                load_in_4bit=True,
                device_map="auto",
                bnb_4bit_quant_type="nf4",
            )

    def test_device_and_dtype_assignment(self):
        r"""
        Test that trying to cast (or assign a device to) a model after converting it to 4-bit raises a
        `ValueError`. Also checks that the non-quantized fp16 model can still be moved and cast.
        """
        with self.assertRaises(ValueError):
            # Tries with `str`
            self.model_4bit.to("cpu")

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            self.model_4bit.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device`
            self.model_4bit.to(torch.device("cuda:0"))

        with self.assertRaises(ValueError):
            # Tries with `float()`
            self.model_4bit.float()

        with self.assertRaises(ValueError):
            # Tries with `half()`
            self.model_4bit.half()

        # Test if we did not break anything
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

        self.model_fp16 = self.model_fp16.to(torch.float32)
        _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS)

        # Check this does not throw an error
        _ = self.model_fp16.to("cpu")

        # Check this does not throw an error
        _ = self.model_fp16.half()

        # Check this does not throw an error
        _ = self.model_fp16.float()

    def test_fp32_4bit_conversion(self):
        r"""
        Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.
        """
        model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_4bit=True, device_map="auto")
        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu
@slow
class Bnb4BitT5Test(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model_name = "t5-small"
        cls.dense_act_model_name = "google/flan-t5-small"  # flan-t5 uses dense-act instead of dense-relu-dense
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.input_text = "Translate in German: Hello, my dog is cute"

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        gc.collect()
        torch.cuda.empty_cache()

    def test_inference_without_keep_in_fp32(self):
        r"""
        Test that 4-bit inference works when `_keep_in_fp32_modules` is disabled (set to `None`), i.e. when no
        module is forced to stay in fp32.

        `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
        both cases.
        """
        from transformers import T5ForConditionalGeneration

        modules = T5ForConditionalGeneration._keep_in_fp32_modules
        T5ForConditionalGeneration._keep_in_fp32_modules = None
        # This patches a class attribute: always restore it, even if loading or generation fails, otherwise
        # the `None` value would leak into every later test that uses `T5ForConditionalGeneration`.
        try:
            # test with `t5-small`
            model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
            encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
            _ = model.generate(**encoded_input)

            # test with `flan-t5-small`
            model = T5ForConditionalGeneration.from_pretrained(
                self.dense_act_model_name, load_in_4bit=True, device_map="auto"
            )
            encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
            _ = model.generate(**encoded_input)
        finally:
            T5ForConditionalGeneration._keep_in_fp32_modules = modules

    def test_inference_with_keep_in_fp32(self):
        r"""
        Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.
        `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
        both cases.
        """
        import bitsandbytes as bnb

        from transformers import T5ForConditionalGeneration

        # test with `t5-small`
        model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")

        # there was a bug with decoders - this test checks that it is fixed
        self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear4bit))

        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
        _ = model.generate(**encoded_input)

        # test with `flan-t5-small`
        model = T5ForConditionalGeneration.from_pretrained(
            self.dense_act_model_name, load_in_4bit=True, device_map="auto"
        )
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
        _ = model.generate(**encoded_input)
class Classes4BitModelTest(Base4bitTest):
    def setUp(self):
        # Override the model name *before* the parent `setUp`, so the tokenizer it loads matches the models
        # loaded below (same pattern as `Bnb4BitTestTraining`).
        self.model_name = "bigscience/bloom-560m"
        super().setUp()

        self.seq_to_seq_name = "t5-small"

        # Different types of model

        self.base_model = AutoModel.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
        # Sequence classification model
        self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name, load_in_4bit=True, device_map="auto"
        )
        # CausalLM model
        self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
        # Seq2seq model
        self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
            self.seq_to_seq_name, load_in_4bit=True, device_map="auto"
        )

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        del self.base_model
        del self.sequence_model
        del self.model_4bit
        del self.seq_to_seq_model

        gc.collect()
        torch.cuda.empty_cache()

    def test_correct_head_class(self):
        r"""
        A simple test to check if the last modules for some classes (AutoModelForCausalLM or SequenceClassification)
        are kept in their native class.
        """
        from bitsandbytes.nn import Params4bit

        self.assertTrue(self.base_model.h[-1].mlp.dense_4h_to_h.weight.__class__ == Params4bit)

        # Other heads should be nn.Parameter
        self.assertTrue(self.model_4bit.lm_head.weight.__class__ == torch.nn.Parameter)
        self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
        self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter)
class Pipeline4BitTest(Base4bitTest):
    def setUp(self):
        super().setUp()

    def tearDown(self):
        r"""
        TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
        avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
        """
        # `self.pipe` only exists if the test got far enough to build it; guard the `del` so a failure inside
        # `pipeline(...)` is reported as-is instead of being masked by an AttributeError raised here.
        if hasattr(self, "pipe"):
            del self.pipe

        gc.collect()
        torch.cuda.empty_cache()

    def test_pipeline(self):
        r"""
        The aim of this test is to verify that the mixed 4bit is compatible with `pipeline` from transformers. Since
        we used pipeline for inference speed benchmarking we want to make sure that this feature does not break
        anything on pipeline.
        """
        self.pipe = pipeline(
            "text-generation",
            model=self.model_name,
            model_kwargs={"device_map": "auto", "load_in_4bit": True, "torch_dtype": torch.float16},
            max_new_tokens=self.MAX_NEW_TOKENS,
        )

        pipeline_output = self.pipe(self.input_text)
        self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)
@require_torch_multi_gpu
class Bnb4bitTestMultiGpu(Base4bitTest):
    def setUp(self):
        super().setUp()

    def test_multi_gpu_loading(self):
        r"""
        This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
        Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB in total, 3GB should
        suffice.
        """

        model_parallel = AutoModelForCausalLM.from_pretrained(
            self.model_name, load_in_4bit=True, device_map="balanced"
        )

        # Check correct device map
        self.assertEqual(set(model_parallel.hf_device_map.values()), {0, 1})

        # Check that inference pass works on the model
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

        # Second real batch
        output_parallel = model_parallel.generate(
            input_ids=encoded_input["input_ids"].to(0), max_new_tokens=self.MAX_NEW_TOKENS
        )
        self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
class Bnb4BitTestTraining(Base4bitTest):
    def setUp(self):
        # Set before the parent `setUp` so the tokenizer matches the OPT model used below.
        self.model_name = "facebook/opt-350m"
        super().setUp()

    def test_training(self):
        r"""
        Check that LoRA-like adapters added on top of a frozen 4-bit OPT model receive non-zero gradients, while the
        frozen embeddings receive none.
        """
        if version.parse(importlib_metadata.version("bitsandbytes")) < version.parse("0.37.0"):
            # Report a skip instead of silently passing on an unsupported bitsandbytes version.
            self.skipTest("test requires bitsandbytes>=0.37.0")

        # Step 1: freeze all parameters
        model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")

        for param in model.parameters():
            param.requires_grad = False  # freeze the model - train adapters later
            if param.ndim == 1:
                # cast the small parameters (e.g. layernorm) to fp32 for stability
                param.data = param.data.to(torch.float32)

        # Step 2: add adapters
        # Snapshot the module list first: submodules are replaced inside the loop.
        for module in list(model.modules()):
            if "OPTAttention" in repr(type(module)):
                module.q_proj = LoRALayer(module.q_proj, rank=16)
                module.k_proj = LoRALayer(module.k_proj, rank=16)
                module.v_proj = LoRALayer(module.v_proj, rank=16)

        # Step 3: dummy batch
        batch = self.tokenizer("Test batch ", return_tensors="pt").to(0)

        # Step 4: Check if the gradient is not None
        with torch.cuda.amp.autocast():
            out = model(**batch)
            out.logits.norm().backward()

        for module in model.modules():
            if isinstance(module, LoRALayer):
                self.assertIsNotNone(module.adapter[1].weight.grad)
                self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
            elif isinstance(module, nn.Embedding):
                self.assertIsNone(module.weight.grad)