conftest.py 42.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12

# ruff: noqa

from tblib import pickling_support

# Install support for pickling exceptions so that we can nicely propagate
# failures from tests running in a subprocess.
# This should be run before any custom exception subclasses are defined.
pickling_support.install()

13
import http.server
14
import json
15
import math
16
import mimetypes
17
import os
18
import socket
19
import tempfile
20
21
import threading
from collections.abc import Generator
22
from contextlib import nullcontext
23
from enum import Enum
24
from typing import Any, Callable, TypedDict, TypeVar, cast
Woosuk Kwon's avatar
Woosuk Kwon committed
25

26
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
27
28
import pytest
import torch
29
import torch.nn as nn
30
import torch.nn.functional as F
31
from huggingface_hub import snapshot_download
32
from PIL import Image
33
34
35
36
37
38
39
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BatchEncoding,
    BatchFeature,
)
40
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
41

42
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
Woosuk Kwon's avatar
Woosuk Kwon committed
43
from vllm import LLM, SamplingParams
44
from vllm.assets.audio import AudioAsset
45
from vllm.assets.image import ImageAsset
46
from vllm.assets.video import VideoAsset
47
from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype
48
from vllm.connections import global_http_connection
49
50
51
52
53
from vllm.distributed import (
    cleanup_dist_env_and_memory,
    init_distributed_environment,
    initialize_model_parallel,
)
54
from vllm.logger import init_logger
55
from vllm.logprobs import Logprob
56
from vllm.multimodal.utils import fetch_image
57
from vllm.outputs import RequestOutput
58
from vllm.sampling_params import BeamSearchParams
59
from vllm.transformers_utils.utils import maybe_model_redirect
60
61
from vllm.utils import set_default_torch_num_threads
from vllm.utils.collections import is_list_of
62

63
# Module-level logger for the shared test configuration.
logger = init_logger(__name__)

# Directory containing this file; the data files below are resolved against it.
_TEST_DIR = os.path.dirname(__file__)

# Prompt files consumed by the `example_prompts` fixture.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]

# Prompt files consumed by the `example_long_prompts` fixture.
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# File read by the `example_system_message` fixture.
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
69

Cyrus Leung's avatar
Cyrus Leung committed
70
# Item type of a multimodal input (an image, a video, an audio clip, ...).
_M = TypeVar("_M")

# Either a flat list of items or a list of item lists.
_PromptMultiModalInput = list[_M] | list[list[_M]]

# Images are PIL images.
PromptImageInput = _PromptMultiModalInput[Image.Image]

# Audio items are `(array, int)` pairs; `HfRunner.get_inputs` forwards the
# second element to the processor as `sampling_rate`.
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]

# Videos are numpy arrays.
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
77

78

79
def _read_prompts(filename: str) -> list[str]:
    """Return the lines of the prompt file at `filename`.

    Each returned element is one line of the file with its trailing newline
    (if any) preserved, exactly as produced by `readlines()`.

    The file is decoded as UTF-8 explicitly so that the result does not
    depend on the locale's default encoding.
    """
    with open(filename, encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
83
84


85
class ImageAssetPrompts(TypedDict):
    """Prompt text for each test image, keyed by asset name."""

    stop_sign: str
    cherry_blossom: str
88
89


90
class ImageTestAssets(list[ImageAsset]):
    """List of the image assets used by tests, in a fixed order."""

    def __init__(self) -> None:
        # The order here must stay in sync with `prompts()` below.
        stop_sign = ImageAsset("stop_sign")
        cherry_blossom = ImageAsset("cherry_blossom")
        super().__init__([stop_sign, cherry_blossom])

    def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
        """
        Convenience method to define the prompt for each test image.

        The returned prompts are ordered the same way as the assets
        yielded when iterating over this object.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
107
108


109
110
class VideoAssetPrompts(TypedDict):
    """Prompt text for each test video, keyed by asset name."""

    baby_reading: str
111
112


113
class VideoTestAssets(list[VideoAsset]):
    """List of the video assets used by tests, in a fixed order."""

    def __init__(self) -> None:
        # The order here must stay in sync with `prompts()` below.
        baby_reading = VideoAsset("baby_reading")
        super().__init__([baby_reading])

    def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
        """Return the prompt for each test video, in asset order."""
        baby_reading_prompt = prompts["baby_reading"]
        return [baby_reading_prompt]
123
124


125
class AudioAssetPrompts(TypedDict):
    """Prompt text for each test audio clip, keyed by asset name."""

    mary_had_lamb: str
    winning_call: str


130
class AudioTestAssets(list[AudioAsset]):
    """List of the audio assets used by tests, in a fixed order."""

    def __init__(self) -> None:
        # The order here must stay in sync with `prompts()` below.
        mary_had_lamb = AudioAsset("mary_had_lamb")
        winning_call = AudioAsset("winning_call")
        super().__init__([mary_had_lamb, winning_call])

    def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
        """Return the prompt for each test audio clip, in asset order."""
        mary_had_lamb_prompt = prompts["mary_had_lamb"]
        winning_call_prompt = prompts["winning_call"]
        return [mary_had_lamb_prompt, winning_call_prompt]

142

143
# Shared asset lists, built once at import time and handed out by the
# session-scoped `image_assets` / `video_assets` / `audio_assets` fixtures.

IMAGE_ASSETS = ImageTestAssets()
"""Singleton instance of {class}`ImageTestAssets`."""

VIDEO_ASSETS = VideoTestAssets()
"""Singleton instance of {class}`VideoTestAssets`."""

AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of {class}`AudioTestAssets`."""
149
150


151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
    """
    Make sure changes to "VLLM_USE_V1" do not leak between tests.

    The V1 oracle sets "VLLM_USE_V1" during loading, which means that
    each invocation of a test changes the env variable.

    Once "VLLM_USE_V1" has been touched through monkeypatch, any change
    made to it by vLLM during the test run is cleaned up afterwards.

    This fixture is used by every test.
    """
    # Nothing to do when the variable is already set by the environment.
    if "VLLM_USE_V1" in os.environ:
        return

    # Set then delete: this registers the variable with monkeypatch so
    # that it cleans up VLLM_USE_V1 upon exit if vLLM modifies the value
    # of envs.VLLM_USE_V1 during the test.
    monkeypatch.setenv("VLLM_USE_V1", "")
    monkeypatch.delenv("VLLM_USE_V1")


171
172
173
174
175
176
177
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Disable client reuse on the global HTTP connection for every test."""
    # pytest_asyncio may run each test on a different event loop, so the
    # async client has to be created anew rather than reused.
    global_http_connection.reuse_client = False


178
179
180
181
182
183
184
185
186
187
188
189
@pytest.fixture
def dist_init():
    """Set up a single-process distributed environment for a test.

    Initializes the distributed environment with `world_size=1` on the
    "nccl" backend, using a fresh temporary file as the rendezvous point,
    then initializes model parallelism with sizes `(1, 1)`. After the test,
    the distributed environment and memory are cleaned up.
    """
    # mkstemp() returns an *open* file descriptor together with the path.
    # Only the path is needed here, so close the descriptor right away
    # instead of leaking one per fixture invocation.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)

    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup_dist_env_and_memory()
191
192


193
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell `cleanup_fixture` whether to run the global cleanup.

    Subdirectories can skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.

    Returns False when the requesting test carries the
    `skip_global_cleanup` marker, True otherwise.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
201
202


203
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run the global distributed/memory cleanup after each test, unless
    `should_do_global_cleanup_after_test` says to skip it."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory()
208
209


210
211
212
213
214
215
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call `torch._dynamo.reset()` after every test."""
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
216
@pytest.fixture
def example_prompts() -> list[str]:
    """Return every line of every file listed in `_TEST_PROMPTS`."""
    collected: list[str] = []
    for path in _TEST_PROMPTS:
        collected.extend(_read_prompts(path))
    return collected


224
225
226
227
228
229
@pytest.fixture
def example_system_message() -> str:
    """Return the full contents of the `_SYS_MSG` file as one string.

    The file is decoded as UTF-8 explicitly so that the result does not
    depend on the locale's default encoding.
    """
    with open(_SYS_MSG, encoding="utf-8") as f:
        return f.read()


230
231
class DecoderPromptType(Enum):
    """Kinds of decoder prompt. For encoder/decoder models only."""

    CUSTOM = 1
    NONE = 2
    EMPTY_STR = 3


238
@pytest.fixture
def example_long_prompts() -> list[str]:
    """Return every line of every file listed in `_LONG_PROMPTS`."""
    collected: list[str] = []
    for path in _LONG_PROMPTS:
        collected.extend(_read_prompts(path))
    return collected
Woosuk Kwon's avatar
Woosuk Kwon committed
244
245


246
@pytest.fixture(scope="session")
def image_assets() -> ImageTestAssets:
    """Provide the shared `IMAGE_ASSETS` singleton to tests."""
    return IMAGE_ASSETS


251
@pytest.fixture(scope="session")
def video_assets() -> VideoTestAssets:
    """Provide the shared `VIDEO_ASSETS` singleton to tests."""
    return VIDEO_ASSETS


256
@pytest.fixture(scope="session")
def audio_assets() -> AudioTestAssets:
    """Provide the shared `AUDIO_ASSETS` singleton to tests."""
    return AUDIO_ASSETS


261
# Types accepted (and returned unchanged in kind) by `HfRunner.wrap_device`.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)

# Generic, unconstrained type variable.
_R = TypeVar("_R")
263

Woosuk Kwon's avatar
Woosuk Kwon committed
264
265

class HfRunner:
266
    def get_default_device(self):
        """Return "cpu" on a CPU platform, else the platform's device type."""
        # Imported lazily rather than at module import time.
        from vllm.platforms import current_platform

        if current_platform.is_cpu():
            return "cpu"
        return current_platform.device_type
270

271
    def wrap_device(self, x: _T, device: str | None = None) -> _T:
        """Move `x` to `device`, recursing into dict values.

        Args:
            x: The object to move. `None` and booleans are returned as-is;
                a dict yields a new dict whose values have been wrapped.
            device: Target device; defaults to `self.device`.

        Returns:
            `x` itself when it already reports a matching `device.type`,
            otherwise the result of `x.to(device)`.
        """
        if x is None or isinstance(x, bool):
            return x

        target = self.device if device is None else device

        if isinstance(x, dict):
            wrapped = {}
            for key, value in x.items():
                wrapped[key] = self.wrap_device(value, target)
            return wrapped

        # Skip the transfer when the object is already on the target device.
        already_there = hasattr(x, "device") and x.device.type == target
        if already_there:
            return x

        return x.to(target)
285

Woosuk Kwon's avatar
Woosuk Kwon committed
286
287
288
    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: dict[str, Any] | None = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
        # Set this to avoid hanging issue
        default_torch_num_threads: int | None = None,
    ) -> None:
        """Load the model, optionally under a fixed torch thread count.

        All arguments except `default_torch_num_threads` are forwarded
        unchanged to `_init`, which does the actual loading. When
        `default_torch_num_threads` is given, `_init` runs inside
        `set_default_torch_num_threads(default_torch_num_threads)`.
        """
        init_ctx = (
            set_default_torch_num_threads(default_torch_num_threads)
            if default_torch_num_threads is not None
            else nullcontext()
        )

        loader_options = {
            "model_kwargs": model_kwargs,
            "trust_remote_code": trust_remote_code,
            "is_sentence_transformer": is_sentence_transformer,
            "is_cross_encoder": is_cross_encoder,
            "skip_tokenizer_init": skip_tokenizer_init,
            "auto_cls": auto_cls,
        }

        with init_ctx:
            self._init(model_name=model_name, dtype=dtype, **loader_options)

    def _init(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: dict[str, Any] | None = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load config, model, tokenizer and processor for `model_name`.

        Sets `self.model_name`, `self.config`, `self.device`, `self.dtype`,
        `self.model`, `self.processor` and `self.tokenizer`.

        Args:
            model_name: Model identifier; passed through
                `maybe_model_redirect` before use.
            dtype: Requested dtype, resolved via `_get_and_verify_dtype`.
            model_kwargs: Extra keyword arguments for the model constructor.
                The caller's dict is not modified.
            trust_remote_code: Forwarded to every `from_pretrained` call and
                to the sentence-transformers constructors.
            is_sentence_transformer: Load the model as a
                `SentenceTransformer`.
            is_cross_encoder: Load the model as a `CrossEncoder` (only
                consulted when `is_sentence_transformer` is false).
            skip_tokenizer_init: Do not load an `AutoTokenizer`; use the
                processor's tokenizer instead.
            auto_cls: Auto class used when neither sentence-transformers
                flag is set.
        """
        model_name = maybe_model_redirect(model_name)
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Work on a copy: `setdefault` below would otherwise write the
        # resolved dtype into the caller's dict, and a dict reused across
        # runners would then carry a stale dtype into later constructions.
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("dtype", dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if getattr(model, "quantization_method", None) is None and any(
                p.dtype != self.dtype for p in model.parameters()
            ):
                model = model.to(dtype=self.dtype)

            # Only move the model when it is not bitsandbytes-quantized and
            # its parameters live on fewer than two distinct devices.
            if (
                getattr(model, "quantization_method", None) != "bitsandbytes"
                and len({p.device for p in model.parameters()}) < 2
            ):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer
Woosuk Kwon's avatar
Woosuk Kwon committed
407

408
    def get_inputs(
        self,
        prompts: list[str] | list[list[int]],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
    ) -> list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]]:
        """Build one model-input object per prompt.

        Text prompts are run through `self.processor` together with the
        multimodal item at the same index, if any. Token-id prompts are
        turned into an `{"input_ids": tensor}` dict with a leading batch
        dimension of 1.

        Raises:
            ValueError: If a non-string prompt is not a list of ints, or if
                multimodal inputs are supplied with a token-id prompt.
        """
        # Every supplied modality must provide exactly one entry per prompt.
        for mm_inputs in (images, videos, audios):
            if mm_inputs is not None:
                assert len(prompts) == len(mm_inputs)

        all_inputs: list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]] = []
        for idx, prompt in enumerate(prompts):
            if not isinstance(prompt, str):
                # check that prompt is (batched) list of integers (token ids)
                if not is_list_of(prompt, typ=int, check="all"):
                    raise ValueError(
                        "Prompt must be a list of ints corresponding to the prompt token ids."
                    )
                # check that no multimodal input is provided
                if images or videos or audios:
                    raise ValueError(
                        "When providing prompt token ids multimodal inputs are not supported."
                    )
                token_ids = torch.tensor(prompt, dtype=torch.long).unsqueeze(0)
                all_inputs.append({"input_ids": token_ids})
                continue

            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }

            if images is not None:
                image = images[idx]
                if image is not None:
                    processor_kwargs["images"] = image

            if videos is not None:
                video = videos[idx]
                if video is not None:
                    processor_kwargs["videos"] = video

            if audios is not None:
                audio_inputs = audios[idx]
                if audio_inputs is not None:
                    # HACK - not all processors take sampling_rate; we should
                    # clean this up in the future.
                    if len(audio_inputs) == 2:
                        audio, sr = audio_inputs
                        processor_kwargs["audio"] = audio
                        processor_kwargs["sampling_rate"] = sr
                    else:
                        processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            # Only BatchFeature outputs are cast to the model dtype.
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)
            all_inputs.append(inputs)

        return all_inputs

467
468
469
470
471
472
473
474
475
    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
        """Return the input embeddings of each prompt.

        Each prompt's `input_ids` are moved to the runner's device, passed
        through the model's input embedding layer, and the result is
        squeezed along dim 0.
        """
        prompt_embeds: list[torch.Tensor] = []
        for model_inputs in self.get_inputs(prompts):
            token_ids = self.wrap_device(model_inputs)["input_ids"]
            embed_layer = self.model.get_input_embeddings()
            prompt_embeds.append(embed_layer(token_ids).squeeze(0))
        return prompt_embeds

476
    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Run the model on each prompt and return its per-class scores.

        The post-processing of the logits depends on the config's
        `problem_type`:

        - "regression": raw logits.
        - "multi_label_classification": element-wise sigmoid.
        - anything else (including unset): softmax over the last dim.

        Returns:
            One list of floats per prompt, taken from the first row of the
            model's (post-processed) logits.
        """
        # output is final logits
        all_inputs = self.get_inputs(prompts)
        outputs: list[list[float]] = []
        problem_type = getattr(self.config, "problem_type", "")

        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            if problem_type == "regression":
                logits = output.logits[0].tolist()
            elif problem_type == "multi_label_classification":
                logits = output.logits.sigmoid()[0].tolist()
            else:
                logits = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(logits)

        return outputs

494
495
    def generate(
        self,
        prompts: list[str] | list[list[int]],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate from each prompt, one `model.generate` call per prompt.

        Extra `kwargs` are forwarded to `self.model.generate`, which is
        always called with `use_cache=True`.

        Returns:
            For each prompt, a `(token_ids, texts)` pair holding every
            returned sequence as token ids and as text decoded with special
            tokens skipped.
        """
        results: list[tuple[list[list[int]], list[str]]] = []
        model_inputs = self.get_inputs(
            prompts, images=images, videos=videos, audios=audios
        )

        for single_input in model_inputs:
            generated = self.model.generate(
                **self.wrap_device(single_input),
                use_cache=True,
                **kwargs,
            )
            # Decode while the ids are still a tensor, then convert the ids
            # to plain Python lists for the caller.
            texts = self.processor.batch_decode(
                generated,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            results.append((generated.cpu().tolist(), texts))

        return results

    def generate_greedy(
        self,
        prompts: list[str] | list[list[int]],
        max_tokens: int,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy-decode at most `max_tokens` new tokens per prompt.

        Calls `generate` with `do_sample=False` and returns, for each
        prompt, only the first sequence as `(token_ids, text)`.
        """
        per_prompt = self.generate(
            prompts,
            do_sample=False,
            max_new_tokens=max_tokens,
            images=images,
            videos=videos,
            audios=audios,
            **kwargs,
        )

        # Keep only the first returned sequence of each prompt.
        return [(ids[0], texts[0]) for ids, texts in per_prompt]
542
543
544

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Beam-search each prompt and return all `beam_width` sequences.

        Calls `generate` with `num_beams` and `num_return_sequences` both
        set to `beam_width`, then removes every occurrence of the
        tokenizer's pad token id from the returned token ids. The decoded
        texts are returned unchanged.
        """
        outputs = self.generate(
            prompts,
            do_sample=False,
            max_new_tokens=max_tokens,
            num_beams=beam_width,
            num_return_sequences=beam_width,
            images=images,
            videos=videos,
            audios=audios,
        )

        for prompt_idx, (beam_ids, beam_texts) in enumerate(outputs):
            for beam_idx, ids in enumerate(beam_ids):
                # Strip the padding added to equalize beam lengths.
                beam_ids[beam_idx] = [
                    tok for tok in ids if tok != self.tokenizer.pad_token_id
                ]
            outputs[prompt_idx] = (beam_ids, beam_texts)

        return outputs
Woosuk Kwon's avatar
Woosuk Kwon committed
571

572
573
    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **kwargs: Any,
    ) -> list[list[torch.Tensor]]:
        """Greedy-decode each prompt and return per-step log-probabilities.

        The model is asked to return its hidden states, which are converted
        to log-probabilities by `_hidden_states_to_seq_logprobs`.

        Returns:
            For each prompt, one log-probability tensor per generation step.
        """
        model_inputs = self.get_inputs(
            prompts, images=images, videos=videos, audios=audios
        )

        per_prompt_logprobs: list[list[torch.Tensor]] = []
        for single_input in model_inputs:
            generation = self.model.generate(
                **self.wrap_device(single_input),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            per_prompt_logprobs.append(
                self._hidden_states_to_seq_logprobs(generation.hidden_states)
            )

        return per_prompt_logprobs

600
    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
    ) -> list[torch.Tensor]:
        """Project hidden states through the output embeddings.

        For each element of `hidden_states`, the last entry of its tuple is
        taken (presumably the final layer - confirm against the HF
        `generate` output format), indexed at 0, multiplied by the
        transposed output-embedding weight (plus bias, when the layer has
        one), and normalized with a float32 log-softmax over the last dim.

        Returns:
            One log-probability tensor per element of `hidden_states`.
        """
        lm_head = self.model.get_output_embeddings()

        per_step_logprobs: list[torch.Tensor] = []
        for step_states in hidden_states:
            final_states = step_states[-1][0]
            # Match the projection weight's device and dtype before matmul.
            projected_input = final_states.to(
                device=lm_head.weight.device,
                dtype=lm_head.weight.dtype,
            )
            logits = torch.matmul(projected_input, lm_head.weight.t())

            bias = getattr(lm_head, "bias", None)
            if bias is not None:
                logits += lm_head.bias.unsqueeze(0)

            per_step_logprobs.append(
                F.log_softmax(logits, dim=-1, dtype=torch.float32)
            )

        return per_step_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
        num_logprobs: int | None,
    ) -> tuple[list[dict[int, float]], int]:
        """Convert hidden states to top-k log-probability dicts per step.

        Args:
            hidden_states: Hidden states as accepted by
                `_hidden_states_to_seq_logprobs`.
            num_logprobs: The `k` passed to `topk` at every step.

        Returns:
            A pair `(logprobs, output_len)`, where `logprobs[i]` maps token
            id to log-probability for the top `num_logprobs` tokens of step
            `i`, and `output_len` is `len(hidden_states)`.
        """
        per_step = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        # convert to dict
        topk_dicts: list[dict[int, float]] = []
        for step, step_logprobs in enumerate(per_step):
            # drop prompt logprobs: for the first step only the last row
            # is kept, reshaped back to a single-row 2-D tensor.
            if step == 0:
                step_logprobs = step_logprobs[-1, :].reshape(1, -1)

            best = step_logprobs.topk(num_logprobs)
            topk_dicts.append(
                {
                    token_id.item(): logprob.item()
                    for token_id, logprob in zip(best.indices[0], best.values[0])
                }
            )

        return topk_dicts, output_len

650
651
    def generate_greedy_logprobs_limit(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int | None,
        images: PromptImageInput | None = None,
        audios: PromptAudioInput | None = None,
        videos: PromptVideoInput | None = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """Greedy-decode each prompt, keeping the top-k logprobs per step.

        Returns:
            For each prompt, a `(token_ids, text, logprobs)` tuple where
            `token_ids` are the trailing tokens of the first generated
            sequence (one per logprob entry), `text` is their decoding by
            `self.tokenizer`, and `logprobs` holds one
            `{token_id: logprob}` dict per generated step.
        """
        model_inputs = self.get_inputs(
            prompts, images=images, videos=videos, audios=audios
        )

        results = []
        for single_input in model_inputs:
            generation = self.model.generate(
                **self.wrap_device(single_input),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            step_logprobs, _ = self._hidden_states_to_logprobs(
                generation.hidden_states, num_logprobs
            )

            # The generated tokens are the trailing entries of the full
            # sequence, one per step that produced logprobs.
            num_generated = len(step_logprobs)
            new_ids = generation.sequences[0][-num_generated:]
            new_ids_list = new_ids.tolist()
            new_text = self.tokenizer.decode(new_ids)

            results.append((new_ids_list, new_text, step_logprobs))

        return results
696

697
    def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]:
        """Delegate to `self.model.encode`, forwarding all arguments."""
        encoded = self.model.encode(prompts, *args, **kwargs)
        return encoded
699

700
701
    def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
        """Delegate to `self.model.predict`, always requesting a tensor."""
        scores = self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs)
        return scores
702

703
704
705
706
    def __enter__(self):
        """Enter the context manager; the runner itself is the resource."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model reference and run the global cleanup."""
        # Delete the attribute first so the cleanup below can reclaim the
        # memory held by the model.
        del self.model
        cleanup_dist_env_and_memory()
709

Woosuk Kwon's avatar
Woosuk Kwon committed
710

Cyrus Leung's avatar
Cyrus Leung committed
711
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the `HfRunner` class (not an instance) to tests."""
    return HfRunner


class VllmRunner:
717
718
    """
    The default value of some arguments have been modified from
719
    {class}`~vllm.LLM` as follows:
720

721
722
723
    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
724
725
    - `block_size`: To reduce memory usage, set default to `64` if on XPU
        devices, otherwise default to `16`.
726
727
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
728
    - `enforce_eager`: Set to `False` to test CUDA graph.
729
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
730
731
732
733

    def __init__(
        self,
        model_name: str,
734
735
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
736
        tokenizer_name: str | None = None,
737
        tokenizer_mode: str = "auto",
738
        trust_remote_code: bool = True,
739
740
        seed: int | None = 0,
        max_model_len: int | None = 1024,
741
        dtype: str = "auto",
742
        disable_log_stats: bool = True,
743
        tensor_parallel_size: int = 1,
744
        block_size: int = 16 if not torch.xpu.is_available() else 64,
745
        enable_chunked_prefill: bool | None = False,
746
        swap_space: int = 4,
747
        enforce_eager: bool | None = False,
748
        # Set this to avoid hanging issue
749
        default_torch_num_threads: int | None = None,
750
        **kwargs,
Woosuk Kwon's avatar
Woosuk Kwon committed
751
    ) -> None:
752
753
754
755
756
        init_ctx = (
            nullcontext()
            if default_torch_num_threads is None
            else set_default_torch_num_threads(default_torch_num_threads)
        )
757

758
        if not kwargs.get("compilation_config", None):
759
760
761
762
            # Note(@tdoublep): This is set to 4 because some tests (e.g., hybrid
            # model tests) may set max_num_seqs=4. If min cudagraph_capture_size is
            # set to larger than max_num_seqs, then it will lead to *no* graphs
            # being captured which can trigger edge cases that we don't handle yet.
763
            kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}
764

765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
        with init_ctx:
            self.llm = LLM(
                model=model_name,
                runner=runner,
                convert=convert,
                tokenizer=tokenizer_name,
                tokenizer_mode=tokenizer_mode,
                trust_remote_code=trust_remote_code,
                dtype=dtype,
                seed=seed,
                swap_space=swap_space,
                enforce_eager=enforce_eager,
                disable_log_stats=disable_log_stats,
                tensor_parallel_size=tensor_parallel_size,
                max_model_len=max_model_len,
                block_size=block_size,
                enable_chunked_prefill=enable_chunked_prefill,
                **kwargs,
            )
Woosuk Kwon's avatar
Woosuk Kwon committed
784

785
    def get_inputs(
Woosuk Kwon's avatar
Woosuk Kwon committed
786
        self,
787
788
789
790
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
791
    ) -> list[dict[str, Any]]:
792
793
794
        if any(
            x is not None and len(x) != len(prompts) for x in [images, videos, audios]
        ):
795
            raise ValueError(
796
797
                "All non-None multimodal inputs must have the same length as prompts"
            )
798

799
        inputs = list[dict[str, Any]]()
800
        for i, prompt in enumerate(prompts):
801
802
803
804
805
806
807
808
809
            prompt_dict = dict[str, Any]()
            if isinstance(prompt, str):
                prompt_dict["prompt"] = prompt
            elif isinstance(prompt, list):
                prompt_dict["prompt_token_ids"] = prompt
            else:
                prompt_dict["prompt_embeds"] = prompt

            multi_modal_data = dict[str, Any]()
810
811
812
813
814
815
816
            if images is not None and (image := images[i]) is not None:
                multi_modal_data["image"] = image
            if videos is not None and (video := videos[i]) is not None:
                multi_modal_data["video"] = video
            if audios is not None and (audio := audios[i]) is not None:
                multi_modal_data["audio"] = audio

817
818
            if multi_modal_data:
                prompt_dict["multi_modal_data"] = multi_modal_data
819

820
            inputs.append(prompt_dict)
821
822
823
824
825

        return inputs

    def generate(
        self,
826
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
827
        sampling_params: SamplingParams,
828
829
830
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
831
        **kwargs: Any,
832
    ) -> list[tuple[list[list[int]], list[str]]]:
833
        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
834

835
836
837
        req_outputs = self.llm.generate(
            inputs, sampling_params=sampling_params, **kwargs
        )
838

839
        outputs: list[tuple[list[list[int]], list[str]]] = []
Woosuk Kwon's avatar
Woosuk Kwon committed
840
841
842
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
843
844
            req_sample_output_ids: list[list[int]] = []
            req_sample_output_strs: list[str] = []
845
846
            for sample in req_output.outputs:
                output_str = sample.text
847
                output_ids = list(sample.token_ids)
848
                req_sample_output_ids.append(prompt_ids + output_ids)
849
                req_sample_output_strs.append((prompt_str or "") + output_str)
850
            outputs.append((req_sample_output_ids, req_sample_output_strs))
Woosuk Kwon's avatar
Woosuk Kwon committed
851
852
        return outputs

853
    @staticmethod
854
    def _final_steps_generate_w_logprobs(
855
856
857
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        outputs: list[TokensTextLogprobsPromptLogprobs] = []
858
        for req_output in req_outputs:
859
            assert len(req_output.outputs) > 0
860
861
            for sample in req_output.outputs:
                output_str = sample.text
862
                output_ids = list(sample.token_ids)
863
                output_logprobs = sample.logprobs
864
865
866
            outputs.append(
                (output_ids, output_str, output_logprobs, req_output.prompt_logprobs)
            )
867
868
        return outputs

869
870
    def generate_w_logprobs(
        self,
871
        prompts: list[str],
872
        sampling_params: SamplingParams,
873
874
875
        images: PromptImageInput | None = None,
        audios: PromptAudioInput | None = None,
        videos: PromptVideoInput | None = None,
876
        **kwargs: Any,
877
    ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
878
879
880
881
882
883
884
885
886
        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)

        req_outputs = self.llm.generate(
            inputs, sampling_params=sampling_params, **kwargs
        )

        toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(
            req_outputs
        )
887
        # Omit prompt logprobs if not required by sampling params
888
889
890
891
892
        return (
            [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
            if sampling_params.prompt_logprobs is None
            else toks_str_logsprobs_prompt_logprobs
        )
893

Woosuk Kwon's avatar
Woosuk Kwon committed
894
895
    def generate_greedy(
        self,
896
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
Woosuk Kwon's avatar
Woosuk Kwon committed
897
        max_tokens: int,
898
899
900
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
901
        **kwargs: Any,
902
    ) -> list[tuple[list[int], str]]:
Woosuk Kwon's avatar
Woosuk Kwon committed
903
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
904
905
906
907
908
909
910
911
912
        outputs = self.generate(
            prompts,
            greedy_params,
            images=images,
            videos=videos,
            audios=audios,
            **kwargs,
        )
        return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
913

914
915
    def generate_greedy_logprobs(
        self,
916
        prompts: list[str],
917
        max_tokens: int,
918
919
920
921
922
923
924
        num_logprobs: int | None,
        num_prompt_logprobs: int | None = None,
        images: PromptImageInput | None = None,
        audios: PromptAudioInput | None = None,
        videos: PromptVideoInput | None = None,
        stop_token_ids: list[int] | None = None,
        stop: list[str] | None = None,
925
        **kwargs: Any,
926
    ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
927
928
929
930
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
931
            prompt_logprobs=num_prompt_logprobs,
932
            stop_token_ids=stop_token_ids,
933
934
            stop=stop,
        )
935

936
937
938
939
940
941
942
943
        return self.generate_w_logprobs(
            prompts,
            greedy_logprobs_params,
            images=images,
            audios=audios,
            videos=videos,
            **kwargs,
        )
944

945
946
947
948
949
950
951
    def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
        """
        Return the perplexity score associated with generating the prompts

        :param prompts: list of prompts to score
        :return: perplexity score of each prompt
        """
952
953
954
        outputs = self.generate_greedy_logprobs(
            prompts, max_tokens=1, num_logprobs=None, num_prompt_logprobs=0
        )
955
956
957
958

        perplexities = []
        for output in outputs:
            output = cast(TokensTextLogprobsPromptLogprobs, output)
959
            token_datas = cast(list[dict[int, Logprob] | None], output[3])
960
961
962
963
964
965
966
967
968
969
970
971
972
            assert token_datas[0] is None
            token_log_probs = []
            for token_data in token_datas[1:]:
                assert token_data is not None
                assert len(token_data) == 1
                token_log_prob = list(token_data.values())[0].logprob
                token_log_probs.append(token_log_prob)

            perplexity = math.exp(-sum(token_log_probs) / len(token_log_probs))
            perplexities.append(perplexity)

        return perplexities

973
    def generate_beam_search(
974
        self,
975
        prompts: list[str],
976
977
        beam_width: int,
        max_tokens: int,
978
979
980
981
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        concurrency_limit: int | None = None,
982
    ) -> list[tuple[list[list[int]], list[str]]]:
983
984
985
986
987
988
989
        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)

        outputs = self.llm.beam_search(
            inputs,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens),
            concurrency_limit=concurrency_limit,
        )
990
991
992
993
994
995
996
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

997
    def classify(self, prompts: list[str]) -> list[list[float]]:
998
        req_outputs = self.llm.classify(prompts)
999
1000
        return [req_output.outputs.probs for req_output in req_outputs]

1001
1002
1003
    def embed(
        self,
        prompts: list[str],
1004
1005
1006
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
1007
1008
1009
1010
        *args,
        **kwargs,
    ) -> list[list[float]]:
        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
Cyrus Leung's avatar
Cyrus Leung committed
1011

1012
        req_outputs = self.llm.embed(inputs, *args, **kwargs)
Cyrus Leung's avatar
Cyrus Leung committed
1013
        return [req_output.outputs.embedding for req_output in req_outputs]
1014

1015
1016
1017
1018
1019
1020
    def token_embed(self, prompts: list[str]) -> list[list[float]]:
        req_outputs = self.llm.encode(prompts, pooling_task="token_embed")
        return [req_output.outputs.data for req_output in req_outputs]

    def token_classify(self, prompts: list[str]) -> list[list[float]]:
        req_outputs = self.llm.encode(prompts, pooling_task="token_classify")
1021
1022
        return [req_output.outputs.data for req_output in req_outputs]

1023
1024
1025
1026
    def reward(self, prompts: list[str]) -> list[list[float]]:
        req_outputs = self.llm.reward(prompts)
        return [req_output.outputs.data for req_output in req_outputs]

1027
1028
    def score(
        self,
1029
1030
        text_1: list[str] | str,
        text_2: list[str] | str,
1031
1032
        *args,
        **kwargs,
1033
    ) -> list[float]:
1034
        req_outputs = self.llm.score(text_1, text_2, *args, **kwargs)
1035
        return [req_output.outputs.score for req_output in req_outputs]
1036

1037
    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
1038
        return self.llm.apply_model(func)
1039

1040
1041
1042
    def get_llm(self) -> LLM:
        return self.llm

1043
1044
1045
1046
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
1047
        del self.llm
1048
        cleanup_dist_env_and_memory()
1049

Woosuk Kwon's avatar
Woosuk Kwon committed
1050

1051
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture exposing the `VllmRunner` class (not an instance)."""
    return VllmRunner
1054
1055


1056
1057
1058
@pytest.fixture()
def temporary_enable_log_propagate():
    """Enable propagation on the "vllm" logger for the duration of a test.

    The logger's previous `propagate` value is remembered and restored on
    teardown, so a logger that was already propagating is left that way
    instead of being forced to `False`.
    """
    import logging

    logger = logging.getLogger("vllm")
    previous_propagate = logger.propagate
    logger.propagate = True
    try:
        yield
    finally:
        # Restore whatever was configured before the fixture ran.
        logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's `caplog` with "vllm" log propagation switched on.

    `caplog` only sees records that propagate to the root logger, which is
    why this fixture depends on `temporary_enable_log_propagate`.
    """
    yield caplog
1071
1072
1073
1074
1075
1076
1077


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process.

    The count is whatever `current_platform.device_count()` reports.
    """
    # Imported lazily, inside the fixture, rather than at module import time.
    from vllm.platforms import current_platform as platform

    device_total = platform.device_count()
    return device_total
1081
1082
1083


# Directories under the system temp dir that the `dummy_*_path` fixtures
# download their model configs into.
temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
1087
1088
1089
1090


@pytest.fixture
def dummy_opt_path():
    """Return a local copy of `facebook/opt-125m` with a renamed architecture.

    On first use (when the target directory does not exist yet) the repo is
    downloaded without weight files and `config.json` is rewritten so that
    `architectures` is `["MyOPTForCausalLM"]`. An existing directory is
    returned as-is.
    """
    if os.path.exists(_dummy_opt_path):
        return _dummy_opt_path

    config_file = os.path.join(_dummy_opt_path, "config.json")
    weight_patterns = ["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"]
    snapshot_download(
        repo_id="facebook/opt-125m",
        local_dir=_dummy_opt_path,
        ignore_patterns=weight_patterns,
    )
    assert os.path.exists(config_file)

    with open(config_file) as reader:
        model_config = json.load(reader)
    model_config["architectures"] = ["MyOPTForCausalLM"]
    with open(config_file, "w") as writer:
        json.dump(model_config, writer)

    return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
    """Return a local copy of `llava-hf/llava-1.5-7b-hf` with a renamed architecture.

    On first use (when the target directory does not exist yet) the repo is
    downloaded without weight files and `config.json` is rewritten so that
    `architectures` is `["MyLlava"]`. An existing directory is returned
    as-is.
    """
    if os.path.exists(_dummy_llava_path):
        return _dummy_llava_path

    config_file = os.path.join(_dummy_llava_path, "config.json")
    weight_patterns = [
        "*.bin",
        "*.bin.index.json",
        "*.pt",
        "*.h5",
        "*.msgpack",
        "*.safetensors",
    ]
    snapshot_download(
        repo_id="llava-hf/llava-1.5-7b-hf",
        local_dir=_dummy_llava_path,
        ignore_patterns=weight_patterns,
    )
    assert os.path.exists(config_file)

    with open(config_file) as reader:
        model_config = json.load(reader)
    model_config["architectures"] = ["MyLlava"]
    with open(config_file, "w") as writer:
        json.dump(model_config, writer)

    return _dummy_llava_path
1130
1131
1132
1133
1134
1135


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Return a local copy of `BAAI/bge-multilingual-gemma2` with a renamed architecture.

    On first use (when the target directory does not exist yet) the repo is
    downloaded without weight files and `config.json` is rewritten so that
    `architectures` is `["MyGemma2Embedding"]`. An existing directory is
    returned as-is.
    """
    if os.path.exists(_dummy_gemma2_embedding_path):
        return _dummy_gemma2_embedding_path

    config_file = os.path.join(_dummy_gemma2_embedding_path, "config.json")
    weight_patterns = [
        "*.bin",
        "*.bin.index.json",
        "*.pt",
        "*.h5",
        "*.msgpack",
        "*.safetensors",
    ]
    snapshot_download(
        repo_id="BAAI/bge-multilingual-gemma2",
        local_dir=_dummy_gemma2_embedding_path,
        ignore_patterns=weight_patterns,
    )
    assert os.path.exists(config_file)

    with open(config_file) as reader:
        model_config = json.load(reader)
    model_config["architectures"] = ["MyGemma2Embedding"]
    with open(config_file, "w") as writer:
        json.dump(model_config, writer)

    return _dummy_gemma2_embedding_path
1155
1156
1157
1158
1159


# Registers the `--optional` command-line flag; passing it lets tests
# marked with @pytest.mark.optional run instead of being skipped.
def pytest_addoption(parser):
    """Add the boolean `--optional` flag (off by default) to `parser`."""
    parser.addoption(
        "--optional",
        action="store_true",
        default=False,
        help="run optional test",
    )
1163
1164
1165
1166
1167
1168
1169
1170
1171


def pytest_collection_modifyitems(config, items):
    """Skip tests carrying the `optional` keyword unless `--optional` is set."""
    run_optional = config.getoption("--optional")
    if not run_optional:
        skip_marker = pytest.mark.skip(reason="need --optional option to run")
        for test_item in items:
            if "optional" in test_item.keywords:
                test_item.add_marker(skip_marker)
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184


@pytest.fixture(scope="session")
def cli_config_file():
    """Return the path to the CLI config file."""
    config_path = os.path.join(_TEST_DIR, "config", "test_config.yaml")
    return config_path


@pytest.fixture(scope="session")
def cli_config_file_with_model():
    """Return the path to the CLI config file with model."""
    config_path = os.path.join(_TEST_DIR, "config", "test_config_with_model.yaml")
    return config_path
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231


class AssetHandler(http.server.BaseHTTPRequestHandler):
    """HTTP handler that serves image assets addressed by file name.

    A GET for `/<name>.<ext>` replies with the bytes returned by
    `ImageAsset(<name>).read_bytes(ext=<ext>)`. Only the `jpg` and `png`
    extensions are accepted; anything else gets a 404, and a failure while
    loading the asset gets a 500.
    """

    # Used only when `mimetypes` cannot guess a type from the file name.
    # `image/jpeg` is the registered media type for JPEG files.
    _FALLBACK_CONTENT_TYPES = {"jpg": "image/jpeg", "png": "image/png"}

    def log_message(self, *args, **kwargs):
        # Overridden with a no-op so requests produce no log output.
        pass

    def do_GET(self):
        """Serve one asset, e.g. `/1280px-Venn_diagram_rgb.jpg`."""
        filename = self.path.lstrip("/")
        if not filename or "." not in filename:
            self.send_error(404, "Missing filename (expected /<name>.<ext>)")
            return

        # Split on the LAST dot so asset names may themselves contain dots.
        base, ext = filename.rsplit(".", 1)
        ext = ext.lower()

        if ext not in self._FALLBACK_CONTENT_TYPES:
            self.send_error(404, f"Unsupported extension: .{ext}")
            return

        try:
            data = ImageAsset(base).read_bytes(ext=ext)
        except Exception as e:
            # Any loading failure is reported to the client as a 500
            # instead of propagating out of the handler.
            self.send_error(500, f"Failed to load asset: {ext} {base} {e} ")
            return

        ctype, _ = mimetypes.guess_type(filename)
        if ctype is None:
            ctype = self._FALLBACK_CONTENT_TYPES[ext]
        self.send_response(200)
        self.send_header("Content-Type", ctype)
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)


def _find_free_port() -> int:
    """Return a port number the OS reported as free on 127.0.0.1.

    A temporary socket is bound to port 0 so the OS picks the port; the
    socket is closed before returning, so the port is not reserved.
    """
    probe = socket.socket()
    try:
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
    finally:
        probe.close()
    return port


class LocalAssetServer:
    """Context manager running an `AssetHandler` HTTP server in a thread.

    Entering starts a `ThreadingHTTPServer` on `address` at an OS-assigned
    port and serves it from a daemon thread; exiting shuts the server down,
    closes its listening socket and joins the thread.
    """

    address: str
    # -1 until the server has been started by `__enter__`.
    port: int
    server: http.server.ThreadingHTTPServer | None
    thread: threading.Thread | None

    def __init__(self, address: str = "127.0.0.1") -> None:
        self.address = address
        self.port = -1
        self.server = None
        self.thread = None

    def __enter__(self):
        # Bind to port 0 so the OS assigns a free port as part of the bind
        # itself; probing for a free port first and binding afterwards
        # leaves a window in which another process can take the port.
        self.server = http.server.ThreadingHTTPServer(
            (self.address, 0), AssetHandler
        )
        self.port = self.server.server_address[1]
        self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
        self.thread.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.server is not None:
            self.server.shutdown()
            # `shutdown()` only stops the serve loop; the listening socket
            # must be closed explicitly or it leaks.
            self.server.server_close()
            self.server = None

        if self.thread is not None:
            self.thread.join()
            self.thread = None

        # Never suppress an exception raised inside the `with` body.
        return None if exc_type is None else False

    @property
    def base_url(self) -> str:
        """`http://<address>:<port>` for the current address and port."""
        assert self.port is not None
        return f"http://{self.address}:{self.port}"

    def url_for(self, name: str) -> str:
        """e.g., name='RGBA_comp.png' -> 'http://127.0.0.1:PORT/RGBA_comp.png'"""
        return f"{self.base_url}/{name}"

    def get_image_asset(self, name: str) -> Image.Image:
        """Fetch `name` via `fetch_image` using the URL from `url_for`."""
        return fetch_image(self.url_for(name))


@pytest.fixture(scope="session")
def local_asset_server() -> Generator[LocalAssetServer, None, None]:
    """
    Starts a thread based HTTP server bound to 127.0.0.1 on a random free port.
    The server currently serves images at:
    http://127.0.0.1:<port>/<name>.<ext>

    The server is torn down when the `with` block around the `yield` exits.
    """
    with LocalAssetServer() as asset_server:
        yield asset_server


@pytest.fixture
def image_url(request, local_asset_server) -> str:
    """Indirect fixture: takes one asset file name, returns its full URL.

    `request.param` is expected to be one of the IMAGE_ASSETS filenames.
    """
    asset_name = request.param
    return local_asset_server.url_for(asset_name)


@pytest.fixture
def image_urls(request, local_asset_server) -> list[str]:
    """Indirect fixture: takes a list of names, returns list of full URLs."""
    asset_names: list[str] = request.param
    urls = []
    for asset_name in asset_names:
        urls.append(local_asset_server.url_for(asset_name))
    return urls