# test_moe_deepep.py: end-to-end tests for the DeepEP MoE all-to-all backend.
import json
import os
import unittest
from types import SimpleNamespace

from sglang.srt.environ import envs
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestPureTP(CustomTestCase):
    """MMLU smoke test for the DeepEP MoE backend under pure tensor parallelism.

    Launches one server with ``--tp 2 --moe-a2a-backend deepep`` (CUDA graph
    disabled, no DP attention) for the whole class, then checks that a small
    MMLU run clears a minimum accuracy.
    """

    @classmethod
    def setUpClass(cls):
        """Start the shared server process used by every test in this class."""
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--moe-a2a-backend",
                "deepep",
                "--disable-cuda-graph",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process and any children it spawned."""
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """Run a 64-example MMLU eval against the server; require score > 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.5)


class TestDPAttn(CustomTestCase):
    """MMLU smoke test for the DeepEP MoE backend with DP attention.

    Launches one server with TP=2, DP=2, ``--enable-dp-attention``, DeepEP in
    ``normal`` mode, CUDA graph disabled, and a custom ``--deepep-config``.
    The test then checks that a small MMLU run clears a minimum accuracy.
    """

    @classmethod
    def setUpClass(cls):
        """Start the shared server process used by every test in this class."""
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # The override only wraps the launch call. Presumably the launched
        # server picks up the setting from its environment at start-up;
        # TODO confirm against envs.SGLANG_ENABLE_JIT_DEEPGEMM.override.
        with envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False):
            cls.process = popen_launch_server(
                cls.model,
                cls.base_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=[
                    "--trust-remote-code",
                    "--tp",
                    "2",
                    "--dp",
                    "2",
                    "--enable-dp-attention",
                    "--moe-a2a-backend",
                    "deepep",
                    "--deepep-mode",
                    "normal",
                    "--disable-cuda-graph",
                    # Test custom config
                    "--deepep-config",
                    json.dumps(
                        {
                            "normal_dispatch": {
                                "num_sms": 20,
                                "num_max_nvl_chunked_send_tokens": 16,
                                "num_max_nvl_chunked_recv_tokens": 256,
                                "num_max_rdma_chunked_send_tokens": 6,
                                "num_max_rdma_chunked_recv_tokens": 128,
                            },
                            "normal_combine": {
                                "num_sms": 20,
                                "num_max_nvl_chunked_send_tokens": 6,
                                "num_max_nvl_chunked_recv_tokens": 256,
                                "num_max_rdma_chunked_send_tokens": 6,
                                "num_max_rdma_chunked_recv_tokens": 128,
                            },
                        }
                    ),
                ],
            )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process and any children it spawned."""
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """Run a 64-example MMLU eval against the server; require score > 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.5)


if __name__ == "__main__":
    # Running this file directly discovers and runs the TestCase classes above.
    unittest.main(module="__main__")