conftest.py 37.4 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import json
4
import os
5
import tempfile
6
from collections import UserList
7
from enum import Enum
8
9
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
import pytest
import torch
14
import torch.nn as nn
15
import torch.nn.functional as F
16
from huggingface_hub import snapshot_download
17
from PIL import Image
18
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
19
                          BatchFeature)
20
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
21

22
23
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm import LLM, SamplingParams
25
from vllm.assets.image import ImageAsset
26
from vllm.assets.video import VideoAsset
27
from vllm.config import LoadFormat, TaskOption, TokenizerPoolConfig
28
from vllm.connections import global_http_connection
29
from vllm.distributed import (cleanup_dist_env_and_memory,
30
31
                              init_distributed_environment,
                              initialize_model_parallel)
32
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
youkaichao's avatar
youkaichao committed
33
34
                         TokensPrompt, to_enc_dec_tuple_list,
                         zip_enc_dec_prompts)
35
from vllm.logger import init_logger
36
from vllm.outputs import RequestOutput
37
from vllm.sampling_params import BeamSearchParams
38
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
youkaichao's avatar
youkaichao committed
39
                        identity, is_list_of)
40

41
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
42

43
44
45
# Directory containing this conftest; the prompt files are resolved against it.
_TEST_DIR = os.path.dirname(__file__)
# Prompt files whose lines are returned by the `example_prompts` fixture.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
# Prompt files whose lines are returned by the `example_long_prompts` fixture.
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
46
# File whose full contents are returned by the `example_system_message` fixture.
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
47

Cyrus Leung's avatar
Cyrus Leung committed
48
# Element type parameter of the `_PromptMultiModalInput` generic alias.
_M = TypeVar("_M")
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

# Model names that `VllmRunner.__init__` rewrites to an S3 path (using only
# the part after the last "/") when the caller gives no `load_format`.
MODELS_ON_S3 = [
    "distilbert/distilgpt2",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Llama-3.2-1B",
    "meta-llama/Llama-3.2-1B-Instruct",
    "openai-community/gpt2",
    "ArthurZ/Ilama-3.2-1B",
    "llava-hf/llava-1.5-7b-hf",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]

# S3 bucket URI for the CI model weights.
# NOTE(review): `VllmRunner.__init__` hard-codes the same prefix instead of
# referencing this constant -- keep the two in sync.
MODEL_WEIGHTS_S3_BUCKET = "s3://vllm-ci-model-weights"

Cyrus Leung's avatar
Cyrus Leung committed
64
65
66
67
68
# Multimodal data for a list of prompts: either a flat list of items or a
# list of item lists.
# NOTE(review): presumably one entry (or one list of entries) per prompt --
# `HfRunner.get_inputs` asserts the outer length equals the prompt count.
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]

# Images are PIL images.
PromptImageInput = _PromptMultiModalInput[Image.Image]
# Audio entries are 2-tuples; `HfRunner.get_inputs` unpacks them as
# (audio, sampling_rate).
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
# Videos are numpy arrays.
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
69

70

71
def _read_prompts(filename: str) -> List[str]:
72
    with open(filename) as f:
73
74
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
75
76


77
78
79
class _ImageAssetPrompts(TypedDict):
    """Prompt text for each test image, keyed by image asset name."""
    # Prompt for the "stop_sign" image asset.
    stop_sign: str
    # Prompt for the "cherry_blossom" image asset.
    cherry_blossom: str
80
81


82
83
class _ImageAssetsBase(UserList[ImageAsset]):
    """``UserList`` parameterised with :class:`ImageAsset` elements."""
    pass
84

85
86

class _ImageAssets(_ImageAssetsBase):
    """List of the image assets used by the tests, in a fixed order."""

    def __init__(self) -> None:
        stop_sign = ImageAsset("stop_sign")
        cherry_blossom = ImageAsset("cherry_blossom")
        super().__init__([stop_sign, cherry_blossom])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Turn a per-image prompt mapping into a list of prompts.

        The returned prompts are ordered the same way as the assets are
        when iterating through this object.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
102
103


104
105
106
107
class _VideoAssetPrompts(TypedDict):
    """Prompt text for each test video, keyed by video asset name."""
    # Prompt for the "sample_demo_1" video asset.
    sample_demo_1: str


108
109
class _VideoAssetsBase(UserList[VideoAsset]):
    """``UserList`` parameterised with :class:`VideoAsset` elements."""
    pass
110
111
112
113
114
115
116
117
118
119
120
121
122


class _VideoAssets(_VideoAssetsBase):
    """List of the video assets used by the tests."""

    def __init__(self) -> None:
        sample_demo = VideoAsset("sample_demo_1.mp4")
        super().__init__([sample_demo])

    def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
        """Turn a per-video prompt mapping into a one-element prompt list."""
        sample_demo_prompt = prompts["sample_demo_1"]
        return [sample_demo_prompt]


123
124
# Created once at import time; the `image_assets` fixture returns this object.
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
125
126
# Created once at import time; the `video_assets` fixture returns this object.
VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
127
128


Joe Runde's avatar
Joe Runde committed
129
@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """Run the requesting test twice: once with V1 enabled, once without.

    The engine is selected through the ``VLLM_USE_V1`` environment variable
    ('1' or '0'). Tests carrying the ``skip_v1`` marker are skipped for the
    V1 parametrization and therefore only run without V1.
    """
    v1_enabled = request.param
    skip_marker = request.node.get_closest_marker("skip_v1")

    if v1_enabled and skip_marker:
        pytest.skip("Skipping test on vllm V1")

    monkeypatch.setenv('VLLM_USE_V1', '1' if v1_enabled else '0')

    yield
Joe Runde's avatar
Joe Runde committed
144
145


146
147
148
149
150
151
152
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Autouse fixture that turns off client reuse on the global HTTP
    connection before each test."""
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


153
154
155
156
157
158
159
160
161
162
163
164
@pytest.fixture
def dist_init():
    """Set up a one-process distributed environment for the duration of a test.

    Initializes the distributed environment (world size 1, rank 0, "nccl"
    backend) with a temporary file as the ``file://`` init method, then
    initializes model parallelism with sizes (1, 1). After the test,
    ``cleanup_dist_env_and_memory()`` is called and the temporary file is
    removed.
    """
    # mkstemp() returns an already-open OS-level descriptor together with the
    # path. Only the path is needed here, so close the descriptor right away
    # instead of leaking one per test.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    try:
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend="nccl",
        )
        initialize_model_parallel(1, 1)
        yield
        cleanup_dist_env_and_memory()
    finally:
        # Best-effort removal: the file may already have been deleted.
        try:
            os.remove(temp_file)
        except OSError:
            pass
166
167


168
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell ``cleanup_fixture`` whether to run the global cleanup.

    Subdirectories may override this fixture to skip global cleanup, which
    can give a ~10x speedup for non-GPU unit tests since they don't need to
    initialize torch. Here the answer is False only for tests carrying the
    ``skip_global_cleanup`` marker.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
176
177


178
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Autouse fixture that runs the global cleanup after a test, unless
    ``should_do_global_cleanup_after_test`` evaluated to a falsy value."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory()
183
184


185
186
187
188
189
190
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Autouse fixture that calls ``torch._dynamo.reset()`` once the test
    has finished."""
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
191
192
@pytest.fixture
def example_prompts() -> List[str]:
    """Return the lines of every file in ``_TEST_PROMPTS`` as one flat list."""
    collected: List[str] = []
    for path in _TEST_PROMPTS:
        collected.extend(_read_prompts(path))
    return collected


199
200
201
202
203
204
@pytest.fixture
def example_system_message() -> str:
    """Return the entire contents of the ``_SYS_MSG`` file as one string."""
    with open(_SYS_MSG) as message_file:
        message = message_file.read()
    return message


205
206
207
208
209
210
211
class DecoderPromptType(Enum):
    """For encoder/decoder models only.

    Keys the dict produced by the ``example_encoder_decoder_prompts``
    fixture; each member selects which decoder prompts are paired with the
    encoder prompts.
    """
    # Decoder prompts are the encoder prompts in reverse order.
    CUSTOM = 1
    # Decoder prompts are all ``None``.
    NONE = 2
    # Decoder prompts are all the empty string.
    EMPTY_STR = 3


212
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for every decoder prompt type.

    The encoder prompts are the lines of the files in ``_TEST_PROMPTS``.
    For each :class:`DecoderPromptType` they are zipped, index by index,
    with a decoder prompt list of the same length:

    * NONE: every decoder prompt is ``None``
    * EMPTY_STR: every decoder prompt is the empty string
    * CUSTOM: the encoder prompt list in reverse order

    Returns:

    * Dict mapping each decoder prompt type to the result of
      ``zip_enc_dec_prompts`` for that type
    '''

    encoder_prompts = []
    for path in _TEST_PROMPTS:
        encoder_prompts.extend(_read_prompts(path))

    num_prompts = len(encoder_prompts)
    decoder_prompts_by_type = {
        DecoderPromptType.NONE: [None] * num_prompts,
        DecoderPromptType.EMPTY_STR: [""] * num_prompts,
        DecoderPromptType.CUSTOM: list(reversed(encoder_prompts)),
    }

    return {
        prompt_type: zip_enc_dec_prompts(encoder_prompts, decoder_prompts)
        for prompt_type, decoder_prompts in decoder_prompts_by_type.items()
    }


245
246
247
248
@pytest.fixture
def example_long_prompts() -> List[str]:
    """Return the lines of every file in ``_LONG_PROMPTS`` as one flat list."""
    collected: List[str] = []
    for path in _LONG_PROMPTS:
        collected.extend(_read_prompts(path))
    return collected
Woosuk Kwon's avatar
Woosuk Kwon committed
251
252


253
254
255
256
257
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Session-scoped fixture returning the module-level ``IMAGE_ASSETS``."""
    return IMAGE_ASSETS


258
259
260
261
262
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    """Session-scoped fixture returning the module-level ``VIDEO_ASSETS``."""
    return VIDEO_ASSETS


263
# Value types accepted (and returned unchanged in type) by
# `HfRunner.wrap_device`.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
264
_R = TypeVar("_R")
265

Woosuk Kwon's avatar
Woosuk Kwon committed
266
267
268

class HfRunner:

269
    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Place ``x`` on ``device`` and return the result.

        ``None`` and ``bool`` values are returned as-is. Plain dicts are
        rebuilt with every value wrapped recursively. Any other object is
        returned unchanged when it has a ``device`` attribute whose
        ``type`` equals the target; otherwise ``x.to(device)`` is returned.

        Args:
            x: The object to move.
            device: Target device string. When omitted, "cpu" is used on
                CPU platforms and "cuda" everywhere else.
        """
        from vllm.platforms import current_platform

        # Values without any notion of a device are handed back untouched.
        if x is None or isinstance(x, bool):
            return x

        target = device
        if target is None:
            target = "cpu" if current_platform.is_cpu() else "cuda"

        if isinstance(x, dict):
            return {
                key: self.wrap_device(value, target)
                for key, value in x.items()
            }

        already_placed = hasattr(x, "device") and x.device.type == target
        return x if already_placed else x.to(target)
284

Woosuk Kwon's avatar
Woosuk Kwon committed
285
286
287
288
    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[..., BatchEncoding] = identity,
    ) -> None:
        """Load the model, tokenizer and processor for ``model_name``.

        Args:
            model_name: Name passed to every ``from_pretrained`` call / model
                constructor.
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE`` selecting the torch
                dtype.
            model_kwargs: Extra keyword arguments for
                ``auto_cls.from_pretrained``; only used on the ``auto_cls``
                path.
            is_sentence_transformer: Load the model as a
                ``SentenceTransformer``.
            is_cross_encoder: Load the model as a ``CrossEncoder`` (checked
                only when ``is_sentence_transformer`` is false).
            skip_tokenizer_init: Do not load a tokenizer with
                ``AutoTokenizer``; use the processor's tokenizer instead.
            auto_cls: Auto model class used when neither of the two flags
                above is set.
            postprocess_inputs: Callable applied to the processor output in
                ``get_inputs``.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            sentence_model = SentenceTransformer(
                model_name,
                device="cpu",
                trust_remote_code=True,
            )
            self.model = self.wrap_device(
                sentence_model.to(dtype=torch_dtype))
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device="cpu",
                trust_remote_code=True,
            )
            # Only the wrapped inner model is moved and cast.
            inner_model = self.wrap_device(self.model.model)
            self.model.model = inner_model.to(dtype=torch_dtype)
        else:
            extra_kwargs = {} if model_kwargs is None else model_kwargs
            hf_model = auto_cls.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
                **extra_kwargs,
            )
            self.model = self.wrap_device(hf_model)

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

        self.dtype = dtype
        self.postprocess_inputs = postprocess_inputs

349
    def get_inputs(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[BatchEncoding]:
        """Run the processor on each prompt together with its multimodal data.

        Every modality that is given must have the same length as
        ``prompts``; entry ``i`` of a modality is passed along with prompt
        ``i`` unless that entry is ``None``. Audio entries are unpacked as
        ``(audio, sampling_rate)``. Each processor output is passed through
        ``self.postprocess_inputs`` (with ``dtype=self.dtype``).

        Returns:
            One processed input per prompt, in prompt order.
        """
        for modality in (images, videos, audios):
            if modality is not None:
                assert len(prompts) == len(modality)

        processed: List[BatchEncoding] = []
        for idx, text in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": text,
                "return_tensors": "pt",
            }

            image = None if images is None else images[idx]
            if image is not None:
                processor_kwargs["images"] = image

            video = None if videos is None else videos[idx]
            if video is not None:
                processor_kwargs["videos"] = video

            audio_entry = None if audios is None else audios[idx]
            if audio_entry is not None:
                audio, sampling_rate = audio_entry
                processor_kwargs["audio"] = audio
                processor_kwargs["sampling_rate"] = sampling_rate

            encoded = self.processor(**processor_kwargs)
            processed.append(
                self.postprocess_inputs(encoded, dtype=self.dtype))

        return processed

387
388
389
390
391
392
393
394
395
396
397
    def classify(self, prompts: List[str]) -> List[List[float]]:
        """Run the model on each prompt and return softmax-normalized scores.

        For every prompt, the model's ``logits`` are softmaxed over the last
        dimension and the first entry along the leading dimension is
        converted to a Python list.

        NOTE(review): the ``List[List[float]]`` annotation assumes logits of
        shape (batch, num_classes) -- confirm for the models used.

        Returns:
            One list of scores per prompt, in prompt order.
        """
        outputs: List[List[float]] = []
        for inputs in self.get_inputs(prompts):
            output = self.model(**self.wrap_device(inputs))
            # These are probabilities (post-softmax), not raw logits.
            probs = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(probs)

        return outputs

398
399
400
401
    def generate(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with the HF model, one prompt at a time.

        Inputs are built with ``get_inputs`` and moved to the model's device
        type. ``kwargs`` are forwarded to ``self.model.generate`` alongside
        ``use_cache=True``.

        Returns:
            For each prompt, a tuple of (generated id sequences as nested
            lists, the processor's ``batch_decode`` of those sequences with
            special tokens skipped).
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        results: List[Tuple[List[List[int]], List[str]]] = []
        for inputs in all_inputs:
            device_inputs = self.wrap_device(
                inputs, device=self.model.device.type)
            token_ids = self.model.generate(
                **device_inputs,
                use_cache=True,
                **kwargs,
            )
            texts = self.processor.batch_decode(
                token_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            results.append((token_ids.cpu().tolist(), texts))
        return results

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Generate without sampling, at most ``max_tokens`` new tokens.

        Calls ``generate`` with ``do_sample=False`` and
        ``max_new_tokens=max_tokens`` and keeps only the first returned
        sequence of each prompt.

        Returns:
            For each prompt, a tuple of (token ids, decoded text).
        """
        per_prompt = self.generate(prompts,
                                   do_sample=False,
                                   max_new_tokens=max_tokens,
                                   images=images,
                                   videos=videos,
                                   audios=audios,
                                   **kwargs)

        first_sequences: List[Tuple[List[int], str]] = []
        for sequences, texts in per_prompt:
            first_sequences.append((sequences[0], texts[0]))
        return first_sequences
446
447
448
449
450
451

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search generation returning ``beam_width`` sequences per prompt.

        Calls ``generate`` with ``num_beams`` and ``num_return_sequences``
        both set to ``beam_width`` and sampling disabled, then removes every
        token equal to the tokenizer's pad token id from the returned id
        sequences. The decoded strings are returned as produced by
        ``generate``.
        """
        beam_outputs = self.generate(prompts,
                                     do_sample=False,
                                     max_new_tokens=max_tokens,
                                     num_beams=beam_width,
                                     num_return_sequences=beam_width)
        for prompt_idx, (beam_ids, beam_strs) in enumerate(beam_outputs):
            for beam_idx, token_ids in enumerate(beam_ids):
                beam_ids[beam_idx] = [
                    token for token in token_ids
                    if token != self.tokenizer.pad_token_id
                ]
            beam_outputs[prompt_idx] = (beam_ids, beam_strs)
        return beam_outputs
Woosuk Kwon's avatar
Woosuk Kwon committed
467

468
469
470
471
    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Generate without sampling and return log-probability tensors.

        For every prompt the model generates up to ``max_tokens`` new tokens
        with hidden states returned; those hidden states are converted by
        ``_hidden_states_to_seq_logprobs``.

        Returns:
            For each prompt, the list of log-probability tensors produced
            from that prompt's hidden states.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        logprobs_per_prompt: List[List[torch.Tensor]] = []
        for inputs in all_inputs:
            device_inputs = self.wrap_device(
                inputs, device=self.model.device.type)
            output = self.model.generate(
                **device_inputs,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            logprobs_per_prompt.append(
                self._hidden_states_to_seq_logprobs(output.hidden_states))
        return logprobs_per_prompt

498
    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
    ) -> List[torch.Tensor]:
        """Project hidden states through the output embeddings into logprobs.

        For each entry of ``hidden_states``, element ``[-1][0]`` is
        multiplied by the transposed output-embedding weight (plus the bias
        when the embedding has one) and passed through a float32
        ``log_softmax`` over the last dimension.

        NOTE(review): ``[-1][0]`` presumably selects the last layer and the
        first batch row of each generation step -- confirm.
        """
        lm_head = self.model.get_output_embeddings()

        per_step_logprobs: List[torch.Tensor] = []
        for step_states in hidden_states:
            final_states = step_states[-1][0]
            logits = torch.matmul(
                final_states.to(lm_head.weight.device),
                lm_head.weight.t(),
            )
            bias = getattr(lm_head, "bias", None)
            if bias is not None:
                logits += bias.unsqueeze(0)
            per_step_logprobs.append(
                F.log_softmax(logits, dim=-1, dtype=torch.float32))

        return per_step_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Reduce hidden states to top-``num_logprobs`` dicts.

        Returns:
            A tuple of (one ``{token_id: logprob}`` dict per entry of
            ``hidden_states``, ``len(hidden_states)``).
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        logprob_dicts: List[Dict[int, float]] = []
        for step, step_logprobs in enumerate(seq_logprobs):
            if step == 0:
                # drop prompt logprobs: keep only the last row of the first
                # entry
                step_logprobs = step_logprobs[-1, :].reshape(1, -1)
            top = step_logprobs.topk(num_logprobs)

            logprob_dicts.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(top.indices[0], top.values[0])
            })

        return logprob_dicts, output_len

545
546
547
548
549
    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        """Generate without sampling, keeping the top logprobs per token.

        Returns:
            For each prompt, a tuple of (generated token ids, their decoded
            text, one ``{token_id: logprob}`` dict of the top
            ``num_logprobs`` entries per generated token).
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        results: List[TokensTextLogprobs] = []
        for inputs in all_inputs:
            device_inputs = self.wrap_device(
                inputs, device=self.model.device.type)
            output = self.model.generate(
                **device_inputs,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            step_logprobs, _ = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)

            # The generated tokens are the trailing ids of the first
            # returned sequence, one per logprob dict.
            num_generated = len(step_logprobs)
            generated_ids = output.sequences[0][-num_generated:]
            id_list = generated_ids.tolist()
            text = self.tokenizer.decode(generated_ids)
            results.append((id_list, text, step_logprobs))

        return results

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for encoder/decoder models, run through
        this runner's HF model.

        The encoder prompt (plus the matching image, if any) goes through
        the processor; a decoder prompt that is not ``None`` is tokenized
        and passed as ``decoder_input_ids``. Logprobs are computed from the
        decoder hidden states.

        Returns:

        * For each prompt pair, a tuple of (generated token ids, their
          decoded text, top-``num_logprobs`` dict per generated token)
        '''

        results: List[TokensTextLogprobs] = []

        prompt_pairs = to_enc_dec_tuple_list(encoder_decoder_prompts)
        for idx, (encoder_prompt, decoder_prompt) in enumerate(prompt_pairs):
            processor_kwargs: Dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            image = None if images is None else images[idx]
            if image is not None:
                processor_kwargs["images"] = image

            encoder_input_ids = self.wrap_device(
                self.processor(**processor_kwargs).input_ids,
                device=self.model.device.type,
            )

            decoder_input_ids = None
            if decoder_prompt is not None:
                decoder_input_ids = self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids,
                    device=self.model.device.type,
                )

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            step_logprobs, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            # Generated tokens are the trailing `output_len` ids of the
            # first returned sequence.
            generated_ids = output.sequences[0][-output_len:]
            id_list = generated_ids.tolist()
            text = self.tokenizer.decode(generated_ids)
            results.append((id_list, text, step_logprobs))

        return results

658
659
660
    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Return ``self.model.encode(prompts)``.

        NOTE(review): presumably only meaningful when the runner was built
        with ``is_sentence_transformer=True`` -- confirm against callers.
        """
        return self.model.encode(prompts)

661
662
663
    def predict(self, prompts: List[List[str]]) -> torch.Tensor:
        """Return ``self.model.predict(prompts, convert_to_tensor=True)``.

        NOTE(review): presumably only meaningful when the runner was built
        with ``is_cross_encoder=True`` -- confirm against callers.
        """
        return self.model.predict(prompts, convert_to_tensor=True)

664
665
666
667
    def __enter__(self):
        """Enter the context manager; the runner itself is the context value."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Delete the ``model`` attribute, then run the shared cleanup helper.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        del self.model
        cleanup_dist_env_and_memory()
670

Woosuk Kwon's avatar
Woosuk Kwon committed
671

Cyrus Leung's avatar
Cyrus Leung committed
672
@pytest.fixture(scope="session")
def hf_runner():
    """Session-scoped fixture returning the :class:`HfRunner` class itself
    (not an instance); tests instantiate it with their own arguments."""
    return HfRunner


class VllmRunner:
    """Test helper that wraps a vLLM ``LLM`` instance behind a small API.

    The wrapped engine is stored in ``self.model``. The class is a context
    manager: leaving the ``with`` block deletes the engine and calls
    ``cleanup_dist_env_and_memory()``.
    """

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        load_format: Optional[LoadFormat] = None,
        **kwargs,
    ) -> None:
        """Create the underlying ``LLM``.

        When ``model_name`` is listed in ``MODELS_ON_S3`` and no
        ``load_format`` was given, the model name is rewritten to the
        ``s3://vllm-ci-model-weights/`` location (keeping only the last path
        component) and ``LoadFormat.RUNAI_STREAMER`` is used. Otherwise a
        missing ``load_format`` falls back to ``LoadFormat.AUTO``.

        All remaining arguments, including ``**kwargs``, are forwarded to
        ``LLM``; ``trust_remote_code`` is always ``True``.
        """
        if model_name in MODELS_ON_S3 and not load_format:
            model_name = (f"s3://vllm-ci-model-weights/"
                          f"{model_name.split('/')[-1]}")
            load_format = LoadFormat.RUNAI_STREAMER
        if not load_format:
            load_format = LoadFormat.AUTO

        self.model = LLM(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            load_format=load_format,
            **kwargs,
        )

    def get_inputs(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[TextPrompt]:
        """Build one ``TextPrompt`` per prompt, attaching multi-modal data.

        Each modality argument, when given, must have the same length as
        ``prompts``; a ``None`` entry means "no data of this modality for
        that prompt". All modalities supplied for a prompt are merged into a
        single ``multi_modal_data`` dict (keys ``"image"``, ``"video"``,
        ``"audio"``), so passing several modalities at once does not make a
        later one overwrite an earlier one. The ``multi_modal_data`` key is
        only set for prompts that have at least one non-``None`` item.
        """
        modalities = (
            ("image", images),
            ("video", videos),
            ("audio", audios),
        )
        for _, items in modalities:
            if items is not None:
                assert len(prompts) == len(items)

        # Collect every modality per prompt first so nothing gets clobbered.
        mm_data_per_prompt: List[Dict[str, Any]] = [{} for _ in prompts]
        for key, items in modalities:
            if items is None:
                continue
            for i, item in enumerate(items):
                if item is not None:
                    mm_data_per_prompt[i][key] = item

        inputs: List[TextPrompt] = []
        for prompt, mm_data in zip(prompts, mm_data_per_prompt):
            text_prompt = TextPrompt(prompt=prompt)
            if mm_data:
                text_prompt["multi_modal_data"] = mm_data
            inputs.append(text_prompt)

        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run generation and return prompt+completion ids and text.

        Returns:
            One ``(token_id_lists, texts)`` pair per request. Each list has
            one entry per sample, and every entry is the prompt concatenated
            with that sample's output.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                req_sample_output_ids.append(prompt_ids +
                                             list(sample.token_ids))
                req_sample_output_strs.append(prompt_str + sample.text)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[TokensTextLogprobsPromptLogprobs]:
        """Reduce request outputs to ``(ids, text, logprobs, prompt_logprobs)``.

        Only the last sample of each request is reported; every request must
        contain at least one sample.
        """
        outputs: List[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            sample = req_output.outputs[-1]
            outputs.append((list(sample.token_ids), sample.text,
                            sample.logprobs, req_output.prompt_logprobs))
        return outputs

    @staticmethod
    def _omit_unrequested_prompt_logprobs(
        toks_str_logsprobs_prompt_logprobs,
        sampling_params,
    ):
        """Drop the trailing prompt-logprobs element when it was not requested.

        If ``sampling_params.prompt_logprobs`` is ``None`` each result loses
        its last element; otherwise the results are returned untouched.
        """
        if sampling_params.prompt_logprobs is None:
            return [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
        return toks_str_logsprobs_prompt_logprobs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return output ids, text and logprobs per request.

        Prompt logprobs are appended to each result only when
        ``sampling_params.prompt_logprobs`` is not ``None``.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        return self._omit_unrequested_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._omit_unrequested_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy (temperature 0) generation.

        Returns:
            For each request, the first sample's prompt+completion token ids
            and text.
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts,
                                greedy_params,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Greedy generation that also returns logprobs.

        Builds temperature-0 ``SamplingParams`` from the arguments and
        delegates to ``generate_w_logprobs``.
        """
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos,
                                        **kwargs)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=(num_prompt_logprobs),
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: Union[List[str], List[List[int]]],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run beam search over text or pre-tokenized prompts.

        Prompts that are all strings are wrapped as ``TextPrompt``; anything
        else is treated as token-id lists and wrapped as ``TokensPrompt``.

        Returns:
            For each prompt, the token lists and texts of its sequences.
        """
        if is_list_of(prompts, str, check="all"):
            inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        else:
            inputs = [
                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
            ]
        outputs = self.model.beam_search(
            inputs,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def classify(self, prompts: List[str]) -> List[List[float]]:
        """Return ``outputs.probs`` for each result of ``model.classify``."""
        req_outputs = self.model.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

    def encode(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[List[float]]:
        """Return ``outputs.embedding`` for each result of ``model.embed``."""
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.embed(inputs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, List[str]],
        text_2: Union[str, List[str]],
    ) -> List[float]:
        """Return ``outputs.score`` for each result of ``model.score``."""
        req_outputs = self.model.score(text_1, text_2)
        return [req_output.outputs.score for req_output in req_outputs]

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run ``func`` through the engine's model executor ``apply_model``."""
        executor = self.model.llm_engine.model_executor
        return executor.apply_model(func)

    def __enter__(self):
        """Enter the ``with`` block and return this runner itself."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Delete the wrapped engine, then clean up distributed state/memory."""
        del self.model
        cleanup_dist_env_and_memory()
977

Woosuk Kwon's avatar
Woosuk Kwon committed
978

979
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture that returns the ``VllmRunner`` class itself.

    Tests receive the class (not an instance) and construct it as needed.
    """
    return VllmRunner
982
983
984
985
986
987
988
989
990


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a ``TokenizerPoolConfig`` for the given tokenizer group type.

    Args:
        tokenizer_group_type: ``None``, the string ``"ray"``, or a class.

    Returns:
        ``None`` when ``tokenizer_group_type`` is ``None``; otherwise a
        ``TokenizerPoolConfig`` with ``pool_size=1``, an empty
        ``extra_config`` and ``pool_type`` set to the given value.

    Raises:
        ValueError: If the value is neither ``None``, ``"ray"`` nor a class.
    """
    if tokenizer_group_type is None:
        return None
    # "ray" and class-valued types yield the same config shape; only the
    # pool_type value differs, and it is exactly the argument passed in.
    if tokenizer_group_type == "ray" or isinstance(tokenizer_group_type,
                                                   type):
        return TokenizerPoolConfig(pool_size=1,
                                   pool_type=tokenizer_group_type,
                                   extra_config={})
    raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011


@pytest.fixture()
def temporary_enable_log_propagate():
    """Turn on propagation for the ``"vllm"`` logger during a test.

    The logger's previous ``propagate`` value is remembered and put back on
    teardown, so the fixture leaves the logging configuration as it found it
    instead of unconditionally forcing ``propagate = False``.
    """
    import logging
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield the ``caplog`` fixture unchanged.

    Requesting ``temporary_enable_log_propagate`` alongside ``caplog`` is the
    whole point of this fixture; see the comment below for why.
    """
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
1012
1013
1014
1015
1016
1017
1018


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process."""

    return cuda_device_count_stateless()
1020
1021
1022


# Locations under the system temp directory used by the dummy-model path
# fixtures defined in this file.
temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
1026
1027
1028
1029


@pytest.fixture
def dummy_opt_path():
    """Return a local copy of ``facebook/opt-125m`` with a patched config.

    If the target directory does not exist yet, the repository is downloaded
    (skipping weight files matching the ignore patterns) and the
    ``architectures`` entry of its ``config.json`` is rewritten to
    ``["MyOPTForCausalLM"]``. An existing directory is reused as-is.
    """
    json_path = os.path.join(_dummy_opt_path, "config.json")
    if not os.path.exists(_dummy_opt_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_opt_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path) as f:
            config = json.load(f)
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w") as f:
            json.dump(config, f)
    return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
    """Return a local copy of ``llava-hf/llava-1.5-7b-hf`` with a patched config.

    If the target directory does not exist yet, the repository is downloaded
    (skipping weight files matching the ignore patterns) and the
    ``architectures`` entry of its ``config.json`` is rewritten to
    ``["MyLlava"]``. An existing directory is reused as-is.
    """
    json_path = os.path.join(_dummy_llava_path, "config.json")
    if not os.path.exists(_dummy_llava_path):
        snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                          local_dir=_dummy_llava_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path) as f:
            config = json.load(f)
        config["architectures"] = ["MyLlava"]
        with open(json_path, "w") as f:
            json.dump(config, f)
    return _dummy_llava_path
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Return a local copy of ``BAAI/bge-multilingual-gemma2`` with a patched config.

    If the target directory does not exist yet, the repository is downloaded
    (skipping weight files matching the ignore patterns) and the
    ``architectures`` entry of its ``config.json`` is rewritten to
    ``["MyGemma2Embedding"]``. An existing directory is reused as-is.
    """
    json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
    if not os.path.exists(_dummy_gemma2_embedding_path):
        snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
                          local_dir=_dummy_gemma2_embedding_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path) as f:
            config = json.load(f)
        config["architectures"] = ["MyGemma2Embedding"]
        with open(json_path, "w") as f:
            json.dump(config, f)
    return _dummy_gemma2_embedding_path
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101


def pytest_addoption(parser):
    """Register the ``--optional`` command-line flag.

    Passing the flag allows tests marked with ``@pytest.mark.optional``
    to run; it is a boolean switch that defaults to off.
    """
    flag_settings = {
        "action": "store_true",
        "default": False,
        "help": "run optional test",
    }
    parser.addoption("--optional", **flag_settings)


def pytest_collection_modifyitems(config, items):
    """Skip tests carrying the ``optional`` keyword unless ``--optional`` is set.

    When the ``--optional`` option is truthy the collected items are left
    untouched. Otherwise every item whose keywords contain ``"optional"``
    gets a skip marker.
    """
    run_optional = config.getoption("--optional")
    if run_optional:
        # --optional given in cli: do not skip optional tests
        return

    skip_marker = pytest.mark.skip(reason="need --optional option to run")
    optional_items = [
        collected for collected in items if "optional" in collected.keywords
    ]
    for collected in optional_items:
        collected.add_marker(skip_marker)