# test_gptqmodel_dynamic.py
import time
import unittest

import requests
import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


def check_quant_method(model_path: str, use_marlin_kernel: bool):
    """Load ``model_path`` in-process and assert its per-module quant methods.

    The checkpoint is expected to carry a GPTQ ``dynamic`` configuration in
    which:

    * ``lm_head`` and the first two decoder layers are GPTQ-quantized,
    * layer 0 uses the base config (bits=4, group_size=128, desc_act=True),
    * layer 1 is overridden to bits=8, group_size=32, desc_act=False,
    * layer 2 is left unquantized.

    Args:
        model_path: Model identifier or path handed to ``ServerArgs``.
        use_marlin_kernel: When true, quantized modules must use
            ``GPTQMarlinLinearMethod``; otherwise ``GPTQLinearMethod``.

    Raises:
        AssertionError: If any inspected module has an unexpected quant
            method or quantization parameters.
    """
    # Imports are local so that merely importing this test module does not
    # pull in the model-loading stack.
    from sglang.srt.configs.device_config import DeviceConfig
    from sglang.srt.configs.load_config import LoadConfig
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.distributed import (
        init_distributed_environment,
        initialize_model_parallel,
    )
    from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
    from sglang.srt.layers.linear import UnquantizedLinearMethod
    from sglang.srt.layers.quantization import get_dynamic_override
    from sglang.srt.layers.quantization.gptq import (
        GPTQLinearMethod,
        GPTQMarlinLinearMethod,
    )
    from sglang.srt.model_loader import get_model
    from sglang.srt.server_args import ServerArgs

    try:
        init_distributed_environment(
            backend="nccl",
            world_size=1,
            rank=0,
            local_rank=0,
            distributed_init_method="tcp://127.0.0.1:2646",
        )
        initialize_model_parallel(tensor_model_parallel_size=1)
        monkey_patch_vllm_parallel_state()
    except AssertionError:
        # ignore this error: tensor model parallel group is already initialized
        pass

    server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
    model_config = ModelConfig.from_server_args(server_args)

    load_config = LoadConfig()
    device_config = DeviceConfig("cuda")
    model = get_model(
        model_config=model_config, load_config=load_config, device_config=device_config
    )

    linear_method_cls = (
        GPTQMarlinLinearMethod if use_marlin_kernel else GPTQLinearMethod
    )
    unquantized_names = (
        "model.layers.2.self_attn.qkv_proj",
        "model.layers.2.mlp.gate_up_proj",
    )

    try:
        # NOTE(review): modules whose names match none of the branches below
        # are skipped, so a checkpoint with different module names would pass
        # without any check being executed.
        for name, submodule in model.named_modules():
            if name == "lm_head":
                assert isinstance(submodule.quant_method, linear_method_cls)
            elif name == "model.layers.0.self_attn.qkv_proj":
                # The first layer is quantized using bits=4, group_size=128
                # desc_act=True
                assert isinstance(submodule.quant_method, linear_method_cls)
                config = submodule.quant_method.quant_config
                assert config.weight_bits == 4
                assert config.group_size == 128
                assert config.desc_act
            elif name == "model.layers.1.self_attn.qkv_proj":
                # The second layer is quantized using bits=8, group_size=32
                # desc_act=False
                assert isinstance(submodule.quant_method, linear_method_cls)
                config = submodule.quant_method.quant_config
                assert get_dynamic_override(config, layer_name=name, key="bits") == 8
                assert (
                    get_dynamic_override(config, layer_name=name, key="group_size")
                    == 32
                )
                assert not get_dynamic_override(
                    config, layer_name=name, key="desc_act"
                )
            elif name in unquantized_names:
                # All other layers (layer index >= 2) are not quantized
                assert isinstance(submodule.quant_method, UnquantizedLinearMethod)
    finally:
        # Drop the model reference even when an assertion fails, so a failing
        # check does not keep the loaded model alive through the traceback.
        del model


# GPTQ with dynamic per-module quantization control.
# Uses GPTQModel (PyPI) to produce the `dynamic` models.
# Tests the GPTQ fallback kernel (i.e. not Marlin).
class TestGPTQModelDynamic(CustomTestCase):
    """Serve the ``symFalse`` dynamic-GPTQ checkpoint and test the non-Marlin path."""

    MODEL_PATH = (
        "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse"
    )

    @classmethod
    def setUpClass(cls):
        """Launch a single float16 server process shared by the tests of this class."""
        cls.model = cls.MODEL_PATH
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = ["--dtype", "float16"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def run_decode(self, max_new_tokens):
        """POST a fixed prompt to ``/generate`` and return the parsed JSON body.

        Args:
            max_new_tokens: Value sent as ``sampling_params.max_new_tokens``.
        """
        sampling_params = {
            "max_new_tokens": max_new_tokens,
            "temperature": 0.001,
        }
        payload = {
            "text": "The capital of France is",
            "sampling_params": sampling_params,
        }
        response = requests.post(self.base_url + "/generate", json=payload)
        return response.json()

    def test_throughput(self):
        # Checks both answer sanity ("paris" in the completion) and a minimum
        # decode speed. Throughput is derived from the requested token budget
        # and the wall-clock time of the whole HTTP round trip.
        max_tokens = 256

        started = time.perf_counter()
        result = self.run_decode(max_tokens)
        elapsed = time.perf_counter() - started

        print(f"result = `{result}`")

        self.assertIn("paris", result["text"].lower())

        throughput = max_tokens / elapsed
        print(f"Throughput: {throughput} tokens/s")
        self.assertGreaterEqual(throughput, 140)

    def test_gptq_module(self):
        # Inspect the loaded modules directly; quantized layers must use the
        # non-Marlin GPTQ linear method.
        check_quant_method(self.MODEL_PATH, use_marlin_kernel=False)


# GPTQ with dynamic per-module quantization control.
# Uses GPTQModel (PyPI) to produce the `dynamic` models.
# Tests the Marlin kernel.
class TestGPTQModelDynamicWithMarlin(CustomTestCase):
    """Serve the ``symTrue`` dynamic-GPTQ checkpoint and test the Marlin path."""

    MODEL_PATH = (
        "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue"
    )

    @classmethod
    def setUpClass(cls):
        """Launch a single bfloat16 server process shared by the tests of this class."""
        cls.model = cls.MODEL_PATH
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--dtype", "bfloat16"],
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def run_decode(self, max_new_tokens):
        """POST a fixed prompt to ``/generate`` and return the parsed JSON body.

        Args:
            max_new_tokens: Value sent as ``sampling_params.max_new_tokens``.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "max_new_tokens": max_new_tokens,
                    "temperature": 0.001,
                },
            },
        )
        return response.json()

    def test_throughput(self):
        # Checks both answer sanity ("paris" in the completion) and a minimum
        # decode speed. Throughput is derived from the requested token budget
        # and the wall-clock time of the whole HTTP round trip.
        max_tokens = 256

        tic = time.perf_counter()
        result = self.run_decode(max_tokens)
        tok = time.perf_counter()

        print(f"result = `{result}`")

        # unittest assertion methods (rather than bare ``assert``) report the
        # offending values on failure and are not stripped under ``python -O``.
        self.assertIn("paris", result["text"].lower())

        throughput = max_tokens / (tok - tic)
        print(f"Throughput: {throughput} tokens/s")
        self.assertGreaterEqual(throughput, 140)

    def test_gptq_marlin_module(self):
        # Inspect the loaded modules directly; quantized layers must use the
        # Marlin GPTQ linear method.
        check_quant_method(self.MODEL_PATH, use_marlin_kernel=True)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()