# test_expert_distribution.py
import csv
import glob
import os
import unittest

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestExpertDistribution(unittest.TestCase):
    """End-to-end test of the expert distribution record HTTP endpoints."""

    # Glob matching the CSV files produced by /dump_expert_distribution_record.
    # NOTE(review): assumes the launched server writes the dump into this
    # test's current working directory — confirm if the dump location changes.
    _DUMP_GLOB = "expert_distribution_*.csv"

    @classmethod
    def _remove_dump_files(cls):
        """Delete any expert distribution CSV files in the working directory."""
        for path in glob.glob(cls._DUMP_GLOB):
            os.remove(path)

    def setUp(self):
        # Clean up any existing expert distribution files before each test
        self._remove_dump_files()

    def tearDown(self):
        # Clean up any expert distribution files after each test
        self._remove_dump_files()

    def _post_ok(self, endpoint, **kwargs):
        """POST to ``endpoint`` on the test server and assert HTTP 200.

        Args:
            endpoint: Path segment appended to ``DEFAULT_URL_FOR_TEST``.
            **kwargs: Forwarded to ``requests.post`` (e.g. ``json=``).

        Returns:
            The ``requests.Response`` object.
        """
        response = requests.post(
            f"{DEFAULT_URL_FOR_TEST}/{endpoint}",
            # requests has no default timeout; bound each call so a wedged
            # server fails the test instead of hanging the run forever.
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            **kwargs,
        )
        self.assertEqual(
            response.status_code,
            200,
            f"POST /{endpoint} failed: {response.text}",
        )
        return response

    def _check_csv_format(self, path):
        """Assert ``path`` is a ``layer_id,expert_id,count`` CSV of integer rows.

        Args:
            path: Path of the dumped CSV file to validate.
        """
        # The csv module expects files opened with newline="".
        with open(path, "r", newline="", encoding="utf-8") as f:
            csv_reader = csv.reader(f)

            # Check header. The None default turns an empty file into a clear
            # assertion failure instead of an uncaught StopIteration.
            header = next(csv_reader, None)
            self.assertEqual(
                header,
                ["layer_id", "expert_id", "count"],
                "CSV header should be 'layer_id,expert_id,count'",
            )

            rows = list(csv_reader)

        # Check data rows
        self.assertGreater(len(rows), 0, "CSV file should contain data rows")

        for row in rows:
            # Verify each row has 3 columns
            self.assertEqual(
                len(row),
                3,
                "Each row should have layer_id, expert_id and count",
            )

            # Verify data types. isdecimal() rather than isdigit(): isdigit()
            # accepts characters such as "²" that int() cannot parse.
            layer_id, expert_id, count = row
            self.assertTrue(
                layer_id.isdecimal(),
                f"layer_id should be an integer {row=} {rows=}",
            )
            self.assertTrue(
                expert_id.isdecimal(),
                f"expert_id should be an integer {row=} {rows=}",
            )
            self.assertTrue(
                count.isdecimal(), f"count should be an integer {row=} {rows=}"
            )

    def test_expert_distribution_record(self):
        """Test expert distribution record endpoints.

        Launches a server, wraps one ``/generate`` call between start/stop
        recording, dumps the record, and checks that exactly one CSV with the
        expected header and integer-valued rows was produced.
        """
        process = popen_launch_server(
            # The feature is only implemented in deepseek_v2.py
            "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
            ],
        )

        try:
            # Start recording
            self._post_ok("start_expert_distribution_record")

            # Make some requests to generate expert distribution data
            self._post_ok(
                "generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )

            # Stop recording
            self._post_ok("stop_expert_distribution_record")

            # Dump the recorded data
            self._post_ok("dump_expert_distribution_record")

            # Verify the dumped file exists and has correct format
            csv_files = glob.glob(self._DUMP_GLOB)
            self.assertEqual(
                len(csv_files),
                1,
                f"Expected exactly one expert distribution CSV file {csv_files=}",
            )
            self._check_csv_format(csv_files[0])

        finally:
            # Always tear down the server process tree, even on failure.
            kill_process_tree(process.pid)


if __name__ == "__main__":
    # Allow running this file directly; "__main__" is unittest.main's default
    # module, spelled out here for clarity.
    unittest.main(module="__main__")