# test_srt_engine_with_quant_args.py
# Smoke tests that construct sgl.Engine with quantization / torchao options
# and run a short generation.
import unittest

import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase


class TestSRTEngineWithQuantArgs(CustomTestCase):
    """Smoke tests for building ``sgl.Engine`` with quantization options.

    Each test builds one engine per configuration, runs a single short
    generation, and shuts the engine down. Nothing about the generated text
    is asserted; a test passes when no step raises.
    """

    def _generate_once(self, **engine_kwargs):
        """Build an engine, run one generation, and always shut it down.

        Args:
            **engine_kwargs: Extra keyword arguments forwarded to
                ``sgl.Engine`` in addition to ``model_path`` and
                ``random_seed`` (for example ``quantization=...`` or
                ``torchao_config=...``).
        """
        prompt = "Today is a sunny day and I like"
        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        engine = sgl.Engine(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            random_seed=42,
            **engine_kwargs,
        )
        try:
            engine.generate(prompt, sampling_params)
        finally:
            # Shut down even when generation raises, so a failed configuration
            # does not leave its engine alive while the next one starts.
            engine.shutdown()

    def test_1_quantization_args(self):
        """Run a generation for every enabled ``quantization`` value."""
        # Only fp8 is tested because the other methods currently depend on
        # vllm. They can be re-enabled once that dependency is resolved.
        quantization_args_list = [
            # "awq",
            "fp8",
            # "gptq",
            # "marlin",
            # "gptq_marlin",
            # "awq_marlin",
            # "bitsandbytes",
            # "gguf",
        ]

        for quantization_args in quantization_args_list:
            self._generate_once(quantization=quantization_args)

    def test_2_torchao_args(self):
        """Run a generation for every enabled ``torchao_config`` value."""
        # int8dq is not tested because it currently conflicts with CUDA graph
        # capture.
        torchao_args_list = [
            # "int8dq",
            "int8wo",
            "fp8wo",
            "fp8dq-per_tensor",
            "fp8dq-per_row",
        ] + [f"int4wo-{group_size}" for group_size in [32, 64, 128, 256]]

        for torchao_config in torchao_args_list:
            self._generate_once(torchao_config=torchao_config)


# Run the tests through unittest's command-line runner when this file is
# executed directly.
if __name__ == "__main__":
    unittest.main()