conftest.py 39.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import json
4
import os
5
import tempfile
6
from collections import UserList
7
from enum import Enum
8
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
9

10
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
import pytest
import torch
13
import torch.nn as nn
14
import torch.nn.functional as F
15
from huggingface_hub import snapshot_download
16
from PIL import Image
17
18
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BatchEncoding, BatchFeature)
19
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
20

21
22
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm import LLM, SamplingParams
24
from vllm.assets.audio import AudioAsset
25
from vllm.assets.image import ImageAsset
26
from vllm.assets.video import VideoAsset
27
from vllm.config import TaskOption, _get_and_verify_dtype
28
from vllm.connections import global_http_connection
29
from vllm.distributed import (cleanup_dist_env_and_memory,
30
31
                              init_distributed_environment,
                              initialize_model_parallel)
32
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
33
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
34
from vllm.logger import init_logger
35
from vllm.outputs import RequestOutput
36
from vllm.sampling_params import BeamSearchParams
37
from vllm.utils import cuda_device_count_stateless
38

39
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
40

41
42
43
# Directory containing this conftest; the prompt files below are resolved
# relative to it.
_TEST_DIR = os.path.dirname(__file__)
# Files read by the ``example_prompts`` and
# ``example_encoder_decoder_prompts`` fixtures.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
# Files read by the ``example_long_prompts`` fixture.
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
44
# File whose full contents the ``example_system_message`` fixture returns.
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
45

Cyrus Leung's avatar
Cyrus Leung committed
46
_M = TypeVar("_M")
47

48
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]
Cyrus Leung's avatar
Cyrus Leung committed
49
50

PromptImageInput = _PromptMultiModalInput[Image.Image]
51
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
Cyrus Leung's avatar
Cyrus Leung committed
52
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
53

54

55
def _read_prompts(filename: str) -> list[str]:
    """Return the lines of ``filename`` as a list.

    Every line keeps its trailing newline character, if it has one.
    """
    with open(filename) as prompt_file:
        # Iterating a text file yields the same lines as ``readlines()``.
        return list(prompt_file)
Woosuk Kwon's avatar
Woosuk Kwon committed
59
60


61
62
63
class _ImageAssetPrompts(TypedDict):
    """Prompt text keyed by image asset name.

    Consumed by :meth:`_ImageAssets.prompts`, which reads both keys.
    """
    # Prompt paired with the "stop_sign" image asset.
    stop_sign: str
    # Prompt paired with the "cherry_blossom" image asset.
    cherry_blossom: str
64
65


66
67
class _ImageAssetsBase(UserList[ImageAsset]):
    """``UserList`` parameterized with :class:`ImageAsset` elements.

    Adds no behavior of its own; it is the base class of ``_ImageAssets``.
    """
    pass
68

69
70

class _ImageAssets(_ImageAssetsBase):
    """List of the image assets used by the tests, in a fixed order."""

    def __init__(self) -> None:
        stop_sign = ImageAsset("stop_sign")
        cherry_blossom = ImageAsset("cherry_blossom")
        super().__init__([stop_sign, cherry_blossom])

    def prompts(self, prompts: _ImageAssetPrompts) -> list[str]:
        """
        Convenience method to define the prompt for each test image.

        The returned list is ordered like the assets in this object:
        the "stop_sign" prompt first, then the "cherry_blossom" prompt.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
86
87


88
89
90
91
class _VideoAssetPrompts(TypedDict):
    """Prompt text keyed by video asset name.

    Consumed by :meth:`_VideoAssets.prompts`.
    """
    # Prompt paired with the "sample_demo_1" video asset.
    sample_demo_1: str


92
93
class _VideoAssetsBase(UserList[VideoAsset]):
    """``UserList`` parameterized with :class:`VideoAsset` elements.

    Adds no behavior of its own; it is the base class of ``_VideoAssets``.
    """
    pass
94
95
96
97
98
99


class _VideoAssets(_VideoAssetsBase):
    """List of the video assets used by the tests."""

    def __init__(self) -> None:
        sample_demo = VideoAsset("sample_demo_1")
        super().__init__([sample_demo])

    def prompts(self, prompts: _VideoAssetPrompts) -> list[str]:
        """Return the prompts ordered like the assets in this object."""
        sample_demo_prompt = prompts["sample_demo_1"]
        return [sample_demo_prompt]


107
108
109
110
111
class _AudioAssetPrompts(TypedDict):
    """Prompt text keyed by audio asset name.

    Consumed by :meth:`_AudioAssets.prompts`, which reads both keys.
    """
    # Prompt paired with the "mary_had_lamb" audio asset.
    mary_had_lamb: str
    # Prompt paired with the "winning_call" audio asset.
    winning_call: str


112
113
114
115
116
117
118
119
120
121
122
123
class _AudioAssetsBase(UserList[AudioAsset]):
    """``UserList`` parameterized with :class:`AudioAsset` elements.

    Adds no behavior of its own; it is the base class of ``_AudioAssets``.
    """
    pass


class _AudioAssets(_AudioAssetsBase):
    """List of the audio assets used by the tests, in a fixed order."""

    def __init__(self) -> None:
        mary_had_lamb = AudioAsset("mary_had_lamb")
        winning_call = AudioAsset("winning_call")
        super().__init__([mary_had_lamb, winning_call])

    def prompts(self, prompts: _AudioAssetPrompts) -> list[str]:
        """Return the prompts ordered like the assets in this object."""
        mary_had_lamb_prompt = prompts["mary_had_lamb"]
        winning_call_prompt = prompts["winning_call"]
        return [mary_had_lamb_prompt, winning_call_prompt]

127

128
129
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
130
131
VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
132
133
AUDIO_ASSETS = _AudioAssets()
"""Singleton instance of :class:`_AudioAssets`."""
134
135


136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
    """
    Keep "VLLM_USE_V1" from leaking between tests.

    The V1 oracle sets "VLLM_USE_V1" during loading, so each test
    invocation may change the environment variable.

    Once monkeypatch has touched "VLLM_USE_V1", any change vLLM makes to
    it during the test is cleaned up when the test exits.

    This fixture is used by every test.
    """
    # Nothing to do when the variable is already set.
    if "VLLM_USE_V1" in os.environ:
        return

    # Set-then-delete registers the variable with monkeypatch without
    # leaving a value behind for the test itself.
    monkeypatch.setenv("VLLM_USE_V1", "")
    monkeypatch.delenv("VLLM_USE_V1")


Joe Runde's avatar
Joe Runde committed
156
@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """Run the requesting test twice: once with V1 and once without.

    Tests carrying the ``skip_v1`` marker are skipped for the V1 run.
    """
    use_v1 = request.param
    skip_v1 = request.node.get_closest_marker("skip_v1")

    if use_v1 and skip_v1:
        pytest.skip("Skipping test on vllm V1")

    monkeypatch.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    yield
Joe Runde's avatar
Joe Runde committed
171
172


173
174
175
176
177
178
179
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Disable client reuse on the global HTTP connection for every test."""
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


180
181
182
183
184
185
186
187
188
189
190
191
@pytest.fixture
def dist_init():
    """Set up a single-process distributed environment around a test.

    Initializes the distributed environment (world size 1, rank 0, "nccl"
    backend) using a temporary file as the rendezvous point, then
    initializes model parallelism with sizes ``(1, 1)``. After the test,
    the distributed state is torn down and the temporary file is removed.
    """
    # mkstemp() returns an *open* OS-level file descriptor along with the
    # path. Only the path is needed, so close the descriptor right away
    # instead of leaking one per test.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    try:
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend="nccl",
        )
        initialize_model_parallel(1, 1)
        yield
        cleanup_dist_env_and_memory()
    finally:
        # Best-effort removal: the file may already be gone by the time
        # teardown runs, and that must not fail the test.
        try:
            os.remove(temp_file)
        except OSError:
            pass
193
194


195
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell ``cleanup_fixture`` whether to run global cleanup after a test.

    Subdirectories can override this fixture to skip global cleanup. That
    can provide a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.

    Returns ``False`` when the test carries the ``skip_global_cleanup``
    marker and ``True`` otherwise.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
203
204


205
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run the global distributed/memory cleanup after each test.

    The cleanup is skipped when ``should_do_global_cleanup_after_test``
    is false.
    """
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory()
210
211


212
213
214
215
216
217
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call ``torch._dynamo.reset()`` after every test."""
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
218
@pytest.fixture
def example_prompts() -> list[str]:
    """Return every line of the files listed in ``_TEST_PROMPTS``."""
    collected: list[str] = []
    for path in _TEST_PROMPTS:
        collected.extend(_read_prompts(path))
    return collected


226
227
228
229
230
231
@pytest.fixture
def example_system_message() -> str:
    """Return the full text of the ``_SYS_MSG`` file."""
    with open(_SYS_MSG) as sys_msg_file:
        contents = sys_msg_file.read()
    return contents


232
233
234
235
236
237
238
class DecoderPromptType(Enum):
    """For encoder/decoder models only.

    Used as the keys of the dict returned by the
    ``example_encoder_decoder_prompts`` fixture.
    """
    # Decoder prompts are real text (the reversed encoder prompt list).
    CUSTOM = 1
    # Decoder prompts are ``None``.
    NONE = 2
    # Decoder prompts are empty strings.
    EMPTY_STR = 3


239
@pytest.fixture
def example_encoder_decoder_prompts(
) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for every decoder prompt type.

    The encoder prompts are the lines of the ``_TEST_PROMPTS`` files. Each
    encoder prompt is paired with the decoder prompt at the same index.

    Returns:

    A dict with one entry per :class:`DecoderPromptType`:

    * ``NONE``: every decoder prompt is ``None``
    * ``EMPTY_STR``: every decoder prompt is the empty string
    * ``CUSTOM``: decoder prompts are the encoder prompt list reversed
    '''
    encoder_prompts: list[str] = []
    for path in _TEST_PROMPTS:
        encoder_prompts.extend(_read_prompts(path))

    num_prompts = len(encoder_prompts)
    reversed_prompts = encoder_prompts[::-1]
    blank_prompts = [""] * num_prompts
    missing_prompts = [None] * num_prompts

    pairs_by_type: dict[DecoderPromptType,
                        list[ExplicitEncoderDecoderPrompt]] = {}
    pairs_by_type[DecoderPromptType.NONE] = zip_enc_dec_prompts(
        encoder_prompts, missing_prompts)
    pairs_by_type[DecoderPromptType.EMPTY_STR] = zip_enc_dec_prompts(
        encoder_prompts, blank_prompts)
    pairs_by_type[DecoderPromptType.CUSTOM] = zip_enc_dec_prompts(
        encoder_prompts, reversed_prompts)
    return pairs_by_type


272
@pytest.fixture
def example_long_prompts() -> list[str]:
    """Return every line of the files listed in ``_LONG_PROMPTS``."""
    collected: list[str] = []
    for path in _LONG_PROMPTS:
        collected.extend(_read_prompts(path))
    return collected
Woosuk Kwon's avatar
Woosuk Kwon committed
278
279


280
281
282
283
284
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Return the module-level ``IMAGE_ASSETS`` instance."""
    return IMAGE_ASSETS


285
286
287
288
289
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    """Return the module-level ``VIDEO_ASSETS`` instance."""
    return VIDEO_ASSETS


290
291
292
293
294
@pytest.fixture(scope="session")
def audio_assets() -> _AudioAssets:
    """Return the module-level ``AUDIO_ASSETS`` instance."""
    return AUDIO_ASSETS


295
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
296
_R = TypeVar("_R")
297

Woosuk Kwon's avatar
Woosuk Kwon committed
298
299
300

class HfRunner:

301
    def get_default_device(self):
        """Return the device string for the current vLLM platform.

        Returns ``"cpu"`` when the platform reports itself as CPU, and the
        platform's ``device_type`` otherwise.
        """
        # Imported lazily, as in the rest of this class's optional imports.
        from vllm.platforms import current_platform

        if current_platform.is_cpu():
            return "cpu"
        return current_platform.device_type
306
307

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move ``x`` to ``device`` (``self.device`` when not given).

        ``None`` and ``bool`` values are returned unchanged. A ``dict`` is
        rebuilt with every value wrapped recursively. An object whose
        ``device.type`` already equals the target is returned as is;
        anything else is moved with ``x.to(...)``.
        """
        if x is None or isinstance(x, bool):
            return x

        target = self.device if device is None else device

        if isinstance(x, dict):
            return {
                key: self.wrap_device(value, target)
                for key, value in x.items()
            }

        already_on_target = hasattr(x, "device") and x.device.type == target
        return x if already_on_target else x.to(target)
321

Woosuk Kwon's avatar
Woosuk Kwon committed
322
323
324
    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load a Hugging Face model with its tokenizer and processor.

        Args:
            model_name: Model identifier passed to every
                ``from_pretrained`` call and to the sentence-transformers
                constructors.
            dtype: Requested dtype, resolved against the model config by
                ``_get_and_verify_dtype``.
            model_kwargs: Extra keyword arguments for model construction.
                ``torch_dtype`` is filled in with the resolved dtype
                unless already present. The caller's dict is not modified.
            is_sentence_transformer: Load through ``SentenceTransformer``.
            is_cross_encoder: Load through ``CrossEncoder``; only consulted
                when ``is_sentence_transformer`` is false.
            skip_tokenizer_init: Do not load an ``AutoTokenizer``; use the
                processor's tokenizer instead.
            auto_cls: Auto-model class used when neither of the
                sentence-transformers flags is set.
        """
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype)

        # Work on a copy: setdefault() below would otherwise write
        # "torch_dtype" into the caller's dict, and a dict reused for a
        # second runner would silently keep the first runner's dtype.
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=True,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=True,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=True,
                **model_kwargs,
            )

            # Only move the model when it is not bitsandbytes-quantized and
            # its parameters sit on fewer than two distinct devices.
            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer
Woosuk Kwon's avatar
Woosuk Kwon committed
396

397
    def get_inputs(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[Union[BatchFeature, BatchEncoding]]:
        """Run the processor once per prompt and collect the results.

        Each multimodal list, when given, must have one entry per prompt;
        an entry that is ``None`` is simply not passed to the processor.
        ``BatchFeature`` results are cast to ``self.dtype``.
        """
        for modality in (images, videos, audios):
            if modality is not None:
                assert len(prompts) == len(modality)

        collected: list[Union[BatchFeature, BatchEncoding]] = []
        for idx, text in enumerate(prompts):
            call_kwargs: dict[str, Any] = {
                "text": text,
                "return_tensors": "pt",
            }

            image = None if images is None else images[idx]
            if image is not None:
                call_kwargs["images"] = image

            video = None if videos is None else videos[idx]
            if video is not None:
                call_kwargs["videos"] = video

            audio_entry = None if audios is None else audios[idx]
            if audio_entry is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_entry) == 2:
                    waveform, sample_rate = audio_entry
                    call_kwargs["audio"] = waveform
                    call_kwargs["sampling_rate"] = sample_rate
                else:
                    call_kwargs["audio"] = audio_entry

            encoded = self.processor(**call_kwargs)
            if isinstance(encoded, BatchFeature):
                encoded = encoded.to(dtype=self.dtype)

            collected.append(encoded)

        return collected

441
    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return class probabilities for each prompt.

        Every prompt is processed on its own. The model's logits are
        passed through a softmax over the last dimension, and the first
        row of the result is converted to a plain list of floats.

        Returns:
            One list of probabilities per prompt, in prompt order.
        """
        all_probs: list[list[float]] = []
        for inputs in self.get_inputs(prompts):
            output = self.model(**self.wrap_device(inputs))
            # These are softmax probabilities, not raw logits.
            probs = output.logits.softmax(dim=-1)[0].tolist()
            all_probs.append(probs)

        return all_probs

452
453
    def generate(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate from each prompt with the Hugging Face model.

        ``kwargs`` are forwarded to ``self.model.generate``. For every
        prompt the result is a pair ``(token ids, decoded strings)`` with
        one entry per returned sequence; decoding skips special tokens.
        """
        batches = self.get_inputs(prompts,
                                  images=images,
                                  videos=videos,
                                  audios=audios)

        results: list[tuple[list[list[int]], list[str]]] = []
        for batch in batches:
            generated = self.model.generate(
                **self.wrap_device(batch),
                use_cache=True,
                **kwargs,
            )
            decoded = self.processor.batch_decode(
                generated,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            results.append((generated.cpu().tolist(), decoded))
        return results

    def generate_greedy(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy generation: sampling off, at most ``max_tokens`` new tokens.

        Returns, for each prompt, the first generated sequence as a pair
        ``(token ids, decoded string)``.
        """
        generations = self.generate(prompts,
                                    do_sample=False,
                                    max_new_tokens=max_tokens,
                                    images=images,
                                    videos=videos,
                                    audios=audios,
                                    **kwargs)

        first_sequences: list[tuple[list[int], str]] = []
        for sequence_ids, sequence_strs in generations:
            first_sequences.append((sequence_ids[0], sequence_strs[0]))
        return first_sequences
500
501
502

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Beam-search generation returning ``beam_width`` sequences.

        Pad token ids (``self.tokenizer.pad_token_id``) are removed from
        every returned token id sequence.
        """
        beams = self.generate(prompts,
                              do_sample=False,
                              max_new_tokens=max_tokens,
                              num_beams=beam_width,
                              num_return_sequences=beam_width,
                              images=images,
                              videos=videos,
                              audios=audios)

        for idx, (beam_ids, beam_strs) in enumerate(beams):
            # Slice assignment updates the existing list in place.
            beam_ids[:] = [[
                token for token in sequence
                if token != self.tokenizer.pad_token_id
            ] for sequence in beam_ids]
            beams[idx] = (beam_ids, beam_strs)
        return beams
Woosuk Kwon's avatar
Woosuk Kwon committed
528

529
530
    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[list[torch.Tensor]]:
        """Greedy generation returning full log-probability tensors.

        For each prompt, returns the list produced by
        ``_hidden_states_to_seq_logprobs`` from the hidden states that
        ``generate`` reports.
        """
        batches = self.get_inputs(prompts,
                                  images=images,
                                  videos=videos,
                                  audios=audios)

        logprobs_per_prompt: list[list[torch.Tensor]] = []
        for batch in batches:
            output = self.model.generate(
                **self.wrap_device(batch),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            logprobs_per_prompt.append(
                self._hidden_states_to_seq_logprobs(output.hidden_states))
        return logprobs_per_prompt

559
    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
    ) -> list[torch.Tensor]:
        """Project hidden states through the output embeddings.

        For every entry of ``hidden_states``, element ``[-1][0]`` is
        multiplied by the transposed output-embedding weight (plus the
        bias, if the layer has one) and turned into float32
        log-probabilities with ``log_softmax`` over the last dimension.
        """
        lm_head = self.model.get_output_embeddings()

        def to_logprobs(step_states: tuple[torch.Tensor, ...]):
            final_states = step_states[-1][0]
            # Match the projection weight's device and dtype first.
            logits = torch.matmul(
                final_states.to(
                    device=lm_head.weight.device,
                    dtype=lm_head.weight.dtype,
                ),
                lm_head.weight.t(),
            )
            if getattr(lm_head, "bias", None) is not None:
                logits += lm_head.bias.unsqueeze(0)
            return F.log_softmax(logits, dim=-1, dtype=torch.float32)

        return [to_logprobs(step_states) for step_states in hidden_states]

    def _hidden_states_to_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> tuple[list[dict[int, float]], int]:
        """Reduce hidden states to top-``num_logprobs`` dicts.

        Returns a pair: a list with one ``{token id: logprob}`` dict per
        entry of ``hidden_states``, and ``len(hidden_states)``.
        """
        per_step_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)

        top_logprobs: list[dict[int, float]] = []
        for step, step_logprobs in enumerate(per_step_logprobs):
            # drop prompt logprobs: only the last row of the first entry
            # is kept
            if step == 0:
                step_logprobs = step_logprobs[-1, :].reshape(1, -1)
            best = step_logprobs.topk(num_logprobs)

            top_logprobs.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(best.indices[0], best.values[0])
            })

        return top_logprobs, len(hidden_states)

609
610
    def generate_greedy_logprobs_limit(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """Greedy generation with top-``num_logprobs`` logprobs per token.

        Returns, for each prompt, a tuple of the generated token ids,
        their decoded text, and the list of ``{token id: logprob}`` dicts.
        """
        batches = self.get_inputs(prompts,
                                  images=images,
                                  videos=videos,
                                  audios=audios)

        results: list[TokensTextLogprobs] = []
        for batch in batches:
            output = self.model.generate(
                **self.wrap_device(batch),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            step_logprobs, _ = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)

            # The generated tokens are the tail of the first sequence,
            # one per logprob dict.
            num_generated = len(step_logprobs)
            generated_ids = output.sequences[0][-num_generated:]
            results.append((
                generated_ids.tolist(),
                self.tokenizer.decode(generated_ids),
                step_logprobs,
            ))

        return results

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for encoder/decoder models.

        Each prompt is split into its encoder and decoder parts. The
        encoder part (plus its image, if any) goes through the processor;
        a decoder part that is not ``None`` is tokenized and passed as
        ``decoder_input_ids``. Logprobs are taken from the decoder hidden
        states.
        '''
        results: list[TokensTextLogprobs] = []

        prompt_pairs = to_enc_dec_tuple_list(encoder_decoder_prompts)
        for idx, (encoder_prompt, decoder_prompt) in enumerate(prompt_pairs):
            processor_kwargs: dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            image = None if images is None else images[idx]
            if image is not None:
                processor_kwargs["images"] = image

            encoder_inputs = self.wrap_device(
                self.processor(**processor_kwargs))

            decoder_input_ids = None
            if decoder_prompt is not None:
                tokenized = self.tokenizer(decoder_prompt,
                                           return_tensors="pt")
                decoder_input_ids = self.wrap_device(tokenized.input_ids)

            output = self.model.generate(
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **encoder_inputs,
                **kwargs,
            )

            step_logprobs, num_generated = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            # The generated tokens are the tail of the first sequence.
            generated_ids = output.sequences[0][-num_generated:]
            results.append((
                generated_ids.tolist(),
                self.tokenizer.decode(generated_ids),
                step_logprobs,
            ))

        return results

718
719
720
    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        """Forward ``prompts`` and all extra arguments to ``model.encode``."""
        return self.model.encode(prompts, *args, **kwargs)
721

722
    def predict(self, prompts: list[list[str]]) -> torch.Tensor:
        """Call ``model.predict`` on ``prompts``, requesting a tensor."""
        scores = self.model.predict(prompts, convert_to_tensor=True)
        return scores

725
726
727
728
    def __enter__(self):
        """Enter the context manager; the runner itself is the target."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model reference, then run the global cleanup."""
        # Remove the attribute before cleanup runs, as the original does.
        del self.model
        cleanup_dist_env_and_memory()
731

Woosuk Kwon's avatar
Woosuk Kwon committed
732

Cyrus Leung's avatar
Cyrus Leung committed
733
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the :class:`HfRunner` class itself (not an instance)."""
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:
    """
    Test helper that owns a :class:`~vllm.LLM` instance in ``self.model``.

    The default values of some arguments differ from those of
    :class:`~vllm.LLM` as follows:

    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
    - `block_size`: Set to `16` instead of `None` to reduce memory usage.
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
    - `enforce_eager`: Set to `False` to test CUDA graph.
    """

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        seed: Optional[int] = 0,
        max_model_len: int = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: Optional[bool] = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Construct the wrapped ``LLM`` and store it in ``self.model``.

        The named parameters are mapped onto the matching ``LLM`` keyword
        arguments (``model_name`` -> ``model``, ``tokenizer_name`` ->
        ``tokenizer``); any extra ``kwargs`` are forwarded unchanged.
        """
        engine_args = dict(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
        )
        # Unpack the two mappings separately (rather than merging them) so a
        # key repeated in ``kwargs`` still raises TypeError instead of
        # silently overriding one of the named arguments.
        self.model = LLM(**engine_args, **kwargs)

    def get_inputs(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[TextPrompt]:
        """Build one ``TextPrompt`` per prompt, attaching multimodal data.

        A ``str`` prompt is stored under ``"prompt"``; anything else is
        stored under ``"prompt_embeds"``. For each modality that is given,
        the item at the prompt's index is attached under ``"image"``,
        ``"video"`` or ``"audio"`` unless that item is ``None``. When no
        item is attached, ``"multi_modal_data"`` is ``None``.

        Raises:
            ValueError: If a non-``None`` modality list differs in length
                from ``prompts``.
        """
        # Insertion order fixes the key order of each multi_modal_data dict.
        modalities = {"image": images, "video": videos, "audio": audios}

        for items in modalities.values():
            if items is not None and len(items) != len(prompts):
                raise ValueError(
                    "All non-None multimodal inputs must have the same length as "
                    "prompts")

        inputs = []
        for idx, prompt in enumerate(prompts):
            mm_data = {}
            for key, items in modalities.items():
                if items is None:
                    continue
                item = items[idx]
                if item is not None:
                    mm_data[key] = item

            prompt_key = "prompt" if isinstance(prompt, str) else "prompt_embeds"
            fields = {prompt_key: prompt, "multi_modal_data": mm_data or None}
            inputs.append(TextPrompt(**fields))

        return inputs

    def generate(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate with ``self.model`` and return prompt-prefixed samples.

        Returns one ``(token_id_lists, texts)`` tuple per request. Each
        entry covers one sample of that request and consists of the
        request's prompt token ids / prompt text followed by the sample's
        own token ids / text. A missing prompt text is treated as ``""``.
        Extra ``kwargs`` are forwarded to ``self.model.generate``.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        results: list[tuple[list[list[int]], list[str]]] = []
        for request in req_outputs:
            prefix_text = request.prompt or ""
            prefix_ids = request.prompt_token_ids
            per_sample = [(prefix_ids + list(sample.token_ids),
                           prefix_text + sample.text)
                          for sample in request.outputs]
            sample_ids: list[list[int]] = [ids for ids, _ in per_sample]
            sample_texts: list[str] = [text for _, text in per_sample]
            results.append((sample_ids, sample_texts))
        return results

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        """Flatten request outputs into 4-tuples.

        Each tuple is ``(token_ids, text, logprobs, prompt_logprobs)``. The
        first three fields come from the LAST sample of the request; the
        fourth is the request's ``prompt_logprobs``. Every request must
        have at least one sample.
        """
        collected: list[TokensTextLogprobsPromptLogprobs] = []
        for request in req_outputs:
            assert len(request.outputs) > 0
            # Only the final sample of each request is reported.
            final_sample = request.outputs[-1]
            collected.append((
                list(final_sample.token_ids),
                final_sample.text,
                final_sample.logprobs,
                request.prompt_logprobs,
            ))
        return collected

    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return tokens, text and logprobs per request.

        Each result is ``(token_ids, text, logprobs)``; a trailing
        ``prompt_logprobs`` element is included only when
        ``sampling_params.prompt_logprobs`` is not ``None``. Extra
        ``kwargs`` are forwarded to ``self.model.generate``.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)
        results = self._final_steps_generate_w_logprobs(req_outputs)

        if sampling_params.prompt_logprobs is not None:
            return results
        # Omit prompt logprobs if not required by sampling params
        return [entry[0:-1] for entry in results]

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Logprobs generation for vLLM encoder/decoder models.

        ``sampling_params.logprobs`` must be set. The prompts are passed to
        ``self.model.generate`` directly (``get_inputs`` is not used). The
        trailing ``prompt_logprobs`` element of each result is included
        only when ``sampling_params.prompt_logprobs`` is not ``None``.
        """
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        results = self._final_steps_generate_w_logprobs(req_outputs)

        if sampling_params.prompt_logprobs is not None:
            return results
        # Omit prompt logprobs if not required by sampling params
        return [entry[0:-1] for entry in results]

    def generate_greedy(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Run :meth:`generate` with ``temperature=0.0``.

        Returns, per request, the first sample's ``(token_ids, text)`` as
        produced by :meth:`generate` (i.e. including the prompt prefix).
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        per_request = self.generate(prompts,
                                    greedy_params,
                                    images=images,
                                    videos=videos,
                                    audios=audios,
                                    **kwargs)
        return [(sample_ids[0], sample_texts[0])
                for sample_ids, sample_texts in per_request]

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[list[int]] = None,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Run :meth:`generate_w_logprobs` with ``temperature=0.0``.

        ``num_logprobs`` and ``num_prompt_logprobs`` become the
        ``logprobs`` / ``prompt_logprobs`` sampling parameters;
        ``stop_token_ids`` and ``stop`` are passed through as given.
        """
        params = SamplingParams(temperature=0.0,
                                max_tokens=max_tokens,
                                logprobs=num_logprobs,
                                prompt_logprobs=num_prompt_logprobs,
                                stop_token_ids=stop_token_ids,
                                stop=stop)
        return self.generate_w_logprobs(prompts,
                                        params,
                                        images=images,
                                        audios=audios,
                                        videos=videos,
                                        **kwargs)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Greedy logprobs generation for vLLM encoder/decoder models.

        Builds ``temperature=0.0`` sampling parameters from the arguments
        and delegates to :meth:`generate_encoder_decoder_w_logprobs`.
        """
        params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            skip_special_tokens=skip_special_tokens,
        )
        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, params)

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run ``self.model.beam_search`` on the prompts.

        Returns one ``(token_lists, texts)`` tuple per output, holding the
        ``tokens`` and ``text`` of each of that output's sequences.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)
        search_params = BeamSearchParams(beam_width=beam_width,
                                         max_tokens=max_tokens)
        beam_outputs = self.model.beam_search(inputs, search_params)

        return [([seq.tokens for seq in beam_output.sequences],
                 [seq.text for seq in beam_output.sequences])
                for beam_output in beam_outputs]

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.probs`` of each ``self.model.classify`` result."""
        return [result.outputs.probs
                for result in self.model.classify(prompts)]

    def encode(self,
               prompts: list[str],
               images: Optional[PromptImageInput] = None,
               videos: Optional[PromptVideoInput] = None,
               audios: Optional[PromptAudioInput] = None,
               *args,
               **kwargs) -> list[list[float]]:
        """Return ``outputs.embedding`` of each ``self.model.embed`` result.

        Inputs are built with :meth:`get_inputs`; extra positional and
        keyword arguments are forwarded to ``self.model.embed``.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)
        embed_results = self.model.embed(inputs, *args, **kwargs)
        return [result.outputs.embedding for result in embed_results]

    def score(
        self,
        text_1: Union[str, list[str]],
        text_2: Union[str, list[str]],
    ) -> list[float]:
        """Return ``outputs.score`` of each ``self.model.score`` result."""
        return [result.outputs.score
                for result in self.model.score(text_1, text_2)]

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Delegate ``func`` to the engine's model executor ``apply_model``."""
        return self.model.llm_engine.model_executor.apply_model(func)

    def __enter__(self):
        """Enter the ``with`` block and return this runner itself."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the wrapped model, then call ``cleanup_dist_env_and_memory``.

        Nothing is returned, so exceptions from the ``with`` block propagate.
        """
        del self.model
        cleanup_dist_env_and_memory()
1046

Woosuk Kwon's avatar
Woosuk Kwon committed
1047

1048
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture that hands tests the ``VllmRunner`` class.

    The class itself is returned, not an instance; each test constructs
    its own runner from it.
    """
    runner_cls = VllmRunner
    return runner_cls
1051
1052


1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
@pytest.fixture()
def temporary_enable_log_propagate():
    """Enable propagation on the ``"vllm"`` logger for one test.

    ``propagate`` is set to ``True`` before the test runs. On teardown the
    value that was in effect before the fixture ran is restored, rather
    than being forced to ``False``, so a logger that was already
    propagating is left as it was found.
    """
    import logging

    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        # Restore even if the generator is closed without a normal resume.
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog`` while ``temporary_enable_log_propagate``
    is active for the requesting test."""
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
1067
1068
1069
1070
1071
1072
1073


@pytest.fixture(scope="session")
def num_gpus_available():
    """Return the GPU count reported by ``cuda_device_count_stateless``.

    That helper is used so the count is obtained without initializing the
    CUDA context in the current process.
    """
    gpu_count = cuda_device_count_stateless()
    return gpu_count
1075
1076
1077


temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")

# File patterns excluded from every dummy-model download
# (presumably the weight files, which these fixtures do not need).
_DUMMY_MODEL_IGNORE_PATTERNS = [
    "*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"
]


def _prepare_dummy_model(repo_id: str, local_dir: str,
                         architecture: str) -> str:
    """Ensure ``local_dir`` holds ``repo_id`` with a patched ``config.json``.

    The snapshot is downloaded when ``config.json`` is missing, and the
    config's ``"architectures"`` entry is rewritten to ``[architecture]``
    whenever it differs. Both steps are checked on every call, so a
    directory left behind by an interrupted earlier run (no config, or a
    config that was never patched) is repaired instead of being reused.

    Returns:
        ``local_dir``, for convenience.
    """
    json_path = os.path.join(local_dir, "config.json")
    if not os.path.exists(json_path):
        snapshot_download(repo_id=repo_id,
                          local_dir=local_dir,
                          ignore_patterns=list(_DUMMY_MODEL_IGNORE_PATTERNS))
    assert os.path.exists(json_path)

    with open(json_path, encoding="utf-8") as f:
        config = json.load(f)

    # Skip the write when the config already carries the dummy architecture.
    if config.get("architectures") != [architecture]:
        config["architectures"] = [architecture]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return local_dir


@pytest.fixture
def dummy_opt_path():
    """Path to a local ``facebook/opt-125m`` snapshot whose config declares
    the ``MyOPTForCausalLM`` architecture."""
    return _prepare_dummy_model("facebook/opt-125m", _dummy_opt_path,
                                "MyOPTForCausalLM")


@pytest.fixture
def dummy_llava_path():
    """Path to a local ``llava-hf/llava-1.5-7b-hf`` snapshot whose config
    declares the ``MyLlava`` architecture."""
    return _prepare_dummy_model("llava-hf/llava-1.5-7b-hf", _dummy_llava_path,
                                "MyLlava")


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Path to a local ``BAAI/bge-multilingual-gemma2`` snapshot whose config
    declares the ``MyGemma2Embedding`` architecture."""
    return _prepare_dummy_model("BAAI/bge-multilingual-gemma2",
                                _dummy_gemma2_embedding_path,
                                "MyGemma2Embedding")
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155


def pytest_addoption(parser):
    """Register the ``--optional`` command-line flag.

    Passing ``--optional`` allows tests marked with
    ``@pytest.mark.optional`` to run; the flag is off by default.
    """
    option_settings = {
        "action": "store_true",
        "default": False,
        "help": "run optional test",
    }
    parser.addoption("--optional", **option_settings)


def pytest_collection_modifyitems(config, items):
    """Skip tests carrying the ``optional`` keyword unless ``--optional``
    was given on the command line."""
    # With --optional on the CLI nothing is skipped; otherwise one shared
    # skip marker is attached to every optional item.
    if not config.getoption("--optional"):
        skip_optional = pytest.mark.skip(
            reason="need --optional option to run")
        for test_item in items:
            if "optional" in test_item.keywords:
                test_item.add_marker(skip_optional)
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168


@pytest.fixture(scope="session")
def cli_config_file():
    """Return the path to the CLI config file.

    The file is ``config/test_config.yaml`` under the tests directory.
    """
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config.yaml")


@pytest.fixture(scope="session")
def cli_config_file_with_model():
    """Return the path to the CLI config file with model.

    The file is ``config/test_config_with_model.yaml`` under the tests
    directory.
    """
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config_with_model.yaml")