# test_deepseek_v3_mtp.py
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.send_one import BenchArgs, send_one_prompt
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_amd_ci,
    is_in_ci,
    popen_launch_server,
    write_github_step_summary,
)

# Checkpoint identifier that the test class passes to popen_launch_server as its model.
FULL_DEEPSEEK_V3_MODEL_PATH: str = "deepseek-ai/DeepSeek-V3-0324"


class TestDeepseekV3MTP(CustomTestCase):
    """End-to-end checks for DeepSeek-V3 served with EAGLE speculative decoding.

    A single server (tensor-parallel size 8) is launched once for the whole
    class and shared by every test. Accuracy, acceptance-length and speed
    thresholds are only asserted when ``is_in_ci()`` is true; outside CI the
    tests just print their measurements.
    """

    # Upper bound (seconds) for the small HTTP control requests below, so a
    # wedged server fails the test instead of hanging it indefinitely.
    _HTTP_TIMEOUT_S = 120

    @classmethod
    def setUpClass(cls):
        """Launch the shared server with EAGLE speculative decoding enabled."""
        cls.model = FULL_DEEPSEEK_V3_MODEL_PATH
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = [
            "--tp",
            "8",
            "--trust-remote-code",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "1",
            "--speculative-num-draft-tokens",
            "4",
        ]
        # The memory-fraction override is applied everywhere except AMD CI.
        if not is_in_amd_ci():
            other_args += ["--mem-frac", "0.7"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the process tree of the server started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def _port(self) -> int:
        """Return the port number that ends ``self.base_url``."""
        return int(self.base_url.rsplit(":", 1)[-1])

    def test_a_gsm8k(self):
        """Run few-shot GSM8K and check accuracy and speculative acceptance length.

        The "a" in the name makes this test run first (unittest orders test
        methods alphabetically), so it also warms up the server.
        """
        # Best-effort cache reset before measuring; the response is not checked.
        requests.get(self.base_url + "/flush_cache", timeout=self._HTTP_TIMEOUT_S)

        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=self._port(),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        server_info = requests.get(
            self.base_url + "/get_server_info", timeout=self._HTTP_TIMEOUT_S
        )
        # Surface an HTTP failure directly instead of a confusing JSON/KeyError.
        server_info.raise_for_status()
        avg_spec_accept_length = server_info.json()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(f"{avg_spec_accept_length=}")

        if is_in_ci():
            write_github_step_summary(
                f"### test_gsm8k (deepseek-v3 mtp)\n"
                f'{metrics["accuracy"]=:.3f}\n'
                f"{avg_spec_accept_length=:.2f}\n"
            )
            self.assertGreater(metrics["accuracy"], 0.935)
            self.assertGreater(avg_spec_accept_length, 2.9)

    def test_bs_1_speed(self):
        """Send one prompt at batch size 1; check acceptance length and speed."""
        args = BenchArgs(port=self._port(), max_new_tokens=2048)
        acc_length, speed = send_one_prompt(args)

        print(f"{acc_length=:.2f} {speed=:.2f}")

        if is_in_ci():
            write_github_step_summary(
                f"### test_bs_1_speed (deepseek-v3 mtp)\n"
                f"{acc_length=:.2f}\n"
                f"{speed=:.2f} token/s\n"
            )
            # AMD CI uses lower thresholds than the other CI runners.
            if is_in_amd_ci():
                min_acc_length, min_speed = 2.8, 15
            else:
                min_acc_length, min_speed = 2.9, 130
            self.assertGreater(acc_length, min_acc_length)
            self.assertGreater(speed, min_speed)

if __name__ == "__main__":
    # Executed directly: discover and run the tests defined in this module.
    unittest.main(module="__main__")