import os
import tempfile
import unittest
from pathlib import Path

import requests
import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

class TestExpertDistribution(CustomTestCase):
    """End-to-end tests for the expert distribution recording HTTP endpoints."""

    def test_expert_distribution_record(self):
        """Run the core recording flow against several model/parallelism configs."""
        # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that)
        for info in [
            dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"),
            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"),
            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2),
            # TODO enable in next PR
            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
        ]:
            with self.subTest(info=info):
                self._execute_core(**info)

    def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1):
        """Test expert distribution record endpoints.

        Launches a server with expert-distribution recording enabled, drives the
        start -> generate -> stop -> dump endpoint sequence, and validates the
        dumped data.

        Args:
            model_path: Model path to serve.
            mode: Recorder mode, forwarded via ``--expert-distribution-recorder-mode``.
            tp_size: Tensor-parallel size for the launched server.
        """
        env_key = "SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"
        # Remember the previous value so we can restore it afterwards and avoid
        # leaking process-global state into other tests in the same runner.
        saved_env = os.environ.get(env_key)
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                os.environ[env_key] = tmp_dir

                process = popen_launch_server(
                    model_path,
                    DEFAULT_URL_FOR_TEST,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        "--trust-remote-code",
                        "--tp-size",
                        str(tp_size),
                        "--expert-distribution-recorder-mode",
                        mode,
                        "--disable-cuda-graph",
                        "--disable-overlap-schedule",
                    ],
                )

                try:
                    # Start recording
                    response = requests.post(
                        f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
                    )
                    self.assertEqual(response.status_code, 200)

                    # Make some requests to generate expert distribution data
                    response = requests.post(
                        f"{DEFAULT_URL_FOR_TEST}/generate",
                        json={
                            "text": "The capital of France is",
                            "sampling_params": {
                                "temperature": 0,
                                "max_new_tokens": 32,
                            },
                        },
                    )
                    self.assertEqual(response.status_code, 200)

                    # Stop recording
                    response = requests.post(
                        f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
                    )
                    self.assertEqual(response.status_code, 200)

                    # Dump the recorded data
                    response = requests.post(
                        f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
                    )
                    self.assertEqual(response.status_code, 200)

                    # Check data rows. Fail with a clear message instead of an
                    # opaque IndexError when no dump file was produced.
                    dump_files = list(Path(tmp_dir).glob("*.pt"))
                    self.assertTrue(dump_files, "Expected at least one dumped *.pt file")
                    data = torch.load(dump_files[0], weights_only=True)
                    print(f"{data=}")

                    if mode in ["per_pass", "per_token"]:
                        self.assertGreater(len(data), 0, "Should contain data rows")
                    else:
                        logical_count = data["logical_count"]
                        print(f"{logical_count.sum()=} {logical_count=}")
                        self.assertTrue(logical_count.sum() > 0)
                finally:
                    kill_process_tree(process.pid)
        finally:
            # Restore the recorder-directory environment variable to its prior state.
            if saved_env is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = saved_env


# Allow running this test module directly, e.g. `python test_expert_distribution.py`.
if __name__ == "__main__":
    unittest.main()