disagg_test.py 3.89 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import subprocess
import sys
import time
from subprocess import Popen

import pytest
import requests
import torch


def _vllm_server_cmd(port,
                     kv_role,
                     kv_rank,
                     model="meta-llama/Meta-Llama-3.1-8B-Instruct"):
    """Return the argv that launches one vLLM OpenAI API server.

    The prefill and decode instances share model and memory settings; they
    differ only in port and in their ``kv_role`` / ``kv_rank`` within the
    PyNcclConnector KV-transfer group (``kv_parallel_size`` 2).
    """
    # Renders exactly: {"kv_connector":"PyNcclConnector","kv_role":"<role>",
    # "kv_rank":<rank>,"kv_parallel_size":2}
    kv_transfer_config = ('{"kv_connector":"PyNcclConnector",'
                          f'"kv_role":"{kv_role}",'
                          f'"kv_rank":{kv_rank},"kv_parallel_size":2}}')
    return [
        sys.executable,
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model",
        model,
        "--port",
        str(port),
        "--gpu-memory-utilization",
        "0.5",
        "--max-model-len",
        "1000",
        "--kv-transfer-config",
        kv_transfer_config,
    ]


def _start_vllm_server(port, kv_role, kv_rank, cuda_device):
    """Launch one server restricted to GPU ``cuda_device`` and return its Popen."""
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = cuda_device
    return Popen(_vllm_server_cmd(port, kv_role, kv_rank), env=env)


def _stop_vllm_servers(procs, timeout=60):
    """Terminate every process in ``procs`` and reap it.

    All still-running processes are signalled before any is waited on, so
    they shut down concurrently. A process that has not exited ``timeout``
    seconds after ``terminate()`` is killed, so teardown cannot hang forever.
    """
    for proc in procs:
        if proc.poll() is None:
            proc.terminate()
    for proc in procs:
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            # Did not exit within the grace period; force it down.
            proc.kill()
            proc.wait()


# Fixture that starts the prefill/decode servers once for this module and
# tears them down afterwards.
@pytest.fixture(scope="module", autouse=True)
def setup_servers():
    """Start a prefill (kv_producer) and a decode (kv_consumer) vLLM server.

    Skips the module when fewer than 4 CUDA devices are visible. Prefill
    listens on port 8100 (GPU 0) and decode on port 8200 (GPU 1). Cleanup
    runs in ``finally`` so both servers are stopped even if startup fails or
    times out, instead of being left running on the GPUs.
    """
    if torch.cuda.device_count() < 4:
        pytest.skip("Skipping test: fewer than 4 GPUs available")

    # First address printed by `hostname -I`; exported into os.environ so the
    # server processes (which copy os.environ) inherit it.
    vllm_host_ip = subprocess.check_output("hostname -I | awk '{print $1}'",
                                           shell=True).decode().strip()
    os.environ["VLLM_HOST_IP"] = vllm_host_ip

    procs = []
    try:
        # Prefill instance: kv_producer, rank 0, GPU 0.
        procs.append(_start_vllm_server(8100, "kv_producer", 0, "0"))
        # Decode instance: kv_consumer, rank 1, GPU 1.
        procs.append(_start_vllm_server(8200, "kv_consumer", 1, "1"))

        assert wait_for_server(8100), "Prefill server did not start in time"
        assert wait_for_server(8200), "Decode server did not start in time"

        # Run the module's tests, then fall through to cleanup.
        yield
    finally:
        _stop_vllm_servers(procs)


# Helper that polls a local server until its HTTP endpoint responds.
def wait_for_server(port,
                    timeout=240,
                    *,
                    poll_interval=1.0,
                    request_timeout=5.0):
    """Poll ``http://localhost:{port}/v1/completions`` until it answers.

    Args:
        port: Local TCP port the server listens on.
        timeout: Total seconds to keep polling before giving up.
        poll_interval: Seconds to sleep after each unsuccessful attempt.
        request_timeout: Per-request timeout in seconds, so a server that
            accepts the connection but never replies cannot stall the loop
            beyond ``timeout``.

    Returns:
        True as soon as a GET returns status 200 or 405; False if the
        deadline passes first.
    """
    url = f"http://localhost:{port}/v1/completions"
    # Monotonic clock: unaffected by wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            response = requests.get(url, timeout=request_timeout)
        except requests.RequestException:
            # Not listening yet (connection refused), or too slow to reply.
            pass
        else:
            # The completions route is presumably POST-only, so a GET may be
            # answered with 405 Method Not Allowed; either code shows the HTTP
            # server is up.
            if response.status_code in (200, 405):
                return True
        # Back off after every unsuccessful attempt, including non-ready
        # status codes, so the loop never busy-spins.
        time.sleep(poll_interval)
    return False


def _request_completion(port,
                        prompt,
                        max_tokens,
                        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
                        timeout=120):
    """POST a temperature-0 completion request to the server on ``port``.

    ``timeout`` bounds the HTTP call so a stuck server fails the test instead
    of hanging the run. Passing ``json=`` makes requests send the
    ``application/json`` Content-Type header itself.
    """
    return requests.post(f"http://localhost:{port}/v1/completions",
                         json={
                             "model": model,
                             "prompt": prompt,
                             "max_tokens": max_tokens,
                             "temperature": 0
                         },
                         timeout=timeout)


# Send each prompt to the prefill server, then to the decode server, and check
# both requests succeed.
@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"])
def test_disaggregated_prefilling(prompt):
    """Exercise the prefill -> decode path for one prompt.

    The prefill server (port 8100) gets a 1-token request. The decode server
    (port 8200) then gets the same prompt with up to 10 tokens. Both must
    return HTTP 200, and the decode response must contain a text completion.
    """
    # Prefill: single-token request on the kv_producer instance.
    response = _request_completion(8100, prompt, max_tokens=1)
    assert response.status_code == 200, response.text

    # Decode: generation on the kv_consumer instance.
    response = _request_completion(8200, prompt, max_tokens=10)
    assert response.status_code == 200, response.text
    choices = response.json()["choices"]
    assert choices, "decode server returned no choices"
    assert isinstance(choices[0]["text"], str)