conftest.py 41 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
import os
5
import tempfile
6
from enum import Enum
7
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
import pytest
import torch
12
import torch.nn as nn
13
import torch.nn.functional as F
14
from huggingface_hub import snapshot_download
15
from PIL import Image
16
17
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BatchEncoding, BatchFeature)
18
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
19

20
21
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm import LLM, SamplingParams
23
from vllm.assets.audio import AudioAsset
24
from vllm.assets.image import ImageAsset
25
from vllm.assets.video import VideoAsset
26
from vllm.config import TaskOption, _get_and_verify_dtype
27
from vllm.connections import global_http_connection
28
from vllm.distributed import (cleanup_dist_env_and_memory,
29
30
                              init_distributed_environment,
                              initialize_model_parallel)
31
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
32
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
33
from vllm.logger import init_logger
34
from vllm.outputs import RequestOutput
35
from vllm.sampling_params import BeamSearchParams
36
from vllm.transformers_utils.utils import maybe_model_redirect
37
from vllm.utils import cuda_device_count_stateless
38

39
# Module-level logger, named after this module.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
40

41
42
43
# Directory containing this conftest; every asset path below is built from it.
_TEST_DIR = os.path.dirname(__file__)

# Files read by the `example_prompts` / `example_encoder_decoder_prompts`
# fixtures.
_TEST_PROMPTS = [
    os.path.join(_TEST_DIR, "prompts", "example.txt"),
]

# Files read by the `example_long_prompts` fixture.
_LONG_PROMPTS = [
    os.path.join(_TEST_DIR, "prompts", "summary.txt"),
]

# File read by the `example_system_message` fixture.
_SYS_MSG = os.path.join(
    _TEST_DIR,
    "system_messages",
    "sonnet3.5_nov2024.txt",
)
45

Cyrus Leung's avatar
Cyrus Leung committed
46
# Type of a single multi-modal item (one image, one audio clip, one video).
_M = TypeVar("_M")

# Multi-modal data for a list of prompts: one entry per prompt, where an
# entry is either a single item or a list of items.
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]

# Images are PIL images.
PromptImageInput = _PromptMultiModalInput[Image.Image]
# Audio clips are `(samples, sampling_rate)` tuples.
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
# Videos are numpy arrays.
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
53

54

55
def _read_prompts(filename: str) -> list[str]:
    """Read a prompt file and return its lines.

    Args:
        filename: Path of the text file to read.

    Returns:
        The lines of the file in order, exactly as produced by
        ``readlines()`` (line endings are kept).
    """
    # Decode explicitly as UTF-8 so the prompts do not depend on the locale
    # of the machine running the tests.
    with open(filename, encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
59
60


61
class ImageAssetPrompts(TypedDict):
    """Prompt text for each test image, keyed by the image asset's name."""

    # Prompt paired with the "stop_sign" image asset.
    stop_sign: str
    # Prompt paired with the "cherry_blossom" image asset.
    cherry_blossom: str
64
65


66
class ImageTestAssets(list[ImageAsset]):
    """List of the image assets used in tests, in a fixed order."""

    def __init__(self) -> None:
        assets = [
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
        """
        Build the list of prompts for the test images.

        The returned prompts are ordered like the assets yielded when
        iterating over this object.
        """
        asset_names = ("stop_sign", "cherry_blossom")
        return [prompts[name] for name in asset_names]
82
83


84
85
class VideoAssetPrompts(TypedDict):
    """Prompt text for each test video, keyed by the video asset's name."""

    # Prompt paired with the "baby_reading" video asset.
    baby_reading: str
86
87


88
class VideoTestAssets(list[VideoAsset]):
    """List of the video assets used in tests, in a fixed order."""

    def __init__(self) -> None:
        assets = [
            VideoAsset("baby_reading"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
        """Return the prompts in the same order as the video assets."""
        asset_names = ("baby_reading", )
        return [prompts[name] for name in asset_names]
97
98


99
class AudioAssetPrompts(TypedDict):
    """Prompt text for each test audio clip, keyed by the asset's name."""

    # Prompt paired with the "mary_had_lamb" audio asset.
    mary_had_lamb: str
    # Prompt paired with the "winning_call" audio asset.
    winning_call: str


104
class AudioTestAssets(list[AudioAsset]):
    """List of the audio assets used in tests, in a fixed order."""

    def __init__(self) -> None:
        assets = [
            AudioAsset("mary_had_lamb"),
            AudioAsset("winning_call"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
        """Return the prompts in the same order as the audio assets."""
        asset_names = ("mary_had_lamb", "winning_call")
        return [prompts[name] for name in asset_names]

115

116
# Shared asset lists, created once at import time and handed out by the
# session-scoped `image_assets` / `video_assets` / `audio_assets` fixtures.
IMAGE_ASSETS = ImageTestAssets()
"""Singleton instance of {class}`ImageTestAssets`."""

VIDEO_ASSETS = VideoTestAssets()
"""Singleton instance of {class}`VideoTestAssets`."""

AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of {class}`AudioTestAssets`."""
122
123


124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
    """
    Keep changes to "VLLM_USE_V1" from leaking between tests.

    The V1 oracle sets "VLLM_USE_V1" during loading, so each test
    invocation may change the environment variable.

    Touching "VLLM_USE_V1" through monkeypatch makes monkeypatch restore
    it afterwards, which cleans up any change vLLM made during the test.

    This fixture is used by every test.
    """
    if "VLLM_USE_V1" in os.environ:
        # Already set by the environment: nothing is registered here.
        return

    # Set then delete, so that monkeypatch records the variable as
    # "originally unset" and cleans it up on exit if vLLM modifies
    # envs.VLLM_USE_V1 during the test.
    monkeypatch.setenv("VLLM_USE_V1", "")
    monkeypatch.delenv("VLLM_USE_V1")


Joe Runde's avatar
Joe Runde committed
144
@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """
    Run the requesting test twice: once with V1 enabled and once without.

    Tests marked `skip_v1` are skipped for the V1 run, and tests marked
    `skip_v0` are skipped for the non-V1 run.
    """
    use_v1 = request.param
    node = request.node
    skip_v0 = node.get_closest_marker("skip_v0")
    skip_v1 = node.get_closest_marker("skip_v1")

    # pytest.skip() raises, so nothing below runs for a skipped variant.
    if use_v1 and skip_v1:
        pytest.skip("Skipping test on vllm V1")
    if not use_v1 and skip_v0:
        pytest.skip("Skipping test on vllm V0")

    monkeypatch.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    yield
Joe Runde's avatar
Joe Runde committed
162
163


164
165
166
167
168
169
170
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """
    Disable reuse of the global HTTP client for every test.

    pytest_asyncio may use a different event loop per test, so the async
    client has to be created anew each time.
    """
    global_http_connection.reuse_client = False


171
172
173
174
175
176
177
178
179
180
181
182
@pytest.fixture
def dist_init():
    """
    Set up a single-process distributed environment around a test.

    The distributed environment is initialized with world size 1, rank 0 and
    the "nccl" backend, using a temporary file as the rendezvous point, and
    model parallelism is initialized with sizes (1, 1). After the test,
    `cleanup_dist_env_and_memory()` is called and the temporary file is
    removed.
    """
    # mkstemp() returns an already-open OS-level file descriptor together
    # with the path. Only the path is needed, so close the descriptor right
    # away instead of leaking it.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    try:
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend="nccl",
        )
        initialize_model_parallel(1, 1)
        yield
        cleanup_dist_env_and_memory()
    finally:
        # mkstemp() leaves deletion to the caller. Removal is best-effort:
        # the file may already be gone or may still be held open.
        try:
            os.remove(temp_file)
        except OSError:
            pass
184
185


186
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell `cleanup_fixture` whether to run the global cleanup.

    Subdirectories can override this fixture to skip global cleanup, which
    can provide a ~10x speedup for non-GPU unit tests since they don't need
    to initialize torch.

    Returns False when the test carries the `skip_global_cleanup` marker.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
194
195


196
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run the global cleanup after each test unless it was opted out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory()
201
202


203
204
205
206
207
208
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call `torch._dynamo.reset()` after every test."""
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
209
@pytest.fixture
def example_prompts() -> list[str]:
    """Return every line of the files listed in `_TEST_PROMPTS`, in order."""
    return [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]


217
218
219
220
221
222
@pytest.fixture
def example_system_message() -> str:
    """Return the full contents of the `_SYS_MSG` file as one string."""
    # Decode explicitly as UTF-8 so the message does not depend on the
    # locale of the machine running the tests.
    with open(_SYS_MSG, encoding="utf-8") as f:
        return f.read()


223
224
225
226
227
228
229
class DecoderPromptType(Enum):
    """Kinds of decoder prompt; for encoder/decoder models only."""

    # An explicit decoder prompt string.
    CUSTOM = 1
    # The decoder prompt is `None`.
    NONE = 2
    # The decoder prompt is the empty string.
    EMPTY_STR = 3


230
@pytest.fixture
def example_encoder_decoder_prompts(
) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for each decoder prompt type.

    The encoder prompts are the lines of the files in `_TEST_PROMPTS`. They
    are paired index-by-index with a decoder prompt list chosen per type:

    * `DecoderPromptType.NONE`: every decoder prompt is `None`
    * `DecoderPromptType.EMPTY_STR`: every decoder prompt is `""`
    * `DecoderPromptType.CUSTOM`: the reverse of the encoder prompt list

    Returns:

    * A dict mapping each decoder prompt type to the result of
      `zip_enc_dec_prompts` for that pairing
    '''
    encoder_prompts = [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]
    num_prompts = len(encoder_prompts)

    decoder_prompts_by_type = {
        DecoderPromptType.NONE: [None] * num_prompts,
        DecoderPromptType.EMPTY_STR: [""] * num_prompts,
        DecoderPromptType.CUSTOM: encoder_prompts[::-1],
    }

    return {
        prompt_type: zip_enc_dec_prompts(encoder_prompts, decoder_prompts)
        for prompt_type, decoder_prompts in decoder_prompts_by_type.items()
    }


263
@pytest.fixture
def example_long_prompts() -> list[str]:
    """Return every line of the files listed in `_LONG_PROMPTS`, in order."""
    return [
        line for filename in _LONG_PROMPTS
        for line in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
269
270


271
@pytest.fixture(scope="session")
def image_assets() -> ImageTestAssets:
    """Provide the shared `IMAGE_ASSETS` list to tests."""
    return IMAGE_ASSETS


276
@pytest.fixture(scope="session")
def video_assets() -> VideoTestAssets:
    """Provide the shared `VIDEO_ASSETS` list to tests."""
    return VIDEO_ASSETS


281
@pytest.fixture(scope="session")
def audio_assets() -> AudioTestAssets:
    """Provide the shared `AUDIO_ASSETS` list to tests."""
    return AUDIO_ASSETS


286
# Types accepted (and returned unchanged in type) by `HfRunner.wrap_device`.
_T = TypeVar(
    "_T",
    nn.Module,
    torch.Tensor,
    BatchEncoding,
    BatchFeature,
    dict,
)
# Unconstrained generic type variable.
_R = TypeVar("_R")
288

Woosuk Kwon's avatar
Woosuk Kwon committed
289
290
291

class HfRunner:
    """Test helper that runs a Hugging Face model to get reference outputs.

    Depending on the constructor flags, the wrapped model is a
    `SentenceTransformer`, a `CrossEncoder`, or a model loaded through
    `auto_cls.from_pretrained`. A tokenizer and a processor are loaded
    alongside it.

    The runner can be used as a context manager; on exit the model is
    dropped and `cleanup_dist_env_and_memory()` is called.
    """

    def get_default_device(self):
        """Return "cpu" on the CPU platform, else the platform device type."""
        from vllm.platforms import current_platform

        if current_platform.is_cpu():
            return "cpu"
        return current_platform.device_type

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move `x` to `device` (by default, the runner's device).

        `None` and booleans are returned as-is. Plain dicts are rebuilt with
        every value wrapped recursively. Objects already on the requested
        device type are returned unchanged; anything else is moved with
        `x.to(device)`.
        """
        if x is None or isinstance(x, (bool, )):
            return x

        if device is None:
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load the model, tokenizer and processor.

        Args:
            model_name: Model identifier; passed through
                `maybe_model_redirect` before use.
            dtype: Requested dtype, resolved via `_get_and_verify_dtype`.
            model_kwargs: Extra keyword arguments for the model constructor.
                The dict is copied, so the caller's object is not modified.
            trust_remote_code: Forwarded to every `from_pretrained` call and
                to the sentence-transformers constructors.
            is_sentence_transformer: Load the model as a
                `SentenceTransformer`.
            is_cross_encoder: Load the model as a `CrossEncoder` (only
                consulted when `is_sentence_transformer` is false).
            skip_tokenizer_init: Do not load an `AutoTokenizer`; use the
                processor's tokenizer instead.
            auto_cls: Auto class used to load the model when neither of the
                sentence-transformers flags is set.
        """
        model_name = maybe_model_redirect(model_name)
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Work on a copy: `setdefault` below must not leak "torch_dtype"
        # into a dict that the caller may reuse for another runner.
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if (getattr(model, "quantization_method", None) is None
                    and any(p.dtype != self.dtype
                            for p in model.parameters())):
                model = model.to(dtype=self.dtype)

            # Only move the model when it is not a bitsandbytes model and
            # its parameters live on fewer than two distinct devices.
            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

    def get_inputs(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[Union[BatchFeature, BatchEncoding]]:
        """Run the processor on each prompt and its multi-modal data.

        Every multi-modal list that is given must have one entry per prompt;
        an entry may be `None` to pass no data for that prompt.

        Returns:
            One processor output per prompt, in prompt order.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_inputs := audios[i]) is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                # Cast the features to the dtype the model was loaded with.
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
        """Return the model's input embeddings of each prompt's token ids."""
        embed_tokens = self.model.get_input_embeddings()
        embeddings = []
        for inputs in self.get_inputs(prompts):
            input_ids = self.wrap_device(inputs)["input_ids"]
            embeddings.append(embed_tokens(input_ids).squeeze(0))
        return embeddings

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return, per prompt, the softmax over the model's output logits.

        Only the first row (`[0]`) of each output is kept and converted to
        a Python list.
        """
        # output is final logits
        outputs = []
        for inputs in self.get_inputs(prompts):
            output = self.model(**self.wrap_device(inputs))
            probs = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(probs)

        return outputs

    def generate(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate from each prompt with `self.model.generate`.

        Extra keyword arguments are forwarded to `self.model.generate`.

        Returns:
            Per prompt, a tuple `(output_ids, output_str)` holding the token
            ids of every returned sequence and their decoded texts.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for inputs in all_inputs:
            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            outputs.append((output_ids.cpu().tolist(), output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Generate without sampling, at most `max_tokens` new tokens each.

        Returns:
            Per prompt, the first returned sequence as `(token_ids, text)`.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run beam search, returning `beam_width` sequences per prompt.

        Tokens equal to the tokenizer's pad token id are removed from every
        returned id sequence; the decoded texts are left untouched.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width,
                                images=images,
                                videos=videos,
                                audios=audios)

        pad_token_id = self.tokenizer.pad_token_id
        for i, (output_ids, output_str) in enumerate(outputs):
            for j, token_ids in enumerate(output_ids):
                output_ids[j] = [x for x in token_ids if x != pad_token_id]
            outputs[i] = (output_ids, output_str)
        return outputs

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[list[torch.Tensor]]:
        """Generate without sampling and return full log-prob tensors.

        Returns:
            Per prompt, one log-probability tensor per generation step, as
            computed by `_hidden_states_to_seq_logprobs`.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[torch.Tensor]] = []
        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs = self._hidden_states_to_seq_logprobs(
                output.hidden_states)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
    ) -> list[torch.Tensor]:
        """Project hidden states through the output embeddings.

        For each element of `hidden_states`, the last tensor of that element
        is taken, its first entry (`[0]`) is multiplied with the transposed
        output-embedding weight, the bias is added when there is one, and a
        float32 log-softmax over the last dimension is applied.

        Returns:
            One log-probability tensor per element of `hidden_states`.
        """
        output_embeddings = self.model.get_output_embeddings()
        weight = output_embeddings.weight
        bias = getattr(output_embeddings, "bias", None)

        seq_logprobs: list[torch.Tensor] = []
        for hidden_state in hidden_states:
            # presumably the final layer's states for the first batch item
            # - TODO confirm against the `generate` output format
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(
                    device=weight.device,
                    dtype=weight.dtype,
                ),
                weight.t(),
            )
            if bias is not None:
                logits += bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> tuple[list[dict[int, float]], int]:
        """Reduce hidden states to the top-`num_logprobs` entries per step.

        Returns:
            A tuple `(seq_logprobs_lst, output_len)`: per step, a dict from
            token id to log-probability for the top `num_logprobs` tokens,
            and the number of steps (`len(hidden_states)`).
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        # convert to dict
        seq_logprobs_lst: list[dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            token_ids = topk.indices[0].tolist()
            logprobs = topk.values[0].tolist()
            seq_logprobs_lst.append(dict(zip(token_ids, logprobs)))

        return (
            seq_logprobs_lst,
            output_len,
        )

    def generate_greedy_logprobs_limit(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """Generate without sampling, keeping the top log-probs per step.

        Returns:
            Per prompt, a tuple `(output_ids, output_str, logprobs)`: the
            generated token ids, their decoded text, and for each step a
            dict holding the top `num_logprobs` log-probabilities.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # Keep only the trailing tokens that have a logprob entry.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """
        Greedy logprobs generation for vLLM encoder/decoder models

        The encoder prompt (and image, if given) goes through the processor;
        the decoder prompt, unless it is `None`, is tokenized and passed as
        `decoder_input_ids`. Log-probabilities are computed from the
        decoder hidden states.

        Returns:
            Per prompt, a tuple `(output_ids, output_str, logprobs)` shaped
            like the result of `generate_greedy_logprobs_limit`.
        """
        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for i, (encoder_prompt, decoder_prompt) in enumerate(
                to_enc_dec_tuple_list(encoder_decoder_prompts)):
            processor_kwargs: dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            encoder_inputs = self.processor(**processor_kwargs)
            encoder_inputs = self.wrap_device(encoder_inputs)

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_inputs = self.tokenizer(decoder_prompt,
                                                return_tensors="pt")
                decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)

            output = self.model.generate(
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **encoder_inputs,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # Keep only the trailing tokens that have a logprob entry.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        """Forward to `self.model.encode` with the given arguments."""
        return self.model.encode(prompts, *args, **kwargs)

    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:
        """Forward to `self.model.predict` with `convert_to_tensor=True`."""
        return self.model.predict(prompts,
                                  *args,
                                  convert_to_tensor=True,
                                  **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the model reference before cleaning up memory.
        del self.model
        cleanup_dist_env_and_memory()
748

Woosuk Kwon's avatar
Woosuk Kwon committed
749

Cyrus Leung's avatar
Cyrus Leung committed
750
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the `HfRunner` class (not an instance) to tests."""
    return HfRunner


class VllmRunner:
    """
    Thin test wrapper around {class}`~vllm.LLM`, stored as `self.model`.

    It can be used as a context manager: on exit the model is deleted and
    `cleanup_dist_env_and_memory()` is called.

    The default values of some arguments have been modified from
    {class}`~vllm.LLM` as follows:

    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
    - `block_size`: Set to `16` instead of `None` to reduce memory usage.
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
    - `enforce_eager`: Set to `False` to test CUDA graph.
    """

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        seed: Optional[int] = 0,
        max_model_len: int = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: Optional[bool] = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Build the wrapped `LLM`.

        The named parameters are forwarded to `LLM` under the same keyword,
        except `model_name` (passed as `model`) and `tokenizer_name` (passed
        as `tokenizer`). Any additional `kwargs` are forwarded unchanged.
        """
        self.model = LLM(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def get_inputs(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[TextPrompt]:
        """Pair every prompt with its multimodal data as a `TextPrompt`.

        A `str` prompt is stored under the `"prompt"` key; any other prompt
        is stored under `"prompt_embeds"`. For each modality list that is
        given, the i-th entry is attached to the i-th prompt unless that
        entry is `None`. When nothing is attached, `multi_modal_data` is
        `None`.

        Raises:
            ValueError: If a non-`None` modality list does not have the same
                length as `prompts`.
        """
        if any(x is not None and len(x) != len(prompts)
               for x in [images, videos, audios]):
            raise ValueError(
                "All non-None multimodal inputs must have the same length as "
                "prompts")

        inputs = []
        for i, prompt in enumerate(prompts):
            multi_modal_data = {}
            if images is not None and (image := images[i]) is not None:
                multi_modal_data["image"] = image
            if videos is not None and (video := videos[i]) is not None:
                multi_modal_data["video"] = video
            if audios is not None and (audio := audios[i]) is not None:
                multi_modal_data["audio"] = audio

            text_prompt_kwargs = {
                ("prompt" if isinstance(prompt, str) else "prompt_embeds"):
                prompt,
                # An empty dict is normalized to None.
                "multi_modal_data": multi_modal_data or None
            }
            inputs.append(TextPrompt(**text_prompt_kwargs))

        return inputs

    def generate(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate completions and return them prefixed by their prompt.

        Returns one tuple per request. Its first element holds, for every
        sample, the prompt token ids followed by the generated token ids;
        its second element holds, for every sample, the prompt text (or an
        empty string when the request has no prompt text) followed by the
        generated text. Extra `kwargs` go to `self.model.generate`.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: list[list[int]] = []
            req_sample_output_strs: list[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append((prompt_str or "") + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        """Flatten request outputs into 4-tuples.

        Each tuple is `(token_ids, text, logprobs, prompt_logprobs)`. Only
        the last sample of each request is reported.
        """
        outputs: list[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            sample = req_output.outputs[-1]
            outputs.append((list(sample.token_ids), sample.text,
                            sample.logprobs, req_output.prompt_logprobs))
        return outputs

    @staticmethod
    def _omit_unrequested_prompt_logprobs(
        toks_str_logsprobs_prompt_logprobs: list[
            TokensTextLogprobsPromptLogprobs],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Drop the trailing prompt-logprobs element from every tuple when
        `sampling_params.prompt_logprobs` is `None`; otherwise return the
        list unchanged."""
        if sampling_params.prompt_logprobs is None:
            return [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
        return toks_str_logsprobs_prompt_logprobs

    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return token ids, text and logprobs per request.

        Prompt logprobs are appended to each tuple only when
        `sampling_params.prompt_logprobs` is not `None`. Extra `kwargs` go
        to `self.model.generate`.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        # Omit prompt logprobs if not required by sampling params
        return self._omit_unrequested_prompt_logprobs(
            toks_str_logsprobs_prompt_logprobs, sampling_params)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for vLLM encoder/decoder models

        `sampling_params.logprobs` must not be `None`. Prompt logprobs are
        appended to each tuple only when `sampling_params.prompt_logprobs`
        is not `None`.
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        # Omit prompt logprobs if not required by sampling params
        return self._omit_unrequested_prompt_logprobs(
            toks_str_logsprobs_prompt_logprobs, sampling_params)

    def generate_greedy(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Generate with `temperature=0.0` and at most `max_tokens` tokens.

        Returns, per request, the first sample reported by `generate` as
        `(token_ids, text)`, both prefixed by the prompt.
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts,
                                greedy_params,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[list[int]] = None,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Greedy (`temperature=0.0`) variant of `generate_w_logprobs`.

        `num_logprobs`, `num_prompt_logprobs`, `stop_token_ids` and `stop`
        are passed to `SamplingParams` as `logprobs`, `prompt_logprobs`,
        `stop_token_ids` and `stop` respectively.
        """
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos,
                                        **kwargs)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            skip_special_tokens=skip_special_tokens,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run `self.model.beam_search` over the prompts.

        Returns, per request, the `tokens` of every returned sequence and
        the `text` of every returned sequence.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        outputs = self.model.beam_search(
            inputs,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return `outputs.probs` of every `self.model.classify` result."""
        req_outputs = self.model.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

    def embed(self,
              prompts: list[str],
              images: Optional[PromptImageInput] = None,
              videos: Optional[PromptVideoInput] = None,
              audios: Optional[PromptAudioInput] = None,
              *args,
              **kwargs) -> list[list[float]]:
        """Return `outputs.embedding` of every `self.model.embed` result.

        Prompts are first combined with their multimodal data through
        `get_inputs`; `args` and `kwargs` go to `self.model.embed`.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.embed(inputs, *args, **kwargs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def encode(self, prompts: list[str]) -> list[list[float]]:
        """Return `outputs.data` of every `self.model.encode` result."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.data for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, list[str]],
        text_2: Union[str, list[str]],
        *args,
        **kwargs,
    ) -> list[float]:
        """Return `outputs.score` of every `self.model.score` result."""
        req_outputs = self.model.score(text_1, text_2, *args, **kwargs)
        return [req_output.outputs.score for req_output in req_outputs]

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Forward `func` to the engine's model executor `apply_model`."""
        executor = self.model.llm_engine.model_executor
        return executor.apply_model(func)

    def __enter__(self):
        """Enter the `with` block; the runner itself is the context value."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Delete the wrapped model and run the distributed/memory cleanup.

        The exception arguments are not inspected and nothing is returned,
        so an exception raised inside the block keeps propagating.
        """
        del self.model
        cleanup_dist_env_and_memory()
1069

Woosuk Kwon's avatar
Woosuk Kwon committed
1070

1071
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture exposing the ``VllmRunner`` class.

    The class itself is returned, not an instance.
    """
    runner_cls = VllmRunner
    return runner_cls
1074
1075


1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
@pytest.fixture()
def temporary_enable_log_propagate():
    """Make the ``"vllm"`` logger propagate its records during one test.

    ``propagate`` is set to ``True`` before the test and put back to the
    value it had beforehand on teardown, so a logger that was already
    propagating is not silently switched off.
    """
    import logging

    # Named differently from the module-level ``logger`` to avoid shadowing.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog`` with ``temporary_enable_log_propagate``
    active for the duration of the test."""
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
1090
1091
1092
1093
1094
1095
1096


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process.

    The value is whatever ``cuda_device_count_stateless()`` reports.
    """
    gpu_count = cuda_device_count_stateless()
    return gpu_count
1098
1099
1100


# Directories under the system temp dir used by the dummy-model fixtures.
temp_dir = tempfile.gettempdir()

(
    _dummy_opt_path,
    _dummy_llava_path,
    _dummy_gemma2_embedding_path,
) = (os.path.join(temp_dir, dir_name)
     for dir_name in ("dummy_opt", "dummy_llava", "dummy_gemma2_embedding"))
1104
1105
1106
1107


def _prepare_dummy_model_dir(repo_id: str, local_dir: str,
                             architecture: str) -> str:
    """Download a weight-less snapshot of ``repo_id`` and relabel it.

    Nothing happens when ``local_dir`` already exists. Otherwise the
    repository is downloaded into ``local_dir`` with the listed weight file
    patterns ignored, and the ``"architectures"`` entry of its
    ``config.json`` is overwritten with ``[architecture]``.

    Returns:
        ``local_dir``.
    """
    if not os.path.exists(local_dir):
        snapshot_download(repo_id=repo_id,
                          local_dir=local_dir,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        json_path = os.path.join(local_dir, "config.json")
        assert os.path.exists(json_path)
        # JSON is read and written as UTF-8 regardless of the locale default.
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = [architecture]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return local_dir


@pytest.fixture
def dummy_opt_path():
    """Path to a ``facebook/opt-125m`` snapshot whose config declares the
    ``MyOPTForCausalLM`` architecture."""
    return _prepare_dummy_model_dir("facebook/opt-125m", _dummy_opt_path,
                                    "MyOPTForCausalLM")


@pytest.fixture
def dummy_llava_path():
    """Path to a ``llava-hf/llava-1.5-7b-hf`` snapshot whose config declares
    the ``MyLlava`` architecture."""
    return _prepare_dummy_model_dir("llava-hf/llava-1.5-7b-hf",
                                    _dummy_llava_path, "MyLlava")


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Path to a ``BAAI/bge-multilingual-gemma2`` snapshot whose config
    declares the ``MyGemma2Embedding`` architecture."""
    return _prepare_dummy_model_dir("BAAI/bge-multilingual-gemma2",
                                    _dummy_gemma2_embedding_path,
                                    "MyGemma2Embedding")
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178


# Add the `--optional` flag so that tests marked with
# @pytest.mark.optional can be run.
def pytest_addoption(parser):
    """Register the ``--optional`` command-line switch.

    It is a ``store_true`` option that defaults to ``False``.
    """
    option_settings = {
        "action": "store_true",
        "default": False,
        "help": "run optional test",
    }
    parser.addoption("--optional", **option_settings)


def pytest_collection_modifyitems(config, items):
    """Mark tests carrying the ``optional`` keyword as skipped.

    Nothing is changed when ``--optional`` was given on the command line.
    """
    if not config.getoption("--optional"):
        # One shared skip marker is attached to every optional test.
        skip_marker = pytest.mark.skip(reason="need --optional option to run")
        for test_item in items:
            if "optional" in test_item.keywords:
                test_item.add_marker(skip_marker)
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191


@pytest.fixture(scope="session")
def cli_config_file():
    """Return the path to the CLI config file."""
    config_path = os.path.join(_TEST_DIR, "config", "test_config.yaml")
    return config_path


@pytest.fixture(scope="session")
def cli_config_file_with_model():
    """Return the path to the CLI config file with model."""
    config_path = os.path.join(_TEST_DIR, "config",
                               "test_config_with_model.yaml")
    return config_path