test_random.py 971 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import collections

import pytest
import requests


@pytest.mark.integration
def test_random_distribution(mock_workers, router_manager):
    """Check that the ``random`` policy spreads requests over all workers.

    Starts four mock workers behind a router launched with
    ``policy="random"``, sends ``N`` completion requests through it and
    tallies which worker served each one.  The worker id is read from the
    ``X-Worker-Id`` response header, falling back to the ``worker_id`` field
    of the JSON body.  Every worker must end up with between 50% and 150%
    of the mean share of requests.
    """
    _procs, urls, ids = mock_workers(n=4)
    rh = router_manager.start_router(worker_urls=urls, policy="random")

    counts = collections.Counter()
    N = 200
    # Requests applies no timeout by default; bound each call so a wedged
    # router or worker fails the test instead of hanging the whole run.
    timeout_s = 30
    payload_base = {
        "model": "test-model",
        "max_tokens": 1,
        "stream": False,
    }
    with requests.Session() as s:
        for i in range(N):
            r = s.post(
                f"{rh.url}/v1/completions",
                json={**payload_base, "prompt": f"p{i}"},
                timeout=timeout_s,
            )
            assert r.status_code == 200, f"request {i}: {r.status_code} {r.text}"
            wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
            # Fail at the offending response rather than tallying it under
            # None and reporting a misleading distribution error later.
            assert wid, f"request {i}: response carries no worker id"
            counts[wid] += 1

    # Every response must be attributed to one of the workers we started;
    # otherwise the per-worker counts below would not add up to N.
    unexpected = set(counts) - set(ids)
    assert not unexpected, f"responses from unknown workers: {unexpected}"

    # simple statistical tolerance: each worker should be within ±50% of mean
    mean = N / len(ids)
    for wid in ids:
        assert 0.5 * mean <= counts[wid] <= 1.5 * mean, counts