conftest.py 30.3 KB
Newer Older
1
2
import contextlib
import gc
3
import json
4
import os
5
import sys
6
import tempfile
7
from collections import UserList
8
from enum import Enum
9
10
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
13
14
import pytest
import torch
15
import torch.nn as nn
16
import torch.nn.functional as F
17
from huggingface_hub import snapshot_download
18
from PIL import Image
19
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
20
                          BatchFeature)
21
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
24
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm import LLM, SamplingParams
26
from vllm.assets.image import ImageAsset
27
from vllm.assets.video import VideoAsset
28
from vllm.config import TokenizerPoolConfig
29
from vllm.connections import global_http_connection
30
from vllm.distributed import (destroy_distributed_environment,
31
32
33
                              destroy_model_parallel,
                              init_distributed_environment,
                              initialize_model_parallel)
34
35
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
36
from vllm.logger import init_logger
37
from vllm.outputs import RequestOutput
38
from vllm.sampling_params import BeamSearchParams
39
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
40
                        identity, is_cpu)
41

42
# Module-level logger for this conftest, created through vLLM's logging helper.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
43

44
45
46
_TEST_DIR = os.path.dirname(__file__)
# Prompt files read by the ``example_prompts`` / ``example_long_prompts``
# fixtures, resolved relative to this conftest's directory.
_TEST_PROMPTS = [
    os.path.join(_TEST_DIR, "prompts", fname) for fname in ("example.txt", )
]
_LONG_PROMPTS = [
    os.path.join(_TEST_DIR, "prompts", fname) for fname in ("summary.txt", )
]
47

48
49
50
# Per-prompt multimodal inputs: either a flat list (one item per prompt) or a
# nested list (several items per prompt).
PromptImageInput = Union[
    List[Image.Image],
    List[List[Image.Image]],
]
# Audio items are (waveform, sampling_rate) pairs.
PromptAudioInput = Union[
    List[Tuple[np.ndarray, int]],
    List[List[Tuple[np.ndarray, int]]],
]
PromptVideoInput = Union[
    List[np.ndarray],
    List[List[np.ndarray]],
]
52

53

54
def _read_prompts(filename: str) -> List[str]:
    """Read ``filename`` and return its lines, one prompt per line.

    Line terminators are kept (``readlines`` semantics), so every prompt but
    possibly the last ends with a newline. The file is decoded explicitly as
    UTF-8 so the result does not depend on the platform's locale-dependent
    default encoding.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
58
59


60
61
62
# Text prompt for each bundled test image, keyed by the image asset's name.
_ImageAssetPrompts = TypedDict("_ImageAssetPrompts", {
    "stop_sign": str,
    "cherry_blossom": str,
})
63
64
65
66
67
68
69


if sys.version_info >= (3, 9):
    # From Python 3.9 on, UserList can be parameterized with its element type.
    class _ImageAssetsBase(UserList[ImageAsset]):
        pass
else:
    # Older interpreters cannot subscript UserList, so use it bare.
    class _ImageAssetsBase(UserList):
        pass
73

74
75

class _ImageAssets(_ImageAssetsBase):
    """The image assets shared by multimodal tests.

    Iteration order is fixed: ``stop_sign`` first, then ``cherry_blossom``.
    """

    def __init__(self) -> None:
        asset_names = ("stop_sign", "cherry_blossom")
        super().__init__([ImageAsset(name) for name in asset_names])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Convenience method to define the prompt for each test image.

        The returned list follows the same order in which the assets are
        yielded when iterating over this object.
        """
        return [prompts[key] for key in ("stop_sign", "cherry_blossom")]
91
92


93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Text prompt for each bundled test video, keyed by the video asset's name.
_VideoAssetPrompts = TypedDict("_VideoAssetPrompts", {
    "sample_demo_1": str,
})


if sys.version_info >= (3, 9):
    # From Python 3.9 on, UserList can be parameterized with its element type.
    class _VideoAssetsBase(UserList[VideoAsset]):
        pass
else:
    # Older interpreters cannot subscript UserList, so use it bare.
    class _VideoAssetsBase(UserList):
        pass


class _VideoAssets(_VideoAssetsBase):
    """The video assets shared by multimodal tests (``sample_demo_1.mp4``)."""

    def __init__(self) -> None:
        assets = [VideoAsset("sample_demo_1.mp4")]
        super().__init__(assets)

    def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
        """Return the prompt for each test video, in iteration order."""
        return [prompts[key] for key in ("sample_demo_1", )]


118
119
#: Singleton instance of :class:`_ImageAssets`.
IMAGE_ASSETS = _ImageAssets()

#: Singleton instance of :class:`_VideoAssets`.
VIDEO_ASSETS = _VideoAssets()
122
123


124
125
126
127
128
129
130
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Stop the shared HTTP connection from reusing its async client.

    pytest_asyncio may use a different event loop per test, so the async
    client has to be created anew rather than carried over between tests.
    """
    global_http_connection.reuse_client = False


131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
@pytest.fixture
def dist_init():
    """Set up a world-size-1 distributed environment for one test.

    Rendezvous uses a temporary file (``file://`` init method) with the NCCL
    backend, and ``initialize_model_parallel(1, 1)`` is called (presumably
    tensor/pipeline parallel sizes of 1). Everything is torn down with
    ``cleanup()`` after the test, and the rendezvous file is then removed.
    """
    # mkstemp returns an open OS-level descriptor along with the path; only
    # the path is needed, so close the descriptor instead of leaking it.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup()
    # The file may already be gone by now; that is fine.
    with contextlib.suppress(FileNotFoundError):
        os.remove(temp_file)


146
147
def cleanup():
    """Release distributed state and cached memory.

    Tears down vLLM's model-parallel groups and distributed environment,
    destroys the torch process group (an ``AssertionError`` from that call is
    ignored), runs the garbage collector and, unless ``is_cpu()`` is true,
    empties the CUDA cache.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()
154
155


156
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether the global ``cleanup()`` runs after the current test.

    Tests carrying the ``skip_global_cleanup`` marker get ``False``; all
    others get ``True``. Subdirectories can also override this fixture to
    skip global cleanup, which can provide a ~10x speedup for non-GPU unit
    tests since they don't need to initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
164
165


166
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After every test, run ``cleanup()`` unless the test opted out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
171
172


173
174
175
176
177
178
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call ``torch._dynamo.reset()`` once each test has finished."""
    yield
    # Teardown only: nothing to prepare before the test.
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
179
180
@pytest.fixture
def example_prompts() -> List[str]:
    """All lines of every file in ``_TEST_PROMPTS``, in file order."""
    return [
        line for path in _TEST_PROMPTS for line in _read_prompts(path)
    ]


187
188
189
190
191
192
193
class DecoderPromptType(Enum):
    """Which kind of decoder prompt accompanies each encoder prompt.

    For encoder/decoder models only.
    """
    CUSTOM = 1  # a caller-supplied decoder prompt string
    NONE = 2  # no decoder prompt at all (``None``)
    EMPTY_STR = 3  # the empty string as decoder prompt


194
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for each DecoderPromptType.

    The encoder prompts are the lines of the files in ``_TEST_PROMPTS``.
    The result maps each DecoderPromptType to ``zip_enc_dec_prompts`` applied
    to the encoder prompts and a decoder prompt list of the same length:

    * ``NONE``: every decoder prompt is ``None``
    * ``EMPTY_STR``: every decoder prompt is ``""``
    * ``CUSTOM``: the encoder prompts in reverse order
    '''
    encoder_prompts = [
        line for path in _TEST_PROMPTS for line in _read_prompts(path)
    ]
    num_prompts = len(encoder_prompts)

    decoder_prompts_by_type = {
        DecoderPromptType.NONE: [None] * num_prompts,
        DecoderPromptType.EMPTY_STR: [""] * num_prompts,
        DecoderPromptType.CUSTOM: list(reversed(encoder_prompts)),
    }
    return {
        prompt_type: zip_enc_dec_prompts(encoder_prompts, decoder_prompts)
        for prompt_type, decoder_prompts in decoder_prompts_by_type.items()
    }


227
228
229
230
@pytest.fixture
def example_long_prompts() -> List[str]:
    """All lines of every file in ``_LONG_PROMPTS``, in file order."""
    return [
        line for path in _LONG_PROMPTS for line in _read_prompts(path)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
233
234


235
236
237
238
239
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Session-scoped access to the shared ``IMAGE_ASSETS`` singleton."""
    return IMAGE_ASSETS


240
241
242
243
244
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    """Session-scoped access to the shared ``VIDEO_ASSETS`` singleton."""
    return VIDEO_ASSETS


245
# Object kinds that HfRunner.wrap_device accepts; it returns the same kind.
_T = TypeVar(
    "_T",
    nn.Module,
    torch.Tensor,
    BatchEncoding,
    BatchFeature,
)
246

Woosuk Kwon's avatar
Woosuk Kwon committed
247
248
249

class HfRunner:
    """Test helper that runs a model through HuggingFace ``transformers``.

    Embedding models (``is_embedding_model=True``) are loaded with
    ``sentence_transformers.SentenceTransformer``; all other models are loaded
    with ``auto_cls.from_pretrained``. A tokenizer and a processor for the
    same checkpoint are loaded alongside. Use it as a context manager so the
    model is dropped and ``cleanup()`` runs on exit.
    """

    def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
        """Return ``input`` placed on ``device``.

        When ``device`` is ``None`` it defaults to ``"cpu"`` if ``is_cpu()``
        is true and ``"cuda"`` otherwise. Objects with a ``device`` attribute
        whose ``type`` already equals ``device`` are returned unchanged;
        everything else is moved with ``input.to(device)``.
        """
        if device is None:
            return self.wrap_device(input, "cpu" if is_cpu() else "cuda")

        if hasattr(input, "device") and input.device.type == device:
            return input

        return input.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        """Load the model, tokenizer and processor for ``model_name``.

        Args:
            model_name: HuggingFace model id or local path.
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE`` selecting the torch
                dtype the model is loaded with.
            model_kwargs: Extra keyword arguments for
                ``auto_cls.from_pretrained`` (unused for embedding models).
            is_embedding_model: Load with ``SentenceTransformer`` instead of
                ``auto_cls``.
            auto_cls: HF auto class used for non-embedding models.
            postprocess_inputs: Applied to the processor output before it is
                passed to ``model.generate`` (the encoder/decoder path
                tokenizes directly and does not use it).
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                    trust_remote_code=True,
                ).to(dtype=torch_dtype))
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        self.postprocess_inputs = postprocess_inputs

    @staticmethod
    def _item_at(items: Optional[List[Any]], index: int) -> Any:
        """Return ``items[index]``, or ``None`` when ``items`` is ``None``."""
        return None if items is None else items[index]

    def _process_prompt(
        self,
        prompt: str,
        image: Any = None,
        audio: Optional[Tuple[np.ndarray, int]] = None,
        video: Optional[np.ndarray] = None,
    ) -> BatchEncoding:
        """Run the processor on one prompt together with its optional media.

        Media arguments that are ``None`` are left out of the processor call.
        ``audio`` is an ``(audio, sampling_rate)`` pair. The processor output
        is passed through ``self.postprocess_inputs`` before being returned.
        """
        processor_kwargs: Dict[str, Any] = {
            "text": prompt,
            "return_tensors": "pt",
        }
        if image is not None:
            processor_kwargs["images"] = image
        if audio is not None:
            audio_data, sampling_rate = audio
            processor_kwargs["audio"] = audio_data
            processor_kwargs["sampling_rate"] = sampling_rate
        if video is not None:
            processor_kwargs["videos"] = video

        inputs = self.processor(**processor_kwargs)
        return self.postprocess_inputs(inputs)

    def _last_tokens(
        self,
        sequence: torch.Tensor,
        num_tokens: int,
    ) -> Tuple[List[int], str]:
        """Return the ids and decoded text of the last ``num_tokens`` tokens.

        The slice starts at ``len(sequence) - num_tokens`` (clamped at 0)
        instead of ``-num_tokens`` because ``sequence[-0:]`` would select the
        whole sequence when no tokens were generated.
        """
        output_ids = sequence[max(len(sequence) - num_tokens, 0):]
        return output_ids.tolist(), self.tokenizer.decode(output_ids)

    def generate(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run ``model.generate`` once per prompt and decode the results.

        Returns one ``(token_ids, texts)`` pair per prompt holding every
        sequence ``model.generate`` returned for that prompt (several when,
        e.g., ``num_return_sequences`` is passed through ``kwargs``).
        """
        if images is not None:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            inputs = self._process_prompt(
                prompt,
                image=self._item_at(images, i),
                video=self._item_at(videos, i),
            )
            output_ids = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            outputs.append((output_ids.cpu().tolist(), output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy decoding; returns the first sequence of each prompt."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search returning all ``beam_width`` beams per prompt, with
        the tokenizer's pad token removed from every beam's ids."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [
            ([[tok for tok in ids if tok != pad_token_id]
              for ids in output_ids], output_str)
            for output_ids, output_str in outputs
        ]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy decoding returning, per prompt, one full-vocabulary float32
        log-probability tensor per generation step."""
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            inputs = self._process_prompt(
                prompt,
                image=self._item_at(images, i),
                video=self._item_at(videos, i),
            )
            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            all_logprobs.append(
                self._hidden_states_to_seq_logprobs(output.hidden_states))
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
    ) -> List[torch.Tensor]:
        """Project each step's hidden state onto the vocabulary.

        For every step, ``hidden_state[-1][0]`` (the last entry of the step's
        tuple, first batch element; assumed to be the final layer per the HF
        ``generate`` output layout) is multiplied by the output-embedding
        weight (plus its bias, if any). The float32 ``log_softmax`` over the
        last dimension is returned.
        """
        output_embeddings = self.model.get_output_embeddings()
        weight = output_embeddings.weight
        bias = getattr(output_embeddings, "bias", None)

        seq_logprobs: List[torch.Tensor] = []
        for hidden_state in hidden_states:
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(weight.device),
                weight.t(),
            )
            if bias is not None:
                logits += bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Convert per-step hidden states into top-``num_logprobs`` dicts.

        Returns ``(per_step_dicts, num_steps)``; each dict maps token id to
        logprob for the ``num_logprobs`` most likely tokens of one step.
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # The first step has one row per prompt position; keep only the
            # last row (the prediction of the first generated token).
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)
            seq_logprobs_lst.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            })

        return seq_logprobs_lst, output_len

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        """Greedy decoding returning ``(output_ids, output_text, logprobs)``
        per prompt, where ``logprobs`` holds top-``num_logprobs`` dicts for
        each generated token. ``None`` entries in ``images``, ``audios`` or
        ``videos`` mean "no media of that kind for this prompt"."""
        outputs: List[TokensTextLogprobs] = []
        for i, prompt in enumerate(prompts):
            inputs = self._process_prompt(
                prompt,
                image=self._item_at(images, i),
                audio=self._item_at(audios, i),
                video=self._item_at(videos, i),
            )
            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)
            output_ids, output_str = self._last_tokens(
                output.sequences[0], output_len)
            outputs.append((output_ids, output_str, seq_logprobs_lst))
        return outputs

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for HF encoder/decoder models.

        Prompts are tokenized with ``self.tokenizer`` (the processor is not
        used). A ``None`` decoder prompt lets ``generate`` pick its own
        decoder start. Returns ``(output_ids, output_text, logprobs)`` per
        prompt pair.
        '''
        outputs: List[TokensTextLogprobs] = []
        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids,
                device=self.model.device.type,
            )

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_input_ids = self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids,
                    device=self.model.device.type,
                )

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)
            output_ids, output_str = self._last_tokens(
                output.sequences[0], output_len)
            outputs.append((output_ids, output_str, seq_logprobs_lst))
        return outputs

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Embed ``prompts`` with the model's own ``encode`` method."""
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
596

Cyrus Leung's avatar
Cyrus Leung committed
597
@pytest.fixture(scope="session")
def hf_runner():
    """Hand tests the :class:`HfRunner` class itself (they instantiate it)."""
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:
    """Test helper wrapping a vLLM :class:`LLM` engine.

    The generation helpers return plain tuples of token ids, text and
    logprobs. Use it as a context manager so the engine is dropped and
    ``cleanup()`` runs on exit.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    @staticmethod
    def _build_inputs(
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> List[TextPrompt]:
        """Wrap each prompt in a ``TextPrompt`` and attach its media.

        Every provided modality contributes one entry per prompt to that
        prompt's ``multi_modal_data`` dict, so several modalities can be
        combined for the same prompt instead of overwriting each other.
        """
        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        for modality, items in (("image", images), ("audio", audios),
                                ("video", videos)):
            if items is None:
                continue
            for i, item in enumerate(items):
                inputs[i].setdefault("multi_modal_data", {})[modality] = item
        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with ``sampling_params`` and return, per request, the
        prompt ids/text concatenated with each sample's ids/text."""
        if images is not None:
            assert len(prompts) == len(images)

        inputs = self._build_inputs(prompts, images=images)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                req_sample_output_ids.append(prompt_ids +
                                             list(sample.token_ids))
                req_sample_output_strs.append(prompt_str + sample.text)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[TokensTextLogprobsPromptLogprobs]:
        """Turn request outputs into ``(ids, text, logprobs,
        prompt_logprobs)`` tuples, one per request.

        When a request has several samples only its last sample is reported,
        which keeps the result aligned one-to-one with the requests.
        """
        outputs: List[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            sample = req_output.outputs[-1]
            outputs.append((list(sample.token_ids), sample.text,
                            sample.logprobs, req_output.prompt_logprobs))
        return outputs

    @staticmethod
    def _maybe_drop_prompt_logprobs(
        outputs: List[TokensTextLogprobsPromptLogprobs],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Omit prompt logprobs if not required by sampling params: drop the
        last tuple element when ``sampling_params.prompt_logprobs`` is None.
        """
        if sampling_params.prompt_logprobs is None:
            return [x[0:-1] for x in outputs]
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return ids, text and logprobs per prompt (plus prompt
        logprobs when ``sampling_params.prompt_logprobs`` is set)."""
        if images is not None:
            assert len(prompts) == len(images)
        if audios is not None:
            assert len(prompts) == len(audios)
        if videos is not None:
            assert len(prompts) == len(videos)

        inputs = self._build_inputs(prompts,
                                    images=images,
                                    audios=audios,
                                    videos=videos)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        return self._maybe_drop_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._maybe_drop_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy decoding; returns the first sample of each request."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Greedy decoding (temperature 0) with top-``num_logprobs`` logprobs."""
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )
        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: Union[List[str], List[List[int]]],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search; returns ``(token_ids, texts)`` of all beams per prompt."""
        outputs = self.model.beam_search(
            prompts,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
        return [([seq.tokens for seq in output.sequences],
                 [seq.text for seq in output.sequences])
                for output in outputs]

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return the embedding of each prompt from ``LLM.encode``."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
827

828
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture that hands tests the ``VllmRunner`` class.

    The class object itself is returned, not an instance.
    """
    runner_cls = VllmRunner
    return runner_cls
831
832
833
834
835
836
837
838
839


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a one-worker ``TokenizerPoolConfig`` for tests.

    ``None`` is returned unchanged. The string ``"ray"`` or any class object
    becomes the config's ``pool_type``. Any other value raises
    ``ValueError``.
    """
    if tokenizer_group_type is None:
        return None
    if tokenizer_group_type == "ray":
        pool_type = "ray"
    elif isinstance(tokenizer_group_type, type):
        pool_type = tokenizer_group_type
    else:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")
    return TokenizerPoolConfig(pool_size=1,
                               pool_type=pool_type,
                               extra_config={})
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860


@pytest.fixture()
def temporary_enable_log_propagate():
    """Temporarily let records from the ``"vllm"`` logger propagate upward.

    On teardown the logger's previous ``propagate`` value is restored. The
    old teardown always forced it to ``False``, which could clobber a logger
    that was already propagating before the test.
    """
    import logging

    # Named to avoid shadowing the module-level ``logger``.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Expose pytest's ``caplog`` for capturing vLLM log records.

    ``caplog`` depends on logs propagated to the root logger, so this fixture
    requests ``temporary_enable_log_propagate`` to turn propagation on for
    the ``"vllm"`` logger while the test runs.
    """
    yield caplog
861
862
863
864
865
866
867


@pytest.fixture(scope="session")
def num_gpus_available():
    """Session-wide GPU count.

    Obtained via ``cuda_device_count_stateless`` so that the CUDA context is
    not initialized in the current process.
    """
    device_count = cuda_device_count_stateless()
    return device_count
869
870
871


# Scratch locations for the dummy-model fixtures, under the system temp dir.
temp_dir = tempfile.gettempdir()
_dummy_opt_path, _dummy_llava_path = (
    os.path.join(temp_dir, leaf) for leaf in ("dummy_opt", "dummy_llava"))
874
875
876
877


_DUMMY_MODEL_IGNORE_PATTERNS = ("*.bin", "*.bin.index.json", "*.pt", "*.h5",
                                "*.msgpack")


def _prepare_dummy_model(repo_id: str, local_dir: str,
                         architecture: str) -> str:
    """Fetch ``repo_id`` into ``local_dir`` and retarget its architecture.

    Files matching ``_DUMMY_MODEL_IGNORE_PATTERNS`` are skipped during the
    download. ``config.json`` is then rewritten so that its
    ``architectures`` entry is ``[architecture]``.

    The download is keyed on ``config.json`` rather than on the directory.
    The config is patched whenever it does not already name
    ``architecture``. A directory left behind by an interrupted earlier run
    is therefore completed and patched, instead of being silently reused
    with the upstream architecture.

    Returns ``local_dir``.
    """
    json_path = os.path.join(local_dir, "config.json")
    if not os.path.exists(json_path):
        snapshot_download(repo_id=repo_id,
                          local_dir=local_dir,
                          ignore_patterns=list(_DUMMY_MODEL_IGNORE_PATTERNS))
        assert os.path.exists(json_path)
    with open(json_path, encoding="utf-8") as f:
        config = json.load(f)
    if config.get("architectures") != [architecture]:
        config["architectures"] = [architecture]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return local_dir


@pytest.fixture
def dummy_opt_path():
    """Local copy of ``facebook/opt-125m`` whose config.json lists
    ``MyOPTForCausalLM`` as its architecture."""
    return _prepare_dummy_model("facebook/opt-125m", _dummy_opt_path,
                                "MyOPTForCausalLM")


@pytest.fixture
def dummy_llava_path():
    """Local copy of ``llava-hf/llava-1.5-7b-hf`` whose config.json lists
    ``MyLlava`` as its architecture."""
    return _prepare_dummy_model("llava-hf/llava-1.5-7b-hf",
                                _dummy_llava_path, "MyLlava")