import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_amd_ci,
    popen_launch_server,
)


class TestDPAttentionDP2TP2(CustomTestCase):
    """Server with DP attention (tp=2, dp=2) and torch.compile enabled.

    One server is launched for the whole class and shared by the MMLU and
    MGSM-en accuracy checks.
    """

    @classmethod
    def setUpClass(cls):
        """Launch DEFAULT_MLA_MODEL_NAME_FOR_TEST with DP attention and torch.compile."""
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--enable-dp-attention",
                "--dp",
                "2",
                "--enable-torch-compile",
                "--torch-compile-max-bs",
                "2",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Shut down the server launched in setUpClass via kill_process_tree."""
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """Run 64 MMLU examples and require a score above 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["score"], 0.5)

    def test_mgsm_en(self):
        """Run the MGSM-en eval (num_examples=None) and require a score above 0.8."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )

        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["score"], 0.8)


class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
    """DP attention (tp=2, dp=2) combined with EAGLE speculative decoding.

    Launches DEFAULT_MODEL_NAME_FOR_TEST_MLA with
    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN as the speculative draft model. The
    test checks both few-shot GSM8K accuracy and the server-reported average
    speculative accept length.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the target + draft model server with DP attention enabled."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = [
            "--trust-remote-code",
            "--disable-radix",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-num-steps",
            "2",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "4",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
            "--tp-size",
            "2",
            "--enable-dp-attention",
            "--dp-size",
            "2",
        ]
        if not is_in_amd_ci():
            # NOTE(review): the memory fraction is capped only outside AMD CI;
            # the reason is not visible here — confirm.
            other_args += ["--mem-frac", "0.7"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Shut down the server launched in setUpClass via kill_process_tree."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Require GSM8K accuracy > 0.60 and average spec accept length > 2.5."""
        # Request a cache flush before evaluating; the response is not checked.
        requests.get(self.base_url + "/flush_cache")

        # The few-shot runner takes host and port separately; derive both from
        # base_url so they always point at the server launched above.
        host, _, port = self.base_url.rpartition(":")
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=host,
            port=int(port),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], 0.60)

        server_info = requests.get(self.base_url + "/get_server_info")
        # Surface an HTTP failure directly instead of a KeyError/JSON error below.
        server_info.raise_for_status()
        avg_spec_accept_length = server_info.json()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(
            f"###test_gsm8k (deepseek-v3 mtp + dp):\n"
            f"accuracy={metrics['accuracy']:.3f}\n"
            f"{avg_spec_accept_length=:.3f}\n"
        )
        self.assertGreater(avg_spec_accept_length, 2.5)


if __name__ == "__main__":
    # Run all test classes in this module when executed directly.
    unittest.main()