conftest.py 40.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
import os
5
import tempfile
6
from enum import Enum
7
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
import pytest
import torch
12
import torch.nn as nn
13
import torch.nn.functional as F
14
from huggingface_hub import snapshot_download
15
from PIL import Image
16
17
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BatchEncoding, BatchFeature)
18
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
19

20
21
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm import LLM, SamplingParams
23
from vllm.assets.audio import AudioAsset
24
from vllm.assets.image import ImageAsset
25
from vllm.assets.video import VideoAsset
26
from vllm.config import TaskOption, _get_and_verify_dtype
27
from vllm.connections import global_http_connection
28
from vllm.distributed import (cleanup_dist_env_and_memory,
29
30
                              init_distributed_environment,
                              initialize_model_parallel)
31
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
32
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
33
from vllm.logger import init_logger
34
from vllm.outputs import RequestOutput
35
from vllm.sampling_params import BeamSearchParams
36
from vllm.utils import cuda_device_count_stateless
37

38
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
39

40
41
42
# Directory containing this conftest; the data files below are resolved
# relative to it.
_TEST_DIR = os.path.dirname(__file__)

# Prompt files read line-by-line by the prompt fixtures.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# File returned verbatim by the `example_system_message` fixture.
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
44

Cyrus Leung's avatar
Cyrus Leung committed
45
_M = TypeVar("_M")

# Multi-modal data accompanying a list of prompts: either a flat list of
# items or a list of lists of items.
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
# Each audio item is an (array, int) pair; `HfRunner.get_inputs` forwards
# the int to the processor as `sampling_rate`.
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
52

53

54
def _read_prompts(filename: str) -> list[str]:
    """Return the lines of ``filename``, one prompt per line.

    Lines are returned exactly as ``readlines()`` yields them, i.e. each
    entry keeps its trailing newline.

    The file is decoded as UTF-8 explicitly so that the result does not
    depend on the locale of the machine running the tests.
    """
    with open(filename, encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
58
59


60
class ImageAssetPrompts(TypedDict):
    """Prompt text for each test image, keyed by the image asset name."""
    stop_sign: str
    cherry_blossom: str
63
64


65
class ImageTestAssets(list[ImageAsset]):
    """List holding the image assets used by the tests."""

    def __init__(self) -> None:
        assets = [
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
        """
        Convenience method to define the prompt for each test image.

        The returned prompts are ordered the same way as the assets
        are when iterating through this object.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
81
82


83
84
class VideoAssetPrompts(TypedDict):
    """Prompt text for each test video, keyed by the video asset name."""
    baby_reading: str
85
86


87
class VideoTestAssets(list[VideoAsset]):
    """List holding the video assets used by the tests."""

    def __init__(self) -> None:
        assets = [
            VideoAsset("baby_reading"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
        """Return the per-asset prompts in the order of the assets."""
        baby_reading_prompt = prompts["baby_reading"]
        return [baby_reading_prompt]
96
97


98
class AudioAssetPrompts(TypedDict):
    """Prompt text for each test audio clip, keyed by the asset name."""
    mary_had_lamb: str
    winning_call: str


103
class AudioTestAssets(list[AudioAsset]):
    """List holding the audio assets used by the tests."""

    def __init__(self) -> None:
        assets = [
            AudioAsset("mary_had_lamb"),
            AudioAsset("winning_call"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
        """Return the per-asset prompts in the order of the assets."""
        mary_had_lamb_prompt = prompts["mary_had_lamb"]
        winning_call_prompt = prompts["winning_call"]
        return [mary_had_lamb_prompt, winning_call_prompt]

114

115
# Module-level asset lists, built once at import time and handed out by the
# session-scoped `*_assets` fixtures.
IMAGE_ASSETS = ImageTestAssets()
"""Singleton instance of {class}`ImageTestAssets`."""

VIDEO_ASSETS = VideoTestAssets()
"""Singleton instance of {class}`VideoTestAssets`."""

AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of {class}`AudioTestAssets`."""
121
122


123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
    """
    Keep changes to "VLLM_USE_V1" from leaking from one test to the next.

    The V1 oracle sets "VLLM_USE_V1" during loading, which means that
    each invocation of a test changes the env variable.

    Once monkeypatch has touched "VLLM_USE_V1", any changes made to it
    by vLLM during the test run are cleaned up afterwards.

    This fixture is used by every test.
    """
    if "VLLM_USE_V1" in os.environ:
        return

    # VLLM_USE_V1 is not set: set it, then delete it. This makes
    # monkeypatch clean up VLLM_USE_V1 upon exit if vLLM modifies
    # the value of envs.VLLM_USE_V1.
    monkeypatch.setenv("VLLM_USE_V1", "")
    monkeypatch.delenv("VLLM_USE_V1")


Joe Runde's avatar
Joe Runde committed
143
@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """Run the requesting test twice: once with V1 and once without.

    Tests marked with `skip_v1` are skipped for the V1 run, so they are
    only executed without V1.
    """
    use_v1 = request.param
    skip_v1 = request.node.get_closest_marker("skip_v1")

    if use_v1 and skip_v1:
        pytest.skip("Skipping test on vllm V1")

    monkeypatch.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    yield
Joe Runde's avatar
Joe Runde committed
158
159


160
161
162
163
164
165
166
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Turn off client reuse on the global HTTP connection for every test."""
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


167
168
169
170
171
172
173
174
175
176
177
178
@pytest.fixture
def dist_init():
    """Set up a single-process distributed environment for one test.

    Initializes the distributed environment with ``world_size=1`` on the
    "nccl" backend using a temporary file for the ``file://`` init method,
    initializes model parallelism with sizes (1, 1), and calls
    ``cleanup_dist_env_and_memory()`` after the test.
    """
    # mkstemp() returns an *open* descriptor together with the path; only
    # the path is needed, so close the descriptor right away instead of
    # leaking it for the rest of the session.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    try:
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend="nccl",
        )
        initialize_model_parallel(1, 1)
        yield
        cleanup_dist_env_and_memory()
    finally:
        # Best-effort removal of the rendezvous file; it may already have
        # been deleted by the time we get here.
        try:
            os.remove(temp_file)
        except OSError:
            pass
180
181


182
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell `cleanup_fixture` whether to run the global cleanup.

    Subdirectories can skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.

    Returns False when the test carries the `skip_global_cleanup` marker.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
190
191


192
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After each test, run the global cleanup unless it was opted out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory()
197
198


199
200
201
202
203
204
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call `torch._dynamo.reset()` after every test."""
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
205
@pytest.fixture
def example_prompts() -> list[str]:
    """Return every line of every file in `_TEST_PROMPTS` as one flat list."""
    return [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]


213
214
215
216
217
218
@pytest.fixture
def example_system_message() -> str:
    """Return the full contents of the `_SYS_MSG` file as a single string.

    The file is decoded as UTF-8 explicitly so that the result does not
    depend on the locale of the machine running the tests.
    """
    with open(_SYS_MSG, encoding="utf-8") as f:
        return f.read()


219
220
221
222
223
224
225
class DecoderPromptType(Enum):
    """For encoder/decoder models only."""
    # Keys of the dict built by the `example_encoder_decoder_prompts` fixture.
    # CUSTOM: decoder prompts are the encoder prompts in reverse order.
    CUSTOM = 1
    # NONE: every decoder prompt is `None`.
    NONE = 2
    # EMPTY_STR: every decoder prompt is the empty string.
    EMPTY_STR = 3


226
@pytest.fixture
def example_encoder_decoder_prompts(
) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for each decoder prompt type.

    The encoder prompts are the lines of the files in `_TEST_PROMPTS`.
    For every `DecoderPromptType` they are zipped, index by index, with
    a decoder prompt list of the same length:

    * NONE: every decoder prompt is `None`
    * EMPTY_STR: every decoder prompt is the empty string
    * CUSTOM: the encoder prompt list in reverse order

    Returns:

    * Dict mapping each decoder prompt type to its zipped prompt list
    '''
    encoder_prompts = [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]
    num_prompts = len(encoder_prompts)

    decoder_prompts_by_type = {
        DecoderPromptType.NONE: [None] * num_prompts,
        DecoderPromptType.EMPTY_STR: [""] * num_prompts,
        DecoderPromptType.CUSTOM: encoder_prompts[::-1],
    }

    return {
        prompt_type: zip_enc_dec_prompts(encoder_prompts, decoder_prompts)
        for prompt_type, decoder_prompts in decoder_prompts_by_type.items()
    }


259
@pytest.fixture
def example_long_prompts() -> list[str]:
    """Return every line of every file in `_LONG_PROMPTS` as one flat list."""
    return [
        line for filename in _LONG_PROMPTS
        for line in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
265
266


267
@pytest.fixture(scope="session")
def image_assets() -> ImageTestAssets:
    """Provide the module-level `IMAGE_ASSETS` list to tests."""
    return IMAGE_ASSETS


272
@pytest.fixture(scope="session")
def video_assets() -> VideoTestAssets:
    """Provide the module-level `VIDEO_ASSETS` list to tests."""
    return VIDEO_ASSETS


277
@pytest.fixture(scope="session")
def audio_assets() -> AudioTestAssets:
    """Provide the module-level `AUDIO_ASSETS` list to tests."""
    return AUDIO_ASSETS


282
# Constrained to the kinds of objects `HfRunner.wrap_device` is annotated
# to accept and return.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)

_R = TypeVar("_R")
284

Woosuk Kwon's avatar
Woosuk Kwon committed
285
286
287

class HfRunner:

288
    def get_default_device(self):
        """Return "cpu" on a CPU platform, else the platform's device type."""
        # Imported lazily, inside the method, as in the rest of this class.
        from vllm.platforms import current_platform

        if current_platform.is_cpu():
            return "cpu"
        return current_platform.device_type
293
294

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
295
296
297
        if x is None or isinstance(x, (bool, )):
            return x

298
        if device is None:
299
            device = self.device
300

301
302
        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}
303

304
305
306
307
        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)
308

Woosuk Kwon's avatar
Woosuk Kwon committed
309
310
311
    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load the config, model, tokenizer and processor for a model.

        Args:
            model_name: Name passed to every ``from_pretrained`` call.
            dtype: Requested dtype, resolved via ``_get_and_verify_dtype``.
            model_kwargs: Extra keyword arguments for the model loader.
                ``torch_dtype`` defaults to the resolved dtype when absent.
                The dict passed in is not modified.
            trust_remote_code: Forwarded to every loader.
            is_sentence_transformer: Load the model with
                ``SentenceTransformer`` instead of ``auto_cls``.
            is_cross_encoder: Load the model with ``CrossEncoder`` instead
                of ``auto_cls``.
            skip_tokenizer_init: Do not load a tokenizer with
                ``AutoTokenizer``; use ``self.processor.tokenizer`` instead.
            auto_cls: Auto class used when neither of the two flags above
                is set.
        """
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Work on a copy: adding the default `torch_dtype` below must not
        # leak into a dict that the caller owns and may reuse.
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if (getattr(model, "quantization_method", None) is None
                    and any(p.dtype != self.dtype
                            for p in model.parameters())):
                model = model.to(dtype=self.dtype)

            # Only move the model when it is not a bitsandbytes model and
            # its parameters sit on fewer than two distinct devices.
            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer
Woosuk Kwon's avatar
Woosuk Kwon committed
395

396
    def get_inputs(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[Union[BatchFeature, BatchEncoding]]:
        """Run ``self.processor`` once per prompt and collect the results.

        Each multi-modal list that is given must have one entry per prompt;
        a ``None`` entry means "no data of that modality for this prompt".
        ``BatchFeature`` results are cast to ``self.dtype``.
        """
        num_prompts = len(prompts)
        for modality_inputs in (images, videos, audios):
            if modality_inputs is not None:
                assert num_prompts == len(modality_inputs)

        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
        for idx, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }

            image = None if images is None else images[idx]
            if image is not None:
                processor_kwargs["images"] = image

            video = None if videos is None else videos[idx]
            if video is not None:
                processor_kwargs["videos"] = video

            audio_inputs = None if audios is None else audios[idx]
            if audio_inputs is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

440
441
442
443
444
445
446
447
448
    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
        all_inputs = self.get_inputs(prompts)
        embeddings = []
        for inputs in all_inputs:
            input_ids = self.wrap_device(inputs)["input_ids"]
            embedding = self.model.get_input_embeddings()(input_ids).squeeze(0)
            embeddings.append(embedding)
        return embeddings

449
    def classify(self, prompts: list[str]) -> list[str]:
450
451
452
453
454
455
456
457
458
459
        # output is final logits
        all_inputs = self.get_inputs(prompts)
        outputs = []
        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            logits = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(logits)

        return outputs

460
461
    def generate(
        self,
462
        prompts: list[str],
463
        images: Optional[PromptImageInput] = None,
Cyrus Leung's avatar
Cyrus Leung committed
464
        videos: Optional[PromptVideoInput] = None,
465
466
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
467
    ) -> list[tuple[list[list[int]], list[str]]]:
468
469
470
471
472
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

473
        outputs: list[tuple[list[list[int]], list[str]]] = []
474
        for inputs in all_inputs:
Woosuk Kwon's avatar
Woosuk Kwon committed
475
            output_ids = self.model.generate(
476
                **self.wrap_device(inputs),
Woosuk Kwon's avatar
Woosuk Kwon committed
477
478
479
                use_cache=True,
                **kwargs,
            )
480
            output_str = self.processor.batch_decode(
Woosuk Kwon's avatar
Woosuk Kwon committed
481
482
483
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
484
485
            )
            output_ids = output_ids.cpu().tolist()
Woosuk Kwon's avatar
Woosuk Kwon committed
486
487
488
489
490
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
491
        prompts: list[str],
Woosuk Kwon's avatar
Woosuk Kwon committed
492
        max_tokens: int,
493
        images: Optional[PromptImageInput] = None,
Cyrus Leung's avatar
Cyrus Leung committed
494
        videos: Optional[PromptVideoInput] = None,
495
        audios: Optional[PromptAudioInput] = None,
496
        **kwargs: Any,
497
    ) -> list[tuple[list[int], str]]:
498
499
        outputs = self.generate(prompts,
                                do_sample=False,
500
                                max_new_tokens=max_tokens,
Chang Su's avatar
Chang Su committed
501
                                images=images,
502
503
                                videos=videos,
                                audios=audios,
Chang Su's avatar
Chang Su committed
504
                                **kwargs)
505
506
507

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]
508
509
510

    def generate_beam_search(
        self,
511
        prompts: list[str],
512
513
        beam_width: int,
        max_tokens: int,
514
515
516
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
517
    ) -> list[tuple[list[list[int]], list[str]]]:
518
519
520
521
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
522
523
524
525
526
                                num_return_sequences=beam_width,
                                images=images,
                                videos=videos,
                                audios=audios)

527
528
529
530
531
532
533
534
535
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            for j in range(len(output_ids)):
                output_ids[j] = [
                    x for x in output_ids[j]
                    if x != self.tokenizer.pad_token_id
                ]
            outputs[i] = (output_ids, output_str)
        return outputs
Woosuk Kwon's avatar
Woosuk Kwon committed
536

537
538
    def generate_greedy_logprobs(
        self,
539
        prompts: list[str],
540
        max_tokens: int,
541
        images: Optional[PromptImageInput] = None,
Cyrus Leung's avatar
Cyrus Leung committed
542
        videos: Optional[PromptVideoInput] = None,
543
        audios: Optional[PromptAudioInput] = None,
544
        **kwargs: Any,
545
    ) -> list[list[torch.Tensor]]:
546
547
548
549
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)
550

551
        all_logprobs: list[list[torch.Tensor]] = []
552
        for inputs in all_inputs:
553
            output = self.model.generate(
554
                **self.wrap_device(inputs),
555
556
557
558
559
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
560
                **kwargs,
561
            )
562
563
            seq_logprobs = self._hidden_states_to_seq_logprobs(
                output.hidden_states)
564
565
566
            all_logprobs.append(seq_logprobs)
        return all_logprobs

567
    def _hidden_states_to_seq_logprobs(
568
        self,
569
570
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
    ) -> list[torch.Tensor]:
571
572
        output_embeddings = self.model.get_output_embeddings()

573
        seq_logprobs: list[torch.Tensor] = []
574
575
576
        for _, hidden_state in enumerate(hidden_states):
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
577
578
579
580
                last_hidden_states.to(
                    device=output_embeddings.weight.device,
                    dtype=output_embeddings.weight.dtype,
                ),
581
                output_embeddings.weight.t(),
582
            )
583
584
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
585
586
587
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

588
589
590
591
        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
592
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
593
        num_logprobs: int,
594
    ) -> tuple[list[dict[int, float]], int]:
595
596
597
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

598
        # convert to dict
599
        seq_logprobs_lst: list[dict[int, float]] = []
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {}
            for token_id, logprob in zip(topk.indices[0], topk.values[0]):
                tok_logprobs_dct[token_id.item()] = logprob.item()

            seq_logprobs_lst.append(tok_logprobs_dct)

        return (
            seq_logprobs_lst,
            output_len,
        )

617
618
    def generate_greedy_logprobs_limit(
        self,
619
        prompts: list[str],
620
621
        max_tokens: int,
        num_logprobs: int,
622
623
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
Cyrus Leung's avatar
Cyrus Leung committed
624
        videos: Optional[PromptVideoInput] = None,
625
        **kwargs: Any,
626
    ) -> list[TokensTextLogprobs]:
627
628
629
630
631
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

632
633
634
        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []
635

636
        for inputs in all_inputs:
637
            output = self.model.generate(
638
                **self.wrap_device(inputs),
639
640
641
642
643
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
644
                **kwargs,
645
646
            )

647
648
649
650
651
652
653
654
655
656
657
658
            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_len = len(seq_logprobs_lst)
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))
659

660
661
662
663
664
665
        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models

        The encoder prompt (and image, if any) goes through the processor;
        the decoder prompt, unless it is `None`, goes through the tokenizer
        and is passed to the model as `decoder_input_ids`. Log-probs are
        computed from the decoder hidden states.
        '''
        results = []

        prompt_pairs = to_enc_dec_tuple_list(encoder_decoder_prompts)
        for idx, (encoder_prompt, decoder_prompt) in enumerate(prompt_pairs):
            processor_kwargs: dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[idx] is not None:
                processor_kwargs["images"] = images[idx]

            encoder_inputs = self.wrap_device(
                self.processor(**processor_kwargs))

            decoder_input_ids = None
            if decoder_prompt is not None:
                decoder_inputs = self.tokenizer(decoder_prompt,
                                                return_tensors="pt")
                decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)

            generation = self.model.generate(
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **encoder_inputs,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                generation.decoder_hidden_states, num_logprobs)

            # The generated tokens are the last `output_len` ids of the
            # first returned sequence.
            output_ids = generation.sequences[0][-output_len:]
            results.append((
                output_ids.tolist(),
                self.tokenizer.decode(output_ids),
                seq_logprobs_lst,
            ))

        return results

726
727
728
    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        """Forward ``prompts`` and all extra arguments to ``self.model.encode``.

        NOTE(review): presumably only meaningful when the runner was built
        with ``is_sentence_transformer=True`` - confirm against callers.
        """
        return self.model.encode(prompts, *args, **kwargs)
729

730
731
732
733
734
735
    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:
        """Forward ``prompts`` to ``self.model.predict``.

        ``convert_to_tensor=True`` is always passed, so callers must not
        supply that keyword themselves.

        NOTE(review): presumably only meaningful when the runner was built
        with ``is_cross_encoder=True`` - confirm against callers.
        """
        return self.model.predict(prompts,
                                  *args,
                                  convert_to_tensor=True,
                                  **kwargs)
736

737
738
739
740
    def __enter__(self):
        """Enter the ``with`` block; the runner itself is the context value."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model reference, then run the shared cleanup helper.

        Nothing is returned, so an exception raised inside the ``with``
        block is not suppressed.
        """
        del self.model
        cleanup_dist_env_and_memory()
743

Woosuk Kwon's avatar
Woosuk Kwon committed
744

Cyrus Leung's avatar
Cyrus Leung committed
745
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the `HfRunner` class itself (not an instance) to tests."""
    return HfRunner


class VllmRunner:
    """
    Test helper that owns a {class}`~vllm.LLM` instance (``self.model``) and
    exposes convenience methods for generation, pooling and scoring.

    It can be used as a context manager: on exit the model reference is
    deleted and `cleanup_dist_env_and_memory()` is called.

    The default values of some arguments have been modified from
    {class}`~vllm.LLM` as follows:

    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
    - `block_size`: Set to `16` instead of `None` to reduce memory usage.
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
    - `enforce_eager`: Set to `False` to test CUDA graph.
    """

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        seed: Optional[int] = 0,
        max_model_len: int = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: Optional[bool] = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Build the underlying `LLM`.

        Every named parameter is forwarded to the `LLM` constructor
        (`model_name` as `model`, `tokenizer_name` as `tokenizer`); any
        additional keyword arguments are passed through unchanged.
        """
        self.model = LLM(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def get_inputs(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[TextPrompt]:
        """Pair each prompt with its multimodal data as a `TextPrompt`.

        A `str` prompt is stored under the `"prompt"` key; anything else is
        stored under `"prompt_embeds"`. `"multi_modal_data"` is `None` when
        no modality supplies a non-`None` item for that prompt.

        Raises:
            ValueError: If any non-`None` multimodal input has a different
                length than `prompts`.
        """
        if any(x is not None and len(x) != len(prompts)
               for x in [images, videos, audios]):
            raise ValueError(
                "All non-None multimodal inputs must have the same length as "
                "prompts")

        inputs = []
        for i, prompt in enumerate(prompts):
            multi_modal_data = {}
            if images is not None and (image := images[i]) is not None:
                multi_modal_data["image"] = image
            if videos is not None and (video := videos[i]) is not None:
                multi_modal_data["video"] = video
            if audios is not None and (audio := audios[i]) is not None:
                multi_modal_data["audio"] = audio

            prompt_key = ("prompt"
                          if isinstance(prompt, str) else "prompt_embeds")
            text_prompt_kwargs = {
                prompt_key: prompt,
                # An empty dict is normalized to None.
                "multi_modal_data": multi_modal_data or None
            }
            inputs.append(TextPrompt(**text_prompt_kwargs))

        return inputs

    def generate(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run `LLM.generate` and return prompt+completion per sample.

        For each request the result is a tuple `(ids, strs)` where `ids[k]`
        is the prompt token ids followed by the k-th sample's token ids and
        `strs[k]` is the prompt text (or `""` when the request has no prompt
        text) followed by the k-th sample's text. Extra keyword arguments
        are forwarded to `LLM.generate`.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: list[list[int]] = []
            req_sample_output_strs: list[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append((prompt_str or "") + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        """Flatten request outputs into 4-tuples.

        Each tuple is `(token_ids, text, logprobs, prompt_logprobs)`. When a
        request carries several samples, only the last one is reported.
        """
        outputs: list[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            # Only the final sample of each request is kept.
            sample = req_output.outputs[-1]
            outputs.append((list(sample.token_ids), sample.text,
                            sample.logprobs, req_output.prompt_logprobs))
        return outputs

    @staticmethod
    def _strip_prompt_logprobs_if_unused(
        toks_str_logsprobs_prompt_logprobs: list[
            TokensTextLogprobsPromptLogprobs],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Drop the trailing prompt-logprobs entry of every tuple when
        `sampling_params.prompt_logprobs` is `None`; otherwise return the
        list unchanged."""
        if sampling_params.prompt_logprobs is None:
            return [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
        return toks_str_logsprobs_prompt_logprobs

    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return `(token_ids, text, logprobs)` per request.

        A fourth `prompt_logprobs` element is included only when
        `sampling_params.prompt_logprobs` is not `None`. Extra keyword
        arguments are forwarded to `LLM.generate`.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        # Omit prompt logprobs if not required by sampling params
        return self._strip_prompt_logprobs_if_unused(
            toks_str_logsprobs_prompt_logprobs, sampling_params)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for vLLM encoder/decoder models

        `sampling_params.logprobs` must not be `None`. The prompts are
        passed to `LLM.generate` as-is, and the result has the same shape
        as that of `generate_w_logprobs`.
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        # Omit prompt logprobs if not required by sampling params
        return self._strip_prompt_logprobs_if_unused(
            toks_str_logsprobs_prompt_logprobs, sampling_params)

    def generate_greedy(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Generate with `temperature=0.0` and return, per request, the
        first sample's `(prompt + output token ids, prompt + output text)`.
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts,
                                greedy_params,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[list[int]] = None,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Call `generate_w_logprobs` with `temperature=0.0` sampling
        parameters built from the given limits and stop conditions."""
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos,
                                        **kwargs)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=(num_prompt_logprobs),
            skip_special_tokens=skip_special_tokens,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run `LLM.beam_search` and return, per prompt, a tuple of
        (token lists, texts) with one entry per returned sequence."""
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        outputs = self.model.beam_search(
            inputs,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return `outputs.probs` of each result of `LLM.classify`."""
        req_outputs = self.model.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

    def encode(self,
               prompts: list[str],
               images: Optional[PromptImageInput] = None,
               videos: Optional[PromptVideoInput] = None,
               audios: Optional[PromptAudioInput] = None,
               *args,
               **kwargs) -> list[list[float]]:
        """Return `outputs.embedding` of each result of `LLM.embed`.

        Note that `*args` follows the multimodal parameters, so extra
        positional arguments only reach `LLM.embed` once `images`, `videos`
        and `audios` have all been given positionally.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.embed(inputs, *args, **kwargs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, list[str]],
        text_2: Union[str, list[str]],
        *args,
        **kwargs,
    ) -> list[float]:
        """Return `outputs.score` of each result of `LLM.score`."""
        req_outputs = self.model.score(text_1, text_2, *args, **kwargs)
        return [req_output.outputs.score for req_output in req_outputs]

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Delegate to `apply_model(func)` on the engine's model executor."""
        executor = self.model.llm_engine.model_executor
        return executor.apply_model(func)

    def __enter__(self):
        """Enter the ``with`` block; the runner itself is the managed object."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Delete the model reference and run the shared cleanup helper.

        Nothing is returned, so exceptions from the ``with`` body propagate.
        """
        del self.model
        cleanup_dist_env_and_memory()
1060

Woosuk Kwon's avatar
Woosuk Kwon committed
1061

1062
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture exposing the ``VllmRunner`` class.

    The class itself (not an instance) is returned, so each test decides
    how to construct the runner.
    """
    return VllmRunner
1065
1066


1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
@pytest.fixture()
def temporary_enable_log_propagate():
    """Enable propagation on the ``"vllm"`` logger for the duration of a test.

    The logger's previous ``propagate`` value is restored on teardown, so the
    fixture leaves the logging configuration as it found it.
    """
    import logging

    # Use a distinct local name so the module-level ``logger`` is not shadowed.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    yield
    vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog`` fixture with the
    ``temporary_enable_log_propagate`` fixture active."""
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
1081
1082
1083
1084
1085
1086
1087


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process."""

    return cuda_device_count_stateless()
1089
1090
1091


# Dummy model checkpoints are stored in sub-directories of the system
# temporary directory, one directory per dummy model.
temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
1095
1096
1097
1098


def _prepare_dummy_model_dir(repo_id: str, local_dir: str,
                             architecture: str) -> str:
    """Download a checkpoint without weights and rename its architecture.

    If ``local_dir`` does not exist yet, the Hugging Face repo ``repo_id`` is
    downloaded into it (skipping weight files) and the ``"architectures"``
    entry of its ``config.json`` is replaced by ``[architecture]``. If the
    directory already exists it is reused untouched.

    Returns:
        ``local_dir``.
    """
    json_path = os.path.join(local_dir, "config.json")
    if not os.path.exists(local_dir):
        snapshot_download(repo_id=repo_id,
                          local_dir=local_dir,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = [architecture]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return local_dir


@pytest.fixture
def dummy_opt_path():
    """Path to a weight-less ``facebook/opt-125m`` copy whose config declares
    the ``MyOPTForCausalLM`` architecture."""
    return _prepare_dummy_model_dir("facebook/opt-125m", _dummy_opt_path,
                                    "MyOPTForCausalLM")


@pytest.fixture
def dummy_llava_path():
    """Path to a weight-less ``llava-hf/llava-1.5-7b-hf`` copy whose config
    declares the ``MyLlava`` architecture."""
    return _prepare_dummy_model_dir("llava-hf/llava-1.5-7b-hf",
                                    _dummy_llava_path, "MyLlava")


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Path to a weight-less ``BAAI/bge-multilingual-gemma2`` copy whose
    config declares the ``MyGemma2Embedding`` architecture."""
    return _prepare_dummy_model_dir("BAAI/bge-multilingual-gemma2",
                                    _dummy_gemma2_embedding_path,
                                    "MyGemma2Embedding")
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169


# Add the `--optional` flag to allow running tests
# that are marked with @pytest.mark.optional
def pytest_addoption(parser):
    """Register the ``--optional`` command-line switch (off by default)."""
    option_settings = {
        "action": "store_true",
        "default": False,
        "help": "run optional test",
    }
    parser.addoption("--optional", **option_settings)


def pytest_collection_modifyitems(config, items):
    """Mark every ``optional`` test as skipped unless ``--optional`` is set.

    When the ``--optional`` option is truthy the collected items are left
    untouched; otherwise each item whose keywords contain ``"optional"``
    receives a skip marker.
    """
    if config.getoption("--optional"):
        # --optional given in cli: do not skip optional tests
        return
    skip_optional = pytest.mark.skip(reason="need --optional option to run")
    for item in items:
        if "optional" in item.keywords:
            item.add_marker(skip_optional)
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182


@pytest.fixture(scope="session")
def cli_config_file():
    """Path of ``config/test_config.yaml`` under the test directory."""
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config.yaml")


@pytest.fixture(scope="session")
def cli_config_file_with_model():
    """Path of ``config/test_config_with_model.yaml`` under the test
    directory."""
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config_with_model.yaml")