conftest.py 39.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import json
3
import os
4
import tempfile
5
from enum import Enum
6
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
7

8
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
import pytest
import torch
11
import torch.nn as nn
12
import torch.nn.functional as F
13
from huggingface_hub import snapshot_download
14
from PIL import Image
15
16
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BatchEncoding, BatchFeature)
17
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
20
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm import LLM, SamplingParams
22
from vllm.assets.audio import AudioAsset
23
from vllm.assets.image import ImageAsset
24
from vllm.assets.video import VideoAsset
25
from vllm.config import TaskOption, _get_and_verify_dtype
26
from vllm.connections import global_http_connection
27
from vllm.distributed import (cleanup_dist_env_and_memory,
28
29
                              init_distributed_environment,
                              initialize_model_parallel)
30
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
31
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
32
from vllm.logger import init_logger
33
from vllm.outputs import RequestOutput
34
from vllm.sampling_params import BeamSearchParams
35
from vllm.utils import cuda_device_count_stateless
36

37
logger = init_logger(__name__)

# Prompt fixtures are read from text files stored next to this conftest.
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")

_M = TypeVar("_M")

# Per-prompt multi-modal input: the outer list is indexed in lockstep with the
# prompt list (HfRunner.get_inputs asserts equal lengths); each entry is either
# a single item or, presumably, a list of items for that prompt.
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
# Each audio item is an (audio array, sampling rate) pair.
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
51

52

53
def _read_prompts(filename: str) -> list[str]:
    """Return every line of *filename*, trailing newlines included.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's locale encoding.
    """
    with open(filename, encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
57
58


59
class ImageAssetPrompts(TypedDict):
    """Text prompt to use with each image test asset, keyed by asset name."""
    stop_sign: str
    cherry_blossom: str
62
63


64
class ImageTestAssets(list[ImageAsset]):
    """List of the image assets used by multi-modal tests, in a fixed order."""

    def __init__(self) -> None:
        super().__init__([
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
        ])

    def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
        """
        Convenience method to define the prompt for each test image.

        The order of the returned prompts matches the order of the
        assets when iterating through this object.
        """
        return [prompts["stop_sign"], prompts["cherry_blossom"]]
80
81


82
83
class VideoAssetPrompts(TypedDict, total=True):
    """Maps each video test asset name to the text prompt used with it."""

    # Prompt paired with the "baby_reading" video asset.
    baby_reading: str
84
85


86
class VideoTestAssets(list[VideoAsset]):
    """List of the video assets used by multi-modal tests."""

    def __init__(self) -> None:
        super().__init__([
            VideoAsset("baby_reading"),
        ])

    def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
        """Return the prompt for each test video, in asset iteration order."""
        return [prompts["baby_reading"]]
95
96


97
class AudioAssetPrompts(TypedDict):
    """Text prompt to use with each audio test asset, keyed by asset name."""
    mary_had_lamb: str
    winning_call: str


102
class AudioTestAssets(list[AudioAsset]):
    """List of the audio assets used by multi-modal tests, in a fixed order."""

    def __init__(self) -> None:
        super().__init__([
            AudioAsset("mary_had_lamb"),
            AudioAsset("winning_call"),
        ])

    def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
        """Return the prompt for each test audio clip, in asset iteration
        order."""
        return [prompts["mary_had_lamb"], prompts["winning_call"]]

113

114
# Module-level singletons; the session-scoped asset fixtures return these.
IMAGE_ASSETS = ImageTestAssets()
"""Singleton instance of {class}`ImageTestAssets`."""
VIDEO_ASSETS = VideoTestAssets()
"""Singleton instance of {class}`VideoTestAssets`."""
AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of {class}`AudioTestAssets`."""
120
121


122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
    """
    Keep changes vLLM makes to "VLLM_USE_V1" from leaking between tests.

    The V1 oracle sets "VLLM_USE_V1" while loading, so each test invocation
    may change the environment variable. Once monkeypatch has touched the
    variable, any change vLLM makes during the test is undone at teardown.

    Being autouse, this fixture applies to every test.
    """
    if "VLLM_USE_V1" in os.environ:
        # An explicitly provided value is left untouched.
        return

    # Setting and then deleting the variable makes monkeypatch record "unset"
    # as the state to restore when the test finishes.
    monkeypatch.setenv("VLLM_USE_V1", "")
    monkeypatch.delenv("VLLM_USE_V1")


Joe Runde's avatar
Joe Runde committed
142
@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """Run the requesting test twice: once with vLLM V1 and once without.

    The engine is selected through the ``VLLM_USE_V1`` environment variable.
    Tests decorated with ``@skip_v1`` are skipped for the V1 run.
    """
    use_v1 = request.param
    skip_v1 = request.node.get_closest_marker("skip_v1")

    if use_v1 and skip_v1:
        pytest.skip("Skipping test on vllm V1")

    monkeypatch.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    yield
Joe Runde's avatar
Joe Runde committed
157
158


159
160
161
162
163
164
165
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Make every test create its own async HTTP client.

    pytest_asyncio may use a different event loop for each test, so the
    shared connection must not reuse a client created earlier.
    """
    connection = global_http_connection
    connection.reuse_client = False


166
167
168
169
170
171
172
173
174
175
176
177
@pytest.fixture
def dist_init():
    """Set up a single-process (world size 1, rank 0) NCCL distributed
    environment with tensor/pipeline model parallelism of 1, and tear it
    down after the test.

    A fresh temporary file serves as the ``file://`` rendezvous point.
    """
    fd, temp_file = tempfile.mkstemp()
    # mkstemp() returns an open descriptor as well as the path; only the path
    # is needed for the rendezvous URL, so close the descriptor to avoid
    # leaking it.
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup_dist_env_and_memory()
    # Remove the rendezvous file. Tolerate it being gone already, since the
    # distributed backend may have deleted it during cleanup.
    try:
        os.remove(temp_file)
    except FileNotFoundError:
        pass
179
180


181
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't need
    to initialize torch.

    Individual tests can also opt out with the ``skip_global_cleanup``
    marker.
    """
    return not request.node.get_closest_marker("skip_global_cleanup")
189
190


191
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After each test, call ``cleanup_dist_env_and_memory()`` unless the
    ``should_do_global_cleanup_after_test`` fixture returns False."""
    yield
    if should_do_global_cleanup_after_test:
        cleanup_dist_env_and_memory()
196
197


198
199
200
201
202
203
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call ``torch._dynamo.reset()`` once each test has finished.

    Nothing runs before the test. The reset after it keeps dynamo state from
    one test out of the next.
    """
    yield
    # Teardown phase: the test body has already run at this point.
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
204
@pytest.fixture
def example_prompts() -> list[str]:
    """All lines of the example prompt file(s), with newlines kept."""
    return [
        prompt for filename in _TEST_PROMPTS
        for prompt in _read_prompts(filename)
    ]


212
213
214
215
216
217
@pytest.fixture
def example_system_message() -> str:
    """Full text of the example system-message file."""
    # Decode explicitly as UTF-8 so the content does not depend on the
    # platform's locale encoding.
    with open(_SYS_MSG, encoding="utf-8") as f:
        return f.read()


218
219
220
221
222
223
224
class DecoderPromptType(Enum):
    """Kind of decoder prompt paired with each encoder prompt.

    For encoder/decoder models only.
    """

    # Decoder prompts are given as explicit text.
    CUSTOM = 1
    # Decoder prompts are None.
    NONE = 2
    # Decoder prompts are the empty string "".
    EMPTY_STR = 3


225
@pytest.fixture
def example_encoder_decoder_prompts(
) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]:
    """
    Build encoder/decoder prompt pairs, once for each `DecoderPromptType`.

    The encoder prompts are the lines of the example prompt file(s). Each one
    is paired index-by-index (via ``zip_enc_dec_prompts``) with a decoder
    prompt that depends on the dictionary key:

    * ``DecoderPromptType.NONE``: every decoder prompt is ``None``
    * ``DecoderPromptType.EMPTY_STR``: every decoder prompt is ``""``
    * ``DecoderPromptType.CUSTOM``: decoder prompts are the encoder prompt
      list in reverse order

    Returns:
        A dict mapping each decoder prompt type to its list of
        encoder/decoder prompts.
    """
    encoder_prompts = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    custom_decoder_prompts = encoder_prompts[::-1]
    empty_str_decoder_prompts = [""] * len(encoder_prompts)
    none_decoder_prompts = [None] * len(encoder_prompts)

    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
    }


258
@pytest.fixture
def example_long_prompts() -> list[str]:
    """All lines of the long example prompt file(s), with newlines kept."""
    return [
        prompt for filename in _LONG_PROMPTS
        for prompt in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
264
265


266
@pytest.fixture(scope="session")
def image_assets() -> ImageTestAssets:
    """Session-wide image test assets (the ``IMAGE_ASSETS`` singleton)."""
    return IMAGE_ASSETS


271
@pytest.fixture(scope="session")
def video_assets() -> VideoTestAssets:
    """Session-wide video test assets (the ``VIDEO_ASSETS`` singleton)."""
    return VIDEO_ASSETS


276
@pytest.fixture(scope="session")
def audio_assets() -> AudioTestAssets:
    """Session-wide audio test assets (the ``AUDIO_ASSETS`` singleton)."""
    return AUDIO_ASSETS


281
# Object types that HfRunner.wrap_device accepts and returns unchanged in type.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_R = TypeVar("_R")
283

Woosuk Kwon's avatar
Woosuk Kwon committed
284
285
286

class HfRunner:
    """Wrapper around a HuggingFace model, used as the reference
    implementation when checking vLLM outputs in tests.

    Depending on the constructor flags, ``self.model`` is a
    ``SentenceTransformer``, a ``CrossEncoder`` or a model loaded with
    ``auto_cls`` (``AutoModelForCausalLM`` by default).
    """

    def get_default_device(self):
        """Return ``"cpu"`` on CPU platforms, else the platform device type."""
        from vllm.platforms import current_platform

        return ("cpu"
                if current_platform.is_cpu() else current_platform.device_type)

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move ``x`` to ``device`` (``self.device`` by default).

        Dicts are handled recursively, value by value. ``None`` and booleans
        are returned unchanged, as is any object whose ``.device.type``
        already equals ``device``.
        """
        if x is None or isinstance(x, bool):
            return x

        if device is None:
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load the config, model, tokenizer and processor for a model.

        Args:
            model_name: Name passed to every ``from_pretrained`` call.
            dtype: dtype spec, resolved with ``_get_and_verify_dtype``.
            model_kwargs: Extra kwargs for the model constructor. The
                caller's dict is not modified; ``torch_dtype`` is filled in
                on a copy when it is absent.
            is_sentence_transformer: Load a ``SentenceTransformer``.
            is_cross_encoder: Load a ``CrossEncoder``.
            skip_tokenizer_init: Use the processor's tokenizer instead of
                loading one with ``AutoTokenizer``.
            auto_cls: HF auto class used when neither flag above is set.
        """
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype)

        # Work on a copy: setdefault() below must not mutate a dict the caller
        # may reuse elsewhere (e.g. for several runners).
        model_kwargs = {} if model_kwargs is None else dict(model_kwargs)
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=True,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=True,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=True,
                **model_kwargs,
            )

            # Only move the model when it is not bitsandbytes-quantized and
            # its parameters do not already span several devices (presumably
            # placed there by a device_map in model_kwargs).
            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

    def get_inputs(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[Union[BatchFeature, BatchEncoding]]:
        """Run ``self.processor`` on each prompt separately.

        ``images``/``videos``/``audios``, when given, must contain one entry
        per prompt; a ``None`` entry means that prompt has no such input.
        An audio entry of length 2 is unpacked as ``(audio, sampling_rate)``.
        ``BatchFeature`` results are converted to ``self.dtype``.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_inputs := audios[i]) is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return, for each prompt, the softmax over the final logits of the
        first (only) batch row, as a list of floats."""
        all_inputs = self.get_inputs(prompts)
        outputs = []
        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            probs = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(probs)

        return outputs

    def generate(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Call ``self.model.generate`` once per prompt.

        Returns:
            For each prompt, a tuple of (token ids of every returned
            sequence, the corresponding decoded strings). ``kwargs`` are
            forwarded to ``generate``.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for inputs in all_inputs:
            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy-decode up to ``max_tokens`` new tokens per prompt and return
        the first sequence's (token ids, text) for each prompt."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Beam search with ``beam_width`` beams, returning all beams per
        prompt with pad tokens removed from the token ids."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width,
                                images=images,
                                videos=videos,
                                audios=audios)

        pad_token_id = self.tokenizer.pad_token_id
        for i, (output_ids, output_str) in enumerate(outputs):
            # Strip pad tokens (presumably padding added so that beams of
            # different lengths fit in one tensor).
            output_ids = [[tok for tok in ids if tok != pad_token_id]
                          for ids in output_ids]
            outputs[i] = (output_ids, output_str)
        return outputs

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[list[torch.Tensor]]:
        """Greedy-decode each prompt and return, per prompt, the list of
        float32 log-prob tensors computed from each generation step's
        hidden states."""
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[torch.Tensor]] = []
        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs = self._hidden_states_to_seq_logprobs(
                output.hidden_states)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
    ) -> list[torch.Tensor]:
        """Project hidden states to vocabulary log-probs, one tensor per
        generation step.

        For each step, the last entry of the inner tuple (presumably the
        final layer) is taken for batch row 0. It is multiplied by the output
        embedding matrix (plus bias, if any) and passed through a float32
        log-softmax.
        """
        output_embeddings = self.model.get_output_embeddings()

        seq_logprobs: list[torch.Tensor] = []
        for hidden_state in hidden_states:
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(
                    device=output_embeddings.weight.device,
                    dtype=output_embeddings.weight.dtype,
                ),
                output_embeddings.weight.t(),
            )
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> tuple[list[dict[int, float]], int]:
        """Return ({token id: logprob} of the top ``num_logprobs`` tokens for
        each step, number of steps)."""
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        # convert to dict
        seq_logprobs_lst: list[dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs: the first step covers every prompt
            # position, so keep only its last row.
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {}
            for token_id, logprob in zip(topk.indices[0], topk.values[0]):
                tok_logprobs_dct[token_id.item()] = logprob.item()

            seq_logprobs_lst.append(tok_logprobs_dct)

        return (
            seq_logprobs_lst,
            output_len,
        )

    def generate_greedy_logprobs_limit(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """Greedy-decode each prompt. For each one, return the generated
        token ids, their decoded text, and the top-``num_logprobs`` logprobs
        per step."""
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            # The last output_len tokens of the sequence are the generated
            # ones (one per generation step).
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """
        Greedy logprobs generation for HuggingFace encoder/decoder models.

        The encoder prompt (plus the image at the same index, if any) goes
        through ``self.processor``. A non-``None`` decoder prompt is tokenized
        and passed as ``decoder_input_ids``. Logprobs come from the decoder
        hidden states.
        """
        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for i, (encoder_prompt, decoder_prompt) in enumerate(
                to_enc_dec_tuple_list(encoder_decoder_prompts)):
            processor_kwargs: dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            encoder_inputs = self.processor(**processor_kwargs)
            encoder_inputs = self.wrap_device(encoder_inputs)

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_inputs = self.tokenizer(decoder_prompt,
                                                return_tensors="pt")
                decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)

            output = self.model.generate(
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **encoder_inputs,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        """Forward to ``self.model.encode`` (intended for the
        sentence-transformer mode)."""
        return self.model.encode(prompts, *args, **kwargs)

    def predict(self, prompts: list[list[str]]) -> torch.Tensor:
        """Forward to ``self.model.predict`` with ``convert_to_tensor=True``
        (intended for the cross-encoder mode)."""
        return self.model.predict(prompts, convert_to_tensor=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the model reference first so cleanup can reclaim its memory.
        del self.model
        cleanup_dist_env_and_memory()
717

Woosuk Kwon's avatar
Woosuk Kwon committed
718

Cyrus Leung's avatar
Cyrus Leung committed
719
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the `HfRunner` class itself (not an instance); tests
    construct runners as needed."""
    return HfRunner


class VllmRunner:
    """
    Test-oriented wrapper around {class}`~vllm.LLM`.

    The default values of some arguments have been modified from
    {class}`~vllm.LLM` as follows:

    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
    - `block_size`: Set to `16` instead of `None` to reduce memory usage.
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
    - `enforce_eager`: Set to `False` to test CUDA graph.

    Instances are context managers: leaving the `with` block deletes the
    underlying `LLM` and calls `cleanup_dist_env_and_memory()`.
    """

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        seed: Optional[int] = 0,
        max_model_len: int = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: Optional[bool] = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        # Any extra keyword arguments are passed straight through to `LLM`.
        self.model = LLM(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def get_inputs(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[TextPrompt]:
        """Build one `TextPrompt` per entry of `prompts`.

        String prompts are stored under `"prompt"`; any other entry (e.g. a
        tensor) is stored under `"prompt_embeds"`. The i-th image, video and
        audio item is attached as `multi_modal_data` when it is not None.

        Raises:
            ValueError: if a non-None multimodal list does not have the same
                length as `prompts`.
        """
        if any(x is not None and len(x) != len(prompts)
               for x in [images, videos, audios]):
            raise ValueError(
                "All non-None multimodal inputs must have the same length as "
                "prompts")

        inputs = []
        for i, prompt in enumerate(prompts):
            multi_modal_data = {}
            if images is not None and (image := images[i]) is not None:
                multi_modal_data["image"] = image
            if videos is not None and (video := videos[i]) is not None:
                multi_modal_data["video"] = video
            if audios is not None and (audio := audios[i]) is not None:
                multi_modal_data["audio"] = audio

            text_prompt_kwargs = {
                ("prompt" if isinstance(prompt, str) else "prompt_embeds"):
                prompt,
                # An empty dict is normalized to None (no multimodal data).
                "multi_modal_data": multi_modal_data or None
            }
            inputs.append(TextPrompt(**text_prompt_kwargs))

        return inputs

    def generate(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate with `sampling_params` and return full sequences.

        For each request, returns a pair of lists with one entry per sample:
        the prompt token ids followed by the generated token ids, and the
        prompt text (or `""` when the request has none) followed by the
        generated text. Extra `kwargs` are forwarded to `LLM.generate`.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: list[list[int]] = []
            req_sample_output_strs: list[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append((prompt_str or "") + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        """Convert request outputs into
        `(token_ids, text, logprobs, prompt_logprobs)` tuples.

        Only the last sample of each request is recorded.
        NOTE(review): this presumably assumes one sample per request;
        confirm callers never request n > 1.
        """
        outputs: list[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                output_logprobs = sample.logprobs
            # Outside the inner loop: one tuple per request, from the last
            # sample seen above.
            outputs.append((output_ids, output_str, output_logprobs,
                            req_output.prompt_logprobs))
        return outputs

    @staticmethod
    def _maybe_drop_prompt_logprobs(
        outputs: list[TokensTextLogprobsPromptLogprobs],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Strip the trailing prompt-logprobs element from every tuple unless
        `sampling_params.prompt_logprobs` was set."""
        if sampling_params.prompt_logprobs is None:
            return [x[0:-1] for x in outputs]
        return outputs

    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return `(token_ids, text, logprobs)` per request,
        with `prompt_logprobs` appended when
        `sampling_params.prompt_logprobs` is set.

        The multimodal parameters are ordered `images, audios, videos` here
        (unlike `generate`); pass them by keyword.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        return self._maybe_drop_prompt_logprobs(
            toks_str_logsprobs_prompt_logprobs, sampling_params)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Logprobs generation for vLLM encoder/decoder models.

        `sampling_params.logprobs` must be set. Prompt logprobs are included
        only when `sampling_params.prompt_logprobs` is set.
        """
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        return self._maybe_drop_prompt_logprobs(
            toks_str_logsprobs_prompt_logprobs, sampling_params)

    def generate_greedy(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy (temperature 0) generation.

        Returns the first sample's token ids and text per request, each
        including the prompt as described in `generate`.
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts,
                                greedy_params,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[list[int]] = None,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Greedy generation requesting `num_logprobs` logprobs per generated
        token and, when `num_prompt_logprobs` is given, prompt logprobs."""
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos,
                                        **kwargs)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Greedy logprobs generation for vLLM encoder/decoder models."""
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            skip_special_tokens=skip_special_tokens,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run beam search and return, per prompt, the token lists and texts
        of every returned sequence."""
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        outputs = self.model.beam_search(
            inputs,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return the `probs` of each request produced by `LLM.classify`."""
        req_outputs = self.model.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

    def encode(self,
               prompts: list[str],
               images: Optional[PromptImageInput] = None,
               videos: Optional[PromptVideoInput] = None,
               audios: Optional[PromptAudioInput] = None,
               *args,
               **kwargs) -> list[list[float]]:
        """Embed the prompts with `LLM.embed` and return each request's
        `embedding`. Extra arguments are forwarded to `LLM.embed`."""
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.embed(inputs, *args, **kwargs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, list[str]],
        text_2: Union[str, list[str]],
    ) -> list[float]:
        """Return the `score` of each request produced by
        `LLM.score(text_1, text_2)`."""
        req_outputs = self.model.score(text_1, text_2)
        return [req_output.outputs.score for req_output in req_outputs]

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Apply `func` to the model through the engine's model executor and
        return the collected results (presumably one per worker)."""
        executor = self.model.llm_engine.model_executor
        return executor.apply_model(func)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the LLM reference before the distributed/memory cleanup runs.
        del self.model
        cleanup_dist_env_and_memory()
1032

Woosuk Kwon's avatar
Woosuk Kwon committed
1033

1034
@pytest.fixture(scope="session")
def vllm_runner():
    """Session fixture exposing the ``VllmRunner`` class itself.

    A class (not an instance) is returned so each test can construct a runner
    with the model and options it needs.
    """
    runner_cls = VllmRunner
    return runner_cls
1037
1038


1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
@pytest.fixture()
def temporary_enable_log_propagate():
    """Temporarily set ``propagate = True`` on the ``vllm`` logger.

    The previous ``propagate`` value is restored on teardown instead of being
    forced to ``False``. This way the fixture does not clobber whatever
    logging configuration was in effect before the test.
    """
    import logging

    # Named distinctly to avoid shadowing the module-level `logger`.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """``caplog`` that also sees records emitted on the ``vllm`` logger.

    ``caplog`` only captures records that reach the root logger, so this
    fixture depends on ``temporary_enable_log_propagate`` to switch on
    propagation for the duration of the test.
    """
    captured = caplog
    yield captured
1053
1054
1055
1056
1057
1058
1059


@pytest.fixture(scope="session")
def num_gpus_available():
    """Number of visible CUDA devices, obtained without initializing a CUDA
    context in the current process."""
    device_count = cuda_device_count_stateless()
    return device_count
1061
1062
1063


# System temporary directory; the dummy model checkouts live beneath it.
temp_dir = tempfile.gettempdir()

# Local checkout locations used by the dummy-model fixtures.
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir,
                                            "dummy_gemma2_embedding")
1067
1068
1069
1070


@pytest.fixture
def dummy_opt_path():
    """Return a local copy of ``facebook/opt-125m`` whose ``config.json``
    declares the architecture ``MyOPTForCausalLM``.

    Files matching ``*.bin``, ``*.bin.index.json``, ``*.pt``, ``*.h5`` and
    ``*.msgpack`` are skipped when downloading.
    """
    json_path = os.path.join(_dummy_opt_path, "config.json")
    # Key the download on config.json rather than on the directory, so a
    # directory left behind by an interrupted download is not reused as-is.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_opt_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
    with open(json_path, encoding="utf-8") as f:
        config = json.load(f)
    # Patch idempotently: a run interrupted between the download and this
    # rewrite must not leave an unpatched config behind.
    if config.get("architectures") != ["MyOPTForCausalLM"]:
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
    """Return a local copy of ``llava-hf/llava-1.5-7b-hf`` whose
    ``config.json`` declares the architecture ``MyLlava``.

    Files matching ``*.bin``, ``*.bin.index.json``, ``*.pt``, ``*.h5`` and
    ``*.msgpack`` are skipped when downloading.
    """
    json_path = os.path.join(_dummy_llava_path, "config.json")
    # Key the download on config.json rather than on the directory, so a
    # directory left behind by an interrupted download is not reused as-is.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                          local_dir=_dummy_llava_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
    with open(json_path, encoding="utf-8") as f:
        config = json.load(f)
    # Patch idempotently: a run interrupted between the download and this
    # rewrite must not leave an unpatched config behind.
    if config.get("architectures") != ["MyLlava"]:
        config["architectures"] = ["MyLlava"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_llava_path
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Return a local copy of ``BAAI/bge-multilingual-gemma2`` whose
    ``config.json`` declares the architecture ``MyGemma2Embedding``.

    Files matching ``*.bin``, ``*.bin.index.json``, ``*.pt``, ``*.h5`` and
    ``*.msgpack`` are skipped when downloading.
    """
    json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
    # Key the download on config.json rather than on the directory, so a
    # directory left behind by an interrupted download is not reused as-is.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
                          local_dir=_dummy_gemma2_embedding_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
    with open(json_path, encoding="utf-8") as f:
        config = json.load(f)
    # Patch idempotently: a run interrupted between the download and this
    # rewrite must not leave an unpatched config behind.
    if config.get("architectures") != ["MyGemma2Embedding"]:
        config["architectures"] = ["MyGemma2Embedding"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_gemma2_embedding_path
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141


def pytest_addoption(parser):
    """Register the ``--optional`` command-line flag.

    Tests marked with ``@pytest.mark.optional`` are skipped unless this flag
    is passed (see ``pytest_collection_modifyitems``).
    """
    parser.addoption(
        "--optional",
        action="store_true",
        default=False,
        help="run optional test",
    )


def pytest_collection_modifyitems(config, items):
    """Skip every test carrying the ``optional`` keyword unless
    ``--optional`` was given on the command line."""
    if not config.getoption("--optional"):
        skip_optional = pytest.mark.skip(
            reason="need --optional option to run")
        for item in (it for it in items if "optional" in it.keywords):
            item.add_marker(skip_optional)
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154


@pytest.fixture(scope="session")
def cli_config_file():
    """Path of ``config/test_config.yaml`` inside the tests directory."""
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config.yaml")


@pytest.fixture(scope="session")
def cli_config_file_with_model():
    """Path of ``config/test_config_with_model.yaml`` (a CLI config that also
    names a model) inside the tests directory."""
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config_with_model.yaml")