import unittest

import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase


class TestSRTEngineWithQuantArgs(CustomTestCase):
    """Smoke tests that launch an SRT engine under various quantization settings.

    For every configuration, a fresh ``sgl.Engine`` is started on the small
    test model with a fixed random seed. It runs one short generation and is
    then shut down. The tests only check that this sequence completes without
    raising; the generated text is not inspected.
    """

    def _run_engine_smoke(self, **engine_kwargs):
        """Start an engine with ``engine_kwargs``, generate once, and always shut it down."""
        prompt = "Today is a sunny day and I like"
        # temperature 0 requests greedy decoding; a few new tokens are enough
        # to exercise the quantized forward pass.
        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        engine = sgl.Engine(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            random_seed=42,
            **engine_kwargs,
        )
        try:
            engine.generate(prompt, sampling_params)
        finally:
            # Shut down even when generation fails, so one broken configuration
            # does not leave a live engine behind for the remaining ones.
            engine.shutdown()

    def test_1_quantization_args(self):
        """Run a short generation for each enabled ``quantization=`` method."""
        # Only fp8 is tested because the other methods currently depend on vLLM.
        # Re-enable them here once that dependency is resolved.
        quantization_args_list = [
            # "awq",
            "fp8",
            # "gptq",
            # "marlin",
            # "gptq_marlin",
            # "awq_marlin",
            # "bitsandbytes",
            # "gguf",
        ]

        for quantization_args in quantization_args_list:
            # subTest makes a failure report which quantization method broke.
            with self.subTest(quantization=quantization_args):
                self._run_engine_smoke(quantization=quantization_args)

    def test_2_torchao_args(self):
        """Run a short generation for each enabled ``torchao_config=`` setting."""
        # int8dq is skipped because it currently conflicts with CUDA graph capture.
        torchao_args_list = [
            # "int8dq",
            "int8wo",
            "fp8wo",
            "fp8dq-per_tensor",
            "fp8dq-per_row",
        ] + [f"int4wo-{group_size}" for group_size in [32, 64, 128, 256]]

        for torchao_config in torchao_args_list:
            # subTest makes a failure report which torchao config broke.
            with self.subTest(torchao_config=torchao_config):
                self._run_engine_smoke(torchao_config=torchao_config)


if __name__ == "__main__":
    # Allow running this file directly with the stock unittest runner.
    # "__main__" is unittest.main's default module; it is spelled out here for clarity.
    unittest.main(module="__main__")