"vscode:/vscode.git/clone" did not exist on "e837b624f25efe5d05412e95e18ed07ced880272"
conftest.py 40.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import json
3
import os
4
import tempfile
5
from enum import Enum
6
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
7

8
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
import pytest
import torch
11
import torch.nn as nn
12
import torch.nn.functional as F
13
from huggingface_hub import snapshot_download
14
from PIL import Image
15
16
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BatchEncoding, BatchFeature)
17
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
20
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm import LLM, SamplingParams
22
from vllm.assets.audio import AudioAsset
23
from vllm.assets.image import ImageAsset
24
from vllm.assets.video import VideoAsset
25
from vllm.config import TaskOption, _get_and_verify_dtype
26
from vllm.connections import global_http_connection
27
from vllm.distributed import (cleanup_dist_env_and_memory,
28
29
                              init_distributed_environment,
                              initialize_model_parallel)
30
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
31
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
32
from vllm.logger import init_logger
33
from vllm.outputs import RequestOutput
34
from vllm.sampling_params import BeamSearchParams
35
from vllm.utils import cuda_device_count_stateless
36

37
logger = init_logger(__name__)

# Test data lives next to this conftest: prompt files under ``prompts/`` and
# system messages under ``system_messages/``.
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
43

Cyrus Leung's avatar
Cyrus Leung committed
44
# Element type of a single multimodal item (image, audio clip, video, ...).
_M = TypeVar("_M")

# Multimodal inputs aligned with a list of prompts: either one item per
# prompt, or a list of items per prompt.
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
# Audio items are (audio array, sampling rate) pairs.
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
51

52

53
def _read_prompts(filename: str) -> list[str]:
54
    with open(filename) as f:
55
56
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
57
58


59
class ImageAssetPrompts(TypedDict):
    """Text prompt to pair with each bundled test image, keyed by asset."""

    # Prompt for the "stop_sign" image asset.
    stop_sign: str
    # Prompt for the "cherry_blossom" image asset.
    cherry_blossom: str
62
63


64
class ImageTestAssets(list[ImageAsset]):
    """The image assets used by tests, stored in a fixed order."""

    def __init__(self) -> None:
        asset_names = ("stop_sign", "cherry_blossom")
        super().__init__([ImageAsset(name) for name in asset_names])

    def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
        """
        Convenience method to define the prompt for each test image.

        The returned prompts follow the same order in which the assets are
        yielded when iterating through this object.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
80
81


82
83
class VideoAssetPrompts(TypedDict):
    """Text prompt to pair with each bundled test video, keyed by asset."""

    # Prompt for the "baby_reading" video asset.
    baby_reading: str
84
85


86
class VideoTestAssets(list[VideoAsset]):
    """The video assets used by tests, stored in a fixed order."""

    def __init__(self) -> None:
        super().__init__([VideoAsset("baby_reading")])

    def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
        """Return one prompt per video, in asset iteration order."""
        baby_reading_prompt = prompts["baby_reading"]
        return [baby_reading_prompt]
95
96


97
class AudioAssetPrompts(TypedDict):
    """Text prompt to pair with each bundled test audio clip, keyed by asset."""

    # Prompt for the "mary_had_lamb" audio asset.
    mary_had_lamb: str
    # Prompt for the "winning_call" audio asset.
    winning_call: str


102
class AudioTestAssets(list[AudioAsset]):
    """The audio assets used by tests, stored in a fixed order."""

    def __init__(self) -> None:
        asset_names = ("mary_had_lamb", "winning_call")
        super().__init__([AudioAsset(name) for name in asset_names])

    def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
        """Return one prompt per audio clip, in asset iteration order."""
        lamb_prompt = prompts["mary_had_lamb"]
        winning_call_prompt = prompts["winning_call"]
        return [lamb_prompt, winning_call_prompt]

113

114
# Module-level asset collections, shared across the whole test session via
# the session-scoped ``*_assets`` fixtures below.
IMAGE_ASSETS = ImageTestAssets()
"""Singleton instance of {class}`ImageTestAssets`."""
VIDEO_ASSETS = VideoTestAssets()
"""Singleton instance of {class}`VideoTestAssets`."""
AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of {class}`AudioTestAssets`."""
120
121


122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
    """
    Make monkeypatch responsible for undoing changes to "VLLM_USE_V1".

    The V1 oracle sets "VLLM_USE_V1" during loading, so any test may end up
    changing the env variable. Once monkeypatch has touched the variable,
    whatever vLLM writes to it during the test is reverted at teardown.

    This fixture is autouse, so it runs for every test.
    """
    if "VLLM_USE_V1" in os.environ:
        # Already present: nothing is registered with monkeypatch here.
        return

    # Set-then-delete records "VLLM_USE_V1 was absent" in monkeypatch, so a
    # value written by vLLM later in the test is removed again on exit.
    monkeypatch.setenv("VLLM_USE_V1", "")
    monkeypatch.delenv("VLLM_USE_V1")


Joe Runde's avatar
Joe Runde committed
142
@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """Run the requesting test twice: once with V1 and once without.

    Tests carrying the ``skip_v1`` marker are skipped for the V1 run and
    only executed without V1.
    """
    use_v1 = request.param

    if use_v1 and request.node.get_closest_marker("skip_v1"):
        pytest.skip("Skipping test on vllm V1")

    monkeypatch.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    yield
Joe Runde's avatar
Joe Runde committed
157
158


159
160
161
162
163
164
165
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Stop the global HTTP connection from reusing its client across tests.

    pytest_asyncio may run each test on a different event loop, so the async
    client has to be recreated instead of reused.
    """
    connection = global_http_connection
    connection.reuse_client = False


166
167
168
169
170
171
172
173
174
175
176
177
@pytest.fixture
def dist_init():
    """Set up a single-rank (world_size=1) NCCL distributed environment.

    A fresh temporary file serves as the ``file://`` rendezvous. After the
    test, distributed state is torn down and the file is removed.
    """
    # mkstemp also returns an open OS-level descriptor; close it so it is not
    # leaked. Only the (empty) file on disk is needed for the init method.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup_dist_env_and_memory()
    # Best-effort removal: the file may already have been deleted.
    try:
        os.remove(temp_file)
    except FileNotFoundError:
        pass
179
180


181
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether the global post-test cleanup should run.

    Tests marked ``skip_global_cleanup`` opt out. Subdirectories can also
    override this fixture to skip the cleanup entirely, which can make
    non-GPU unit tests roughly 10x faster because torch never has to be
    initialized.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
189
190


191
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After each test, release distributed state and memory unless opted out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory()
196
197


198
199
200
201
202
203
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call ``torch._dynamo.reset()`` once each test has finished."""
    # Nothing to set up; the reset happens on teardown only.
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
204
@pytest.fixture
def example_prompts() -> list[str]:
    """All lines of every file in ``_TEST_PROMPTS``, concatenated in order."""
    return [
        line for path in _TEST_PROMPTS for line in _read_prompts(path)
    ]


212
213
214
215
216
217
@pytest.fixture
def example_system_message() -> str:
    """Full text of the system-message file at ``_SYS_MSG``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.
    """
    with open(_SYS_MSG, encoding="utf-8") as f:
        return f.read()


218
219
220
221
222
223
224
class DecoderPromptType(Enum):
    """How the decoder prompt is formed in a test case.

    For encoder/decoder models only.
    """

    # An explicit decoder prompt string is supplied.
    CUSTOM = 1
    # The decoder prompt is ``None``.
    NONE = 2
    # The decoder prompt is the empty string ``""``.
    EMPTY_STR = 3


225
@pytest.fixture
def example_encoder_decoder_prompts(
) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for every DecoderPromptType.

    The encoder prompts are the lines of the files in ``_TEST_PROMPTS``.
    Each encoder prompt is paired, index by index, with a decoder prompt
    through ``zip_enc_dec_prompts``. The decoder prompt depends on the type:

    * DecoderPromptType.NONE: the decoder prompt is ``None``
    * DecoderPromptType.EMPTY_STR: the decoder prompt is ``""``
    * DecoderPromptType.CUSTOM: the decoder prompts are the encoder prompt
      list in reverse order

    Returns:
        A dict mapping each DecoderPromptType to its list of paired
        encoder/decoder prompts.
    '''
    encoder_prompts = [
        line for path in _TEST_PROMPTS for line in _read_prompts(path)
    ]
    num_prompts = len(encoder_prompts)

    decoder_prompts_by_type = {
        DecoderPromptType.NONE: [None] * num_prompts,
        DecoderPromptType.EMPTY_STR: [""] * num_prompts,
        DecoderPromptType.CUSTOM: encoder_prompts[::-1],
    }
    return {
        prompt_type: zip_enc_dec_prompts(encoder_prompts, decoder_prompts)
        for prompt_type, decoder_prompts in decoder_prompts_by_type.items()
    }


258
@pytest.fixture
def example_long_prompts() -> list[str]:
    """All lines of every file in ``_LONG_PROMPTS``, concatenated in order."""
    long_prompts: list[str] = []
    for path in _LONG_PROMPTS:
        long_prompts.extend(_read_prompts(path))
    return long_prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
264
265


266
@pytest.fixture(scope="session")
def image_assets() -> ImageTestAssets:
    """Session-wide access to the shared ``IMAGE_ASSETS`` collection."""
    assets = IMAGE_ASSETS
    return assets


271
@pytest.fixture(scope="session")
def video_assets() -> VideoTestAssets:
    """Session-wide access to the shared ``VIDEO_ASSETS`` collection."""
    assets = VIDEO_ASSETS
    return assets


276
@pytest.fixture(scope="session")
def audio_assets() -> AudioTestAssets:
    """Session-wide access to the shared ``AUDIO_ASSETS`` collection."""
    assets = AUDIO_ASSETS
    return assets


281
# Types that HfRunner.wrap_device knows how to move between devices.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
# Unconstrained result type variable.
_R = TypeVar("_R")
283

Woosuk Kwon's avatar
Woosuk Kwon committed
284
285
286

class HfRunner:
    """Wrapper around a Hugging Face model plus its tokenizer and processor.

    Depending on the constructor flags, the wrapped model is a
    ``sentence_transformers.SentenceTransformer``, a
    ``sentence_transformers.CrossEncoder``, or a ``transformers`` auto model
    (``AutoModelForCausalLM`` by default). Use the runner as a context
    manager so the model and distributed state are released on exit.
    """

    def get_default_device(self):
        """Return ``"cpu"`` on CPU platforms, else the platform device type."""
        from vllm.platforms import current_platform

        return ("cpu"
                if current_platform.is_cpu() else current_platform.device_type)

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move ``x`` to ``device``, recursing into plain dicts.

        ``None`` and booleans are returned unchanged. Objects whose
        ``device.type`` already equals ``device`` are returned as-is, and
        everything else is moved with ``x.to(device)``.

        Args:
            x: Value to move.
            device: Target device; defaults to ``self.device``.
        """
        if x is None or isinstance(x, (bool, )):
            return x

        if device is None:
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load the config, model, tokenizer and processor for a model.

        Args:
            model_name: Hugging Face model name or path.
            dtype: Requested dtype, resolved via ``_get_and_verify_dtype``.
            model_kwargs: Extra keyword arguments for loading the model. The
                resolved dtype is added as ``torch_dtype`` unless the dict
                already has that key. The caller's dict is never modified.
            trust_remote_code: Forwarded to every config/model/tokenizer/
                processor loader.
            is_sentence_transformer: Load a ``SentenceTransformer``.
            is_cross_encoder: Load a ``CrossEncoder``.
            skip_tokenizer_init: Take the tokenizer from the processor
                instead of loading it with ``AutoTokenizer``.
            auto_cls: Auto model class used when neither sentence-transformers
                flag is set.
        """
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Copy before setdefault so the caller's dict is not mutated. A dict
        # reused across runners would otherwise keep the first runner's dtype.
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if (getattr(model, "quantization_method", None) is None
                    and any(p.dtype != self.dtype
                            for p in model.parameters())):
                model = model.to(dtype=self.dtype)

            # Only move the model when it is not bitsandbytes-quantized and
            # its parameters are not already spread over several devices.
            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

    def get_inputs(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[Union[BatchFeature, BatchEncoding]]:
        """Run the processor on each prompt together with its multimodal data.

        ``images``, ``videos`` and ``audios``, when given, must contain one
        entry per prompt. A ``None`` entry means that prompt has no such
        input. An audio entry of length 2 is unpacked as
        ``(audio, sampling_rate)``.

        Returns:
            One processor output per prompt (built with
            ``return_tensors="pt"``). ``BatchFeature`` outputs are cast to
            ``self.dtype``.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_inputs := audios[i]) is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
        """Return the input-embedding tensor of each tokenized prompt.

        The leading dimension of size 1 is squeezed away.
        """
        all_inputs = self.get_inputs(prompts)
        embeddings = []
        for inputs in all_inputs:
            input_ids = self.wrap_device(inputs)["input_ids"]
            embedding = self.model.get_input_embeddings()(input_ids).squeeze(0)
            embeddings.append(embedding)
        return embeddings

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return the softmax over the model's output logits for each prompt.

        Each entry is the list of per-class probabilities for the first
        (only) row of the prompt's logits.
        """
        # output is final logits
        all_inputs = self.get_inputs(prompts)
        outputs = []
        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            logits = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(logits)

        return outputs

    def generate(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Call ``model.generate`` on each prompt separately.

        Extra ``kwargs`` are forwarded to ``generate``.

        Returns:
            One ``(token_ids_per_sequence, decoded_text_per_sequence)``
            tuple per prompt. Decoding skips special tokens and does not
            clean up tokenization spaces.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for inputs in all_inputs:
            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy-decode up to ``max_tokens`` new tokens per prompt.

        Returns:
            The first ``(token_ids, text)`` sequence for each prompt.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Beam-search each prompt, returning all ``beam_width`` sequences.

        Pad tokens (``tokenizer.pad_token_id``) are removed from every
        returned token-id sequence.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width,
                                images=images,
                                videos=videos,
                                audios=audios)

        pad_token_id = self.tokenizer.pad_token_id
        return [([[tok for tok in seq_ids if tok != pad_token_id]
                  for seq_ids in output_ids], output_str)
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[list[torch.Tensor]]:
        """Greedy-decode each prompt and return full per-step logprobs.

        Returns:
            For each prompt, one float32 log-softmax tensor per generation
            step, computed from the step's hidden states.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[torch.Tensor]] = []
        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs = self._hidden_states_to_seq_logprobs(
                output.hidden_states)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
    ) -> list[torch.Tensor]:
        """Project each step's hidden state onto the output embeddings.

        Args:
            hidden_states: ``generate(...).hidden_states``, i.e. one tuple
                per generation step whose last element is used.

        Returns:
            One float32 log-softmax tensor per generation step.
        """
        output_embeddings = self.model.get_output_embeddings()

        seq_logprobs: list[torch.Tensor] = []
        for hidden_state in hidden_states:
            # Last entry of the step's tuple, first (only) batch element.
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(
                    device=output_embeddings.weight.device,
                    dtype=output_embeddings.weight.dtype,
                ),
                output_embeddings.weight.t(),
            )
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> tuple[list[dict[int, float]], int]:
        """Turn per-step hidden states into top-``num_logprobs`` dicts.

        Returns:
            ``(per_step_dicts, output_len)``. Each dict maps token id to
            logprob for one generated token, and ``output_len`` is the number
            of generation steps.
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        # convert to dict
        seq_logprobs_lst: list[dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs: the first step also covers prompt
            # positions, so keep only its last row.
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            }
            seq_logprobs_lst.append(tok_logprobs_dct)

        return (
            seq_logprobs_lst,
            output_len,
        )

    def generate_greedy_logprobs_limit(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """Greedy-decode each prompt, keeping top-``num_logprobs`` per step.

        Returns:
            One ``(generated_token_ids, decoded_text, per_step_logprobs)``
            tuple per prompt.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # The generated tokens are the last ``output_len`` ids.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        '''
        Greedy generation with top-``num_logprobs`` logprobs for Hugging Face
        encoder/decoder models.

        The encoder prompt (plus its optional image) goes through the
        processor. The decoder prompt, when not ``None``, is tokenized and
        passed as ``decoder_input_ids``.

        Returns:
            One ``(generated_token_ids, decoded_text, per_step_logprobs)``
            tuple per prompt pair.
        '''
        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for i, (encoder_prompt, decoder_prompt) in enumerate(
                to_enc_dec_tuple_list(encoder_decoder_prompts)):
            processor_kwargs: dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            encoder_inputs = self.processor(**processor_kwargs)
            encoder_inputs = self.wrap_device(encoder_inputs)

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_inputs = self.tokenizer(decoder_prompt,
                                                return_tensors="pt")
                decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)

            output = self.model.generate(
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **encoder_inputs,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        """Delegate to the wrapped model's ``encode`` (sentence-transformers)."""
        return self.model.encode(prompts, *args, **kwargs)

    def predict(self, prompts: list[list[str]]) -> torch.Tensor:
        """Delegate to the wrapped model's ``predict`` as a tensor."""
        return self.model.predict(prompts, convert_to_tensor=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the model reference before the global cleanup so its memory
        # can be released.
        del self.model
        cleanup_dist_env_and_memory()
738

Woosuk Kwon's avatar
Woosuk Kwon committed
739

Cyrus Leung's avatar
Cyrus Leung committed
740
@pytest.fixture(scope="session")
def hf_runner():
    """Hand tests the ``HfRunner`` class itself (tests instantiate it)."""
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:
    """
    Test-friendly wrapper around {class}`~vllm.LLM`.

    The default values of some arguments have been changed from those of
    {class}`~vllm.LLM` as follows:

    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
    - `block_size`: Set to `16` instead of `None` to reduce memory usage.
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
    - `enforce_eager`: Set to `False` to test CUDA graph.

    The runner is a context manager: leaving the ``with`` block deletes the
    wrapped model and calls ``cleanup_dist_env_and_memory()``.
    """

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        seed: Optional[int] = 0,
        max_model_len: int = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: Optional[bool] = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Create the wrapped {class}`~vllm.LLM`.

        ``model_name`` and ``tokenizer_name`` are passed as ``model`` and
        ``tokenizer``; every other argument, including extra ``kwargs``, is
        forwarded under its own name.
        """
        self.model = LLM(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def get_inputs(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[TextPrompt]:
        """Build one ``TextPrompt`` per entry of ``prompts``.

        String prompts are stored under ``"prompt"``; anything else (e.g. a
        tensor) is stored under ``"prompt_embeds"``. For each index, the
        non-``None`` image/video/audio entries are attached under
        ``"multi_modal_data"``; when there are none, that key is ``None``.

        Raises:
            ValueError: if a non-``None`` multimodal list has a different
                length than ``prompts``.
        """
        if any(x is not None and len(x) != len(prompts)
               for x in [images, videos, audios]):
            raise ValueError(
                "All non-None multimodal inputs must have the same length as "
                "prompts")

        inputs = []
        for i, prompt in enumerate(prompts):
            multi_modal_data = {}
            if images is not None and (image := images[i]) is not None:
                multi_modal_data["image"] = image
            if videos is not None and (video := videos[i]) is not None:
                multi_modal_data["video"] = video
            if audios is not None and (audio := audios[i]) is not None:
                multi_modal_data["audio"] = audio

            text_prompt_kwargs = {
                ("prompt" if isinstance(prompt, str) else "prompt_embeds"):
                prompt,
                "multi_modal_data": multi_modal_data or None
            }
            inputs.append(TextPrompt(**text_prompt_kwargs))

        return inputs

    def generate(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate with ``sampling_params`` and return full sequences.

        Returns:
            One ``(ids_per_sample, texts_per_sample)`` pair per request. Each
            id list is the prompt token ids followed by the sample's token
            ids. Each text is the prompt text (``""`` when the request has no
            prompt string) followed by the sample's text.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: list[list[int]] = []
            req_sample_output_strs: list[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append((prompt_str or "") + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        """Reduce each request output to a single logprobs tuple.

        Each tuple is ``(output_ids, output_str, output_logprobs,
        prompt_logprobs)``. Only the LAST sample of each request is kept;
        with more than one sample per request, the others are ignored.
        """
        outputs: list[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            # Explicitly take the last sample (previously done implicitly by
            # letting a loop overwrite its variables).
            sample = req_output.outputs[-1]
            outputs.append((list(sample.token_ids), sample.text,
                            sample.logprobs, req_output.prompt_logprobs))
        return outputs

    @staticmethod
    def _drop_prompt_logprobs_if_unrequested(
        sampling_params: SamplingParams,
        outputs: list[TokensTextLogprobsPromptLogprobs],
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Omit the trailing prompt-logprobs element when it wasn't requested.

        If ``sampling_params.prompt_logprobs`` is ``None``, every tuple loses
        its last element; otherwise ``outputs`` is returned unchanged.
        """
        if sampling_params.prompt_logprobs is None:
            return [x[0:-1] for x in outputs]
        return outputs

    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return per-request (ids, text, logprobs[, prompt]).

        The prompt-logprobs element is only present when
        ``sampling_params.prompt_logprobs`` is not ``None``.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        return self._drop_prompt_logprobs_if_unrequested(
            sampling_params, toks_str_logsprobs_prompt_logprobs)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Logprobs generation for vLLM encoder/decoder models.

        ``sampling_params.logprobs`` must be set. The prompt-logprobs element
        is only present when ``sampling_params.prompt_logprobs`` is not
        ``None``.
        """
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        return self._drop_prompt_logprobs_if_unrequested(
            sampling_params, toks_str_logsprobs_prompt_logprobs)

    def generate_greedy(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy (temperature 0) generation of up to ``max_tokens`` tokens.

        Returns the first sample's ``(ids, text)`` for each request, where
        both include the prompt (see :meth:`generate`).
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts,
                                greedy_params,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[list[int]] = None,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Greedy generation that also returns ``num_logprobs`` logprobs.

        Prompt logprobs are included only if ``num_prompt_logprobs`` is set.
        """
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos,
                                        **kwargs)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Greedy logprobs generation for vLLM encoder/decoder models.

        Prompt logprobs are included only if ``num_prompt_logprobs`` is set.
        """
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            skip_special_tokens=skip_special_tokens,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run beam search; return ``(token_ids, texts)`` of every beam.

        One pair is returned per prompt, with one entry per returned
        sequence.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        outputs = self.model.beam_search(
            inputs,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.probs`` of each request from ``LLM.classify``."""
        req_outputs = self.model.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

    def encode(self,
               prompts: list[str],
               images: Optional[PromptImageInput] = None,
               videos: Optional[PromptVideoInput] = None,
               audios: Optional[PromptAudioInput] = None,
               *args,
               **kwargs) -> list[list[float]]:
        """Embed the (optionally multimodal) prompts via ``LLM.embed``.

        Extra ``args``/``kwargs`` are forwarded to ``LLM.embed``; each
        request's ``outputs.embedding`` is returned.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.embed(inputs, *args, **kwargs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, list[str]],
        text_2: Union[str, list[str]],
    ) -> list[float]:
        """Return ``outputs.score`` of each request from ``LLM.score``."""
        req_outputs = self.model.score(text_1, text_2)
        return [req_output.outputs.score for req_output in req_outputs]

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Call ``func`` through the engine's model executor.

        Returns whatever ``model_executor.apply_model(func)`` returns.
        """
        executor = self.model.llm_engine.model_executor
        return executor.apply_model(func)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop our reference to the LLM before releasing distributed state
        # and memory.
        del self.model
        cleanup_dist_env_and_memory()
1053

Woosuk Kwon's avatar
Woosuk Kwon committed
1054

1055
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture exposing the ``VllmRunner`` class itself.

    The class (not an instance) is returned, so each test constructs its own
    runner, e.g. inside a ``with`` block.
    """
    runner_cls = VllmRunner
    return runner_cls
1058
1059


1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
@pytest.fixture()
def temporary_enable_log_propagate():
    """Temporarily let records from the ``vllm`` logger reach the root logger.

    ``propagate`` is forced to ``True`` for the duration of the test. On
    teardown it is restored to the value it had before, rather than being
    hard-coded to ``False``, which would clobber a pre-existing setting.
    """
    import logging

    # Named distinctly so it does not shadow this module's own ``logger``.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Provide pytest's ``caplog`` with ``vllm`` log records captured.

    ``caplog`` only sees records that propagate to the root logger. This
    fixture therefore depends on ``temporary_enable_log_propagate``, which
    switches propagation on for the ``vllm`` logger while the test runs.
    """
    captured = caplog
    yield captured
1074
1075
1076
1077
1078
1079
1080


@pytest.fixture(scope="session")
def num_gpus_available():
    """Return the CUDA device count without initializing CUDA in-process.

    Delegates to ``cuda_device_count_stateless``. Because the fixture is
    session-scoped, the count is computed once per test session.
    """
    device_count = cuda_device_count_stateless()
    return device_count
1082
1083
1084


# Scratch locations under the system temporary directory; the dummy_*
# fixtures download model snapshots into these and patch their configs.
temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir,
                                            "dummy_gemma2_embedding")
1088
1089
1090
1091


@pytest.fixture
def dummy_opt_path():
    """Return a local copy of ``facebook/opt-125m`` relabelled for tests.

    On first use the repo is downloaded into ``_dummy_opt_path``. Files
    matching ``*.bin``, ``*.bin.index.json``, ``*.pt``, ``*.h5`` and
    ``*.msgpack`` are skipped. ``config.json`` is then rewritten so that
    ``architectures`` is ``["MyOPTForCausalLM"]``. An existing directory is
    reused as-is.
    """
    config_file = os.path.join(_dummy_opt_path, "config.json")
    if os.path.exists(_dummy_opt_path):
        return _dummy_opt_path

    snapshot_download(repo_id="facebook/opt-125m",
                      local_dir=_dummy_opt_path,
                      ignore_patterns=[
                          "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                          "*.msgpack"
                      ])
    assert os.path.exists(config_file)
    with open(config_file) as fh:
        cfg = json.load(fh)
    cfg["architectures"] = ["MyOPTForCausalLM"]
    with open(config_file, "w") as fh:
        json.dump(cfg, fh)
    return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
    """Return a local copy of ``llava-hf/llava-1.5-7b-hf`` relabelled.

    On first use the repo is downloaded into ``_dummy_llava_path``. Files
    matching ``*.bin``, ``*.bin.index.json``, ``*.pt``, ``*.h5`` and
    ``*.msgpack`` are skipped. ``config.json`` is then rewritten so that
    ``architectures`` is ``["MyLlava"]``. An existing directory is reused
    as-is.
    """
    config_file = os.path.join(_dummy_llava_path, "config.json")
    if os.path.exists(_dummy_llava_path):
        return _dummy_llava_path

    snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                      local_dir=_dummy_llava_path,
                      ignore_patterns=[
                          "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                          "*.msgpack"
                      ])
    assert os.path.exists(config_file)
    with open(config_file) as fh:
        cfg = json.load(fh)
    cfg["architectures"] = ["MyLlava"]
    with open(config_file, "w") as fh:
        json.dump(cfg, fh)
    return _dummy_llava_path
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Return a local copy of ``BAAI/bge-multilingual-gemma2`` relabelled.

    On first use the repo is downloaded into
    ``_dummy_gemma2_embedding_path``. Files matching ``*.bin``,
    ``*.bin.index.json``, ``*.pt``, ``*.h5`` and ``*.msgpack`` are skipped.
    ``config.json`` is then rewritten so that ``architectures`` is
    ``["MyGemma2Embedding"]``. An existing directory is reused as-is.
    """
    config_file = os.path.join(_dummy_gemma2_embedding_path, "config.json")
    if os.path.exists(_dummy_gemma2_embedding_path):
        return _dummy_gemma2_embedding_path

    snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
                      local_dir=_dummy_gemma2_embedding_path,
                      ignore_patterns=[
                          "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                          "*.msgpack"
                      ])
    assert os.path.exists(config_file)
    with open(config_file) as fh:
        cfg = json.load(fh)
    cfg["architectures"] = ["MyGemma2Embedding"]
    with open(config_file, "w") as fh:
        json.dump(cfg, fh)
    return _dummy_gemma2_embedding_path
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162


# Add the `--optional` flag, which enables running tests
# that are marked with @pytest.mark.optional.
def pytest_addoption(parser):
    """Register the boolean ``--optional`` flag (off unless passed)."""
    option_kwargs = dict(action="store_true",
                         default=False,
                         help="run optional test")
    parser.addoption("--optional", **option_kwargs)


def pytest_collection_modifyitems(config, items):
    """Skip tests carrying the ``optional`` keyword unless ``--optional``.

    When ``--optional`` is given on the command line, nothing is changed.
    """
    run_optional = config.getoption("--optional")
    if run_optional:
        return
    marker = pytest.mark.skip(reason="need --optional option to run")
    optional_items = (item for item in items if "optional" in item.keywords)
    for item in optional_items:
        item.add_marker(marker)
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175


@pytest.fixture(scope="session")
def cli_config_file():
    """Return the path to ``config/test_config.yaml`` next to this file."""
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config.yaml")


@pytest.fixture(scope="session")
def cli_config_file_with_model():
    """Return the path to ``config/test_config_with_model.yaml``.

    The path is resolved relative to the directory containing this file.
    """
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config_with_model.yaml")