utils.py 44.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import contextlib
6
import copy
7
import functools
8
import importlib
9
import itertools
10
import json
11
import os
12
import random
13
import signal
14
15
import subprocess
import sys
16
import tempfile
17
import time
18
import warnings
19
from collections.abc import Callable, Iterable
20
from contextlib import ExitStack, contextmanager, suppress
21
from multiprocessing import Process
22
from pathlib import Path
23
from typing import Any, Literal
24
from unittest.mock import patch
25

26
import anthropic
27
import cloudpickle
28
import httpx
29
import openai
30
import pytest
31
import requests
32
import torch
33
import torch.nn.functional as F
34
from openai.types.completion import Completion
35
from typing_extensions import ParamSpec
36

37
import vllm.envs as envs
38
from tests.models.utils import TextTextLogprobs
39
40
41
42
from vllm.distributed import (
    ensure_model_parallel_initialized,
    init_distributed_environment,
)
43
from vllm.engine.arg_utils import AsyncEngineArgs
44
from vllm.entrypoints.cli.serve import ServeSubcommand
45
from vllm.model_executor.model_loader import get_model_loader
46
from vllm.platforms import current_platform
47
from vllm.transformers_utils.tokenizer import get_tokenizer
48
49
50
from vllm.utils import (
    FlexibleArgumentParser,
)
51
from vllm.utils.mem_constants import GB_bytes
52
from vllm.utils.network_utils import get_open_port
53
from vllm.utils.torch_utils import cuda_device_count_stateless
54

55
if current_platform.is_rocm():
    from amdsmi import (
        amdsmi_get_gpu_vram_usage,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
    )

    @contextmanager
    def _nvml():
        """Keep amdsmi initialised for the duration of the context."""
        # Initialise before entering ``try``: if initialisation itself fails
        # the error must propagate as-is, without a shutdown call on a
        # library that never started.
        amdsmi_init()
        try:
            yield
        finally:
            amdsmi_shut_down()

elif current_platform.is_cuda():
    from vllm.third_party.pynvml import (
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlInit,
        nvmlShutdown,
    )

    @contextmanager
    def _nvml():
        """Keep NVML initialised for the duration of the context."""
        # Same ordering as the ROCm variant: only shut down what was
        # successfully initialised.
        nvmlInit()
        try:
            yield
        finally:
            nvmlShutdown()

else:

    @contextmanager
    def _nvml():
        """No-op context manager used when the platform is neither ROCm nor CUDA."""
        yield
90

91

92
93
# Computed from this file's location: the parent of the directory containing it.
VLLM_PATH = Path(__file__).parent.parent
"""Path to root of the vLLM repository."""
94
95


96
97
class RemoteOpenAIServer:
    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
98

99
    def _start_server(
100
        self, model: str, vllm_serve_args: list[str], env_dict: dict[str, str] | None
101
102
    ) -> None:
        """Subclasses override this method to customize server process launch"""
103
104
105
        env = os.environ.copy()
        # the current process might initialize cuda,
        # to be safe, we should use spawn method
106
        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
107
108
        if env_dict is not None:
            env.update(env_dict)
109
110
        serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
        print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
111
        self.proc: subprocess.Popen = subprocess.Popen(
112
            serve_cmd,
113
114
115
116
117
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )

118
119
120
121
122
    def __init__(
        self,
        model: str,
        vllm_serve_args: list[str],
        *,
123
124
        env_dict: dict[str, str] | None = None,
        seed: int | None = 0,
125
        auto_port: bool = True,
126
127
        max_wait_seconds: float | None = None,
        override_hf_configs: dict[str, Any] | None = None,
128
    ) -> None:
129
        if auto_port:
130
            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
131
132
133
                raise ValueError(
                    "You have manually specified the port when `auto_port=True`."
                )
134

135
136
137
            # No need for a port if using unix sockets
            if "--uds" not in vllm_serve_args:
                # Don't mutate the input args
138
                vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())]
139
140
        if seed is not None:
            if "--seed" in vllm_serve_args:
141
142
143
                raise ValueError(
                    f"You have manually specified the seed when `seed={seed}`."
                )
144
145

            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
146

147
148
149
        if override_hf_configs is not None:
            vllm_serve_args = vllm_serve_args + [
                "--hf-overrides",
150
                json.dumps(override_hf_configs),
151
152
            ]

153
        parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
154
155
        subparsers = parser.add_subparsers(required=False, dest="subparser")
        parser = ServeSubcommand().subparser_init(subparsers)
156
        args = parser.parse_args(["--model", model, *vllm_serve_args])
157
158
159
160
161
        self.uds = args.uds
        if args.uds:
            self.host = None
            self.port = None
        else:
162
            self.host = str(args.host or "127.0.0.1")
163
            self.port = int(args.port)
164

165
        self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None
166

167
168
169
170
        # download the model before starting the server to avoid timeout
        is_local = os.path.isdir(model)
        if not is_local:
            engine_args = AsyncEngineArgs.from_cli_args(args)
171
172
173
174
175
            model_config = engine_args.create_model_config()
            load_config = engine_args.create_load_config()

            model_loader = get_model_loader(load_config)
            model_loader.download_model(model_config)
176

177
        self._start_server(model, vllm_serve_args, env_dict)
178
        max_wait_seconds = max_wait_seconds or 240
179
        self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds)
180
181
182
183
184
185

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.terminate()
186
        try:
187
            self.proc.wait(8)
188
189
190
        except subprocess.TimeoutExpired:
            # force kill if needed
            self.proc.kill()
191

192
    def _poll(self) -> int | None:
193
194
195
        """Subclasses override this method to customize process polling"""
        return self.proc.poll()

196
197
198
    def _wait_for_server(self, *, url: str, timeout: float):
        # run health check
        start = time.time()
199
200
201
202
203
        client = (
            httpx.Client(transport=httpx.HTTPTransport(uds=self.uds))
            if self.uds
            else requests
        )
204
205
        while True:
            try:
206
                if client.get(url).status_code == 200:
207
                    break
208
209
210
211
212
            except Exception:
                # this exception can only be raised by requests.get,
                # which means the server is not ready yet.
                # the stack trace is not useful, so we suppress it
                # by using `raise from None`.
213
                result = self._poll()
214
                if result is not None and result != 0:
215
                    raise RuntimeError("Server exited unexpectedly.") from None
216
217
218

                time.sleep(0.5)
                if time.time() - start > timeout:
219
                    raise RuntimeError("Server failed to start in time.") from None
220
221
222

    @property
    def url_root(self) -> str:
223
224
225
226
227
        return (
            f"http://{self.uds.split('/')[-1]}"
            if self.uds
            else f"http://{self.host}:{self.port}"
        )
228
229
230
231

    def url_for(self, *parts: str) -> str:
        return self.url_root + "/" + "/".join(parts)

232
233
234
    def get_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
235
236
237
        return openai.OpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
238
239
            max_retries=0,
            **kwargs,
240
241
        )

242
    def get_async_client(self, **kwargs):
243
244
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
245
246
247
248
249
250
        return openai.AsyncOpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
            max_retries=0,
            **kwargs,
        )
251
252


253
254
255
class RemoteOpenAIServerCustom(RemoteOpenAIServer):
    """Launch test server with custom child process"""

    def _start_server(
        self, model: str, vllm_serve_args: list[str], env_dict: dict[str, str] | None
    ) -> None:
        """Run ``child_process_fxn(env_dict, model, vllm_serve_args)`` in a
        ``multiprocessing.Process`` instead of spawning ``vllm serve``."""
        self.proc: Process = Process(
            target=self.child_process_fxn, args=(env_dict, model, vllm_serve_args)
        )  # type: ignore[assignment]
        self.proc.start()

    def __init__(
        self,
        model: str,
        vllm_serve_args: list[str],
        child_process_fxn: Callable[[dict[str, str] | None, str, list[str]], None],
        *,
        env_dict: dict[str, str] | None = None,
        seed: int | None = 0,
        auto_port: bool = True,
        max_wait_seconds: float | None = None,
        override_hf_configs: dict[str, Any] | None = None,
    ) -> None:
        """Store custom child process function then invoke superclass
        constructor which will indirectly launch it.

        All keyword options are forwarded unchanged to the superclass
        constructor, including ``override_hf_configs``.
        """
        # Must be set first: the superclass constructor calls
        # ``_start_server``, which reads this attribute.
        self.child_process_fxn = child_process_fxn
        super().__init__(
            model=model,
            vllm_serve_args=vllm_serve_args,
            env_dict=env_dict,
            seed=seed,
            auto_port=auto_port,
            max_wait_seconds=max_wait_seconds,
            override_hf_configs=override_hf_configs,
        )

    def _poll(self) -> int | None:
        """Return the child's exit code, or ``None`` while it is running."""
        return self.proc.exitcode

    def __exit__(self, exc_type, exc_value, traceback):
        """Terminate the child, escalating to a kill after 8 seconds."""
        self.proc.terminate()
        self.proc.join(8)
        if self.proc.is_alive():
            # force kill if needed
            self.proc.kill()
            # Reap the killed child so it does not linger as a zombie.
            self.proc.join()


298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
class RemoteAnthropicServer:
    DUMMY_API_KEY = "token-abc123"  # vLLM's Anthropic server does not need API key

    def __init__(
        self,
        model: str,
        vllm_serve_args: list[str],
        *,
        env_dict: dict[str, str] | None = None,
        seed: int | None = 0,
        auto_port: bool = True,
        max_wait_seconds: float | None = None,
    ) -> None:
        if auto_port:
            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
                raise ValueError(
                    "You have manually specified the port when `auto_port=True`."
                )

            # Don't mutate the input args
            vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())]
        if seed is not None:
            if "--seed" in vllm_serve_args:
                raise ValueError(
                    f"You have manually specified the seed when `seed={seed}`."
                )

            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]

        parser = FlexibleArgumentParser(description="vLLM's remote Anthropic server.")
        subparsers = parser.add_subparsers(required=False, dest="subparser")
        parser = ServeSubcommand().subparser_init(subparsers)
        args = parser.parse_args(["--model", model, *vllm_serve_args])
        self.host = str(args.host or "localhost")
        self.port = int(args.port)

        self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None

        # download the model before starting the server to avoid timeout
        is_local = os.path.isdir(model)
        if not is_local:
            engine_args = AsyncEngineArgs.from_cli_args(args)
            model_config = engine_args.create_model_config()
            load_config = engine_args.create_load_config()

            model_loader = get_model_loader(load_config)
            model_loader.download_model(model_config)

        env = os.environ.copy()
        # the current process might initialize cuda,
        # to be safe, we should use spawn method
        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
        if env_dict is not None:
            env.update(env_dict)
        self.proc = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "vllm.entrypoints.anthropic.api_server",
                model,
                *vllm_serve_args,
            ],
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        max_wait_seconds = max_wait_seconds or 240
        self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.terminate()
        try:
            self.proc.wait(8)
        except subprocess.TimeoutExpired:
            # force kill if needed
            self.proc.kill()

    def _wait_for_server(self, *, url: str, timeout: float):
        # run health check
        start = time.time()
        while True:
            try:
                if requests.get(url).status_code == 200:
                    break
            except Exception:
                # this exception can only be raised by requests.get,
                # which means the server is not ready yet.
                # the stack trace is not useful, so we suppress it
                # by using `raise from None`.
                result = self.proc.poll()
                if result is not None and result != 0:
                    raise RuntimeError("Server exited unexpectedly.") from None

                time.sleep(0.5)
                if time.time() - start > timeout:
                    raise RuntimeError("Server failed to start in time.") from None

    @property
    def url_root(self) -> str:
        return f"http://{self.host}:{self.port}"

    def url_for(self, *parts: str) -> str:
        return self.url_root + "/" + "/".join(parts)

    def get_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return anthropic.Anthropic(
            base_url=self.url_for(),
            api_key=self.DUMMY_API_KEY,
            max_retries=0,
            **kwargs,
        )

    def get_async_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return anthropic.AsyncAnthropic(
            base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs
        )


423
424
425
426
def _test_completion(
    client: openai.OpenAI,
    model: str,
    prompt: str,
    token_ids: list[int],
):
    """Run a fixed series of completion requests and collect their outputs.

    Args:
        client: Client used to issue the requests.
        model: Model name sent with every request.
        prompt: Text prompt.
        token_ids: Token ids used for the token-id request.

    Returns:
        A list of dicts, one per request, each tagged with a ``"test"`` name.
    """
    results = []

    # test with text prompt
    completion = client.completions.create(
        model=model, prompt=prompt, max_tokens=5, temperature=0.0
    )
    results.append(
        {
            "test": "single_completion",
            "text": completion.choices[0].text,
            "finish_reason": completion.choices[0].finish_reason,
            "usage": completion.usage,
        }
    )

    # test using token IDs
    completion = client.completions.create(
        model=model,
        prompt=token_ids,
        max_tokens=5,
        temperature=0.0,
    )
    results.append(
        {
            "test": "token_ids",
            "text": completion.choices[0].text,
            "finish_reason": completion.choices[0].finish_reason,
            "usage": completion.usage,
        }
    )

    # test seeded random sampling
    completion = client.completions.create(
        model=model, prompt=prompt, max_tokens=5, seed=33, temperature=1.0
    )
    results.append(
        {
            "test": "seeded_sampling",
            "text": completion.choices[0].text,
            "finish_reason": completion.choices[0].finish_reason,
            "usage": completion.usage,
        }
    )

    # test seeded random sampling with multiple prompts
    completion = client.completions.create(
        model=model, prompt=[prompt, prompt], max_tokens=5, seed=33, temperature=1.0
    )
    results.append(
        {
            # Distinct tag so a failure report identifies which of the two
            # seeded requests differed.
            "test": "seeded_sampling_multi_prompt",
            "text": [choice.text for choice in completion.choices],
            "finish_reason": [choice.finish_reason for choice in completion.choices],
            "usage": completion.usage,
        }
    )

    # test simple list
    batch = client.completions.create(
        model=model,
        prompt=[prompt, prompt],
        max_tokens=5,
        temperature=0.0,
    )
    results.append(
        {
            "test": "simple_list",
            "text0": batch.choices[0].text,
            "text1": batch.choices[1].text,
        }
    )

    # test streaming
    batch = client.completions.create(
        model=model,
        prompt=[prompt, prompt],
        max_tokens=5,
        temperature=0.0,
        stream=True,
    )

    # Each chunk carries one choice; reassemble the text per prompt index.
    texts = [""] * 2
    for chunk in batch:
        assert len(chunk.choices) == 1
        choice = chunk.choices[0]
        texts[choice.index] += choice.text

    results.append(
        {
            "test": "streaming",
            "texts": texts,
        }
    )

    return results


531
532
533
534
535
536
537
538
def _test_completion_close(
    client: openai.OpenAI,
    model: str,
    prompt: str,
):
    """Request a single token with top-5 logprobs and return them rounded.

    Logprobs are rounded to two decimals so that results which are merely
    close still compare equal.

    Returns:
        A one-element list holding a dict tagged ``"completion_close"``.
    """
    results = []

    # test with text prompt
    completion = client.completions.create(
        model=model, prompt=prompt, max_tokens=1, logprobs=5, temperature=0.0
    )

    logprobs = completion.choices[0].logprobs.top_logprobs[0]
    logprobs = {k: round(v, 2) for k, v in logprobs.items()}

    results.append(
        {
            "test": "completion_close",
            "logprobs": logprobs,
        }
    )

    return results


556
557
558
559
560
561
562
def _test_chat(
    client: openai.OpenAI,
    model: str,
    prompt: str,
):
    """Send ``prompt`` as a single user chat message and collect the reply.

    Returns:
        A one-element list holding a dict with the reply text, finish reason
        and usage.
    """
    results = []

    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]

    # test with text prompt
    chat_response = client.chat.completions.create(
        model=model, messages=messages, max_tokens=5, temperature=0.0
    )

    results.append(
        {
            # Tagged as a chat result; the previous tag was the name of the
            # logprob completion helper, which mislabelled failure reports.
            "test": "chat_completion",
            "text": chat_response.choices[0].message.content,
            "finish_reason": chat_response.choices[0].finish_reason,
            "usage": chat_response.usage,
        }
    )

    return results


582
583
584
585
586
587
588
589
590
591
592
593
594
595
def _test_embeddings(
    client: openai.OpenAI,
    model: str,
    text: str,
):
    """Embed ``text`` with float encoding and collect the first embedding.

    Returns:
        A one-element list holding a dict with the embedding and usage.
    """
    results = []

    # test with text input
    embeddings = client.embeddings.create(
        model=model,
        input=text,
        encoding_format="float",
    )

    results.append(
        {
            "test": "single_embedding",
            "embedding": embeddings.data[0].embedding,
            "usage": embeddings.usage,
        }
    )

    return results


607
608
609
610
611
612
613
614
def _test_image_text(
    client: openai.OpenAI,
    model_name: str,
    image_url: str,
):
    """Collect first-token top-5 logprobs for a text-only and an image chat.

    Both requests ask for one token at temperature 0. The logprobs of both
    are rounded to two decimals so that results which are merely close still
    compare equal.

    Returns:
        A two-element list: the ``"pure_text"`` result followed by the
        ``"text_image"`` result.
    """

    def first_token_top_logprobs(messages):
        """Issue one chat request and return its rounded top logprobs."""
        chat_completion = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=0.0,
            max_tokens=1,
            logprobs=True,
            top_logprobs=5,
        )
        top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
        for x in top_logprobs:
            x.logprob = round(x.logprob, 2)
        return top_logprobs

    results = []

    # test pure text input
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "How do you feel today?"},
            ],
        }
    ]
    results.append(
        {
            "test": "pure_text",
            "logprobs": first_token_top_logprobs(messages),
        }
    )

    # test image + text input
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": "What's in this image?"},
            ],
        }
    ]
    results.append(
        {
            "test": "text_image",
            "logprobs": first_token_top_logprobs(messages),
        }
    )

    return results


674
675
676
677
def compare_two_settings(
    model: str,
    arg1: list[str],
    arg2: list[str],
    env1: dict[str, str] | None = None,
    env2: dict[str, str] | None = None,
    *,
    method: str = "generate",
    max_wait_seconds: float | None = None,
) -> None:
    """
    Launch API server with two different sets of arguments/environments
    and compare the results of the API calls.

    Args:
        model: The model to test.
        arg1: The first set of arguments to pass to the API server.
        arg2: The second set of arguments to pass to the API server.
        env1: The first set of environment variables to pass to the API server.
        env2: The second set of environment variables to pass to the API server.
        method: Which API calls to make; forwarded to `compare_all_settings`.
        max_wait_seconds: Server startup timeout; forwarded to
            `compare_all_settings`.
    """

    compare_all_settings(
        model,
        [arg1, arg2],
        [env1, env2],
        method=method,
        max_wait_seconds=max_wait_seconds,
    )


705
706
707
def compare_all_settings(
    model: str,
    all_args: list[list[str]],
    all_envs: list[dict[str, str] | None],
    *,
    method: str = "generate",
    max_wait_seconds: float | None = None,
) -> None:
    """
    Launch API server with several different sets of arguments/environments
    and compare the results of the API calls with the first set of arguments.

    Args:
        model: The model to test.
        all_args: A list of argument lists to pass to the API server.
        all_envs: A list of environment dictionaries to pass to the API server.
            Must have the same length as `all_args`.
        method: Which API calls to make: "generate", "generate_close",
            "generate_chat", "generate_with_image" or "encode".
        max_wait_seconds: Startup timeout passed to each server.

    Raises:
        ValueError: If `all_args` and `all_envs` differ in length, or if
            `method` is unknown.
        AssertionError: If a setting's results differ from the first one's.
    """
    # A length mismatch would otherwise silently drop the unpaired settings.
    if len(all_args) != len(all_envs):
        raise ValueError(
            "all_args and all_envs must have the same length, got "
            f"{len(all_args)} and {len(all_envs)}."
        )

    trust_remote_code = any("--trust-remote-code" in args for args in all_args)

    # The first setting that names a tokenizer mode decides the mode of the
    # tokenizer loaded in this process.
    tokenizer_mode = "auto"
    for args in all_args:
        if "--tokenizer-mode" in args:
            tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
            break

    tokenizer = get_tokenizer(
        model,
        trust_remote_code=trust_remote_code,
        tokenizer_mode=tokenizer_mode,
    )

    # Only force a load format when no setting chose one itself.
    can_force_load_format = all("--load-format" not in args for args in all_args)

    prompt = "Hello, my name is"
    token_ids = tokenizer(prompt).input_ids

    runners = {
        "generate": lambda client: _test_completion(client, model, prompt, token_ids),
        "generate_close": lambda client: _test_completion_close(
            client, model, prompt
        ),
        "generate_chat": lambda client: _test_chat(client, model, prompt),
        "generate_with_image": lambda client: _test_image_text(
            client,
            model,
            "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
        ),
        "encode": lambda client: _test_embeddings(client, model, prompt),
    }
    # Reject an unknown method before any server is launched.
    if method not in runners:
        raise ValueError(f"Unknown method: {method}")
    run_method = runners[method]

    ref_results: list = []
    for i, (args, env) in enumerate(zip(all_args, all_envs)):
        if can_force_load_format:
            # we are comparing the results and
            # usually we don't need real weights.
            # we force to use dummy weights by default,
            # and it should work for most of the cases.
            # if not, we can use VLLM_TEST_FORCE_LOAD_FORMAT
            # environment variable to force the load format,
            # e.g. in quantization tests.
            args = args + ["--load-format", envs.VLLM_TEST_FORCE_LOAD_FORMAT]
        compare_results: list = []
        results = ref_results if i == 0 else compare_results
        with RemoteOpenAIServer(
            model, args, env_dict=env, max_wait_seconds=max_wait_seconds
        ) as server:
            client = server.get_client()

            # test models list
            served_model = client.models.list().data[0]
            results.append(
                {
                    "test": "models_list",
                    "id": served_model.id,
                    "root": served_model.root,
                }
            )

            results += run_method(client)

            if i > 0:
                # if any setting fails, raise an error early
                ref_args = all_args[0]
                ref_envs = all_envs[0]
                compare_args = all_args[i]
                compare_envs = all_envs[i]
                # strict: every setting runs the same calls, so a length
                # difference is itself a failure worth surfacing.
                for ref_result, compare_result in zip(
                    ref_results, compare_results, strict=True
                ):
                    # Copy so that removing the embedding below does not
                    # alter the stored reference results.
                    ref_result = copy.deepcopy(ref_result)
                    compare_result = copy.deepcopy(compare_result)
                    if "embedding" in ref_result and method == "encode":
                        # Embeddings are compared by cosine similarity
                        # rather than exact equality.
                        sim = F.cosine_similarity(
                            torch.tensor(ref_result["embedding"]),
                            torch.tensor(compare_result["embedding"]),
                            dim=0,
                        )
                        assert sim >= 0.999, (
                            f"Embedding for {model=} are not the same.\n"
                            f"cosine_similarity={sim}\n"
                        )
                        del ref_result["embedding"]
                        del compare_result["embedding"]
                    assert ref_result == compare_result, (
                        f"Results for {model=} are not the same.\n"
                        f"{ref_args=} {ref_envs=}\n"
                        f"{compare_args=} {compare_envs=}\n"
                        f"{ref_result=}\n"
                        f"{compare_result=}\n"
                    )
824
825


826
827
828
829
830
831
832
833
834
835
836
837
def init_test_distributed_environment(
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
    local_rank: int = -1,
) -> None:
    """Initialise the distributed environment and model parallelism for a rank.

    The world size is ``pp_size * tp_size`` and the rendezvous address is
    ``tcp://localhost:<distributed_init_port>``.

    Args:
        tp_size: First argument to ``ensure_model_parallel_initialized``
            (presumably the tensor-parallel degree).
        pp_size: Second argument to ``ensure_model_parallel_initialized``
            (presumably the pipeline-parallel degree).
        rank: Rank of the calling process.
        distributed_init_port: Port on localhost used for initialisation.
        local_rank: Forwarded to ``init_distributed_environment``.
    """
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    init_distributed_environment(
        world_size=pp_size * tp_size,
        rank=rank,
        distributed_init_method=distributed_init_method,
        local_rank=local_rank,
    )
    ensure_model_parallel_initialized(tp_size, pp_size)


843
def multi_process_parallel(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    test_target: Any,
) -> None:
    """Run ``test_target`` once per rank as ray remote tasks and wait for all.

    Args:
        monkeypatch: Passed as the first argument to every remote call.
        tp_size: Together with ``pp_size`` determines the number of ranks
            (``tp_size * pp_size``).
        pp_size: See ``tp_size``.
        test_target: Object whose ``.remote(...)`` is invoked with
            ``(monkeypatch, tp_size, pp_size, rank, port)``.
    """
    import ray

    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers
    # NOTE: Force ray not to use gitignore file as excluding, otherwise
    # it will not move .so files to working dir.
    # So we have to manually add some of large directories
    os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1"
    ray.init(
        runtime_env={
            "working_dir": VLLM_PATH,
            "excludes": [
                "build",
                ".git",
                "cmake-build-*",
                "shellcheck",
                "dist",
                "ep_kernels_workspace",
            ],
        }
    )

    # Always shut ray down, even when a worker raises, so a failing test
    # does not leave an initialised ray session behind for the next one.
    try:
        distributed_init_port = get_open_port()
        refs = [
            test_target.remote(
                monkeypatch,
                tp_size,
                pp_size,
                rank,
                distributed_init_port,
            )
            for rank in range(tp_size * pp_size)
        ]
        ray.get(refs)
    finally:
        ray.shutdown()
888
889
890


@contextmanager
def error_on_warning(category: type[Warning] = Warning):
    """
    Within the scope of this context manager, tests will fail if any warning
    of the given category is emitted.

    The warning filters in effect before entering are restored on exit.
    """
    with warnings.catch_warnings():
        # Turn matching warnings into exceptions for the duration of the block.
        warnings.filterwarnings("error", category=category)

        yield
900
901


902
903
904
905
906
907
908
909
910
911
def get_physical_device_indices(devices):
    """Translate device indices through ``CUDA_VISIBLE_DEVICES``.

    Args:
        devices: Indices into the list of visible devices.

    Returns:
        ``devices`` unchanged when ``CUDA_VISIBLE_DEVICES`` is unset.
        Otherwise the entries of ``CUDA_VISIBLE_DEVICES`` found at those
        positions; indices outside the visible list are dropped. An empty
        variable means no device is visible, so the result is empty.

    Raises:
        ValueError: If an entry of ``CUDA_VISIBLE_DEVICES`` is not an
            integer.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return devices

    # Skip blank entries so that an empty value or a trailing comma does not
    # reach ``int("")``.
    visible_indices = [
        int(entry) for entry in visible_devices.split(",") if entry.strip()
    ]
    index_mapping = dict(enumerate(visible_indices))
    return [index_mapping[i] for i in devices if i in index_mapping]


912
@_nvml()
def wait_for_gpu_memory_to_clear(
    *,
    devices: list[int],
    threshold_bytes: int | None = None,
    threshold_ratio: float | None = None,
    timeout_s: float = 120,
) -> None:
    """Block until the used memory on every given device is below a threshold.

    Memory usage is polled every 5 seconds and printed on each poll.

    Args:
        devices: Logical device indices; they are translated with
            ``get_physical_device_indices`` before querying.
        threshold_bytes: Absolute limit on used memory, in bytes. Takes
            precedence over ``threshold_ratio`` when both are given.
        threshold_ratio: Limit on the used/total memory ratio.
        timeout_s: Give up once this many seconds have elapsed.

    Raises:
        ValueError: If some device is still above the threshold after
            ``timeout_s`` seconds.
    """
    assert threshold_bytes is not None or threshold_ratio is not None
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    devices = get_physical_device_indices(devices)

    # The threshold does not change between polls, so build the predicate and
    # its printable form once up front.
    if threshold_bytes is not None:
        threshold = f"{threshold_bytes / 2**30} GiB"

        def is_free(used, total):
            return used <= threshold_bytes / 2**30
    else:
        threshold = f"{threshold_ratio:.2f}"

        def is_free(used, total):
            return used / total <= threshold_ratio

    on_rocm = current_platform.is_rocm
    start_time = time.time()
    while True:
        # device index -> (used GiB, total GiB)
        usage: dict[int, tuple[float, float]] = {}
        for device in devices:
            if on_rocm():
                dev_handle = amdsmi_get_processor_handles()[device]
                mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
                # NOTE(review): dividing by 2**10 to get GiB presumes amdsmi
                # reports these fields in MiB - confirm against amdsmi docs.
                gb_used = mem_info["vram_used"] / 2**10
                gb_total = mem_info["vram_total"] / 2**10
            else:
                dev_handle = nvmlDeviceGetHandleByIndex(device)
                mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
                gb_used = mem_info.used / 2**30
                gb_total = mem_info.total / 2**30
            usage[device] = (gb_used, gb_total)

        report = "".join(
            f"{dev}={used:.02f}/{total:.02f}; "
            for dev, (used, total) in usage.items()
        )
        print("gpu memory used/total (GiB): " + report)

        dur_s = time.time() - start_time
        if all(is_free(used, total) for used, total in usage.values()):
            print(
                f"Done waiting for free GPU memory on devices {devices=} "
                f"({threshold=}) {dur_s=:.02f}"
            )
            return

        if dur_s >= timeout_s:
            raise ValueError(
                f"Memory of devices {devices=} not free after "
                f"{dur_s=:.02f} ({threshold=})"
            )

        time.sleep(5)
969
970


971
972
973
# Parameter specification shared by the test decorators below so that the
# wrappers they return keep the wrapped function's call signature for type
# checkers.
_P = ParamSpec("_P")


974
def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to fork a new process for each test function.

    The decorated test body runs in a child created with ``os.fork()``. The
    parent waits for the child, sends SIGTERM to its own process group to
    clean up leftover processes, and then reports a child failure by raising:

    * the original exception object, when the child managed to pickle it;
    * otherwise an ``AssertionError`` carrying the child's formatted
      traceback, or only the exit code when nothing could be recovered.

    A ``Skipped`` outcome in the child is treated as success (exit code 0).

    See https://github.com/vllm-project/vllm/issues/7053 for more details.
    """

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        from _pytest.outcomes import Skipped

        # Create a unique temporary file to store exception info from child
        # process. Use test function name and process ID to avoid collisions.
        with (
            tempfile.NamedTemporaryFile(
                delete=False,
                mode="w+b",
                prefix=f"vllm_test_{func.__name__}_{os.getpid()}_",
                suffix=".exc",
            ) as exc_file,
            ExitStack() as delete_after,
        ):
            exc_file_path = exc_file.name
            delete_after.callback(os.remove, exc_file_path)

            pid = os.fork()
            print(f"Fork a new process to run a test {pid}")
            if pid == 0:
                # Parent process responsible for deleting, don't delete
                # in child.
                delete_after.pop_all()
                try:
                    func(*args, **kwargs)
                except Skipped as e:
                    # convert Skipped to exit code 0
                    print(str(e))
                    os._exit(0)
                except Exception as e:
                    import traceback

                    tb_string = traceback.format_exc()

                    # Try to serialize the exception object first
                    exc_to_serialize: dict[str, Any]
                    try:
                        # First, try to pickle the actual exception with
                        # its traceback.
                        exc_to_serialize = {"pickled_exception": e}
                        # Test if it can be pickled
                        cloudpickle.dumps(exc_to_serialize)
                    except (Exception, KeyboardInterrupt):
                        # Fall back to string-based approach.
                        exc_to_serialize = {
                            "exception_type": type(e).__name__,
                            "exception_msg": str(e),
                            "traceback": tb_string,
                        }
                    try:
                        with open(exc_file_path, "wb") as f:
                            cloudpickle.dump(exc_to_serialize, f)
                    except Exception:
                        # Fallback: just print the traceback.
                        print(tb_string)
                    os._exit(1)
                else:
                    os._exit(0)
            else:
                # The child was forked after `os.setpgrp()`, so it shares this
                # process's group; look the group up before the child is
                # reaped by `waitpid`.
                pgid = os.getpgid(pid)
                _pid, _status = os.waitpid(pid, 0)
                # ignore SIGTERM signal itself
                old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
                # kill all child processes
                os.killpg(pgid, signal.SIGTERM)
                # restore the signal handler
                signal.signal(signal.SIGTERM, old_signal_handler)
                # `os.waitpid` yields a raw wait status (e.g. 256 for
                # `os._exit(1)`); convert it so messages show the real exit
                # code, or a negative signal number if the child was killed.
                _exitcode = os.waitstatus_to_exitcode(_status)
                if _exitcode != 0:
                    # Try to read the exception from the child process
                    exc_info = {}
                    if os.path.exists(exc_file_path):
                        with (
                            contextlib.suppress(Exception),
                            open(exc_file_path, "rb") as f,
                        ):
                            exc_info = cloudpickle.load(f)

                    if (
                        original_exception := exc_info.get("pickled_exception")
                    ) is not None:
                        # Re-raise the actual exception object if it was
                        # successfully pickled.
                        assert isinstance(original_exception, Exception)
                        raise original_exception

                    if (original_tb := exc_info.get("traceback")) is not None:
                        # Use string-based traceback for fallback case
                        raise AssertionError(
                            f"Test {func.__name__} failed when called with"
                            f" args {args} and kwargs {kwargs}"
                            f" (exit code: {_exitcode}):\n{original_tb}"
                        ) from None

                    # Fallback to the original generic error
                    raise AssertionError(
                        f"function {func.__name__} failed when called with"
                        f" args {args} and kwargs {kwargs}"
                        f" (exit code: {_exitcode})"
                    ) from None

    return wrapper
1084
1085


1086
1087
def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to spawn a new process for each test function.

    Unless ``RUNNING_IN_SUBPROCESS=1`` is already set (in which case the test
    runs in-process), the wrapper launches ``python -m <module of f>`` with
    that variable set and the repository root prepended to ``PYTHONPATH``,
    feeding a cloudpickled ``(f, output_filepath)`` tuple on stdin.

    Note that only ``f`` and the output path are sent to the subprocess; the
    wrapper's ``args``/``kwargs`` are not part of the payload. How the target
    module consumes stdin is not visible here - presumably its ``__main__``
    unpickles and runs the function.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero return code; the
            message contains the subprocess's stderr.
    """

    @functools.wraps(f)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Check if we're already in a subprocess
        if os.environ.get("RUNNING_IN_SUBPROCESS") == "1":
            # If we are, just run the function directly
            return f(*args, **kwargs)

        import torch.multiprocessing as mp

        # The start method can only be set once per process; a second attempt
        # raises RuntimeError, which is fine to ignore.
        with suppress(RuntimeError):
            mp.set_start_method("spawn")

        # Get the module
        module_name = f.__module__

        # Create a process with environment variable set
        env = os.environ.copy()
        env["RUNNING_IN_SUBPROCESS"] = "1"

        # Prepend the repo root to PYTHONPATH. Only append the previous value
        # when there is one, so that no empty entry (trailing separator) ends
        # up on the path.
        repo_root = str(VLLM_PATH.resolve())
        inherited_pythonpath = env.get("PYTHONPATH")
        if inherited_pythonpath:
            env["PYTHONPATH"] = repo_root + os.pathsep + inherited_pythonpath
        else:
            env["PYTHONPATH"] = repo_root

        with tempfile.TemporaryDirectory() as tempdir:
            output_filepath = os.path.join(tempdir, "new_process.tmp")

            # `cloudpickle` allows pickling complex functions directly
            input_bytes = cloudpickle.dumps((f, output_filepath))

            cmd = [sys.executable, "-m", f"{module_name}"]

            returned = subprocess.run(
                cmd, input=input_bytes, capture_output=True, env=env
            )

            # check if the subprocess is successful
            try:
                returned.check_returncode()
            except subprocess.CalledProcessError as e:
                # wrap raised exception to provide more information;
                # decode leniently so undecodable bytes in stderr cannot mask
                # the actual failure with a UnicodeDecodeError.
                stderr_text = returned.stderr.decode(errors="replace")
                raise RuntimeError(
                    f"Error raised in subprocess:\n{stderr_text}"
                ) from e

    return wrapper


def create_new_process_for_each_test(
    method: Literal["spawn", "fork"] | None = None,
) -> Callable[[Callable[_P, None]], Callable[_P, None]]:
    """Pick the decorator that isolates each test in its own process.

    Args:
        method: How the process is created, "spawn" or "fork". When omitted,
            "spawn" is chosen on ROCm and XPU platforms and "fork" everywhere
            else.

    Returns:
        Either `fork_new_process_for_each_test` or
        `spawn_new_process_for_each_test`, to be applied to a test function.
    """
    if method is None:
        needs_spawn = current_platform.is_rocm() or current_platform.is_xpu()
        method = "spawn" if needs_spawn else "fork"

    assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'"

    return (
        fork_new_process_for_each_test
        if method == "fork"
        else spawn_new_process_for_each_test
    )


1162
def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
    """
    Build a `skipif` mark that skips a test when the device has less than
    `min_gb` GB of total memory.

    On CPU platforms, or when the memory query fails (a warning is emitted in
    that case), the available memory is treated as 0 GB.

    The mark can be applied through `@large_gpu_test`, or used directly when
    filtering which tests to run.
    """
    # Stays at 0 on CPU and when the query below raises.
    memory_gb = 0
    try:
        if not current_platform.is_cpu():
            memory_gb = current_platform.get_device_total_memory() / GB_bytes
    except Exception as e:
        warnings.warn(
            f"An error occurred when finding the available memory: {e}",
            stacklevel=2,
        )

    insufficient_memory = memory_gb < min_gb
    return pytest.mark.skipif(
        insufficient_memory,
        reason=f"Need at least {min_gb}GB GPU memory to run the test.",
    )

1187
1188
1189
1190
1191
1192
1193
1194

def large_gpu_test(*, min_gb: int):
    """
    Decorator factory: skip the decorated test when no GPU is available or
    the GPU has less than `min_gb` GB of memory.

    Currently, the CI machine uses L4 GPU which has 24 GB VRAM.
    """
    # Evaluated once, when the decorator is created.
    skip_mark = large_gpu_mark(min_gb)

    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
        marked = skip_mark(f)
        return marked

    return wrapper


1203
1204
1205
def multi_gpu_marks(*, num_gpus: int):
    """Return the pytest marks that `@multi_gpu_test` applies.

    The list holds, in order, a `distributed` selector mark carrying
    `num_gpus` and a `skipif` mark that triggers when fewer than `num_gpus`
    CUDA devices are counted.
    """
    selector_mark = pytest.mark.distributed(num_gpus=num_gpus)
    too_few_gpus = cuda_device_count_stateless() < num_gpus
    skip_mark = pytest.mark.skipif(
        too_few_gpus,
        reason=f"Need at least {num_gpus} GPUs to run the test.",
    )
    return [selector_mark, skip_mark]


def multi_gpu_test(*, num_gpus: int):
    """
    Decorator factory for tests that need at least `num_gpus` GPUs.

    The test is first wrapped so it runs in a new process, then the marks
    from `multi_gpu_marks` are applied on top, last mark first.
    """
    marks = multi_gpu_marks(num_gpus=num_gpus)

    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
        isolated = create_new_process_for_each_test()(f)
        # Fold the marks over the isolated test, innermost being the last one.
        return functools.reduce(
            lambda decorated, mark: mark(decorated),
            reversed(marks),
            isolated,
        )

    return wrapper


1230
async def completions_with_server_args(
    prompts: list[str],
    model_name: str,
    server_cli_args: list[str],
    num_logprobs: int | None,
    max_wait_seconds: int = 240,
    max_tokens: int | list = 5,
) -> list[Completion]:
    """Start a remote OpenAI-compatible server and query its completions API.

    One request is issued per prompt (temperature 0, non-streaming); all
    requests are awaited concurrently while the server is up.

    Args:
      prompts: test prompts
      model_name: model to spin up on the vLLM server
      server_cli_args: CLI args for starting the server
      num_logprobs: Number of logprobs to report (or `None`)
      max_wait_seconds: timeout interval for bringing up server.
                        Default: 240sec
      max_tokens: max_tokens value for each of the given input prompts.
        If a single int is given, the same value is used for all the
        prompts; a list must have one entry per prompt.

    Returns:
      A list with one OpenAI Completion per prompt, in prompt order.
    """

    per_prompt_max_tokens = (
        [max_tokens] * len(prompts) if isinstance(max_tokens, int) else max_tokens
    )
    assert len(per_prompt_max_tokens) == len(prompts)

    outputs = None

    with RemoteOpenAIServer(
        model_name, server_cli_args, max_wait_seconds=max_wait_seconds
    ) as server:
        client = server.get_async_client()

        # Create every request first, then await them together.
        pending = []
        for prompt, prompt_max_tokens in zip(prompts, per_prompt_max_tokens):
            request = client.completions.create(
                model=model_name,
                prompt=[prompt],
                temperature=0,
                stream=False,
                max_tokens=prompt_max_tokens,
                logprobs=num_logprobs,
            )
            pending.append(request)

        outputs = await asyncio.gather(*pending)

    assert outputs is not None, "Completion API call failed."

    return outputs


1284
def get_client_text_generations(completions: list[Completion]) -> list[str]:
    """Return the generated text of each completion.

    Each completion, as returned by an OpenAI-protocol completions endpoint,
    must contain exactly one choice.
    """
    for completion in completions:
        assert len(completion.choices) == 1

    texts = []
    for completion in completions:
        only_choice = completion.choices[0]
        texts.append(only_choice.text)
    return texts
1290
1291
1292


def get_client_text_logprob_generations(
    completions: list[Completion],
) -> list[TextTextLogprobs]:
    """Collect text and top-rank logprobs from completions endpoint output.

    Produces one tuple per choice of every completion. Each tuple holds the
    list of all generated texts, those texts joined into one string, and the
    choice's `top_logprobs` (or `None` when the choice has no logprobs). The
    first two elements are the same objects in every tuple.
    """
    text_generations = get_client_text_generations(completions)
    text = "".join(text_generations)

    results = []
    for completion in completions:
        for choice in completion.choices:
            if choice.logprobs is None:
                top_logprobs = None
            else:
                top_logprobs = choice.logprobs.top_logprobs
            results.append((text_generations, text, top_logprobs))
    return results
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320


def has_module_attribute(module_name, attribute_name):
    """
    Report whether importing `module_name` succeeds and the resulting module
    exposes `attribute_name`. An `ImportError` yields `False`.
    """
    # The attribute lookup stays inside the guard as well, so an ImportError
    # raised during the lookup is also reported as False.
    with suppress(ImportError):
        return hasattr(importlib.import_module(module_name), attribute_name)
    return False
1321
1322
1323
1324


def get_attn_backend_list_based_on_platform() -> list[str]:
    """Return the attention backend names to test on the current platform.

    On ROCm, "FLASH_ATTN" is included only when `aiter` can be imported.

    Raises:
        ValueError: If the platform is neither CUDA, ROCm nor XPU.
    """
    if current_platform.is_cuda():
        return ["FLASH_ATTN", "TRITON_ATTN", "TREE_ATTN"]

    if current_platform.is_rocm():
        backends = ["TRITON_ATTN"]
        try:
            import aiter  # noqa: F401
        except Exception:
            print("Skip FLASH_ATTN on ROCm as aiter is not installed")
        else:
            backends.append("FLASH_ATTN")
        return backends

    if current_platform.is_xpu():
        return ["FLASH_ATTN", "TRITON_ATTN"]

    raise ValueError("Unsupported platform")
1340
1341
1342
1343
1344


@contextmanager
def override_cutlass_fp8_supported(value: bool):
    """Patch `cutlass_fp8_supported` in `w8a8_utils` to return `value` for the
    duration of the `with` block."""
    target = (
        "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported"
    )
    patcher = patch(target, return_value=value)
    with patcher:
        yield
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367


def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):
    """
    Generate prompts made of a long run of variable assignments followed by
    a question about the value of one of them.

    Per the original note, the prompt is just under 10k tokens while the
    sliding window is 4k, so the answer lies outside the sliding window but
    should still be retrievable.

    The global `random` module is re-seeded with 1 on every call, so the
    output is deterministic for a given set of arguments.

    Args:
        batch_size: number of prompts to generate
        ln_range: inclusive bounds from which the exclusive upper variable
            index of each prompt is drawn; controls the prompt length

    Returns:
        A `(prompts, answer, indices)` tuple: the prompt strings, the value
        assigned to the queried variable in each prompt, and the index of
        that variable.
    """
    prompts: list[str] = []
    answer: list[int] = []
    indices: list[int] = []
    random.seed(1)
    for _ in range(batch_size):
        # NOTE: the order of the random draws below must not change, as it
        # determines the generated content.
        target = random.randint(30, 90)
        indices.append(target)
        pieces = [
            "```python\n# We set a number of variables, "
            + f"x{target} will be important later\n"
        ]
        upper = random.randint(*ln_range)
        for var_index in range(30, upper):
            value = random.randint(10, 99)
            if var_index == target:
                answer.append(value)
            pieces.append(f"x{var_index} = {value}\n")
        pieces.append(f"# Now, we check the value of x{target}:\n")
        pieces.append(f"assert x{target} == ")
        prompts.append("".join(pieces))
    return prompts, answer, indices


1384
1385
1386
def check_answers(
    indices: list[int], answer: list[int], outputs: list[str], accept_rate: float = 0.7
):
    """Assert that enough generated outputs match the expected answers.

    The first two characters of each output (whitespace-stripped) are parsed
    as an integer and compared with the corresponding entry of `answer`. The
    check passes when the fraction of matches is at least `accept_rate`.

    Raises:
        AssertionError: If the match fraction is below `accept_rate`.
        ValueError: If the leading characters of an output are not an integer.
    """
    predicted = [int(output[0:2].strip()) for output in outputs]
    print(list(zip(indices, zip(answer, predicted))))
    numok = sum(1 for expected, got in zip(answer, predicted) if expected == got)
    total = len(answer)
    frac_ok = numok / total
    print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
    assert frac_ok >= accept_rate
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415


def flat_product(*iterables: Iterable[Any]):
    """
    Yield the cartesian product of `iterables` with tuple elements flattened
    one level, so that test params can be unpacked directly from the
    decorator instead of arriving as nested tuples.

    Example:
    flat_product([(1, 2), (3, 4)], ["a", "b"]) ->
    [
      (1, 2, "a"),
      (1, 2, "b"),
      (3, 4, "a"),
      (3, 4, "b"),
    ]
    """
    for combination in itertools.product(*iterables):
        flattened: list[Any] = []
        for member in combination:
            if isinstance(member, tuple):
                # Splice tuple members in place.
                flattened.extend(member)
            else:
                flattened.append(member)
        yield tuple(flattened)