# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import tempfile
from enum import Enum
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BatchEncoding, BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import TaskOption, _get_and_verify_dtype
from vllm.connections import global_http_connection
from vllm.distributed import (cleanup_dist_env_and_memory,
                              init_distributed_environment,
                              initialize_model_parallel)
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import cuda_device_count_stateless

logger = init_logger(__name__)

# Prompt files are resolved relative to this conftest, so the fixtures do
# not depend on the current working directory.
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")

_M = TypeVar("_M")

# Multi-modal test input: either a flat list (one item per prompt) or a
# nested list (a list of items per prompt).
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
# Audio items are (audio, sampling_rate) pairs; see HfRunner.get_inputs.
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]


def _read_prompts(filename: str) -> list[str]:
55
    with open(filename) as f:
56
57
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
58
59


60
class ImageAssetPrompts(TypedDict):
    """Prompt text for each image in ``ImageTestAssets``, keyed by asset."""
    stop_sign: str
    cherry_blossom: str


class ImageTestAssets(list[ImageAsset]):
    """The image assets used by tests, in a fixed order."""

    def __init__(self) -> None:
        super().__init__([
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
        ])

    def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
        """
        Convenience method to define the prompt for each test image.

        The order of the returned prompts matches the order of the
        assets when iterating through this object.
        """
        return [prompts["stop_sign"], prompts["cherry_blossom"]]


class VideoAssetPrompts(TypedDict):
    """Prompt text for each video in ``VideoTestAssets``, keyed by asset."""
    baby_reading: str


class VideoTestAssets(list[VideoAsset]):
    """The video assets used by tests, in a fixed order."""

    def __init__(self) -> None:
        super().__init__([
            VideoAsset("baby_reading"),
        ])

    def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
        """Return the prompt for each video, in asset iteration order."""
        return [prompts["baby_reading"]]


class AudioAssetPrompts(TypedDict):
    """Prompt text for each audio clip in ``AudioTestAssets``, keyed by asset."""
    mary_had_lamb: str
    winning_call: str


class AudioTestAssets(list[AudioAsset]):
    """The audio assets used by tests, in a fixed order."""

    def __init__(self) -> None:
        super().__init__([
            AudioAsset("mary_had_lamb"),
            AudioAsset("winning_call"),
        ])

    def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
        """Return the prompt for each audio clip, in asset iteration order."""
        return [prompts["mary_had_lamb"], prompts["winning_call"]]


IMAGE_ASSETS = ImageTestAssets()
"""Singleton instance of {class}`ImageTestAssets`."""
VIDEO_ASSETS = VideoTestAssets()
"""Singleton instance of {class}`VideoTestAssets`."""
AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of {class}`AudioTestAssets`."""


@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
    """
    The V1 oracle sets "VLLM_USE_V1" during loading. This means
    that each invocation of a test changes the env variable.

    If we touch "VLLM_USE_V1" with monkeypatch, then any changes
    made during the test run by vLLM will be cleaned up.

    This fixture is used by every test.
    """

    # If VLLM_USE_V1 is not set, set it and then delete it. This registers
    # the variable with monkeypatch, so monkeypatch cleans up VLLM_USE_V1
    # on exit if vLLM modifies the value of envs.VLLM_USE_V1.
    if "VLLM_USE_V1" not in os.environ:
        monkeypatch.setenv("VLLM_USE_V1", "")
        monkeypatch.delenv("VLLM_USE_V1")


@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """Run the requesting test twice: once with VLLM_USE_V1=1, once with 0.

    Tests carrying the ``skip_v1`` marker are skipped in the V1 run.
    """
    # request.param is True for the V1 run and False for the V0 run.
    use_v1 = request.param
    # Tests decorated with `@skip_v1` are only run without v1
    skip_v1 = request.node.get_closest_marker("skip_v1")

    if use_v1:
        if skip_v1:
            pytest.skip("Skipping test on vllm V1")
        monkeypatch.setenv('VLLM_USE_V1', '1')
    else:
        monkeypatch.setenv('VLLM_USE_V1', '0')

    yield


@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Disable HTTP client reuse for every test."""
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


@pytest.fixture
def dist_init():
    """Initialize a world_size=1 NCCL distributed environment for one test.

    Rendezvous uses a ``file://`` init method backed by a temporary file.
    After the test, ``cleanup_dist_env_and_memory`` tears the environment
    down and the temporary file is removed.
    """
    # mkstemp returns an already-open OS file descriptor together with the
    # path; only the path is needed, so close the descriptor right away
    # instead of leaking it for the rest of the session.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup_dist_env_and_memory()
    # Best-effort removal of the rendezvous file; tolerate it already
    # being gone.
    try:
        os.remove(temp_file)
    except FileNotFoundError:
        pass


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't need
    to initialize torch.

    Also returns False for tests carrying the ``skip_global_cleanup`` marker.
    """

    return not request.node.get_closest_marker("skip_global_cleanup")


@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After each test, call ``cleanup_dist_env_and_memory`` unless
    ``should_do_global_cleanup_after_test`` returned False."""
    yield
    if should_do_global_cleanup_after_test:
        cleanup_dist_env_and_memory()


@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call ``torch._dynamo.reset()`` after every test to clear Dynamo's
    compilation state between tests."""
    yield
    torch._dynamo.reset()


@pytest.fixture
def example_prompts() -> list[str]:
    """Prompts read from the files in ``_TEST_PROMPTS``, one per line."""
    prompts = []
    for filename in _TEST_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


@pytest.fixture
def example_system_message() -> str:
    """Full text of the system-message file at ``_SYS_MSG``."""
    # Explicit UTF-8 so the content does not depend on the locale encoding.
    with open(_SYS_MSG, encoding="utf-8") as f:
        return f.read()


class DecoderPromptType(Enum):
    """For encoder/decoder models only.

    Selects the decoder prompt paired with each encoder prompt in
    ``example_encoder_decoder_prompts``:

    * ``CUSTOM``: the encoder prompts in reverse order
    * ``NONE``: ``None`` (no decoder prompt)
    * ``EMPTY_STR``: the empty string
    """
    CUSTOM = 1
    NONE = 2
    EMPTY_STR = 3


@pytest.fixture
def example_encoder_decoder_prompts(
) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]:
    """
    Build encoder/decoder prompt pairs for every ``DecoderPromptType``.

    The encoder prompts are the lines of the files in ``_TEST_PROMPTS``.
    Each encoder prompt is paired, by index, with a decoder prompt chosen
    by the dictionary key:

    * ``NONE``: ``None``
    * ``EMPTY_STR``: ``""``
    * ``CUSTOM``: the encoder prompt list in reverse order

    Returns:
        A dict mapping each ``DecoderPromptType`` to the prompt list
        produced by ``zip_enc_dec_prompts`` for that decoder choice.
    """

    encoder_prompts = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    custom_decoder_prompts = encoder_prompts[::-1]
    empty_str_decoder_prompts = [""] * len(encoder_prompts)
    none_decoder_prompts = [None] * len(encoder_prompts)

    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
    }


@pytest.fixture
def example_long_prompts() -> list[str]:
    """Prompts read from the files in ``_LONG_PROMPTS``, one per line."""
    prompts = []
    for filename in _LONG_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


@pytest.fixture(scope="session")
def image_assets() -> ImageTestAssets:
    """Session-wide access to the ``IMAGE_ASSETS`` singleton."""
    return IMAGE_ASSETS


@pytest.fixture(scope="session")
def video_assets() -> VideoTestAssets:
    """Session-wide access to the ``VIDEO_ASSETS`` singleton."""
    return VIDEO_ASSETS


@pytest.fixture(scope="session")
def audio_assets() -> AudioTestAssets:
    """Session-wide access to the ``AUDIO_ASSETS`` singleton."""
    return AUDIO_ASSETS


# Value types accepted and returned by HfRunner.wrap_device.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_R = TypeVar("_R")


class HfRunner:
    """Runs a model through HuggingFace libraries.

    Depending on the constructor flags, the wrapped ``self.model`` is a
    ``sentence_transformers`` ``SentenceTransformer`` or ``CrossEncoder``,
    or a ``transformers`` model loaded through ``auto_cls``.

    Usable as a context manager: on exit the model is deleted and
    ``cleanup_dist_env_and_memory`` is called.
    """

    def get_default_device(self):
        """Return ``"cpu"`` on CPU platforms, else the platform device type."""
        from vllm.platforms import current_platform

        return ("cpu"
                if current_platform.is_cpu() else current_platform.device_type)

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move ``x`` to ``device`` (default: ``self.device``).

        ``None`` and ``bool`` values are returned unchanged, dicts are
        handled recursively per value, and objects whose ``device.type``
        already equals ``device`` are returned without calling ``.to``.
        """
        if x is None or isinstance(x, (bool, )):
            return x

        if device is None:
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load the config, model, tokenizer and processor for a model.

        Args:
            model_name: Name or path passed to every ``from_pretrained``.
            dtype: Requested dtype, resolved with ``_get_and_verify_dtype``.
            model_kwargs: Extra model-construction kwargs. The dict is
                copied, so the caller's object is never modified;
                ``torch_dtype`` is filled in when absent.
            trust_remote_code: Forwarded to all loaders.
            is_sentence_transformer: Wrap the model in ``SentenceTransformer``.
            is_cross_encoder: Wrap the model in ``CrossEncoder``.
            skip_tokenizer_init: Skip loading a tokenizer with
                ``AutoTokenizer`` and use ``self.processor.tokenizer``.
            auto_cls: Auto class used to load a plain transformers model.
        """
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Copy before filling in defaults: setdefault on the caller's dict
        # would leak torch_dtype into a dict they may reuse elsewhere.
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if (getattr(model, "quantization_method", None) is None
                    and any(p.dtype != self.dtype
                            for p in model.parameters())):
                model = model.to(dtype=self.dtype)

            # Move to self.device only if the model is not bitsandbytes-
            # quantized and its parameters span fewer than two devices.
            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

    def get_inputs(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[Union[BatchFeature, BatchEncoding]]:
        """Run ``self.processor`` separately on each prompt.

        ``images``/``videos``/``audios``, when given, must have one entry per
        prompt; a ``None`` entry means that prompt has no data of that kind.
        Audio entries of length 2 are unpacked as ``(audio, sampling_rate)``.
        ``BatchFeature`` outputs are cast to ``self.dtype``.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_inputs := audios[i]) is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
        """Return the model's input embeddings for each prompt.

        Dimension 0 of each embedding is squeezed (dropped if it has size 1).
        """
        all_inputs = self.get_inputs(prompts)
        embeddings = []
        for inputs in all_inputs:
            input_ids = self.wrap_device(inputs)["input_ids"]
            embedding = self.model.get_input_embeddings()(input_ids).squeeze(0)
            embeddings.append(embedding)
        return embeddings

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return softmax class probabilities for each prompt.

        Only row ``[0]`` of each prompt's softmaxed logits is kept and it is
        converted to a plain list of floats.
        """
        all_inputs = self.get_inputs(prompts)
        outputs = []
        for inputs in all_inputs:
            # output.logits are the final (pre-softmax) logits
            output = self.model(**self.wrap_device(inputs))
            probs = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(probs)

        return outputs

    def generate(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Call ``self.model.generate`` once per prompt.

        ``kwargs`` are forwarded to ``generate``; ``use_cache=True`` is always
        set. Returns, per prompt, ``(output_ids, output_strs)``: all returned
        sequences as token-id lists plus their decoded text (special tokens
        skipped).
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for inputs in all_inputs:
            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy decoding; returns the first sequence's (ids, text) per
        prompt."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Beam search returning all ``beam_width`` sequences per prompt.

        Pad tokens (``self.tokenizer.pad_token_id``) are removed from every
        returned token-id list.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width,
                                images=images,
                                videos=videos,
                                audios=audios)

        pad_token_id = self.tokenizer.pad_token_id
        return [([[tok for tok in seq if tok != pad_token_id]
                  for seq in output_ids], output_str)
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[list[torch.Tensor]]:
        """Greedy decoding returning full-vocabulary float32 log-probs.

        For each prompt, returns one tensor per entry of the generate
        output's ``hidden_states``.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[torch.Tensor]] = []
        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs = self._hidden_states_to_seq_logprobs(
                output.hidden_states)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
    ) -> list[torch.Tensor]:
        """Project last-layer hidden states to float32 log-probs.

        For each entry of ``hidden_states``, the last tensor's first element
        is multiplied by the output-embedding weight (plus its bias, if any)
        and log-softmaxed over the last dimension.
        """
        output_embeddings = self.model.get_output_embeddings()

        seq_logprobs: list[torch.Tensor] = []
        for hidden_state in hidden_states:
            # Last layer, first batch element.
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(
                    device=output_embeddings.weight.device,
                    dtype=output_embeddings.weight.dtype,
                ),
                output_embeddings.weight.t(),
            )
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> tuple[list[dict[int, float]], int]:
        """Return the top-``num_logprobs`` ``{token_id: logprob}`` dict per
        step, and the step count (``len(hidden_states)``)."""
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        # convert to dict
        seq_logprobs_lst: list[dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs: the first step keeps only its last row
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            }

            seq_logprobs_lst.append(tok_logprobs_dct)

        return (
            seq_logprobs_lst,
            output_len,
        )

    def generate_greedy_logprobs_limit(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """Greedy decoding returning ``(token_ids, text, top-k logprobs)``.

        Note the positional order here is ``images, audios, videos`` (unlike
        ``generate``); passing them by keyword avoids mix-ups.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            # output_len already equals len(seq_logprobs_lst): one entry per
            # hidden-state step. The generated ids are the sequence's tail.
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """
        Greedy logprobs generation for vLLM encoder/decoder models

        The encoder prompt (with the optional image for that index) goes
        through ``self.processor``; a non-``None`` decoder prompt is
        tokenized and passed as ``decoder_input_ids``.
        """

        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for i, (encoder_prompt, decoder_prompt) in enumerate(
                to_enc_dec_tuple_list(encoder_decoder_prompts)):
            processor_kwargs: dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            encoder_inputs = self.processor(**processor_kwargs)
            encoder_inputs = self.wrap_device(encoder_inputs)

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_inputs = self.tokenizer(decoder_prompt,
                                                return_tensors="pt")
                decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)

            output = self.model.generate(
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **encoder_inputs,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        """Delegate to ``self.model.encode``."""
        return self.model.encode(prompts, *args, **kwargs)

    def predict(self, prompts: list[list[str]]) -> torch.Tensor:
        """Delegate to ``self.model.predict`` with ``convert_to_tensor=True``."""
        return self.model.predict(prompts, convert_to_tensor=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the model reference before the global cleanup runs.
        del self.model
        cleanup_dist_env_and_memory()


@pytest.fixture(scope="session")
def hf_runner():
    """Provide the ``HfRunner`` class itself (not an instance), so each test
    constructs runners with its own arguments, e.g. in a ``with`` block."""
    return HfRunner


class VllmRunner:
    """
    Test harness around {class}`~vllm.LLM`.

    The default values of some arguments have been modified from
    {class}`~vllm.LLM` as follows:

    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
    - `block_size`: Set to `16` instead of `None` to reduce memory usage.
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
    - `enforce_eager`: Set to `False` to test CUDA graph.

    Intended to be used as a context manager: on exit the wrapped `LLM` is
    deleted and `cleanup_dist_env_and_memory()` is called.
    """

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        seed: Optional[int] = 0,
        max_model_len: int = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: Optional[bool] = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        # Every explicit argument is forwarded to ``LLM`` under its engine
        # name; any extra keyword arguments are passed through unchanged.
        self.model = LLM(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def get_inputs(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[TextPrompt]:
        """Pair each prompt with its (optional) multimodal data.

        String prompts are stored under ``"prompt"``; any other prompt object
        (e.g. a tensor) is stored under ``"prompt_embeds"``. Per-prompt
        multimodal entries that are ``None`` are omitted, and an empty
        multimodal dict becomes ``None``.

        Raises:
            ValueError: if any non-None multimodal list does not have the same
                length as ``prompts``.
        """
        if any(x is not None and len(x) != len(prompts)
               for x in [images, videos, audios]):
            raise ValueError(
                "All non-None multimodal inputs must have the same length as "
                "prompts")

        inputs = []
        for i, prompt in enumerate(prompts):
            multi_modal_data = {}
            if images is not None and (image := images[i]) is not None:
                multi_modal_data["image"] = image
            if videos is not None and (video := videos[i]) is not None:
                multi_modal_data["video"] = video
            if audios is not None and (audio := audios[i]) is not None:
                multi_modal_data["audio"] = audio

            text_prompt_kwargs = {
                ("prompt" if isinstance(prompt, str) else "prompt_embeds"):
                prompt,
                "multi_modal_data": multi_modal_data or None
            }
            inputs.append(TextPrompt(**text_prompt_kwargs))

        return inputs

    def generate(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate with ``sampling_params`` and return prompt+output pairs.

        For every request, returns a tuple of (token ids per sample, text per
        sample), where each entry is the prompt followed by that sample's
        generated output. Extra ``kwargs`` are forwarded to ``LLM.generate``.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for req_output in req_outputs:
            # Both the prompt text and prompt token ids are Optional on the
            # request output; treat a missing one as empty so the
            # concatenations below cannot fail on ``None``.
            prompt_str = req_output.prompt or ""
            prompt_ids = req_output.prompt_token_ids or []
            req_sample_output_ids: list[list[int]] = []
            req_sample_output_strs: list[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        """Flatten request outputs to (ids, text, logprobs, prompt logprobs).

        Exactly one tuple is produced per request, built from that request's
        LAST sample.
        """
        outputs: list[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            # NOTE(review): only the final sample is reported; callers appear
            # to request a single sample per prompt — confirm before using
            # this with n > 1.
            sample = req_output.outputs[-1]
            outputs.append((list(sample.token_ids), sample.text,
                            sample.logprobs, req_output.prompt_logprobs))
        return outputs

    @staticmethod
    def _maybe_drop_prompt_logprobs(
        toks_str_logsprobs_prompt_logprobs: list[
            TokensTextLogprobsPromptLogprobs],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Strip the trailing prompt-logprobs element unless it was requested."""
        if sampling_params.prompt_logprobs is None:
            return [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
        return toks_str_logsprobs_prompt_logprobs

    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return (ids, text, logprobs[, prompt logprobs]).

        The prompt-logprobs element is included only when
        ``sampling_params.prompt_logprobs`` is set. Extra ``kwargs`` are
        forwarded to ``LLM.generate``.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        return self._maybe_drop_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Logprobs generation for vLLM encoder/decoder models.

        ``sampling_params.logprobs`` must be set. The prompt-logprobs element
        is included only when ``sampling_params.prompt_logprobs`` is set.
        """
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._maybe_drop_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_greedy(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy (temperature 0) generation.

        Returns the first sample's (token ids, text) for each prompt, both
        including the prompt.
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts,
                                greedy_params,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[list[int]] = None,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Greedy generation that also requests ``num_logprobs`` logprobs.

        ``num_prompt_logprobs`` (when given) also requests prompt logprobs.
        The result format is that of ``generate_w_logprobs``.
        """
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos,
                                        **kwargs)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Greedy logprobs generation for vLLM encoder/decoder models."""
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            skip_special_tokens=skip_special_tokens,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run beam search and return (token ids, texts) of every beam.

        One tuple is returned per prompt, holding one entry per sequence in
        that prompt's beam-search output.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        outputs = self.model.beam_search(
            inputs,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return the ``probs`` of each classification output."""
        req_outputs = self.model.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

    def encode(self,
               prompts: list[str],
               images: Optional[PromptImageInput] = None,
               videos: Optional[PromptVideoInput] = None,
               audios: Optional[PromptAudioInput] = None,
               *args,
               **kwargs) -> list[list[float]]:
        """Embed the prompts (with optional multimodal data).

        Returns each output's ``embedding``. Extra arguments are forwarded to
        ``LLM.embed``.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.embed(inputs, *args, **kwargs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, list[str]],
        text_2: Union[str, list[str]],
    ) -> list[float]:
        """Return the ``score`` of each output of ``LLM.score(text_1, text_2)``."""
        req_outputs = self.model.score(text_1, text_2)
        return [req_output.outputs.score for req_output in req_outputs]

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Call ``func`` through the engine's model executor.

        Returns whatever list the executor's ``apply_model`` produces.
        Presumably this is one result per worker; confirm against the
        executor.
        """
        executor = self.model.llm_engine.model_executor
        return executor.apply_model(func)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the engine before cleanup so its resources can be released.
        del self.model
        cleanup_dist_env_and_memory()
1054

Woosuk Kwon's avatar
Woosuk Kwon committed
1055

1056
@pytest.fixture(scope="session")
def vllm_runner():
    """Session fixture that hands tests the ``VllmRunner`` class itself.

    The class (not an instance) is returned, so each test constructs its own
    runner with whatever model and engine options it needs.
    """
    runner_cls = VllmRunner
    return runner_cls
1059
1060


1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
@pytest.fixture()
def temporary_enable_log_propagate():
    """Let records from the ``vllm`` logger reach the root logger for one test.

    On teardown the logger's previous ``propagate`` value is restored. It is
    not forced to ``False``, so the fixture leaves the logger exactly as it
    found it even when propagation was already enabled.
    """
    import logging

    # Distinct local name: avoid shadowing the module-level ``logger``.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """``caplog`` that also captures records emitted through the vllm logger.

    ``caplog`` only sees records that propagate to the root logger. This
    fixture therefore depends on ``temporary_enable_log_propagate`` to switch
    propagation on for the duration of the test.
    """
    captured = caplog
    yield captured
1075
1076
1077
1078
1079
1080
1081


@pytest.fixture(scope="session")
def num_gpus_available():
    """Number of visible CUDA devices, counted without creating a CUDA
    context in the current process."""
    device_count = cuda_device_count_stateless()
    return device_count
1083
1084
1085


temp_dir = tempfile.gettempdir()

# Scratch checkpoint directories under the system temp dir; the dummy_*
# fixtures only download into them when the directory does not exist yet.
_dummy_opt_path, _dummy_llava_path, _dummy_gemma2_embedding_path = [
    os.path.join(temp_dir, leaf_name)
    for leaf_name in ("dummy_opt", "dummy_llava", "dummy_gemma2_embedding")
]
1089
1090
1091
1092


@pytest.fixture
def dummy_opt_path():
    """Local copy of ``facebook/opt-125m`` relabelled as ``MyOPTForCausalLM``.

    On first use the repo is downloaded into ``_dummy_opt_path``, skipping
    files matching ``*.bin``, ``*.bin.index.json``, ``*.pt``, ``*.h5`` and
    ``*.msgpack``. Its ``config.json`` is then rewritten so that
    ``architectures`` is ``["MyOPTForCausalLM"]``. If the directory already
    exists it is returned as-is.
    """
    json_path = os.path.join(_dummy_opt_path, "config.json")
    if not os.path.exists(_dummy_opt_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_opt_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        # JSON files are UTF-8; don't depend on the platform's locale encoding.
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
    """Local copy of ``llava-hf/llava-1.5-7b-hf`` relabelled as ``MyLlava``.

    On first use the repo is downloaded into ``_dummy_llava_path``, skipping
    files matching ``*.bin``, ``*.bin.index.json``, ``*.pt``, ``*.h5`` and
    ``*.msgpack``. Its ``config.json`` is then rewritten so that
    ``architectures`` is ``["MyLlava"]``. If the directory already exists it
    is returned as-is.
    """
    json_path = os.path.join(_dummy_llava_path, "config.json")
    if not os.path.exists(_dummy_llava_path):
        snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                          local_dir=_dummy_llava_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        # JSON files are UTF-8; don't depend on the platform's locale encoding.
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = ["MyLlava"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_llava_path
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Local copy of ``BAAI/bge-multilingual-gemma2`` relabelled as
    ``MyGemma2Embedding``.

    On first use the repo is downloaded into ``_dummy_gemma2_embedding_path``,
    skipping files matching ``*.bin``, ``*.bin.index.json``, ``*.pt``,
    ``*.h5`` and ``*.msgpack``. Its ``config.json`` is then rewritten so that
    ``architectures`` is ``["MyGemma2Embedding"]``. If the directory already
    exists it is returned as-is.
    """
    json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
    if not os.path.exists(_dummy_gemma2_embedding_path):
        snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
                          local_dir=_dummy_gemma2_embedding_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        # JSON files are UTF-8; don't depend on the platform's locale encoding.
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = ["MyGemma2Embedding"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_gemma2_embedding_path
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163


# Add the flag `--optional` to allow run tests
# that are marked with @pytest.mark.optional
def pytest_addoption(parser):
    """Register the ``--optional`` CLI flag.

    Tests marked ``@pytest.mark.optional`` only run when it is passed.
    """
    flag_options = dict(action="store_true",
                        default=False,
                        help="run optional test")
    parser.addoption("--optional", **flag_options)


def pytest_collection_modifyitems(config, items):
    """Attach a skip marker to every ``optional`` test unless ``--optional``
    was given on the command line."""
    run_optional_tests = config.getoption("--optional")
    if run_optional_tests:
        # --optional given in cli: do not skip optional tests
        return
    skip_marker = pytest.mark.skip(reason="need --optional option to run")
    for collected in items:
        if "optional" not in collected.keywords:
            continue
        collected.add_marker(skip_marker)
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176


@pytest.fixture(scope="session")
def cli_config_file():
    """Path of ``config/test_config.yaml`` inside the tests directory."""
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config.yaml")


@pytest.fixture(scope="session")
def cli_config_file_with_model():
    """Path of ``config/test_config_with_model.yaml`` (a CLI config that also
    names a model) inside the tests directory."""
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config_with_model.yaml")