conftest.py 42.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12

# ruff: noqa

from tblib import pickling_support

# Install support for pickling exceptions so that we can nicely propagate
# failures from tests running in a subprocess.
# This should be run before any custom exception subclasses are defined.
pickling_support.install()

13
import http.server
14
import json
15
import math
16
import mimetypes
17
import os
18
import socket
19
import tempfile
20
21
import threading
from collections.abc import Generator
22
from contextlib import nullcontext
23
from enum import Enum
24
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast
Woosuk Kwon's avatar
Woosuk Kwon committed
25

26
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
27
28
import pytest
import torch
29
import torch.nn as nn
30
import torch.nn.functional as F
31
from huggingface_hub import snapshot_download
32
from PIL import Image
33
34
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BatchEncoding, BatchFeature)
35
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
36

37
38
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm import LLM, SamplingParams
40
from vllm.assets.audio import AudioAsset
41
from vllm.assets.image import ImageAsset
42
from vllm.assets.video import VideoAsset
43
44
from vllm.config.model import (ConvertOption, RunnerOption,
                               _get_and_verify_dtype)
45
from vllm.connections import global_http_connection
46
from vllm.distributed import (cleanup_dist_env_and_memory,
47
48
                              init_distributed_environment,
                              initialize_model_parallel)
49
from vllm.logger import init_logger
50
from vllm.logprobs import Logprob
51
from vllm.multimodal.utils import fetch_image
52
from vllm.outputs import RequestOutput
53
from vllm.sampling_params import BeamSearchParams
54
from vllm.transformers_utils.utils import maybe_model_redirect
55
from vllm.utils import set_default_torch_num_threads
56

57
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
58

59
60
61
# Directory holding this conftest; the data files below live underneath it.
_TEST_DIR = os.path.dirname(__file__)

# Prompt files consumed by the `example_prompts` / `example_long_prompts`
# fixtures, and the file read by `example_system_message`.
_TEST_PROMPTS = [
    os.path.join(_TEST_DIR, "prompts", "example.txt"),
]
_LONG_PROMPTS = [
    os.path.join(_TEST_DIR, "prompts", "summary.txt"),
]
_SYS_MSG = os.path.join(
    _TEST_DIR,
    "system_messages",
    "sonnet3.5_nov2024.txt",
)
63

Cyrus Leung's avatar
Cyrus Leung committed
64
# Element type of a single multi-modal item (image, audio clip, video, ...).
_M = TypeVar("_M")

# Multi-modal data aligned with a list of prompts: either one item per
# prompt, or a list of items per prompt.
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
# Audio items are (waveform, sampling rate) pairs.
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
71

72

73
def _read_prompts(filename: str) -> list[str]:
    """Return the lines of a prompt file, one prompt per line.

    Lines are returned exactly as stored, including their trailing newline.

    Args:
        filename: Path of the text file to read.

    Returns:
        The file's lines in order; an empty list for an empty file.
    """
    # Decode as UTF-8 explicitly instead of relying on the locale default,
    # so the same prompts are produced on every platform.
    with open(filename, encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
77
78


79
class ImageAssetPrompts(TypedDict):
    """Prompt text for each test image, keyed by image asset name."""

    stop_sign: str
    cherry_blossom: str
82
83


84
class ImageTestAssets(list[ImageAsset]):
    """List of the image assets used by the tests, in a fixed order."""

    def __init__(self) -> None:
        assets = [
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
        """
        Convenience method to define the prompt for each test image.

        The returned prompts are ordered the same way as the assets
        yielded when iterating through this object.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
100
101


102
103
class VideoAssetPrompts(TypedDict):
    """Prompt text for each test video, keyed by video asset name."""

    baby_reading: str
104
105


106
class VideoTestAssets(list[VideoAsset]):
    """List of the video assets used by the tests, in a fixed order."""

    def __init__(self) -> None:
        assets = [
            VideoAsset("baby_reading"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
        """Return the prompts in the same order as the assets in this list."""
        baby_reading_prompt = prompts["baby_reading"]
        return [baby_reading_prompt]
115
116


117
class AudioAssetPrompts(TypedDict):
    """Prompt text for each test audio clip, keyed by audio asset name."""

    mary_had_lamb: str
    winning_call: str


122
class AudioTestAssets(list[AudioAsset]):
    """List of the audio assets used by the tests, in a fixed order."""

    def __init__(self) -> None:
        assets = [
            AudioAsset("mary_had_lamb"),
            AudioAsset("winning_call"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
        """Return the prompts in the same order as the assets in this list."""
        mary_had_lamb_prompt = prompts["mary_had_lamb"]
        winning_call_prompt = prompts["winning_call"]
        return [mary_had_lamb_prompt, winning_call_prompt]

133

134
# Module-level singletons; the session-scoped `image_assets`, `video_assets`
# and `audio_assets` fixtures return these same objects.
IMAGE_ASSETS = ImageTestAssets()
"""Singleton instance of {class}`ImageTestAssets`."""

VIDEO_ASSETS = VideoTestAssets()
"""Singleton instance of {class}`VideoTestAssets`."""

AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of {class}`AudioTestAssets`."""
140
141


142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
    """
    The V1 oracle sets "VLLM_USE_V1" during loading, which means that
    each test invocation may change the environment variable.

    Touching "VLLM_USE_V1" through monkeypatch makes monkeypatch restore
    it afterwards, undoing any change vLLM made during the test run.

    This fixture is used by every test.
    """
    # An explicitly configured value is left alone.
    if "VLLM_USE_V1" in os.environ:
        return

    # Set then delete: this registers the variable with monkeypatch so that
    # it is cleaned up on exit if vLLM modifies envs.VLLM_USE_V1.
    monkeypatch.setenv("VLLM_USE_V1", "")
    monkeypatch.delenv("VLLM_USE_V1")


162
163
164
165
166
167
168
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Disable client reuse on the global HTTP connection for every test.

    pytest_asyncio may use a different event loop per test, so the async
    client has to be created anew rather than reused.
    """
    global_http_connection.reuse_client = False


169
170
171
172
173
174
175
176
177
178
179
180
@pytest.fixture
def dist_init():
    """Set up a single-process distributed environment for one test.

    Initializes the distributed environment (world size 1, rank 0, NCCL
    backend, file-based rendezvous) and model parallelism of size 1x1, yields
    to the test, then tears everything down again.
    """
    fd, temp_file = tempfile.mkstemp()
    # mkstemp() hands back an already-open descriptor; only the path is
    # needed for the rendezvous, so close it right away to avoid leaking it.
    os.close(fd)
    try:
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend="nccl",
        )
        initialize_model_parallel(1, 1)
        yield
    finally:
        # Runs after the test, and also if setup itself fails part-way.
        cleanup_dist_env_and_memory()
        # The rendezvous file may already be gone by now; only remove it if
        # it is still present.
        if os.path.exists(temp_file):
            os.remove(temp_file)
182
183


184
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell `cleanup_fixture` whether to run global cleanup after the test.

    Subdirectories can skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.

    Returns False when the test carries the `skip_global_cleanup` marker.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
192
193


194
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run the global distributed/memory cleanup after each test if enabled."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory()
199
200


201
202
203
204
205
206
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call `torch._dynamo.reset()` after every test."""
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
207
@pytest.fixture
def example_prompts() -> list[str]:
    """Return every line of every file listed in `_TEST_PROMPTS`."""
    return [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]


215
216
217
218
219
220
@pytest.fixture
def example_system_message() -> str:
    """Return the full text of the file at `_SYS_MSG`."""
    # Decode as UTF-8 explicitly instead of relying on the locale default,
    # so the same text is produced on every platform.
    with open(_SYS_MSG, encoding="utf-8") as f:
        return f.read()


221
222
223
224
225
226
227
class DecoderPromptType(Enum):
    """Kinds of decoder prompt; for encoder/decoder models only."""

    CUSTOM = 1
    NONE = 2
    EMPTY_STR = 3


228
@pytest.fixture
def example_long_prompts() -> list[str]:
    """Return every line of every file listed in `_LONG_PROMPTS`."""
    return [
        line for filename in _LONG_PROMPTS
        for line in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
234
235


236
@pytest.fixture(scope="session")
def image_assets() -> ImageTestAssets:
    """Provide the shared `IMAGE_ASSETS` singleton."""
    return IMAGE_ASSETS


241
@pytest.fixture(scope="session")
def video_assets() -> VideoTestAssets:
    """Provide the shared `VIDEO_ASSETS` singleton."""
    return VIDEO_ASSETS


246
@pytest.fixture(scope="session")
def audio_assets() -> AudioTestAssets:
    """Provide the shared `AUDIO_ASSETS` singleton."""
    return AUDIO_ASSETS


251
# Constrained to the kinds of object `HfRunner.wrap_device` is annotated
# to accept and return.
_T = TypeVar(
    "_T",
    nn.Module,
    torch.Tensor,
    BatchEncoding,
    BatchFeature,
    dict,
)
# Unconstrained type variable.
_R = TypeVar("_R")
253

Woosuk Kwon's avatar
Woosuk Kwon committed
254
255
256

class HfRunner:

257
    def get_default_device(self):
        """Return the device string to use for the model and its inputs.

        This is ``"cpu"`` when the current vLLM platform reports itself as
        CPU, and the platform's ``device_type`` otherwise.
        """
        from vllm.platforms import current_platform

        if current_platform.is_cpu():
            return "cpu"
        return current_platform.device_type
262
263

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move `x` to `device`, defaulting to ``self.device``.

        `None` and booleans are returned untouched. Dicts are rebuilt with
        every value moved recursively. Any other object is returned as-is
        when its ``device.type`` already equals the target, and otherwise
        moved with ``x.to(...)``.
        """
        if x is None or isinstance(x, bool):
            return x

        target = self.device if device is None else device

        if isinstance(x, dict):
            moved = {}
            for key, value in x.items():
                moved[key] = self.wrap_device(value, target)
            return moved

        already_on_target = hasattr(x, "device") and x.device.type == target
        return x if already_on_target else x.to(target)
277

Woosuk Kwon's avatar
Woosuk Kwon committed
278
279
280
    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
        # Set this to avoid hanging issue
        default_torch_num_threads: Optional[int] = None,
    ) -> None:
        """Build the runner, optionally under a torch thread-count override.

        All arguments except `default_torch_num_threads` are forwarded
        unchanged to `_init`. When `default_torch_num_threads` is given,
        `_init` runs inside `set_default_torch_num_threads(...)`; otherwise
        it runs inside a no-op context.
        """
        if default_torch_num_threads is None:
            thread_ctx = nullcontext()
        else:
            thread_ctx = set_default_torch_num_threads(
                default_torch_num_threads)

        init_kwargs = dict(
            model_kwargs=model_kwargs,
            trust_remote_code=trust_remote_code,
            is_sentence_transformer=is_sentence_transformer,
            is_cross_encoder=is_cross_encoder,
            skip_tokenizer_init=skip_tokenizer_init,
            auto_cls=auto_cls,
        )
        with thread_ctx:
            self._init(model_name=model_name, dtype=dtype, **init_kwargs)

    def _init(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load the config, model, tokenizer and processor for `model_name`.

        Sets ``self.model_name``, ``self.config``, ``self.device``,
        ``self.dtype``, ``self.model``, ``self.tokenizer`` and
        ``self.processor``.

        Args:
            model_name: Model identifier; passed through
                `maybe_model_redirect` before use.
            dtype: Requested dtype, resolved via `_get_and_verify_dtype`.
            model_kwargs: Extra keyword arguments for model construction.
                ``torch_dtype`` is filled in with the resolved dtype unless
                already present. The caller's dict is not modified.
            trust_remote_code: Forwarded to every ``from_pretrained`` call
                and model constructor.
            is_sentence_transformer: Load the model as a
                ``SentenceTransformer``.
            is_cross_encoder: Load the model as a ``CrossEncoder`` (only
                consulted when `is_sentence_transformer` is false).
            skip_tokenizer_init: Skip loading a tokenizer directly and use
                the processor's tokenizer instead.
            auto_cls: Auto-model class used when neither of the two flags
                above is set.
        """
        model_name = maybe_model_redirect(model_name)
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Work on a copy: writing the default into the caller's dict would
        # pin the first run's torch_dtype for every later runner that is
        # constructed from the same (e.g. parametrized) dict.
        model_kwargs = {} if model_kwargs is None else dict(model_kwargs)
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if (getattr(model, "quantization_method", None) is None
                    and any(p.dtype != self.dtype
                            for p in model.parameters())):
                model = model.to(dtype=self.dtype)

            # Only move the model when it is not bitsandbytes-quantized and
            # its parameters sit on fewer than two distinct devices.
            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer
Woosuk Kwon's avatar
Woosuk Kwon committed
394

395
    def get_inputs(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[Union[BatchFeature, BatchEncoding]]:
        """Run ``self.processor`` on each prompt and its multi-modal data.

        Each of `images`, `videos` and `audios`, when given, must have one
        entry per prompt; a `None` entry means "nothing for this prompt".
        The processor is called once per prompt with ``return_tensors="pt"``.
        `BatchFeature` results are cast to ``self.dtype``.

        Returns:
            One processor output per prompt, in prompt order.
        """
        for modality_items in (images, videos, audios):
            if modality_items is not None:
                assert len(prompts) == len(modality_items)

        def item_at(items, idx):
            # Entry for prompt `idx`, or None when the modality is unused.
            return None if items is None else items[idx]

        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
        for idx, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }

            image = item_at(images, idx)
            if image is not None:
                processor_kwargs["images"] = image

            video = item_at(videos, idx)
            if video is not None:
                processor_kwargs["videos"] = video

            audio_inputs = item_at(audios, idx)
            if audio_inputs is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)
            all_inputs.append(inputs)

        return all_inputs

439
440
441
442
443
444
445
446
447
    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
        """Return the model's input embeddings for each prompt.

        Each prompt's ``input_ids`` are moved to the runner's device, passed
        through the model's input embedding layer, and squeezed on dim 0.
        """
        result: list[torch.Tensor] = []
        for model_inputs in self.get_inputs(prompts):
            token_ids = self.wrap_device(model_inputs)["input_ids"]
            embed_layer = self.model.get_input_embeddings()
            result.append(embed_layer(token_ids).squeeze(0))
        return result

448
    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Run the model on each prompt and return its per-class scores.

        How the logits are post-processed depends on
        ``self.config.problem_type`` (treated as ``""`` when absent):

        - ``"regression"``: the raw logits are returned.
        - ``"multi_label_classification"``: a sigmoid is applied.
        - anything else: a softmax over the last dimension is applied.

        Only the first row of each model output is kept.

        Returns:
            One list of floats per prompt, in prompt order.
        """
        all_inputs = self.get_inputs(prompts)
        problem_type = getattr(self.config, "problem_type", "")

        outputs: list[list[float]] = []
        for inputs in all_inputs:
            logits = self.model(**self.wrap_device(inputs)).logits

            if problem_type == "regression":
                scores = logits
            elif problem_type == "multi_label_classification":
                scores = logits.sigmoid()
            else:
                scores = logits.softmax(dim=-1)

            outputs.append(scores[0].tolist())

        return outputs

466
467
    def generate(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate from each prompt with ``self.model.generate``.

        Inputs are prepared by `get_inputs` and moved to the runner's
        device. `kwargs` are forwarded to ``self.model.generate`` alongside
        ``use_cache=True``.

        Returns:
            For each prompt, a pair of (generated token id sequences, the
            same sequences decoded by ``self.processor.batch_decode`` with
            special tokens skipped).
        """
        prepared = self.get_inputs(
            prompts,
            images=images,
            videos=videos,
            audios=audios,
        )

        results: list[tuple[list[list[int]], list[str]]] = []
        for model_inputs in prepared:
            generated = self.model.generate(
                **self.wrap_device(model_inputs),
                use_cache=True,
                **kwargs,
            )
            decoded = self.processor.batch_decode(
                generated,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            results.append((generated.cpu().tolist(), decoded))
        return results

    def generate_greedy(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy-decode up to `max_tokens` new tokens for each prompt.

        Calls `generate` with ``do_sample=False`` and keeps only the first
        sequence returned for each prompt.

        Returns:
            For each prompt, a pair of (token ids, decoded text).
        """
        generations = self.generate(
            prompts,
            do_sample=False,
            max_new_tokens=max_tokens,
            images=images,
            videos=videos,
            audios=audios,
            **kwargs,
        )

        first_sequences: list[tuple[list[int], str]] = []
        for ids_per_seq, text_per_seq in generations:
            first_sequences.append((ids_per_seq[0], text_per_seq[0]))
        return first_sequences
514
515
516

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run beam search and return every beam for each prompt.

        Calls `generate` with ``do_sample=False``, ``num_beams=beam_width``
        and ``num_return_sequences=beam_width``, then strips the tokenizer's
        pad token id from each returned token sequence.

        Returns:
            For each prompt, a pair of (pad-free token id sequences, decoded
            texts as produced by `generate`).
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width,
                                images=images,
                                videos=videos,
                                audios=audios)

        # Look the pad id up once instead of once per token.
        pad_token_id = self.tokenizer.pad_token_id
        return [(
            [[tok for tok in beam if tok != pad_token_id]
             for beam in output_ids],
            output_str,
        ) for output_ids, output_str in outputs]
Woosuk Kwon's avatar
Woosuk Kwon committed
542

543
544
    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[list[torch.Tensor]]:
        """Greedy-decode each prompt and return log-probabilities per step.

        Generation runs with ``do_sample=False`` and hidden states enabled;
        the hidden states are converted to log-probabilities by
        `_hidden_states_to_seq_logprobs`.

        Returns:
            For each prompt, the list of log-probability tensors produced
            for its generation steps.
        """
        prepared = self.get_inputs(
            prompts,
            images=images,
            videos=videos,
            audios=audios,
        )

        logprobs_per_prompt: list[list[torch.Tensor]] = []
        for model_inputs in prepared:
            generation = self.model.generate(
                **self.wrap_device(model_inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            logprobs_per_prompt.append(
                self._hidden_states_to_seq_logprobs(generation.hidden_states))
        return logprobs_per_prompt

573
    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
    ) -> list[torch.Tensor]:
        """Turn per-step hidden states into float32 log-probabilities.

        For every entry of `hidden_states`, the first batch element of its
        last tensor is projected with the model's output embedding weight
        (plus bias, if the layer has one) and passed through a log-softmax
        over the last dimension.
        """
        lm_head = self.model.get_output_embeddings()
        weight = lm_head.weight
        bias = getattr(lm_head, "bias", None)

        per_step: list[torch.Tensor] = []
        for step_states in hidden_states:
            final_layer = step_states[-1][0]
            # Match the projection weight's device and dtype before matmul.
            projected_input = final_layer.to(
                device=weight.device,
                dtype=weight.dtype,
            )
            logits = torch.matmul(projected_input, weight.t())
            if bias is not None:
                logits += bias.unsqueeze(0)

            per_step.append(
                F.log_softmax(logits, dim=-1, dtype=torch.float32))

        return per_step

    def _hidden_states_to_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
        num_logprobs: Optional[int],
    ) -> tuple[list[dict[int, float]], int]:
        """Reduce per-step hidden states to top-k ``{token_id: logprob}`` maps.

        Returns:
            A pair of (one dict per generation step holding the
            `num_logprobs` highest log-probabilities of that step's first
            row, ``len(hidden_states)``).
        """
        step_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)

        top_logprobs: list[dict[int, float]] = []
        for step, logprobs in enumerate(step_logprobs):
            # Drop prompt logprobs: for the first step keep only the final
            # row, reshaped back to a single-row matrix.
            rows = logprobs[-1, :].reshape(1, -1) if step == 0 else logprobs
            top = rows.topk(num_logprobs)

            top_logprobs.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(top.indices[0], top.values[0])
            })

        return top_logprobs, len(hidden_states)

623
624
    def generate_greedy_logprobs_limit(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: Optional[int],
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """Greedy-decode each prompt, returning tokens, text and top logprobs.

        Generation runs with ``do_sample=False`` and hidden states enabled.
        The generated ids are taken as the trailing tokens of the first
        returned sequence, one per step that produced logprobs.

        Returns:
            For each prompt, a tuple of (generated token ids, their decoded
            text, one ``{token_id: logprob}`` dict per generated token).
        """
        prepared = self.get_inputs(
            prompts,
            images=images,
            videos=videos,
            audios=audios,
        )

        results: list[TokensTextLogprobs] = []
        for model_inputs in prepared:
            generation = self.model.generate(
                **self.wrap_device(model_inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            step_logprobs, _ = self._hidden_states_to_logprobs(
                generation.hidden_states, num_logprobs)

            # The sequence holds prompt + generated tokens; slice off the
            # trailing tokens, one per step with logprobs.
            num_generated = len(step_logprobs)
            generated_ids = generation.sequences[0][-num_generated:]

            results.append((
                generated_ids.tolist(),
                self.tokenizer.decode(generated_ids),
                step_logprobs,
            ))

        return results

670
671
672
    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        """Forward `prompts` and any extra arguments to ``self.model.encode``."""
        embeddings = self.model.encode(prompts, *args, **kwargs)
        return embeddings
673

674
675
676
677
678
679
    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:
        """Forward to ``self.model.predict`` with ``convert_to_tensor=True``."""
        scores = self.model.predict(
            prompts,
            *args,
            convert_to_tensor=True,
            **kwargs,
        )
        return scores
680

681
682
683
684
    def __enter__(self):
        """Enter the ``with`` block; the runner itself is the bound value."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model and run the global distributed/memory cleanup.

        Returns nothing, so an exception raised inside the ``with`` block
        is propagated.
        """
        del self.model
        cleanup_dist_env_and_memory()
687

Woosuk Kwon's avatar
Woosuk Kwon committed
688

Cyrus Leung's avatar
Cyrus Leung committed
689
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the `HfRunner` class itself; tests instantiate it as needed."""
    return HfRunner


class VllmRunner:
695
696
    """
    The default values of some arguments have been modified from
697
    {class}`~vllm.LLM` as follows:
698

699
700
701
    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
702
703
    - `block_size`: To reduce memory usage, set default to `64` if on XPU
        devices, otherwise default to `16`.
704
705
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
706
    - `enforce_eager`: Set to `False` to test CUDA graph.
707
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
708
709
710
711

    def __init__(
        self,
        model_name: str,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        seed: Optional[int] = 0,
        max_model_len: Optional[int] = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16 if not torch.xpu.is_available() else 64,
        enable_chunked_prefill: Optional[bool] = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        # Set this to avoid hanging issue
        default_torch_num_threads: Optional[int] = None,
        **kwargs,
    ) -> None:
        """Build the underlying `LLM` and store it as `self.llm`.

        The named parameters are passed to `LLM` under the same keyword,
        except `model_name` (passed as `model`) and `tokenizer_name`
        (passed as `tokenizer`). Remaining `kwargs` are forwarded as-is;
        repeating one of the explicit keywords there raises `TypeError`.

        If `kwargs` has no truthy `compilation_config`, it is set to
        `{"cudagraph_capture_sizes": [4]}`. When
        `default_torch_num_threads` is given, `LLM` is constructed inside
        `set_default_torch_num_threads(default_torch_num_threads)`.
        """
        if default_torch_num_threads is None:
            init_ctx = nullcontext()
        else:
            init_ctx = set_default_torch_num_threads(
                default_torch_num_threads)

        if not kwargs.get("compilation_config", None):
            kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]}

        explicit_args = dict(
            model=model_name,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
        )

        with init_ctx:
            # Two ** expansions keep the "duplicate keyword" TypeError that
            # a single explicit call would raise.
            self.llm = LLM(**explicit_args, **kwargs)
Woosuk Kwon's avatar
Woosuk Kwon committed
755

756
    def get_inputs(
        self,
        prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[dict[str, Any]]:
        """Turn prompts plus optional multimodal data into input dicts.

        Each prompt is stored under a key chosen by its type: `str` ->
        `"prompt"`, `list` -> `"prompt_token_ids"`, anything else ->
        `"prompt_embeds"`. Non-None image/video/audio items at the same
        index are grouped under `"multi_modal_data"`; that key is omitted
        when there are none.

        Raises:
            ValueError: if a non-None multimodal list differs in length
                from `prompts`.
        """
        modalities = {"image": images, "video": videos, "audio": audios}

        for items in modalities.values():
            if items is None:
                continue
            if len(items) != len(prompts):
                raise ValueError(
                    "All non-None multimodal inputs must have the same length as "
                    "prompts")

        results: list[dict[str, Any]] = []
        for idx, prompt in enumerate(prompts):
            if isinstance(prompt, str):
                prompt_key = "prompt"
            elif isinstance(prompt, list):
                prompt_key = "prompt_token_ids"
            else:
                prompt_key = "prompt_embeds"
            entry: dict[str, Any] = {prompt_key: prompt}

            mm_data: dict[str, Any] = {}
            for modality, items in modalities.items():
                if items is None:
                    continue
                item = items[idx]
                if item is not None:
                    mm_data[modality] = item

            if mm_data:
                entry["multi_modal_data"] = mm_data

            results.append(entry)

        return results

    def generate(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate completions and return them prefixed with their prompt.

        Returns one `(token_id_lists, texts)` tuple per request, with one
        entry per sample. Each token list is the prompt token ids followed
        by the sample's token ids; each text is the prompt text (empty
        string when the request has no prompt text) followed by the
        sample's text. Extra `kwargs` go to `self.llm.generate`.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)
        req_outputs = self.llm.generate(inputs,
                                        sampling_params=sampling_params,
                                        **kwargs)

        results: list[tuple[list[list[int]], list[str]]] = []
        for req_output in req_outputs:
            prefix_text = req_output.prompt
            prefix_ids = req_output.prompt_token_ids

            sample_ids: list[list[int]] = []
            sample_texts: list[str] = []
            for sample in req_output.outputs:
                completion_text = sample.text
                completion_ids = list(sample.token_ids)
                sample_ids.append(prefix_ids + completion_ids)
                sample_texts.append((prefix_text or "") + completion_text)

            results.append((sample_ids, sample_texts))

        return results

826
    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        """Flatten request outputs into per-request result tuples.

        Each tuple is `(token_ids, text, logprobs, prompt_logprobs)`. The
        first three come from the request's LAST sample; the fourth comes
        from the request itself. Every request must have at least one
        sample (asserted).
        """
        results: list[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            # Only the final sample is reported, so earlier samples of a
            # multi-sample request are dropped.
            # NOTE(review): confirm that dropping them is intended.
            last_sample = req_output.outputs[-1]
            results.append((
                list(last_sample.token_ids),
                last_sample.text,
                last_sample.logprobs,
                req_output.prompt_logprobs,
            ))
        return results

841
842
    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return token ids, text and logprobs per request.

        The trailing prompt-logprobs element of each result tuple is kept
        only when `sampling_params.prompt_logprobs` is not None; otherwise
        it is sliced off. Extra `kwargs` go to `self.llm.generate`.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)
        req_outputs = self.llm.generate(inputs,
                                        sampling_params=sampling_params,
                                        **kwargs)

        with_prompt_logprobs = self._final_steps_generate_w_logprobs(
            req_outputs)

        if sampling_params.prompt_logprobs is not None:
            return with_prompt_logprobs

        # Omit prompt logprobs if not required by sampling params
        return [entry[0:-1] for entry in with_prompt_logprobs]
866

Woosuk Kwon's avatar
Woosuk Kwon committed
867
868
    def generate_greedy(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Generate with temperature 0 and keep the first sample per request.

        Returns one `(token_ids, text)` pair per request, taken from the
        first sample that `generate` returns. Extra `kwargs` are passed
        through to `generate`.
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        per_request = self.generate(prompts,
                                    greedy_params,
                                    images=images,
                                    videos=videos,
                                    audios=audios,
                                    **kwargs)

        first_samples: list[tuple[list[int], str]] = []
        for sample_ids, sample_texts in per_request:
            first_samples.append((sample_ids[0], sample_texts[0]))
        return first_samples
885

886
887
    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: Optional[int],
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[list[int]] = None,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate with temperature 0 and return logprobs.

        Builds `SamplingParams` from `max_tokens`, `num_logprobs`,
        `num_prompt_logprobs`, `stop_token_ids` and `stop`, then delegates
        to `generate_w_logprobs`. Extra `kwargs` are forwarded to it.
        """
        sampling_options = dict(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop,
        )
        greedy_logprobs_params = SamplingParams(**sampling_options)

        results = self.generate_w_logprobs(prompts,
                                           greedy_logprobs_params,
                                           images=images,
                                           audios=audios,
                                           videos=videos,
                                           **kwargs)
        return results
914

915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
    def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
        """
        Return the perplexity score associated with generating the prompts

        Perplexity is `exp(-mean(logprob))` over each prompt's
        prompt-logprob entries, skipping the first entry (asserted to be
        None).

        :param prompts: list of prompts to score
        :return: perplexity score of each prompt
        """
        # NOTE(review): if a prompt yields no entries after the first, the
        # mean below divides by zero.
        results = self.generate_greedy_logprobs(prompts,
                                                max_tokens=1,
                                                num_logprobs=None,
                                                num_prompt_logprobs=0)

        scores = []
        for result in results:
            result = cast(TokensTextLogprobsPromptLogprobs, result)
            prompt_entries = cast(list[Optional[dict[int, Logprob]]],
                                  result[3])
            assert prompt_entries[0] is None

            logprob_values = []
            for entry in prompt_entries[1:]:
                assert entry is not None
                assert len(entry) == 1
                # Exactly one entry per position (asserted above).
                only_logprob = next(iter(entry.values()))
                logprob_values.append(only_logprob.logprob)

            mean_neg_logprob = -sum(logprob_values) / len(logprob_values)
            scores.append(math.exp(mean_neg_logprob))

        return scores

944
    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        concurrency_limit: Optional[int] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run beam search and return every beam's tokens and text.

        Returns one `(token_lists, texts)` tuple per prompt, with one entry
        per sequence in that prompt's beam-search output.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        search_params = BeamSearchParams(beam_width=beam_width,
                                         max_tokens=max_tokens)
        beam_outputs = self.llm.beam_search(
            inputs, search_params, concurrency_limit=concurrency_limit)

        return [([seq.tokens for seq in beam_output.sequences],
                 [seq.text for seq in beam_output.sequences])
                for beam_output in beam_outputs]

970
    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Classify `prompts`; return each request's `outputs.probs`."""
        probs_per_prompt = []
        for req_output in self.llm.classify(prompts):
            probs_per_prompt.append(req_output.outputs.probs)
        return probs_per_prompt

974
975
976
977
978
979
980
    def embed(self,
              prompts: list[str],
              images: Optional[PromptImageInput] = None,
              videos: Optional[PromptVideoInput] = None,
              audios: Optional[PromptAudioInput] = None,
              *args,
              **kwargs) -> list[list[float]]:
        """Embed `prompts` (with optional multimodal data).

        Inputs are assembled by `get_inputs`; extra positional and keyword
        arguments go to `self.llm.embed`. Returns each request's
        `outputs.embedding`.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        embeddings = []
        for req_output in self.llm.embed(inputs, *args, **kwargs):
            embeddings.append(req_output.outputs.embedding)
        return embeddings
988

989
    def encode(self, prompts: list[str]) -> list[list[float]]:
        """Run `self.llm.encode`; return each request's `outputs.data`."""
        data_per_prompt = []
        for req_output in self.llm.encode(prompts):
            data_per_prompt.append(req_output.outputs.data)
        return data_per_prompt

993
994
995
996
    def reward(self, prompts: list[str]) -> list[list[float]]:
        """Run `self.llm.reward`; return each request's `outputs.data`."""
        data_per_prompt = []
        for req_output in self.llm.reward(prompts):
            data_per_prompt.append(req_output.outputs.data)
        return data_per_prompt

997
998
    def score(
        self,
        text_1: Union[str, list[str]],
        text_2: Union[str, list[str]],
        *args,
        **kwargs,
    ) -> list[float]:
        """Score `text_1` against `text_2` via `self.llm.score`.

        Extra positional and keyword arguments are forwarded. Returns each
        request's `outputs.score`.
        """
        scores = []
        for req_output in self.llm.score(text_1, text_2, *args, **kwargs):
            scores.append(req_output.outputs.score)
        return scores
1006

1007
    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Forward `func` to `self.llm.apply_model` and return its result."""
        applied = self.llm.apply_model(func)
        return applied
1009

1010
1011
1012
    def get_llm(self) -> LLM:
        """Return the `LLM` instance stored on this runner as `self.llm`."""
        return self.llm

1013
1014
1015
1016
    def __enter__(self):
        """Enter the `with` block; the runner itself is the managed object."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the `with` block: discard the engine, then run cleanup.

        Returns None, so an exception raised inside the block propagates.
        """
        # Remove the attribute before cleanup runs.
        # NOTE(review): presumably so the cleanup helper can reclaim the
        # engine's memory -- confirm against cleanup_dist_env_and_memory.
        del self.llm
        cleanup_dist_env_and_memory()
1019

Woosuk Kwon's avatar
Woosuk Kwon committed
1020

1021
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture that provides the `VllmRunner` class itself.

    The class (not an instance) is returned, so each test constructs its
    own runner.
    """
    return VllmRunner
1024
1025


1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
@pytest.fixture()
def temporary_enable_log_propagate():
    """Enable propagation on the "vllm" logger for the duration of a test.

    Sets `propagate = True` before the test. Afterwards the flag is put
    back to the value it had before the fixture ran, rather than being
    forced to False, so a logger configured to propagate is left as it was.
    """
    import logging
    logger = logging.getLogger("vllm")
    previous_propagate = logger.propagate
    logger.propagate = True
    try:
        yield
    finally:
        # Restore even if the generator is closed early.
        logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's `caplog` while vLLM log propagation is enabled.

    Requesting `temporary_enable_log_propagate` is what turns propagation
    on for the "vllm" logger; this fixture only hands `caplog` through.
    """
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
1040
1041
1042
1043
1044
1045
1046


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process."""
    # Imported lazily, inside the fixture, rather than at module import.
    from vllm.platforms import current_platform

    device_count = current_platform.device_count()
    return device_count
1049
1050
1051


# Directories under the system temp dir where the dummy-model fixtures
# below keep their downloaded snapshots.
temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
1055
1056
1057
1058


@pytest.fixture
def dummy_opt_path():
    """Return a local copy of facebook/opt-125m with a custom architecture.

    If the target directory does not exist yet, the repo snapshot is
    downloaded (skipping the ignored weight formats) and the
    `architectures` entry of its config.json is rewritten to
    `["MyOPTForCausalLM"]`. An existing directory is returned untouched.
    """
    if os.path.exists(_dummy_opt_path):
        return _dummy_opt_path

    snapshot_download(repo_id="facebook/opt-125m",
                      local_dir=_dummy_opt_path,
                      ignore_patterns=[
                          "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                          "*.msgpack"
                      ])

    config_file = os.path.join(_dummy_opt_path, "config.json")
    assert os.path.exists(config_file)
    with open(config_file) as src:
        config = json.load(src)
    config["architectures"] = ["MyOPTForCausalLM"]
    with open(config_file, "w") as dst:
        json.dump(config, dst)

    return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
    """Return a local copy of llava-hf/llava-1.5-7b-hf with a custom
    architecture.

    If the target directory does not exist yet, the repo snapshot is
    downloaded (skipping the ignored weight formats, safetensors included)
    and the `architectures` entry of its config.json is rewritten to
    `["MyLlava"]`. An existing directory is returned untouched.
    """
    if os.path.exists(_dummy_llava_path):
        return _dummy_llava_path

    snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                      local_dir=_dummy_llava_path,
                      ignore_patterns=[
                          "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                          "*.msgpack", "*.safetensors"
                      ])

    config_file = os.path.join(_dummy_llava_path, "config.json")
    assert os.path.exists(config_file)
    with open(config_file) as src:
        config = json.load(src)
    config["architectures"] = ["MyLlava"]
    with open(config_file, "w") as dst:
        json.dump(config, dst)

    return _dummy_llava_path
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Return a local copy of BAAI/bge-multilingual-gemma2 with a custom
    architecture.

    If the target directory does not exist yet, the repo snapshot is
    downloaded (skipping the ignored weight formats, safetensors included)
    and the `architectures` entry of its config.json is rewritten to
    `["MyGemma2Embedding"]`. An existing directory is returned untouched.
    """
    if os.path.exists(_dummy_gemma2_embedding_path):
        return _dummy_gemma2_embedding_path

    snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
                      local_dir=_dummy_gemma2_embedding_path,
                      ignore_patterns=[
                          "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                          "*.msgpack", "*.safetensors"
                      ])

    config_file = os.path.join(_dummy_gemma2_embedding_path, "config.json")
    assert os.path.exists(config_file)
    with open(config_file) as src:
        config = json.load(src)
    config["architectures"] = ["MyGemma2Embedding"]
    with open(config_file, "w") as dst:
        json.dump(config, dst)

    return _dummy_gemma2_embedding_path
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129


# Add the `--optional` flag, which allows running tests
# that are marked with @pytest.mark.optional
def pytest_addoption(parser):
    """Register the boolean `--optional` command-line flag (default off)."""
    flag_settings = {
        "action": "store_true",
        "default": False,
        "help": "run optional test",
    }
    parser.addoption("--optional", **flag_settings)


def pytest_collection_modifyitems(config, items):
    """Skip tests carrying the `optional` keyword unless `--optional` is set.

    When the flag is given, the collected items are left untouched.
    """
    if not config.getoption("--optional"):
        skip_marker = pytest.mark.skip(reason="need --optional option to run")
        for test_item in items:
            if "optional" in test_item.keywords:
                test_item.add_marker(skip_marker)
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142


@pytest.fixture(scope="session")
def cli_config_file():
    """Return the path to the CLI config file."""
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config.yaml")


@pytest.fixture(scope="session")
def cli_config_file_with_model():
    """Return the path to the CLI config file with model."""
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config_with_model.yaml")
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258


class AssetHandler(http.server.BaseHTTPRequestHandler):
    """HTTP handler that serves image assets for GET `/<name>.<ext>`.

    Only the extensions listed in `_FALLBACK_CONTENT_TYPES` are accepted.
    Image bytes are loaded with `ImageAsset(name).read_bytes(ext=ext)`.
    """

    # Supported extensions, each mapped to the Content-Type used when
    # `mimetypes` cannot guess one. JPEG's registered media type is
    # "image/jpeg" ("image/jpg" is not a registered type).
    _FALLBACK_CONTENT_TYPES = {"jpg": "image/jpeg", "png": "image/png"}

    def log_message(self, *args, **kwargs):
        """Suppress the base class's per-request logging."""
        pass

    def do_GET(self):
        """Answer a GET with the requested image, or an error status.

        Sends 404 when the path has no `<name>.<ext>` form or the extension
        is unsupported, and 500 when loading the asset raises.
        """
        # Accepts paths like: /1280px-Venn_diagram_rgb.jpg
        filename = self.path.lstrip("/")
        if not filename or "." not in filename:
            self.send_error(404, "Missing filename (expected /<name>.<ext>)")
            return

        base, ext = filename.rsplit(".", 1)
        ext = ext.lower()

        if ext not in self._FALLBACK_CONTENT_TYPES:
            self.send_error(404, f"Unsupported extension: .{ext}")
            return

        try:
            data = ImageAsset(base).read_bytes(ext=ext)
        except Exception as e:
            # Request boundary: any load failure becomes a 500 response
            # instead of propagating out of the handler.
            self.send_error(500, f"Failed to load asset: {ext} {base} {e} ")
            return

        ctype, _ = mimetypes.guess_type(filename)
        if ctype is None:
            ctype = self._FALLBACK_CONTENT_TYPES[ext]
        self.send_response(200)
        self.send_header("Content-Type", ctype)
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)


def _find_free_port() -> int:
    """Return a TCP port number the OS reported as free on 127.0.0.1.

    A temporary socket is bound to port 0 so the OS picks the port; the
    socket is closed before returning, so the port is not reserved.
    """
    with socket.socket() as probe:
        probe.bind(("127.0.0.1", 0))
        bound_address = probe.getsockname()
    return bound_address[1]


class LocalAssetServer:
    """Context manager that serves `AssetHandler` from a background thread.

    Entering starts a `ThreadingHTTPServer` on `address` with an
    OS-assigned port; exiting stops the server, joins its thread and
    closes the listening socket.
    """

    address: str
    # -1 until the server has been started by __enter__.
    port: int
    # Both are None whenever the server is not running.
    server: Optional[http.server.ThreadingHTTPServer]
    thread: Optional[threading.Thread]

    def __init__(self, address: str = "127.0.0.1") -> None:
        self.address = address
        self.port = -1
        self.server = None
        self.thread = None

    def __enter__(self):
        """Start the server thread and return this object."""
        # Bind to port 0 so the OS picks a free port as part of the bind
        # itself. Probing for a port first and binding afterwards leaves a
        # window in which another process can take it.
        self.server = http.server.ThreadingHTTPServer((self.address, 0),
                                                      AssetHandler)
        self.port = self.server.server_address[1]
        self.thread = threading.Thread(target=self.server.serve_forever,
                                       daemon=True)
        self.thread.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop the server and release its resources.

        Safe to call more than once. Always returns a falsy value, so
        exceptions raised inside the `with` block propagate.
        """
        # Reset the attributes to None (instead of deleting them) so they
        # keep matching their declared Optional types after exit.
        server, self.server = self.server, None
        thread, self.thread = self.thread, None

        if server is not None:
            server.shutdown()

        if thread is not None:
            thread.join()

        if server is not None:
            # shutdown() only stops serve_forever(); the listening socket
            # stays open until server_close() is called.
            server.server_close()

        return False

    @property
    def base_url(self) -> str:
        """Base URL built from `address` and `port`, without a trailing slash."""
        assert self.port is not None
        return f"http://{self.address}:{self.port}"

    def url_for(self, name: str) -> str:
        """e.g., name='RGBA_comp.png' -> 'http://127.0.0.1:PORT/RGBA_comp.png'"""
        return f"{self.base_url}/{name}"

    def get_image_asset(self, name: str) -> Image.Image:
        """Fetch `name` from this server via `fetch_image` and return it."""
        return fetch_image(self.url_for(name))


@pytest.fixture(scope="session")
def local_asset_server() -> Generator[LocalAssetServer, None, None]:
    """
    Starts a thread-based HTTP server bound to 127.0.0.1 on a random free
    port and yields it for the whole test session.
    The server currently serves images at:
    http://127.0.0.1:<port>/<name>.<ext>
    """
    with LocalAssetServer() as srv:
        yield srv


@pytest.fixture
def image_url(request, local_asset_server) -> str:
    """Indirect fixture: takes one asset filename, returns its full URL."""
    # request.param is one of the IMAGE_ASSETS filenames
    return local_asset_server.url_for(request.param)


@pytest.fixture
def image_urls(request, local_asset_server) -> list[str]:
    """Indirect fixture: takes a list of names, returns list of full URLs."""
    urls: list[str] = []
    for asset_name in request.param:
        urls.append(local_asset_server.url_for(asset_name))
    return urls