conftest.py 35.9 KB
Newer Older
1
import json
2
import os
3
import tempfile
4
from collections import UserList
5
from enum import Enum
6
7
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
import pytest
import torch
12
import torch.nn as nn
13
import torch.nn.functional as F
14
from huggingface_hub import snapshot_download
15
from PIL import Image
16
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
17
                          BatchFeature)
18
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
19

20
21
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm import LLM, SamplingParams
23
from vllm.assets.image import ImageAsset
24
from vllm.assets.video import VideoAsset
25
from vllm.config import TaskOption, TokenizerPoolConfig
26
from vllm.connections import global_http_connection
27
from vllm.distributed import (cleanup_dist_env_and_memory,
28
29
                              init_distributed_environment,
                              initialize_model_parallel)
30
31
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
32
from vllm.logger import init_logger
33
from vllm.outputs import RequestOutput
34
from vllm.sampling_params import BeamSearchParams
35
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
36
                        identity)
37

38
logger = init_logger(__name__)

# Fixture files that live next to this conftest.
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")

_M = TypeVar("_M")
# Multi-modal input for a batch of prompts, indexed like the prompt list.
# Each entry is either one item, a list of items, or ``None``. The runners'
# ``get_inputs`` skip ``None`` entries, so a prompt may carry no data of
# this modality.
_PromptMultiModalInput = Union[List[_M], List[Optional[_M]], List[List[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
51

52

53
def _read_prompts(filename: str) -> List[str]:
54
    with open(filename) as f:
55
56
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
57
58


59
60
61
class _ImageAssetPrompts(TypedDict):
    stop_sign: str
    cherry_blossom: str
62
63


64
65
class _ImageAssetsBase(UserList[ImageAsset]):
    """List of :class:`ImageAsset` objects; base of :class:`_ImageAssets`."""
66

67
68

class _ImageAssets(_ImageAssetsBase):
    """The fixed set of image assets: ``stop_sign`` then ``cherry_blossom``."""

    def __init__(self) -> None:
        assets = [ImageAsset(name) for name in ("stop_sign", "cherry_blossom")]
        super().__init__(assets)

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Convenience method to define the prompt for each test image.

        The returned prompts follow the same order as the assets yielded
        when iterating over this object.
        """
        ordered_keys = ("stop_sign", "cherry_blossom")
        return [prompts[key] for key in ordered_keys]
84
85


86
87
88
89
class _VideoAssetPrompts(TypedDict):
    sample_demo_1: str


90
91
class _VideoAssetsBase(UserList[VideoAsset]):
    """List of :class:`VideoAsset` objects; base of :class:`_VideoAssets`."""
92
93
94
95
96
97
98
99
100
101
102
103
104


class _VideoAssets(_VideoAssetsBase):
    """The fixed set of video assets: just ``sample_demo_1.mp4``."""

    def __init__(self) -> None:
        assets = [VideoAsset("sample_demo_1.mp4")]
        super().__init__(assets)

    def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
        """Return the prompt for each test video, in asset iteration order."""
        return [prompts[key] for key in ("sample_demo_1", )]


105
106
# Process-wide asset collections shared by the fixtures below.
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""

VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
109
110


Joe Runde's avatar
Joe Runde committed
111
@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """Run the requesting test twice: once with vLLM V1 and once without.

    The engine is selected through the ``VLLM_USE_V1`` environment variable.
    Tests carrying the ``skip_v1`` marker are skipped for the V1 run.
    """
    use_v1 = request.param
    if use_v1 and request.node.get_closest_marker("skip_v1"):
        pytest.skip("Skipping test on vllm V1")

    monkeypatch.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
    yield
Joe Runde's avatar
Joe Runde committed
126
127


128
129
130
131
132
133
134
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Stop reusing the global async HTTP client across tests.

    pytest_asyncio may use a different event loop for each test, so the
    async client has to be created anew rather than reused.
    """
    global_http_connection.reuse_client = False


135
136
137
138
139
140
141
142
143
144
145
146
@pytest.fixture
def dist_init():
    """Set up a single-process (world size 1) NCCL distributed environment.

    The process group rendezvous goes through a ``file://`` URL backed by a
    fresh temporary file. After the test the distributed state is torn down
    and the temporary file is removed.
    """
    # mkstemp() returns an already-open descriptor as well as the path. Only
    # the path is needed, so close the descriptor instead of leaking it.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    try:
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend="nccl",
        )
        initialize_model_parallel(1, 1)
        yield
        cleanup_dist_env_and_memory()
    finally:
        # The rendezvous store may already have removed the file itself.
        try:
            os.remove(temp_file)
        except FileNotFoundError:
            pass
148
149


150
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether ``cleanup_fixture`` should run after this test.

    Tests opt out with the ``skip_global_cleanup`` marker. Subdirectories
    can also override this fixture. Skipping the cleanup can give a ~10x
    speedup for non-GPU unit tests, because they then do not need to
    initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
158
159


160
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After each test, tear down distributed state and free memory.

    Does nothing when ``should_do_global_cleanup_after_test`` is false.
    """
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory()
165
166


167
168
169
170
171
172
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call ``torch._dynamo.reset()`` once each test has finished.

    This keeps TorchDynamo state from one test out of the next.
    """
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
173
174
@pytest.fixture
def example_prompts() -> List[str]:
    """All lines of every file in ``_TEST_PROMPTS``, concatenated in order."""
    return [
        line for filename in _TEST_PROMPTS for line in _read_prompts(filename)
    ]


181
182
183
184
185
186
@pytest.fixture
def example_system_message() -> str:
    """Return the full text of the example system-message file.

    The file is decoded as UTF-8 explicitly, so the result does not depend
    on the platform's locale encoding.
    """
    with open(_SYS_MSG, encoding="utf-8") as f:
        return f.read()


187
188
189
190
191
192
193
class DecoderPromptType(Enum):
    """Kind of decoder prompt to use; for encoder/decoder models only.

    ``example_encoder_decoder_prompts`` builds one prompt set per member:
    CUSTOM uses explicit text, NONE uses ``None``, and EMPTY_STR uses ``""``.
    """
    CUSTOM = 1  # explicit decoder prompt text
    NONE = 2  # decoder prompt is None
    EMPTY_STR = 3  # decoder prompt is the empty string


194
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    """Build encoder/decoder prompt pairs for each :class:`DecoderPromptType`.

    The encoder prompts are the lines of the files in ``_TEST_PROMPTS``. For
    each prompt type, ``zip_enc_dec_prompts`` pairs those encoder prompts,
    index by index, with a list of decoder prompts:

    * ``DecoderPromptType.NONE``: every decoder prompt is ``None``.
    * ``DecoderPromptType.EMPTY_STR``: every decoder prompt is ``""``.
    * ``DecoderPromptType.CUSTOM``: the decoder prompts are the encoder
      prompts in reverse order.

    Returns:
        A dict mapping each ``DecoderPromptType`` to its list of
        encoder/decoder prompts.
    """
    encoder_prompts: List[str] = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    num_prompts = len(encoder_prompts)
    decoder_prompts_by_type = {
        DecoderPromptType.NONE: [None] * num_prompts,
        DecoderPromptType.EMPTY_STR: [""] * num_prompts,
        DecoderPromptType.CUSTOM: encoder_prompts[::-1],
    }
    return {
        prompt_type: zip_enc_dec_prompts(encoder_prompts, decoder_prompts)
        for prompt_type, decoder_prompts in decoder_prompts_by_type.items()
    }


227
228
229
230
@pytest.fixture
def example_long_prompts() -> List[str]:
    """All lines of every file in ``_LONG_PROMPTS``, concatenated in order."""
    return [
        line for filename in _LONG_PROMPTS for line in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
233
234


235
236
237
238
239
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Session-wide access to the shared ``IMAGE_ASSETS`` collection."""
    assets = IMAGE_ASSETS
    return assets


240
241
242
243
244
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    """Session-wide access to the shared ``VIDEO_ASSETS`` collection."""
    assets = VIDEO_ASSETS
    return assets


245
# Value types accepted and returned by ``HfRunner.wrap_device``.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
246

Woosuk Kwon's avatar
Woosuk Kwon committed
247
248
249

class HfRunner:
    """Runs a model through the HuggingFace libraries for reference results.

    Depending on the constructor flags, the wrapped ``self.model`` is a
    ``sentence_transformers.SentenceTransformer``, a
    ``sentence_transformers.CrossEncoder``, or a model created by
    ``auto_cls.from_pretrained``. The method names mirror those of
    ``VllmRunner``.
    """

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move *x* to *device*, recursing into dict values.

        ``None`` and ``bool`` values are returned unchanged. When *device* is
        omitted, ``"cpu"`` is used on CPU-only platforms and ``"cuda"``
        otherwise. Objects whose ``device.type`` already equals *device* are
        returned as-is. Anything else is moved with ``x.to(device)``.
        """
        from vllm.platforms import current_platform

        if x is None or isinstance(x, (bool, )):
            return x

        if device is None:
            device = "cpu" if current_platform.is_cpu() else "cuda"

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[..., BatchEncoding] = identity,
    ) -> None:
        """Load *model_name* together with its tokenizer and processor.

        Args:
            model_name: HuggingFace model name or path.
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE`` that selects the
                torch dtype the model is loaded in.
            model_kwargs: Extra keyword arguments for
                ``auto_cls.from_pretrained``. Unused for sentence-transformer
                and cross-encoder models.
            is_sentence_transformer: Load the model with
                ``sentence_transformers.SentenceTransformer``.
            is_cross_encoder: Load the model with
                ``sentence_transformers.CrossEncoder``.
            skip_tokenizer_init: Do not load a tokenizer via
                ``AutoTokenizer``; use ``self.processor.tokenizer`` instead.
            auto_cls: HF auto class used for all other models.
            postprocess_inputs: Applied to every processor output in
                :meth:`get_inputs`, called with ``dtype=self.dtype``.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                    trust_remote_code=True,
                ).to(dtype=torch_dtype))
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder
            self.model = CrossEncoder(model_name,
                                      device="cpu",
                                      trust_remote_code=True)
            self.model.model = self.wrap_device(
                self.model.model).to(dtype=torch_dtype)
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

        self.dtype = dtype
        self.postprocess_inputs = postprocess_inputs

    def get_inputs(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[BatchEncoding]:
        """Run the processor on each prompt with its multi-modal data.

        Each provided modality list must have one entry per prompt. ``None``
        entries are skipped. Audio entries are ``(audio, sampling_rate)``
        tuples. Every processor output goes through ``postprocess_inputs``.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: List[BatchEncoding] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_tuple := audios[i]) is not None:
                audio, sr = audio_tuple
                processor_kwargs["audio"] = audio
                processor_kwargs["sampling_rate"] = sr

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs, dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

    def classify(self, prompts: List[str]) -> List[List[float]]:
        """Return the softmax over the model's output logits for each prompt.

        For every prompt, this is row 0 of ``output.logits.softmax(dim=-1)``
        as a list of floats (probabilities, not raw logits).
        """
        all_inputs = self.get_inputs(prompts)
        outputs: List[List[float]] = []
        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            probs = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(probs)

        return outputs

    def generate(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Call ``self.model.generate`` once per prompt.

        Returns one ``(sequence_ids, decoded_strs)`` pair per prompt, with
        one entry per returned sequence. Extra *kwargs* are forwarded to
        ``generate``.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for inputs in all_inputs:
            output_ids = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode up to *max_tokens* new tokens per prompt.

        Returns the first sequence's ``(ids, text)`` for each prompt.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search that returns all *beam_width* beams per prompt.

        Pad tokens are stripped from each beam's token ids.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        # All beams of a prompt come back as one rectangular tensor, so the
        # shorter beams carry pad tokens. Strip them from the ids.
        pad_token_id = self.tokenizer.pad_token_id
        stripped_outputs: List[Tuple[List[List[int]], List[str]]] = []
        for output_ids, output_str in outputs:
            stripped_ids = [[tok for tok in seq if tok != pad_token_id]
                            for seq in output_ids]
            stripped_outputs.append((stripped_ids, output_str))
        return stripped_outputs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy-decode and return full-vocabulary logprobs per step.

        Each prompt gets one list holding one logprob tensor per generation
        step. See :meth:`_hidden_states_to_seq_logprobs`.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: List[List[torch.Tensor]] = []
        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs = self._hidden_states_to_seq_logprobs(
                output.hidden_states)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
    ) -> List[torch.Tensor]:
        """Project each step's hidden states through the output embedding.

        For every element of *hidden_states*, the last entry (``[-1]``) of
        batch row ``[0]`` is multiplied by the transposed output-embedding
        weight. The embedding bias is added if there is one, and a float32
        log-softmax is taken over the last dimension.
        NOTE(review): assumes HF's per-step tuple of per-layer states, with
        the final layer last — confirm against the transformers version.
        """
        output_embeddings = self.model.get_output_embeddings()

        seq_logprobs: List[torch.Tensor] = []
        for hidden_state in hidden_states:
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(output_embeddings.weight.device),
                output_embeddings.weight.t(),
            )
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Reduce per-step logprobs to the top-*num_logprobs* entries.

        Returns ``(per_step_dicts, output_len)``. Each dict maps token id to
        logprob, and ``output_len == len(hidden_states)``. The first step
        also covers the prompt positions, so only its last row is kept.
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        # convert to dict
        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {}
            for token_id, logprob in zip(topk.indices[0], topk.values[0]):
                tok_logprobs_dct[token_id.item()] = logprob.item()

            seq_logprobs_lst.append(tok_logprobs_dct)

        return (
            seq_logprobs_lst,
            output_len,
        )

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        """Greedy-decode and keep the top-*num_logprobs* logprobs per token.

        Returns one ``(output_ids, output_str, logprobs)`` tuple per prompt.
        ``output_ids`` are the trailing ``output_len`` ids of the generated
        sequence, i.e. one id per decoding step.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            # output_len already equals len(seq_logprobs_lst): the helper
            # produces one dict per step.
            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            output_ids = output.sequences[0][-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        """Greedy logprobs generation for encoder/decoder models, run with HF.

        The encoder prompt, plus the image if one is given, goes through the
        processor. The decoder prompt, when not ``None``, is tokenized
        separately and passed as ``decoder_input_ids``. Returns one
        ``(output_ids, output_str, logprobs)`` tuple per prompt pair.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, (encoder_prompt, decoder_prompt) in enumerate(
                to_enc_dec_tuple_list(encoder_decoder_prompts)):
            processor_kwargs: Dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            encoder_input_ids = self.wrap_device(
                self.processor(**processor_kwargs).input_ids,
                device=self.model.device.type,
            )

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_input_ids = self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids,
                    device=self.model.device.type,
                )

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            output_ids = output.sequences[0][-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Delegate to ``self.model.encode``."""
        return self.model.encode(prompts)

    def predict(self, prompts: List[List[str]]) -> torch.Tensor:
        """Delegate to ``self.model.predict``, returning a tensor."""
        return self.model.predict(prompts, convert_to_tensor=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the model reference before the global cleanup so its memory
        # can be released.
        del self.model
        cleanup_dist_env_and_memory()
651

Woosuk Kwon's avatar
Woosuk Kwon committed
652

Cyrus Leung's avatar
Cyrus Leung committed
653
@pytest.fixture(scope="session")
def hf_runner():
    """Expose the :class:`HfRunner` class itself (not an instance) to tests."""
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Create the wrapped :class:`LLM` engine as ``self.model``.

        Every named argument is forwarded to ``LLM`` (``model_name`` as
        ``model``, ``tokenizer_name`` as ``tokenizer``), together with
        ``trust_remote_code=True``. Extra ``kwargs`` are passed through.
        """
        engine_kwargs: Dict[str, Any] = dict(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
        )
        self.model = LLM(**engine_kwargs, **kwargs)

695
    def get_inputs(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[TextPrompt]:
        """Build one :class:`TextPrompt` per prompt with its multi-modal data.

        Each provided modality list must have one entry per prompt, and
        ``None`` entries are skipped. All modalities given for a prompt are
        merged into its ``multi_modal_data`` dict under the keys ``"image"``,
        ``"video"`` and ``"audio"``, matching how ``HfRunner.get_inputs``
        passes all of them to the processor at once.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        for modality, items in (("image", images), ("video", videos),
                                ("audio", audios)):
            if items is None:
                continue
            for i, item in enumerate(items):
                if item is None:
                    continue
                # Merge rather than assign, so that a second modality for
                # the same prompt does not discard the first one.
                mm_data = inputs[i].setdefault("multi_modal_data", {})
                mm_data[modality] = item

        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with vLLM and prepend the prompt to every sample.

        Returns one ``(sample_token_ids, sample_texts)`` pair per request.
        Each element is the prompt's token ids (or text) followed by one
        completion's token ids (or text), in completion order.
        """
        req_outputs = self.model.generate(
            self.get_inputs(prompts,
                            images=images,
                            videos=videos,
                            audios=audios),
            sampling_params=sampling_params,
        )

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_ids = req_output.prompt_token_ids
            prompt_str = req_output.prompt
            samples = req_output.outputs
            outputs.append((
                [prompt_ids + list(sample.token_ids) for sample in samples],
                [prompt_str + sample.text for sample in samples],
            ))
        return outputs

759
    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[TokensTextLogprobsPromptLogprobs]:
        """Flatten requests into ``(ids, text, logprobs, prompt_logprobs)``.

        Exactly one tuple is produced per request, taken from that request's
        last completion (its only one when a single sample is requested).
        """
        results: List[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            last_sample = req_output.outputs[-1]
            results.append((
                list(last_sample.token_ids),
                last_sample.text,
                last_sample.logprobs,
                req_output.prompt_logprobs,
            ))
        return results

774
775
776
777
    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return ids, text and logprobs for each request.

        The trailing prompt-logprobs field is included only when
        ``sampling_params.prompt_logprobs`` is set.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        per_request = self._final_steps_generate_w_logprobs(req_outputs)
        if sampling_params.prompt_logprobs is None:
            # Prompt logprobs were not requested: drop the last field.
            return [entry[:-1] for entry in per_request]
        return per_request
797
798
799

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Logprobs generation for vLLM encoder/decoder models.

        ``sampling_params.logprobs`` must be set. The trailing
        prompt-logprobs element of each result tuple is removed unless
        ``sampling_params.prompt_logprobs`` is set.
        """
        assert sampling_params.logprobs is not None
        request_outputs = self.model.generate(encoder_decoder_prompts,
                                              sampling_params=sampling_params)
        results = self._final_steps_generate_w_logprobs(request_outputs)
        if sampling_params.prompt_logprobs is None:
            # Strip prompt logprobs (last element) when they weren't asked for.
            results = [entry[:-1] for entry in results]
        return results
817

Woosuk Kwon's avatar
Woosuk Kwon committed
818
819
820
821
    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode every prompt and keep its first sample.

        Decoding uses ``temperature=0.0``. Each result is the
        ``(token_ids, text)`` pair of the first sample that
        ``self.generate`` returns for that prompt.
        """
        params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        generated = self.generate(prompts,
                                  params,
                                  images=images,
                                  videos=videos,
                                  audios=audios)
        first_samples = []
        for sample_ids, sample_strs in generated:
            first_samples.append((sample_ids[0], sample_strs[0]))
        return first_samples
834

835
836
837
838
839
    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
        stop: Optional[List[str]] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Greedy-decode ``prompts`` (temperature 0) and return logprobs.

        ``num_logprobs`` / ``num_prompt_logprobs`` are forwarded as the
        sampling ``logprobs`` / ``prompt_logprobs``; ``stop_token_ids`` and
        ``stop`` are forwarded unchanged. Prompt logprobs are included in
        the results only when ``num_prompt_logprobs`` is not None.
        """
        params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop,
        )
        multimodal = dict(images=images, audios=audios, videos=videos)
        return self.generate_w_logprobs(prompts, params, **multimodal)
861

862
863
    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Greedy logprobs generation for vLLM encoder/decoder models.

        Builds greedy sampling parameters (``temperature=0.0``) and delegates
        to ``generate_encoder_decoder_w_logprobs``.

        Args:
            encoder_decoder_prompts: explicit encoder/decoder prompt pairs.
            max_tokens: maximum number of tokens to generate per prompt.
            num_logprobs: number of logprobs to return per generated token.
            num_prompt_logprobs: number of prompt logprobs to return, or
                None to omit prompt logprobs from the results.

        Returns:
            One tuple per request; the prompt-logprobs element is present
            only when ``num_prompt_logprobs`` is not None.
        """
        # The docstring must be the first statement of the body. Placed after
        # this assignment (as before), it was a no-op string expression and
        # never became ``__doc__``.
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )
        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

883
    def generate_beam_search(
        self,
        prompts: Union[List[str], List[List[int]]],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run beam search and collect every beam for each prompt.

        ``prompts`` may be strings or token-id lists. Returns one
        ``(token_ids_per_beam, text_per_beam)`` pair per output of
        ``self.model.beam_search``, in the order it returns them.
        """
        search_params = BeamSearchParams(beam_width=beam_width,
                                         max_tokens=max_tokens)
        beam_outputs = self.model.beam_search(prompts, search_params)
        return [([seq.tokens for seq in out.sequences],
                 [seq.text for seq in out.sequences])
                for out in beam_outputs]

899
900
901
902
    def classify(self, prompts: List[str]) -> List[List[float]]:
        """Return ``outputs.probs`` of each result of ``model.classify``."""
        probabilities = []
        for request_output in self.model.classify(prompts):
            probabilities.append(request_output.outputs.probs)
        return probabilities

Cyrus Leung's avatar
Cyrus Leung committed
903
904
905
906
907
908
909
910
911
912
913
914
    def encode(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[List[float]]:
        """Embed each prompt (plus optional multi-modal data).

        The inputs are built with ``self.get_inputs`` and passed to
        ``self.model.embed``; the ``outputs.embedding`` of every result is
        returned.
        """
        model_inputs = self.get_inputs(prompts,
                                       images=images,
                                       videos=videos,
                                       audios=audios)
        embeddings = []
        for request_output in self.model.embed(model_inputs):
            embeddings.append(request_output.outputs.embedding)
        return embeddings
917

918
919
920
921
    def score(
        self,
        text_1: Union[str, List[str]],
        text_2: Union[str, List[str]],
    ) -> List[float]:
        """Return ``outputs.score`` for every result of ``model.score``."""
        scores: List[float] = []
        for result in self.model.score(text_1, text_2):
            scores.append(result.outputs.score)
        return scores
925

926
927
928
929
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model reference and clean up distributed state/memory.

        Returns None, so any exception raised inside the ``with`` body
        still propagates to the caller.
        """
        del self.model
        cleanup_dist_env_and_memory()
        return None
932

Woosuk Kwon's avatar
Woosuk Kwon committed
933

934
@pytest.fixture(scope="session")
def vllm_runner():
    """Session fixture exposing the ``VllmRunner`` class (not an instance).

    Tests instantiate it themselves, typically as a context manager.
    """
    runner_cls = VllmRunner
    return runner_cls
937
938
939
940
941
942
943
944
945


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a single-worker ``TokenizerPoolConfig`` for a group type.

    Args:
        tokenizer_group_type: None (no pool), the string ``"ray"``, or a
            class used directly as the pool type.

    Returns:
        None when ``tokenizer_group_type`` is None, otherwise a config with
        ``pool_size=1`` and empty ``extra_config``.

    Raises:
        ValueError: for any other value.
    """
    if tokenizer_group_type is None:
        return None

    if tokenizer_group_type == "ray":
        pool_type = "ray"
    elif isinstance(tokenizer_group_type, type):
        pool_type = tokenizer_group_type
    else:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")

    return TokenizerPoolConfig(pool_size=1,
                               pool_type=pool_type,
                               extra_config={})
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966


@pytest.fixture()
def temporary_enable_log_propagate():
    """Temporarily make the ``vllm`` logger propagate to the root logger.

    On teardown the logger's previous ``propagate`` value is restored,
    rather than being forced to False, so a logger that was already
    propagating is left as it was found.
    """
    import logging

    # Named distinctly so it does not shadow the module-level ``logger``.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog`` with vllm log records captured too.

    ``caplog`` only sees records that reach the root logger, so this
    fixture depends on ``temporary_enable_log_propagate`` to switch on
    propagation for the ``vllm`` logger while the test runs.
    """
    yield caplog
967
968
969
970
971
972
973


@pytest.fixture(scope="session")
def num_gpus_available():
    """Return the number of CUDA devices available to this session.

    Uses ``cuda_device_count_stateless`` so that counting does not
    initialize a CUDA context in the current process.
    """
    gpu_count = cuda_device_count_stateless()
    return gpu_count
975
976
977


# Scratch checkout locations, under the system temp directory, used by the
# dummy_*_path fixtures.
temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir,
                                            "dummy_gemma2_embedding")
981
982
983
984


@pytest.fixture
def dummy_opt_path():
    """Local copy of facebook/opt-125m whose config names MyOPTForCausalLM.

    On first use the repository is fetched into ``_dummy_opt_path`` and
    files matching the ignore patterns (.bin/.pt/.h5/.msgpack) are skipped.
    The ``architectures`` entry of its ``config.json`` is then rewritten.
    An existing directory is reused as-is.

    Returns:
        Path of the local checkout directory.
    """
    json_path = os.path.join(_dummy_opt_path, "config.json")
    if not os.path.exists(_dummy_opt_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_opt_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
    """Local copy of llava-hf/llava-1.5-7b-hf whose config names MyLlava.

    On first use the repository is fetched into ``_dummy_llava_path`` and
    files matching the ignore patterns (.bin/.pt/.h5/.msgpack) are skipped.
    The ``architectures`` entry of its ``config.json`` is then rewritten.
    An existing directory is reused as-is.

    Returns:
        Path of the local checkout directory.
    """
    json_path = os.path.join(_dummy_llava_path, "config.json")
    if not os.path.exists(_dummy_llava_path):
        snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                          local_dir=_dummy_llava_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = ["MyLlava"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_llava_path
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Local copy of BAAI/bge-multilingual-gemma2 named MyGemma2Embedding.

    On first use the repository is fetched into
    ``_dummy_gemma2_embedding_path`` and files matching the ignore patterns
    (.bin/.pt/.h5/.msgpack) are skipped. The ``architectures`` entry of
    its ``config.json`` is then rewritten. An existing directory is reused
    as-is.

    Returns:
        Path of the local checkout directory.
    """
    json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
    if not os.path.exists(_dummy_gemma2_embedding_path):
        snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
                          local_dir=_dummy_gemma2_embedding_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = ["MyGemma2Embedding"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_gemma2_embedding_path
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056


# Register the `--optional` command-line flag; tests marked with
# @pytest.mark.optional only run when it is given.
def pytest_addoption(parser):
    """pytest hook: add the boolean ``--optional`` option (default off)."""
    parser.addoption(
        "--optional",
        action="store_true",
        default=False,
        help="run optional test",
    )


def pytest_collection_modifyitems(config, items):
    """pytest hook: skip ``optional``-marked tests unless --optional is set.

    When the flag is given, the collected items are left untouched.
    """
    if not config.getoption("--optional"):
        skip_marker = pytest.mark.skip(reason="need --optional option to run")
        optional_items = (it for it in items if "optional" in it.keywords)
        for it in optional_items:
            it.add_marker(skip_marker)