# test_expert_distribution.py
import os
import tempfile
import unittest
from pathlib import Path

import requests
import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestExpertDistribution(CustomTestCase):
    """End-to-end tests for the expert-distribution recording HTTP endpoints.

    Each scenario launches a real server process, drives the
    start/generate/stop/dump record cycle over HTTP, and validates the
    dumped ``*.pt`` file.
    """

    def test_expert_distribution_record(self):
        """Run the core record/dump scenario across several MoE models and recorder modes."""
        # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that)
        for info in [
            dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"),
            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"),
            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2),
            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
        ]:
            # subTest so one failing scenario does not mask the others.
            with self.subTest(info=info):
                self._execute_core(**info)

    def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1):
        """Test expert distribution record endpoints.

        Args:
            model_path: Hugging Face model identifier to launch the server with.
            mode: Expert-distribution recorder mode passed to the server
                ("stat", "per_pass", or "per_token").
            tp_size: Tensor-parallel size for the launched server.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Point the recorder at a throwaway directory so dumps are easy to
            # find and cleaned up automatically.  Remember the previous value so
            # we can restore it and avoid leaking state into later tests.
            env_key = "SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"
            old_recorder_dir = os.environ.get(env_key)
            os.environ[env_key] = tmp_dir

            try:
                process = popen_launch_server(
                    model_path,
                    DEFAULT_URL_FOR_TEST,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        "--trust-remote-code",
                        "--tp-size",
                        str(tp_size),
                        "--expert-distribution-recorder-mode",
                        mode,
                        "--disable-cuda-graph",
                        "--disable-overlap-schedule",
                    ],
                )

                try:
                    # Start recording
                    response = requests.post(
                        f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
                    )
                    self.assertEqual(response.status_code, 200)

                    # Make some requests to generate expert distribution data
                    response = requests.post(
                        f"{DEFAULT_URL_FOR_TEST}/generate",
                        json={
                            "text": "The capital of France is",
                            "sampling_params": {
                                "temperature": 0,
                                "max_new_tokens": 32,
                            },
                        },
                    )
                    self.assertEqual(response.status_code, 200)

                    # Stop recording
                    response = requests.post(
                        f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
                    )
                    self.assertEqual(response.status_code, 200)

                    # Dump the recorded data
                    response = requests.post(
                        f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
                    )
                    self.assertEqual(response.status_code, 200)

                    # Check data rows.  Fail with a clear message (rather than an
                    # IndexError) if the server produced no dump file.
                    dump_files = list(Path(tmp_dir).glob("*.pt"))
                    self.assertTrue(
                        dump_files, "Expected the server to dump a *.pt record file"
                    )
                    data = torch.load(dump_files[0], weights_only=True)
                    print(f"{data=}")

                    if mode in ["per_pass", "per_token"]:
                        self.assertGreater(len(data), 0, "Should contain data rows")
                    else:
                        logical_count = data["logical_count"]
                        print(f"{logical_count.sum()=} {logical_count=}")
                        self.assertTrue(logical_count.sum() > 0)
                finally:
                    kill_process_tree(process.pid)
            finally:
                # Restore the recorder-dir environment variable to its prior state.
                if old_recorder_dir is None:
                    os.environ.pop(env_key, None)
                else:
                    os.environ[env_key] = old_recorder_dir


# Allow running this test module directly: `python test_expert_distribution.py`.
if __name__ == "__main__":
    unittest.main()