test_pd_routing.py 3.78 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import collections
import concurrent.futures
import time

import pytest
import requests


@pytest.mark.integration
def test_pd_power_of_two_decode_attribution(router_manager, mock_workers):
    """Check that PD requests are attributed to decode workers and spread out.

    Starts two prefill and three decode mock workers and a ``power_of_two``
    router in PD-disaggregation mode. It then sends 30 sequential
    non-streaming completions through the router and asserts that:

    * every response is HTTP 200;
    * the worker id reported by the response (``X-Worker-Id`` header, falling
      back to the ``worker_id`` JSON field) is one of the decode workers;
    * at least two distinct decode workers served traffic.
    """
    # Per-request timeout so a stuck router fails the test instead of hanging it.
    request_timeout_s = 8

    # Start two prefill and three decode mock workers via fixture.
    _, prefill_urls_raw, _ = mock_workers(n=2)
    _, decode_urls_raw, decode_ids_list = mock_workers(n=3)
    # Prefill workers are passed as (url, None) pairs; the second element's
    # meaning is defined by router_manager -- NOTE(review): confirm.
    prefill_urls = [(u, None) for u in prefill_urls_raw]
    decode_urls = list(decode_urls_raw)
    decode_ids = set(decode_ids_list)

    rh = router_manager.start_router(
        policy="power_of_two",
        pd_disaggregation=True,
        prefill_urls=prefill_urls,
        decode_urls=decode_urls,
        extra={"worker_startup_check_interval": 1},
    )

    counts = collections.Counter()
    with requests.Session() as session:
        for i in range(30):
            resp = session.post(
                f"{rh.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"p{i}",
                    "max_tokens": 1,
                    "stream": False,
                },
                timeout=request_timeout_s,
            )
            assert resp.status_code == 200, resp.text
            wid = resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")
            assert wid in decode_ids, (wid, decode_ids)
            counts[wid] += 1

    # The Counter only holds ids that were incremented, so its size is the
    # number of distinct decode workers that served at least one request.
    assert len(counts) >= 2, counts


@pytest.mark.integration
def test_pd_power_of_two_skews_to_faster_decode(router_manager, mock_workers):
    """Check that ``power_of_two`` PD routing prefers the faster decode worker.

    Setup:

    * two default-latency prefill workers;
    * two decode workers, one started with ``--latency-ms 300`` and one
      default;
    * a ``power_of_two`` PD router.

    The test sends warm-up traffic through the router and background traffic
    directly to the slow decode worker. It then routes 200 concurrent
    completions and asserts that the slow decode worker was picked strictly
    less often than the fast one.
    """
    # Per-request timeout so a stuck worker/router cannot hang the suite.
    request_timeout_s = 8

    # Start two prefill workers (default latency).
    _, prefill_urls_raw, _ = mock_workers(n=2)

    # Start two decode workers: one with 300 ms injected latency, one default.
    _, [decode_slow_url], [slow_id] = mock_workers(
        n=1, args=["--latency-ms", "300"]
    )
    _, [decode_fast_url], [fast_id] = mock_workers(n=1)

    rh = router_manager.start_router(
        policy="power_of_two",
        pd_disaggregation=True,
        prefill_urls=[(u, None) for u in prefill_urls_raw],
        decode_urls=[decode_slow_url, decode_fast_url],
        extra={"worker_startup_check_interval": 1},
    )

    def _post_completion(base_url, prompt):
        """POST a minimal non-streaming completion to ``base_url``."""
        return requests.post(
            f"{base_url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": prompt,
                "max_tokens": 1,
                "stream": False,
            },
            timeout=request_timeout_s,
        )

    def _best_effort(base_url, prompt):
        """Send a load-generation request, ignoring transport-level failures."""
        try:
            _post_completion(base_url, prompt)
        except requests.RequestException:
            # Load generation only: timeouts/connection errors are expected
            # under this much concurrency and are irrelevant to the assertion.
            pass

    # Warm up the router with 128 concurrent requests (failures ignored).
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(lambda i: _best_effort(rh.url, f"warm-{i}"), range(128)))
    # Presumably gives the router time to refresh its view of worker load --
    # NOTE(review): confirm against the router's load-polling interval.
    time.sleep(2)

    # Send load straight to the slow decode worker, bypassing the router.
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(lambda i: _best_effort(decode_slow_url, f"bg-{i}"), range(128)))
    time.sleep(1)

    def _routed_worker_id(i):
        """Route one completion and return the worker id the response reports."""
        resp = _post_completion(rh.url, f"p{i}")
        assert resp.status_code == 200, resp.text
        return resp.headers.get("X-Worker-Id") or resp.json().get("worker_id")

    counts = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        # ex.map re-raises any assertion from a worker thread here.
        counts.update(ex.map(_routed_worker_id, range(200)))

    # Every routed request must be attributed to one of the two decode workers;
    # otherwise a missing id would silently be counted under None.
    assert set(counts) <= {slow_id, fast_id}, counts
    assert counts[slow_id] < counts[fast_id], counts