# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile
import unittest

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AwqConfig, OPTForCausalLM
from transformers.testing_utils import (
    require_accelerate,
    require_auto_awq,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
    torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available


if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate import init_empty_weights


@require_torch_gpu
class AwqConfigTest(unittest.TestCase):
    """Unit tests for `AwqConfig` construction, validation and (de)serialization."""

    def test_wrong_backend(self):
        """
        Simple test that checks if a user passes a wrong backend an error is raised
        """
        # This should work fine
        _ = AwqConfig(bits=4)

        with self.assertRaises(ValueError):
            AwqConfig(bits=4, backend="")

        # These should work fine
        _ = AwqConfig(bits=4, version="GEMM")
        _ = AwqConfig(bits=4, version="gemm")

        with self.assertRaises(ValueError):
            AwqConfig(bits=4, backend="unexisting-backend")

        # The `llm-awq` backend is rejected on older GPUs such as a T4 (compute capability 7.5)
        # but accepted on Ampere+ cards, so the expectation depends on the hardware running the test.
        major, _ = torch.cuda.get_device_capability()
        if major < 8:
            # LLMAWQ does not work on a T4
            with self.assertRaises(ValueError):
                AwqConfig(bits=4, backend="llm-awq")
        else:
            # LLMAWQ should work on e.g. an A100
            AwqConfig(bits=4, backend="llm-awq")

    def test_to_dict(self):
        """
        Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
        """
        quantization_config = AwqConfig(bits=4)
        config_to_dict = quantization_config.to_dict()

        for key in config_to_dict:
            self.assertEqual(getattr(quantization_config, key), config_to_dict[key])

    def test_from_dict(self):
        """
        Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
        """
        # Named `config_dict` rather than `dict` to avoid shadowing the builtin.
        config_dict = {"bits": 2, "zero_point": False, "backend": "autoawq"}
        quantization_config = AwqConfig.from_dict(config_dict)

        self.assertEqual(config_dict["bits"], quantization_config.bits)
        self.assertEqual(config_dict["zero_point"], quantization_config.zero_point)
        self.assertEqual(config_dict["backend"], quantization_config.backend)


@slow
@require_torch_gpu
@require_auto_awq
@require_accelerate
class AwqTest(unittest.TestCase):
    """Integration tests loading a pre-quantized AWQ checkpoint and checking greedy generations."""

    model_name = "TheBloke/Mistral-7B-v0.1-AWQ"
    dummy_transformers_model_name = "bigscience/bloom-560m"

    input_text = "Hello my name is"

    EXPECTED_OUTPUT = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Journalism and minoring in Spanish"
    EXPECTED_OUTPUT_BF16 = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Exercise and Sport Science with a"

    device_map = "cuda"

    # called only once for all test in this class
    @classmethod
    def setUpClass(cls):
        """
        Setup quantized model
        """
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.model_name,
            device_map=cls.device_map,
        )

    def tearDown(self):
        # Free GPU memory held by models created inside individual tests.
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def test_quantized_model_conversion(self):
        """
        Simple test that checks if the quantized model has been converted properly
        """
        from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

        from transformers.integrations.awq import replace_with_awq_linear

        model_id = "facebook/opt-350m"
        config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
        quantization_config = AwqConfig(bits=4)

        with init_empty_weights():
            model = OPTForCausalLM(config)

        nb_linears = sum(isinstance(module, torch.nn.Linear) for module in model.modules())

        model, _ = replace_with_awq_linear(model, quantization_config=quantization_config)
        nb_awq_linear = sum(isinstance(module, (WQLinear_GEMM, WQLinear_GEMV)) for module in model.modules())

        # Every nn.Linear should have been swapped for an AWQ linear.
        self.assertEqual(nb_linears, nb_awq_linear)

        # Try with `modules_not_to_convert`
        with init_empty_weights():
            model = OPTForCausalLM(config)

        model, _ = replace_with_awq_linear(
            model, quantization_config=quantization_config, modules_to_not_convert=["lm_head"]
        )
        nb_awq_linear = sum(isinstance(module, (WQLinear_GEMM, WQLinear_GEMV)) for module in model.modules())

        # Only `lm_head` is excluded, so exactly one linear layer must remain unconverted.
        self.assertEqual(nb_linears - 1, nb_awq_linear)

    def test_quantized_model(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_raise_if_non_quantized(self):
        """
        Loading a non-quantized checkpoint with an `AwqConfig` must raise a `ValueError`.
        """
        model_id = "facebook/opt-125m"
        quantization_config = AwqConfig(bits=4)

        with self.assertRaises(ValueError):
            _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)

    def test_quantized_model_bf16(self):
        """
        Simple test that checks if the quantized model is working properly with bf16
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.bfloat16).to(
            torch_device
        )

        output = quantized_model.generate(**input_ids, max_new_tokens=40)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_BF16)

    def test_quantized_model_no_device_map(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name).to(torch_device)
        output = quantized_model.generate(**input_ids, max_new_tokens=40)

        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_save_pretrained(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)
            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

            output = model.generate(**input_ids, max_new_tokens=40)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    @require_torch_multi_gpu
    def test_quantized_model_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly with multiple GPUs
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")

        # `require_torch_multi_gpu` only guarantees at least two GPUs, so don't hard-code the device
        # set: check the model was actually split across GPUs, and only onto GPUs that exist.
        used_devices = set(quantized_model.hf_device_map.values())
        self.assertGreater(len(used_devices), 1)
        self.assertTrue(
            used_devices.issubset(set(range(torch.cuda.device_count()))),
            f"Unexpected devices in hf_device_map: {used_devices}",
        )

        output = quantized_model.generate(**input_ids, max_new_tokens=40)

        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)


@slow
@require_torch_gpu
@require_auto_awq
@require_accelerate
class AwqFusedTest(unittest.TestCase):
    """Integration tests for AWQ models loaded with fused modules (`do_fuse` / `modules_to_fuse`)."""

    model_name = "TheBloke/Mistral-7B-OpenOrca-AWQ"
    model_revision = "7048b2af77d0dd1c81b000b19d73f9cc8950b510"

    custom_mapping_model_id = "TheBloke/Yi-34B-AWQ"
    custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589"

    prompt = (
        "You're standing on the surface of the Earth. "
        "You walk one mile south, one mile west and one mile north. "
        "You end up exactly where you started. Where are you?"
    )

    EXPECTED_GENERATION = prompt + "\n\nThis is a classic puzzle that has been around for"
    EXPECTED_GENERATION_CUSTOM_MODEL = "HelloWorld.java:11)\r\n\tat org"

    def tearDown(self):
        # Free GPU memory held by the models each test loads.
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def _load_fused_model(self):
        """Load `model_name` at `model_revision` with module fusing enabled and move it to `torch_device`."""
        quantization_config = AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True)

        return AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=quantization_config,
            low_cpu_mem_usage=True,
            revision=self.model_revision,
        ).to(torch_device)

    def _check_fused_modules(self, model):
        """Assert that at least one submodule of `model` is one of the known fused AWQ module classes."""
        fused_modules_name = {"QuantAttentionFused", "QuantFusedMLP", "FasterTransformerRMSNorm"}

        has_fused_modules = any(
            module.__class__.__name__ in fused_modules_name for _, module in model.named_modules()
        )

        self.assertTrue(has_fused_modules, "Modules fusing not performed correctly!")

    def test_raise_save_pretrained(self):
        """
        Test that `save_pretrained` is effectively blocked for fused models
        """
        model = self._load_fused_model()

        self._check_fused_modules(model)

        with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

    def test_generation_fused(self):
        """
        Test generation quality for fused models - single batch case
        """
        model = self._load_fused_model()

        self._check_fused_modules(model)

        tokenizer = AutoTokenizer.from_pretrained(self.model_name, revision=self.model_revision)

        inputs = tokenizer(self.prompt, return_tensors="pt").to(torch_device)

        outputs = model.generate(**inputs, max_new_tokens=12)

        self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)

    def test_generation_fused_batched(self):
        """
        Test generation quality for fused models - multi batch case
        """
        model = self._load_fused_model()

        self._check_fused_modules(model)

        tokenizer = AutoTokenizer.from_pretrained(self.model_name, revision=self.model_revision)

        # Reuse EOS as the padding token so the tokenizer can pad the batch.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        inputs = tokenizer([self.prompt, self.prompt], return_tensors="pt", padding=True).to(torch_device)

        outputs = model.generate(**inputs, max_new_tokens=12)

        self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)

    @require_torch_multi_gpu
    def test_generation_custom_model(self):
        """
        Test generation quality for fused models using custom fused map.
        """
        # Explicit fusing map for a remote-code architecture whose layer names are not known to
        # the built-in fusing mappings.
        quantization_config = AwqConfig(
            bits=4,
            fuse_max_seq_len=512,
            modules_to_fuse={
                "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
                "layernorm": ["ln1", "ln2", "norm"],
                "mlp": ["gate_proj", "up_proj", "down_proj"],
                "use_alibi": False,
                "num_attention_heads": 56,
                "num_key_value_heads": 8,
                "hidden_size": 7168,
            },
        )

        model = AutoModelForCausalLM.from_pretrained(
            self.custom_mapping_model_id,
            quantization_config=quantization_config,
            trust_remote_code=True,
            device_map="balanced",
            revision=self.custom_model_revision,
        )

        self._check_fused_modules(model)

        tokenizer = AutoTokenizer.from_pretrained(
            self.custom_mapping_model_id, revision=self.custom_model_revision, trust_remote_code=True
        )

        prompt = "Hello"
        inputs = tokenizer(prompt, return_tensors="pt").to(torch_device)

        outputs = model.generate(**inputs, max_new_tokens=12)
        self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_CUSTOM_MODEL)