"""
Usage:
python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_mmlu
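python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_latency_default_model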
"""

import unittest
from functools import wraps
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE,
    DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8,
    DEFAULT_MODEL_NAME_FOR_TEST_W8A8,
    DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
    run_bench_one_batch,
)


def intel_amx_benchmark(extra_args=None, min_throughput=None):
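    """Wrap a latency test so it benchmarks the model it returns.

    The decorated test only returns a model name. The wrapper runs
    run_bench_one_batch with the shared intel_amx server arguments plus
    extra_args, prints the measured latencies and decode throughput,
    and, when running in CI, asserts that decode throughput stays above
    min_throughput.
    """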
    def decorator(test_func):
        @wraps(test_func)
        def wrapper(self):
            common_args = [
                "--attention-backend",
                "intel_amx",
                "--disable-radix",
                "--trust-remote-code",
            ]
            full_args = common_args + (extra_args or [])

            model = test_func(self)
            prefill_latency, decode_throughput, decode_latency = run_bench_one_batch(
                model, full_args
            )

            print(f"{model=}")
            print(f"{prefill_latency=}")
            print(f"{decode_throughput=}")
            print(f"{decode_latency=}")

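            # Enforce the decode-throughput floor only in CI and only when a
            # floor was provided; local runs just print the numbers above.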
            if is_in_ci() and min_throughput is not None:
                self.assertGreater(decode_throughput, min_throughput)

        return wrapper

    return decorator


class TestIntelAMXAttnBackend(CustomTestCase):
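    """Latency benchmarks and an MMLU accuracy check for the intel_amx attention backend."""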

    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=10)
    def test_latency_mla_model(self):
        return DEFAULT_MLA_MODEL_NAME_FOR_TEST

    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=40)
    def test_latency_default_model(self):
        return DEFAULT_MODEL_NAME_FOR_TEST

    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=150)
    def test_latency_fp8_qwen(self):
        return DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8

    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=50)
    def test_latency_fp8_moe_model(self):
        return DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE

    @intel_amx_benchmark(
        extra_args=["--batch-size", "4", "--quantization", "w8a8_int8"],
        min_throughput=100,
    )
    def test_latency_w8a8_default_model(self):
        return DEFAULT_MODEL_NAME_FOR_TEST_W8A8

    @intel_amx_benchmark(
        extra_args=[
            "--batch-size",
            "4",
            "--quantization",
            "w8a8_int8",
            "--mem-fraction-static",
            "0.9",
            "--max-total-tokens",
            "65536",
            "--tp",
            "6",
        ],
        min_throughput=100,
    )
    def test_latency_w8a8_moe_model(self):
        return DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE

    def test_mmlu(self):
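        """Serve an MLA model with the intel_amx backend and check MMLU accuracy on 64 examples."""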
        model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        base_url = DEFAULT_URL_FOR_TEST
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--attention-backend",
                "intel_amx",
                "--mem-fraction-static",
                "0.05",
                "--disable-radix",
                "--trust-remote-code",
                "--disable-overlap-schedule",
            ],
        )

        try:
            args = SimpleNamespace(
                base_url=base_url,
                model=model,
                eval_name="mmlu",
                num_examples=64,
                num_threads=32,
            )
            metrics = run_eval(args)
            if is_in_ci():
                self.assertGreater(metrics["score"], 0.45)
        finally:
            kill_process_tree(process.pid)


if __name__ == "__main__":
    unittest.main()