conftest.py 32.3 KB
Newer Older
1
import json
2
import os
3
import sys
4
import tempfile
5
from collections import UserList
6
from enum import Enum
7
8
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
9

10
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
import pytest
import torch
13
import torch.nn as nn
14
import torch.nn.functional as F
15
from huggingface_hub import snapshot_download
16
from PIL import Image
17
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
18
                          BatchFeature)
19
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
20

21
22
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm import LLM, SamplingParams
24
from vllm.assets.image import ImageAsset
25
from vllm.assets.video import VideoAsset
26
from vllm.config import TaskOption, TokenizerPoolConfig
27
from vllm.connections import global_http_connection
28
from vllm.distributed import (cleanup_dist_env_and_memory,
29
30
                              init_distributed_environment,
                              initialize_model_parallel)
31
32
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
33
from vllm.logger import init_logger
34
from vllm.outputs import RequestOutput
35
from vllm.sampling_params import BeamSearchParams
36
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
37
                        identity, is_cpu)
38

39
logger = init_logger(__name__)

# Prompt files are looked up in the ``prompts`` directory next to this file.
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Multi-modal test inputs: either one item per prompt, or a list of items
# per prompt.
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
# Each audio item is a ``(waveform, sampling_rate)`` pair.
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
                         List[List[Tuple[np.ndarray, int]]]]
PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]]
49

50

51
def _read_prompts(filename: str) -> List[str]:
    """Read a prompt file and return its lines as individual prompts.

    Each line, including its trailing newline if present, becomes one prompt.
    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the platform's default locale encoding.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
55
56


57
58
59
class _ImageAssetPrompts(TypedDict):
    """Prompt text for each test image, keyed by the image asset's name."""
    stop_sign: str
    cherry_blossom: str
60
61
62
63
64
65
66


# ``UserList[...]`` can only be subscripted from Python 3.9 onwards; older
# interpreters fall back to the unparameterized base class.
if sys.version_info >= (3, 9):

    class _ImageAssetsBase(UserList[ImageAsset]):
        pass
else:

    class _ImageAssetsBase(UserList):  # type: ignore[no-redef]
        pass
70

71
72

class _ImageAssets(_ImageAssetsBase):
    """The fixed list of image assets shared by the test suite."""

    def __init__(self) -> None:
        assets = [ImageAsset("stop_sign"), ImageAsset("cherry_blossom")]
        super().__init__(assets)

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Return one prompt per test image.

        The prompts come back in the same order in which the assets are
        yielded when iterating over this object.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
88
89


90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class _VideoAssetPrompts(TypedDict):
    """Prompt text for each test video, keyed by its name without extension."""
    sample_demo_1: str


# ``UserList[...]`` can only be subscripted from Python 3.9 onwards; older
# interpreters fall back to the unparameterized base class.
if sys.version_info >= (3, 9):

    class _VideoAssetsBase(UserList[VideoAsset]):
        pass
else:

    class _VideoAssetsBase(UserList):  # type: ignore[no-redef]
        pass


class _VideoAssets(_VideoAssetsBase):
    """The fixed list of video assets shared by the test suite."""

    def __init__(self) -> None:
        assets = [VideoAsset("sample_demo_1.mp4")]
        super().__init__(assets)

    def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
        """Return one prompt per test video, in asset iteration order."""
        demo_prompt = prompts["sample_demo_1"]
        return [demo_prompt]


115
116
# Built once at import time and handed out by the ``image_assets`` and
# ``video_assets`` fixtures.
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""

VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
119
120


121
122
123
124
125
126
127
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Disable client reuse on the global HTTP connection for every test."""
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


128
129
130
131
132
133
134
135
136
137
138
139
@pytest.fixture
def dist_init():
    """Set up a single-process (world size 1, rank 0) distributed environment.

    Rendezvous goes through a ``file://`` URL backed by a fresh temporary
    file. ``initialize_model_parallel(1, 1)`` is then called, and everything
    is torn down via ``cleanup_dist_env_and_memory()`` after the test.
    """
    # mkstemp() returns an already-open OS-level descriptor together with the
    # path. Only the path is needed for the rendezvous, so close the
    # descriptor right away instead of leaking it for the rest of the session.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup_dist_env_and_memory()
    # Best-effort removal of the rendezvous file. It may already be gone
    # (torch's file store can delete it itself), which is fine.
    try:
        os.remove(temp_file)
    except OSError:
        pass
141
142


143
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether ``cleanup_fixture`` should run global cleanup.

    Returns False for tests carrying the ``skip_global_cleanup`` marker.
    Subdirectories may also override this fixture to skip cleanup entirely,
    which can give a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
151
152


153
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After each test, call ``cleanup_dist_env_and_memory()`` unless the
    ``should_do_global_cleanup_after_test`` fixture opts out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory()
158
159


160
161
162
163
164
165
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call ``torch._dynamo.reset()`` after every test.

    This keeps TorchDynamo compilation state from carrying over between tests.
    """
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
166
167
@pytest.fixture
def example_prompts() -> List[str]:
    """All lines of every file in ``_TEST_PROMPTS``, one prompt per line."""
    return [
        prompt for filename in _TEST_PROMPTS
        for prompt in _read_prompts(filename)
    ]


174
175
176
177
178
179
180
class DecoderPromptType(Enum):
    """Which kind of decoder prompt accompanies each encoder prompt.

    For encoder/decoder models only.
    """
    CUSTOM = 1  # an explicit, non-empty decoder prompt string
    NONE = 2  # no decoder prompt at all (``None``)
    EMPTY_STR = 3  # an empty-string decoder prompt


181
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for every :class:`DecoderPromptType`.

    The encoder prompts are the lines of the files in ``_TEST_PROMPTS``. For
    each decoder prompt type, they are paired (via ``zip_enc_dec_prompts``)
    with these decoder prompts:

    * ``DecoderPromptType.NONE``: ``None`` for every prompt
    * ``DecoderPromptType.EMPTY_STR``: ``""`` for every prompt
    * ``DecoderPromptType.CUSTOM``: the encoder prompt list reversed, so the
      first encoder prompt is paired with the last one, and so on

    Returns:

    * Dict mapping each decoder prompt type to its list of encoder/decoder
      prompts, in encoder prompt order
    '''

    encoder_prompts = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    custom_decoder_prompts = encoder_prompts[::-1]
    empty_str_decoder_prompts = [""] * len(encoder_prompts)
    none_decoder_prompts = [None] * len(encoder_prompts)

    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
    }


214
215
216
217
@pytest.fixture
def example_long_prompts() -> List[str]:
    """All lines of every file in ``_LONG_PROMPTS``, one prompt per line."""
    return [
        prompt for filename in _LONG_PROMPTS
        for prompt in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
220
221


222
223
224
225
226
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Session-scoped access to the shared ``IMAGE_ASSETS`` instance."""
    return IMAGE_ASSETS


227
228
229
230
231
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    """Session-scoped access to the shared ``VIDEO_ASSETS`` instance."""
    return VIDEO_ASSETS


232
# Values that ``HfRunner.wrap_device`` can move between devices; the method
# is annotated to return the same type it receives.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
233

Woosuk Kwon's avatar
Woosuk Kwon committed
234
235
236

class HfRunner:
    """Run a HuggingFace model as the reference implementation in tests.

    Holds a ``transformers`` model, or a ``sentence_transformers`` model when
    ``is_sentence_transformer=True``, together with its tokenizer and
    processor. The generation helpers return formats parallel to
    :class:`VllmRunner`'s so that outputs can be compared. Supports use as a
    context manager; ``__exit__`` drops the model and runs
    ``cleanup_dist_env_and_memory()``.
    """

    def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
        """Move ``input`` to ``device`` unless it already reports that device.

        Args:
            input: Module, tensor or batch to move.
            device: Target device type. Defaults to ``"cpu"`` when
                ``is_cpu()`` is true and ``"cuda"`` otherwise.

        Returns:
            ``input`` itself if its ``device.type`` already equals
            ``device``, otherwise ``input.to(device)``.
        """
        if device is None:
            device = "cpu" if is_cpu() else "cuda"

        if hasattr(input, "device") and input.device.type == device:
            return input

        return input.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_sentence_transformer: bool = False,
        auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        """Load the model, tokenizer and processor for ``model_name``.

        Args:
            model_name: HF model id or local path.
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE``.
            model_kwargs: Extra keyword arguments for
                ``auto_cls.from_pretrained`` (not used for sentence
                transformers).
            is_sentence_transformer: Load via ``SentenceTransformer``
                instead of ``auto_cls``.
            auto_cls: Auto-model class used to load the model.
            postprocess_inputs: Applied to every processor output in
                :meth:`get_inputs`.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                    trust_remote_code=True,
                ).to(dtype=torch_dtype))
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        self.postprocess_inputs = postprocess_inputs

    def get_inputs(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[BatchEncoding]:
        """Run the processor on each prompt with its multi-modal data.

        Each of ``images``, ``videos`` and ``audios``, when given, must have
        one entry per prompt. ``None`` entries are skipped. Audio entries
        are ``(audio, sampling_rate)`` pairs.

        Returns:
            One processed (and ``postprocess_inputs``-transformed) batch per
            prompt.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: List[BatchEncoding] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]
            if videos is not None and videos[i] is not None:
                processor_kwargs["videos"] = videos[i]
            if audios is not None and audios[i] is not None:
                audio, sr = audios[i]
                processor_kwargs["audio"] = audio
                processor_kwargs["sampling_rate"] = sr

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs)

            all_inputs.append(inputs)

        return all_inputs

    def generate(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with one ``self.model.generate`` call per prompt.

        Extra ``kwargs`` are forwarded to ``self.model.generate``.

        Returns:
            For each prompt, the returned token-id sequences and their
            decoded strings (one entry per returned sequence).
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for inputs in all_inputs:
            output_ids = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode up to ``max_tokens`` new tokens per prompt.

        Returns:
            The first returned sequence's ids and text for each prompt.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search returning all ``beam_width`` beams per prompt.

        Pad tokens are stripped from every beam's token ids.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [([[tok for tok in beam_ids if tok != pad_token_id]
                  for beam_ids in output_ids], output_str)
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy-decode and return full float32 log-prob tensors.

        Returns:
            For each prompt, one log-softmax tensor per generation step, as
            produced by :meth:`_hidden_states_to_seq_logprobs`.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: List[List[torch.Tensor]] = []
        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs = self._hidden_states_to_seq_logprobs(
                output.hidden_states)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
    ) -> List[torch.Tensor]:
        """Project each step's last-layer hidden states to log-probs.

        Uses the model's output embedding (weight plus bias, if any) and a
        float32 log-softmax over the last dimension.
        """
        output_embeddings = self.model.get_output_embeddings()

        seq_logprobs: List[torch.Tensor] = []
        for hidden_state in hidden_states:
            # Last layer, batch element 0.
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(output_embeddings.weight.device),
                output_embeddings.weight.t(),
            )
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Convert hidden states into per-step top-``num_logprobs`` dicts.

        Returns:
            A list with one ``{token_id: logprob}`` dict per generation step,
            and the number of steps (``len(hidden_states)``).
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        # convert to dict
        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs: the first step covers every prompt
            # position, keep only the last row (the first generated token)
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {}
            for token_id, logprob in zip(topk.indices[0], topk.values[0]):
                tok_logprobs_dct[token_id.item()] = logprob.item()

            seq_logprobs_lst.append(tok_logprobs_dct)

        return (
            seq_logprobs_lst,
            output_len,
        )

    def _to_tokens_text_logprobs(
        self,
        sequences: torch.Tensor,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> TokensTextLogprobs:
        """Turn one greedy ``generate`` result into ``(ids, text, logprobs)``.

        Only the generated tail of ``sequences[0]`` (one token per step in
        ``hidden_states``) is returned and decoded.
        """
        seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
            hidden_states, num_logprobs)
        seq_ids = sequences[0]
        # Use an explicit start index: ``seq_ids[-0:]`` would return the whole
        # sequence (prompt included) when no token was generated.
        output_ids = seq_ids[len(seq_ids) - output_len:]
        return (output_ids.tolist(), self.tokenizer.decode(output_ids),
                seq_logprobs_lst)

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        """Greedy-decode, keeping the top ``num_logprobs`` logprobs per step.

        Returns:
            One ``(output_ids, output_str, logprobs)`` triple per prompt,
            covering only the generated tokens.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        outputs: List[TokensTextLogprobs] = []
        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            outputs.append(
                self._to_tokens_text_logprobs(output.sequences,
                                              output.hidden_states,
                                              num_logprobs))
        return outputs

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for HF encoder/decoder models

        A ``None`` decoder prompt is passed to ``generate`` as
        ``decoder_input_ids=None``. Returns one ``(output_ids, output_str,
        logprobs)`` triple per prompt pair, computed from the decoder
        hidden states.
        '''
        outputs: List[TokensTextLogprobs] = []
        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):

            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids,
                device=self.model.device.type,
            )

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_input_ids = self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids,
                    device=self.model.device.type,
                )

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            outputs.append(
                self._to_tokens_text_logprobs(output.sequences,
                                              output.decoder_hidden_states,
                                              num_logprobs))
        return outputs

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Return ``self.model.encode(prompts)`` (sentence-transformer use)."""
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup_dist_env_and_memory()
597

Woosuk Kwon's avatar
Woosuk Kwon committed
598

Cyrus Leung's avatar
Cyrus Leung committed
599
@pytest.fixture(scope="session")
def hf_runner():
    """Session-scoped fixture that hands out the :class:`HfRunner` class.

    Tests construct the runner themselves; it supports use as a context
    manager.
    """
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:
    """Run a model through vLLM's offline :class:`LLM` API in tests.

    The generation helpers return formats parallel to :class:`HfRunner`'s so
    that the two can be compared. Supports use as a context manager;
    ``__exit__`` drops the engine and runs ``cleanup_dist_env_and_memory()``.
    """

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Build the :class:`LLM`; extra ``kwargs`` are forwarded to it."""
        self.model = LLM(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def get_inputs(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[TextPrompt]:
        """Build one :class:`TextPrompt` per prompt with its multi-modal data.

        Each of ``images``, ``videos`` and ``audios``, when given, must have
        one entry per prompt. ``None`` entries are skipped. All modalities
        given for the same prompt are combined into a single
        ``multi_modal_data`` dict. A prompt with no multi-modal entry gets
        no ``multi_modal_data`` key.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        inputs: List[TextPrompt] = []
        for i, prompt in enumerate(prompts):
            # Merge every modality instead of letting a later one overwrite
            # an earlier one, and skip per-prompt ``None`` placeholders just
            # like ``HfRunner.get_inputs`` does.
            multi_modal_data: Dict[str, Any] = {}
            if images is not None and images[i] is not None:
                multi_modal_data["image"] = images[i]
            if videos is not None and videos[i] is not None:
                multi_modal_data["video"] = videos[i]
            if audios is not None and audios[i] is not None:
                multi_modal_data["audio"] = audios[i]

            text_prompt = TextPrompt(prompt=prompt)
            if multi_modal_data:
                text_prompt["multi_modal_data"] = multi_modal_data
            inputs.append(text_prompt)

        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with ``sampling_params``.

        Returns:
            For each request, one entry per completion: the prompt token ids
            followed by the output token ids, and the prompt text followed by
            the output text.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[TokensTextLogprobsPromptLogprobs]:
        """Flatten request outputs into
        ``(ids, text, logprobs, prompt_logprobs)`` tuples.

        Exactly one tuple is produced per request. When a request has
        several completions, only the last one is reported.
        """
        outputs: List[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            sample = req_output.outputs[-1]
            outputs.append((list(sample.token_ids), sample.text,
                            sample.logprobs, req_output.prompt_logprobs))
        return outputs

    @staticmethod
    def _maybe_drop_prompt_logprobs(
        outputs: List[TokensTextLogprobsPromptLogprobs],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Strip the trailing prompt-logprobs element unless it was requested
        via ``sampling_params.prompt_logprobs``."""
        if sampling_params.prompt_logprobs is None:
            return [x[0:-1] for x in outputs]
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return ``(ids, text, logprobs[, prompt_logprobs])``
        per request. Prompt logprobs are included only when
        ``sampling_params.prompt_logprobs`` is set."""
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        return self._maybe_drop_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for vLLM encoder/decoder models

        ``sampling_params.logprobs`` must be set.
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._maybe_drop_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy (temperature 0) generation of up to ``max_tokens`` tokens.

        Returns:
            The first completion's ids and text (prompt included) for each
            prompt.
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts,
                                greedy_params,
                                images=images,
                                videos=videos,
                                audios=audios)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Greedy generation that also returns the top ``num_logprobs``
        logprobs per token (and prompt logprobs if requested)."""
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: Union[List[str], List[List[int]]],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search via ``LLM.beam_search``.

        Returns:
            For each prompt, the token ids and text of every returned
            sequence.
        """
        outputs = self.model.beam_search(
            prompts,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return the embedding from each request output of
        ``LLM.encode``."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup_dist_env_and_memory()
852

Woosuk Kwon's avatar
Woosuk Kwon committed
853

854
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture that hands tests the ``VllmRunner`` class.

    The class object itself is returned (not an instance); each test
    constructs its own runner.
    """
    runner_cls = VllmRunner
    return runner_cls


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a single-worker ``TokenizerPoolConfig`` for tests.

    Args:
        tokenizer_group_type: ``None`` (no pool), the string ``"ray"``, or a
            class to be used directly as the pool type.

    Returns:
        ``None`` when ``tokenizer_group_type`` is ``None``; otherwise a
        ``TokenizerPoolConfig`` with ``pool_size=1`` and an empty
        ``extra_config``.

    Raises:
        ValueError: If ``tokenizer_group_type`` is any other value.
    """
    if tokenizer_group_type is None:
        return None

    if tokenizer_group_type == "ray":
        pool_type = "ray"
    elif isinstance(tokenizer_group_type, type):
        pool_type = tokenizer_group_type
    else:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")

    return TokenizerPoolConfig(pool_size=1,
                               pool_type=pool_type,
                               extra_config={})


@pytest.fixture()
def temporary_enable_log_propagate():
    """Turn on ``propagate`` for the ``vllm`` logger for one test.

    On teardown the logger's previous ``propagate`` value is restored rather
    than being forced to ``False``, so the fixture leaves the logger exactly
    as it found it.
    """
    import logging

    # Named so it does not shadow the module-level ``logger``.
    vllm_logger = logging.getLogger("vllm")
    saved_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        # ``finally`` also covers the generator being closed early.
        vllm_logger.propagate = saved_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Like ``caplog``, but also captures ``vllm`` log records.

    ``caplog`` only sees records that reach the root logger, so this
    fixture depends on ``temporary_enable_log_propagate`` to switch on
    propagation for the ``vllm`` logger while the test runs.
    """
    yield caplog


@pytest.fixture(scope="session")
def num_gpus_available():
    """Number of CUDA devices, counted via ``cuda_device_count_stateless``.

    The stateless helper is used so that the CUDA context is not
    initialized in the current process.
    """
    gpu_count = cuda_device_count_stateless()
    return gpu_count


# Root directory for the on-disk dummy-model snapshots used below.
temp_dir = tempfile.gettempdir()

# One sub-directory per dummy model; the fixtures further down download a
# snapshot into these paths when missing and rewrite its config.json.
_dummy_opt_path, _dummy_llava_path, _dummy_gemma2_embedding_path = (
    os.path.join(temp_dir, subdir)
    for subdir in ("dummy_opt", "dummy_llava", "dummy_gemma2_embedding"))


def _prepare_dummy_model(local_dir: str, repo_id: str,
                         architecture: str) -> str:
    """Fetch ``repo_id`` into ``local_dir`` and relabel its architecture.

    The snapshot is downloaded with ``*.bin``, ``*.bin.index.json``,
    ``*.pt``, ``*.h5`` and ``*.msgpack`` files excluded. Afterwards
    ``config.json`` is rewritten so that ``architectures`` is
    ``[architecture]``.
    NOTE(review): presumably the dummy architecture names are model classes
    registered by the tests using these fixtures — confirm.

    Args:
        local_dir: Directory holding the snapshot (reused across runs).
        repo_id: Hugging Face Hub repository to download.
        architecture: Sole entry written to ``config["architectures"]``.

    Returns:
        ``local_dir``.
    """
    json_path = os.path.join(local_dir, "config.json")
    # Key the download on config.json rather than on the directory, so a
    # directory left without a config by an interrupted download is fetched
    # again instead of being handed to the test as-is.
    if not os.path.exists(json_path):
        snapshot_download(repo_id=repo_id,
                          local_dir=local_dir,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Check on every call (not only right after a download) so a checkout
    # whose config was never patched still gets fixed; skip the write when
    # the config is already correct.
    if config.get("architectures") != [architecture]:
        config["architectures"] = [architecture]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return local_dir


@pytest.fixture
def dummy_opt_path():
    """Path to a local ``facebook/opt-125m`` snapshot whose config.json
    declares ``architectures = ["MyOPTForCausalLM"]``."""
    return _prepare_dummy_model(_dummy_opt_path, "facebook/opt-125m",
                                "MyOPTForCausalLM")


@pytest.fixture
def dummy_llava_path():
    """Path to a local ``llava-hf/llava-1.5-7b-hf`` snapshot whose
    config.json declares ``architectures = ["MyLlava"]``."""
    return _prepare_dummy_model(_dummy_llava_path, "llava-hf/llava-1.5-7b-hf",
                                "MyLlava")


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Path to a local ``BAAI/bge-multilingual-gemma2`` snapshot whose
    config.json declares ``architectures = ["MyGemma2Embedding"]``."""
    return _prepare_dummy_model(_dummy_gemma2_embedding_path,
                                "BAAI/bge-multilingual-gemma2",
                                "MyGemma2Embedding")