utils.py 6.34 KB
Newer Older
1
2
3
4
import os
import subprocess
import sys
import time
5
6
import warnings
from contextlib import contextmanager
7
8
from pathlib import Path
from typing import Any, Dict, List
9

10
import openai
11
12
13
14
15
import ray
import requests

from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
16
from vllm.entrypoints.openai.cli_args import make_arg_parser
17
18
from vllm.utils import get_open_port, is_hip

19
20
21
22
23
24
25
26
27
28
29
30
31
if is_hip():
    from amdsmi import (amdsmi_get_gpu_vram_usage,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down)

    @contextmanager
    def _nvml():
        """Initialise amdsmi for the duration of the context, then shut it down.

        Initialisation happens before the ``try`` so that a failing
        ``amdsmi_init()`` propagates as-is instead of also triggering a
        shutdown call on a library that was never initialised.
        """
        amdsmi_init()
        try:
            yield
        finally:
            amdsmi_shut_down()
else:
    from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                        nvmlInit, nvmlShutdown)

    @contextmanager
    def _nvml():
        """Initialise NVML for the duration of the context, then shut it down.

        Initialisation happens before the ``try`` so that a failing
        ``nvmlInit()`` propagates as-is instead of also triggering a
        shutdown call on a library that was never initialised.
        """
        nvmlInit()
        try:
            yield
        finally:
            nvmlShutdown()

43

44
45
# Directory two levels above this file (the parent of the directory that
# contains this module).
VLLM_PATH = Path(__file__).parent.parent
"""Path to root of the vLLM repository."""
46
47


48
49
class RemoteOpenAIServer:
    """Run vLLM's OpenAI-compatible API server in a subprocess.

    The constructor launches ``python -m vllm.entrypoints.openai.api_server``
    with the given CLI arguments and blocks until the server's ``health``
    endpoint answers with HTTP 200. Use the instance as a context manager so
    the subprocess is stopped when the ``with`` block exits.
    """

    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
    MAX_SERVER_START_WAIT_S = 600  # wait up to 600 seconds for server start
    HEALTH_CHECK_TIMEOUT_S = 5  # upper bound for one health-check request
    SERVER_SHUTDOWN_WAIT_S = 10  # grace period after terminate() before kill()

    def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:
        """Start the server subprocess and wait until it is healthy.

        Args:
            cli_args: Command-line arguments forwarded to the API server.
            auto_port: When true, a free port is picked and appended as
                ``--port``; ``cli_args`` must then not specify a port itself.

        Raises:
            ValueError: If ``auto_port`` is true and ``cli_args`` already
                contains ``-p``, ``--port`` or ``--port=...``.
            RuntimeError: If the server exits with a non-zero code or does
                not become healthy within ``MAX_SERVER_START_WAIT_S``.
        """
        if auto_port:
            port_given = any(
                arg in ("-p", "--port") or arg.startswith("--port=")
                for arg in cli_args)
            if port_given:
                raise ValueError("You have manually specified the port "
                                 "when `auto_port=True`.")

            # Build a new list so the caller's list is not mutated.
            cli_args = cli_args + ["--port", str(get_open_port())]

        # Parse the same arguments the server will see, to learn host/port.
        parser = make_arg_parser()
        args = parser.parse_args(cli_args)
        self.host = str(args.host or 'localhost')
        self.port = int(args.port)

        env = os.environ.copy()
        # the current process might initialize cuda,
        # to be safe, we should use spawn method
        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
        self.proc = subprocess.Popen(
            [sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
            cli_args,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr)
        try:
            self._wait_for_server(url=self.url_for("health"),
                                  timeout=self.MAX_SERVER_START_WAIT_S)
        except BaseException:
            # __exit__ never runs when the constructor fails, so stop the
            # subprocess here instead of leaking it.
            self._stop_server()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._stop_server()

    def _stop_server(self) -> None:
        """Terminate the subprocess and reap it, killing it if necessary."""
        self.proc.terminate()
        try:
            self.proc.wait(timeout=self.SERVER_SHUTDOWN_WAIT_S)
        except subprocess.TimeoutExpired:
            self.proc.kill()
            self.proc.wait()

    def _wait_for_server(self, *, url: str, timeout: float):
        """Poll ``url`` until it returns HTTP 200.

        Raises:
            RuntimeError: If the subprocess exited with a non-zero code, or
                if ``timeout`` seconds elapse without a 200 response.
        """
        start = time.time()
        while True:
            last_err = None
            try:
                response = requests.get(url,
                                        timeout=self.HEALTH_CHECK_TIMEOUT_S)
                if response.status_code == 200:
                    return
            except Exception as err:
                # Typically "connection refused" while the server is still
                # starting; remember it for exception chaining and retry.
                last_err = err

            result = self.proc.poll()
            if result is not None and result != 0:
                raise RuntimeError(
                    "Server exited unexpectedly.") from last_err

            # Sleep and enforce the deadline on every unsuccessful attempt,
            # including non-200 responses, so the loop cannot spin forever.
            time.sleep(0.5)
            if time.time() - start > timeout:
                raise RuntimeError(
                    "Server failed to start in time.") from last_err

    @property
    def url_root(self) -> str:
        """Base ``http://host:port`` URL of the server."""
        return f"http://{self.host}:{self.port}"

    def url_for(self, *parts: str) -> str:
        """Join ``parts`` with ``/`` and append them to ``url_root``."""
        return self.url_root + "/" + "/".join(parts)

    def get_client(self):
        """Return a synchronous OpenAI client pointed at this server."""
        return openai.OpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
        )

    def get_async_client(self):
        """Return an asynchronous OpenAI client pointed at this server."""
        return openai.AsyncOpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
        )


def init_test_distributed_environment(
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
    local_rank: int = -1,
) -> None:
    """Initialise the distributed and model-parallel state for one rank.

    The world size is ``pp_size * tp_size`` and the rendezvous address is
    ``tcp://localhost:<distributed_init_port>``.
    """
    world_size = pp_size * tp_size
    init_method = f"tcp://localhost:{distributed_init_port}"

    init_distributed_environment(world_size=world_size,
                                 rank=rank,
                                 distributed_init_method=init_method,
                                 local_rank=local_rank)
    ensure_model_parallel_initialized(tp_size, pp_size)


137
def multi_process_parallel(
    tp_size: int,
    pp_size: int,
    test_target: Any,
) -> None:
    """Run ``test_target`` remotely once per rank and wait for all of them.

    ``test_target.remote`` is called ``tp_size * pp_size`` times with
    ``(tp_size, pp_size, rank, distributed_init_port)``, where the port is
    one freshly obtained open port shared by all ranks. Ray is always shut
    down afterwards, including when a remote task raises.
    """
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers
    ray.init(runtime_env={"working_dir": VLLM_PATH})
    try:
        distributed_init_port = get_open_port()
        refs = [
            test_target.remote(tp_size, pp_size, rank, distributed_init_port)
            for rank in range(tp_size * pp_size)
        ]
        ray.get(refs)
    finally:
        # Without this, a failing task would leave Ray initialised.
        ray.shutdown()
156
157
158
159
160
161
162
163
164
165
166
167


@contextmanager
def error_on_warning():
    """
    Turn every warning raised inside the ``with`` block into an exception.

    The previous warning filters are restored when the block exits.
    """
    saved_filters = warnings.catch_warnings()
    with saved_filters:
        warnings.simplefilter(action="error")
        yield
168
169


170
@_nvml()
def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    """Block until every listed device uses at most ``threshold_bytes``.

    Memory usage is sampled every 5 seconds and printed on each pass.

    Raises:
        ValueError: If usage is still above the threshold once
            ``timeout_s`` seconds have elapsed.
    """
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    began_at = time.time()
    while True:
        used_gb: Dict[int, float] = {}
        for device in devices:
            if is_hip():
                handle = amdsmi_get_processor_handles()[device]
                vram = amdsmi_get_gpu_vram_usage(handle)
                # presumably "vram_used" is reported in MiB, hence 2**10 to
                # reach GiB -- TODO confirm against amdsmi
                used_gb[device] = vram["vram_used"] / 2**10
            else:
                handle = nvmlDeviceGetHandleByIndex(device)
                used_gb[device] = nvmlDeviceGetMemoryInfo(handle).used / 2**30

        report = ''.join(f'{dev}={gb:.02f}; ' for dev, gb in used_gb.items())
        print('gpu memory used (GB): ' + report)

        dur_s = time.time() - began_at
        if all(gb <= (threshold_bytes / 2**30) for gb in used_gb.values()):
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            return

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)