conftest.py 29 KB
Newer Older
1
2
import contextlib
import gc
3
import json
4
import os
5
import sys
6
import tempfile
7
from collections import UserList
8
from enum import Enum
9
10
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
13
14
import pytest
import torch
15
import torch.nn as nn
16
import torch.nn.functional as F
17
from huggingface_hub import snapshot_download
18
from PIL import Image
19
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
20
                          BatchFeature)
21
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23

from vllm import LLM, SamplingParams
24
from vllm.assets.image import ImageAsset
25
from vllm.assets.video import VideoAsset
26
from vllm.config import TokenizerPoolConfig
27
from vllm.connections import global_http_connection
28
from vllm.distributed import (destroy_distributed_environment,
29
30
31
                              destroy_model_parallel,
                              init_distributed_environment,
                              initialize_model_parallel)
32
33
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
34
from vllm.logger import init_logger
35
from vllm.outputs import RequestOutput
36
from vllm.sequence import SampleLogprobs
37
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
38
                        identity, is_cpu)
39

40
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
41

42
43
44
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
45

46
47
48
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
                         List[List[Tuple[np.ndarray, int]]]]
49
PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]]
50

51

52
def _read_prompts(filename: str) -> List[str]:
    """Read a prompt file and return its lines.

    Args:
        filename: Path of the text file to read.

    Returns:
        The lines of the file in order, as produced by ``readlines()``
        (trailing newline characters are kept).
    """
    # Pin the encoding so the result does not depend on the platform's
    # default locale encoding.
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
56
57


58
59
60
class _ImageAssetPrompts(TypedDict):
    """Prompt strings keyed by image asset name.

    Consumed by ``_ImageAssets.prompts``, which reads both keys.
    """
    # Prompt associated with the "stop_sign" image asset.
    stop_sign: str
    # Prompt associated with the "cherry_blossom" image asset.
    cherry_blossom: str
61
62
63
64
65
66
67


if sys.version_info >= (3, 9):

    class _ImageAssetsBase(UserList[ImageAsset]):
        """Base list type for image assets (subscripted ``UserList``)."""
else:
    # UserList cannot be subscripted on older interpreters
    class _ImageAssetsBase(UserList):
        """Base list type for image assets (plain ``UserList``)."""
71

72
73

class _ImageAssets(_ImageAssetsBase):
    """List holding the image assets used by the tests.

    The list always contains the "stop_sign" asset followed by the
    "cherry_blossom" asset.
    """

    def __init__(self) -> None:
        asset_names = ("stop_sign", "cherry_blossom")
        super().__init__([ImageAsset(name) for name in asset_names])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Return the per-image prompts as a list.

        The returned list is ordered like the assets in this object:
        the "stop_sign" prompt first, then the "cherry_blossom" prompt.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
89
90


91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class _VideoAssetPrompts(TypedDict):
    """Prompt strings keyed by video asset name.

    Consumed by ``_VideoAssets.prompts``, which reads the single key.
    """
    # Prompt associated with the "sample_demo_1" video asset.
    sample_demo_1: str


if sys.version_info >= (3, 9):

    class _VideoAssetsBase(UserList[VideoAsset]):
        """Base list type for video assets (subscripted ``UserList``)."""
else:
    # UserList cannot be subscripted on older interpreters
    class _VideoAssetsBase(UserList):
        """Base list type for video assets (plain ``UserList``)."""


class _VideoAssets(_VideoAssetsBase):
    """List holding the video assets used by the tests.

    The list contains a single asset, "sample_demo_1.mp4".
    """

    def __init__(self) -> None:
        sample_video = VideoAsset("sample_demo_1.mp4")
        super().__init__([sample_video])

    def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
        """Return the per-video prompts as a list, in asset order."""
        sample_prompt = prompts["sample_demo_1"]
        return [sample_prompt]


116
117
# Module-level instance; the `image_assets` fixture returns this object.
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
118
119
# Module-level instance; the `video_assets` fixture returns this object.
VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
120
121


122
123
124
125
126
127
128
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Autouse fixture: turn off client reuse on the global HTTP connection.

    Runs before every test because the fixture is function-scoped and
    ``autouse=True``. The flag is not reset afterwards.
    """
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
@pytest.fixture
def dist_init():
    """Set up a single-process distributed environment for one test.

    Initializes the distributed environment with ``world_size=1`` and
    ``rank=0`` over the "nccl" backend, using a temporary file as the
    rendezvous point, then initializes model parallelism with sizes
    ``(1, 1)``. After the test, ``cleanup()`` is run and the temporary
    file is removed.
    """
    # mkstemp() returns an already-open OS-level descriptor together with
    # the path. Only the path is needed, so close the descriptor right away
    # instead of leaking it for the lifetime of the test process.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    try:
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend="nccl",
        )
        initialize_model_parallel(1, 1)
        yield
        cleanup()
    finally:
        # Best-effort removal: the file may already have been deleted.
        with contextlib.suppress(OSError):
            os.remove(temp_file)


144
145
def cleanup():
    """Tear down distributed state and release cached memory.

    Destroys the model-parallel groups and the distributed environment,
    destroys the torch process group (an ``AssertionError`` from that call
    is ignored), runs the garbage collector and, unless running on CPU,
    empties the CUDA cache.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        # NOTE(review): presumably raised when no process group is
        # initialized; confirm against torch.distributed.
        pass
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()
152
153


154
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell ``cleanup_fixture`` whether to run the global cleanup.

    Returns False when the requesting test carries (or inherits) the
    ``skip_global_cleanup`` marker, True otherwise. Subdirectories may also
    override this fixture; skipping the cleanup is meant to speed up
    non-GPU unit tests (about 10x according to the original note) since
    they don't need to initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
162
163


164
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Autouse fixture that runs ``cleanup()`` after each test.

    The cleanup is skipped when ``should_do_global_cleanup_after_test``
    evaluates to a falsy value.
    """
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
169
170


Woosuk Kwon's avatar
Woosuk Kwon committed
171
172
@pytest.fixture
def example_prompts() -> List[str]:
    """Return the lines of every file in ``_TEST_PROMPTS``, concatenated."""
    return [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]


179
180
181
182
183
184
185
class DecoderPromptType(Enum):
    """For encoder/decoder models only.

    Selects the kind of decoder prompt paired with each encoder prompt;
    used as the key of the dict built by the
    ``example_encoder_decoder_prompts`` fixture.
    """
    # Decoder prompts are the encoder prompts in reverse order.
    CUSTOM = 1
    # Every decoder prompt is None.
    NONE = 2
    # Every decoder prompt is the empty string.
    EMPTY_STR = 3


186
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for each decoder prompt type.

    The encoder prompts are the lines of every file in `_TEST_PROMPTS`.
    They are paired index-by-index with a decoder prompt list through
    `zip_enc_dec_prompts`.

    Returns:

    A dict mapping each :class:`DecoderPromptType` to its prompt pairs:

    * NONE: every decoder prompt is None
    * EMPTY_STR: every decoder prompt is the empty string
    * CUSTOM: decoder prompts are the encoder prompts in reverse order
    '''

    encoder_prompts = [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]
    num_prompts = len(encoder_prompts)

    decoder_prompts_by_type = {
        DecoderPromptType.NONE: [None] * num_prompts,
        DecoderPromptType.EMPTY_STR: [""] * num_prompts,
        DecoderPromptType.CUSTOM: encoder_prompts[::-1],
    }

    return {
        prompt_type: zip_enc_dec_prompts(encoder_prompts, decoder_prompts)
        for prompt_type, decoder_prompts in decoder_prompts_by_type.items()
    }


219
220
221
222
@pytest.fixture
def example_long_prompts() -> List[str]:
    """Return the lines of every file in ``_LONG_PROMPTS``, concatenated."""
    return [
        line for filename in _LONG_PROMPTS
        for line in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
225
226


227
228
229
230
231
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Session-scoped fixture returning the module-level ``IMAGE_ASSETS``."""
    return IMAGE_ASSETS


232
233
234
235
236
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    """Session-scoped fixture returning the module-level ``VIDEO_ASSETS``."""
    return VIDEO_ASSETS


237
# Constrained type variable used by `HfRunner.wrap_device`, which returns
# the same kind of object it was given.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
238

Woosuk Kwon's avatar
Woosuk Kwon committed
239
240
241

class HfRunner:

242
    def wrap_device(self, input: _T) -> _T:
        """Return ``input`` placed on the device used for HF reference runs.

        The target device is "cpu" when ``is_cpu()`` is true and "cuda"
        otherwise. An object that exposes a ``device`` attribute already
        matching the target type is returned unchanged; anything else is
        moved with ``input.to(...)``.
        """
        target_device = "cpu" if is_cpu() else "cuda"
        already_placed = (hasattr(input, 'device')
                          and input.device.type == target_device)
        if already_placed:
            return input  # No transfer needed
        return input.to(target_device)

Woosuk Kwon's avatar
Woosuk Kwon committed
254
255
256
257
    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        """Load the model, tokenizer and processor for ``model_name``.

        Args:
            model_name: Name passed to every ``from_pretrained`` call.
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE``.
            model_kwargs: Extra keyword arguments for
                ``auto_cls.from_pretrained``; unused for embedding models.
            is_embedding_model: Load through ``SentenceTransformer``
                instead of ``auto_cls``.
            auto_cls: Auto-model class used for non-embedding models.
            postprocess_inputs: Callable applied to the processor output
                before it is passed to the model.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if not is_embedding_model:
            extra_model_kwargs = {} if model_kwargs is None else model_kwargs
            hf_model = auto_cls.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
                **extra_model_kwargs,
            )
        else:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            # Created on CPU first, cast to the requested dtype, and only
            # then handed to wrap_device().
            cpu_model = SentenceTransformer(model_name, device="cpu")
            hf_model = cpu_model.to(dtype=torch_dtype)
        self.model = self.wrap_device(hf_model)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        self.postprocess_inputs = postprocess_inputs

Woosuk Kwon's avatar
Woosuk Kwon committed
304
305
306
    def generate(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run ``self.model.generate`` once per prompt.

        Args:
            prompts: Text prompts, processed one at a time.
            images: Optional per-prompt images; when non-empty it must have
                the same length as ``prompts``.
            videos: Optional per-prompt videos.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            One ``(output_ids, output_strs)`` tuple per prompt, where
            ``output_ids`` is the generated id tensor converted to nested
            lists and ``output_strs`` is its ``batch_decode`` result.
        """
        if images:
            assert len(prompts) == len(images)

        results: List[Tuple[List[List[int]], List[str]]] = []
        for idx, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            # Per-prompt media is attached only when actually provided.
            if images is not None and images[idx] is not None:
                processor_kwargs["images"] = images[idx]
            if videos is not None and videos[idx] is not None:
                processor_kwargs["videos"] = videos[idx]

            model_inputs = self.postprocess_inputs(
                self.processor(**processor_kwargs))

            generated = self.model.generate(
                **self.wrap_device(model_inputs),
                use_cache=True,
                **kwargs,
            )
            decoded = self.processor.batch_decode(
                generated,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            results.append((generated.cpu().tolist(), decoded))
        return results

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Generate without sampling and keep the first sequence per prompt.

        Calls ``self.generate`` with ``do_sample=False`` and
        ``max_new_tokens=max_tokens``; extra ``kwargs`` are forwarded.

        Returns:
            One ``(token_ids, text)`` tuple per prompt.
        """
        greedy_outputs: List[Tuple[List[int], str]] = []
        for sample_ids, sample_strs in self.generate(
                prompts,
                do_sample=False,
                max_new_tokens=max_tokens,
                images=images,
                **kwargs):
            greedy_outputs.append((sample_ids[0], sample_strs[0]))
        return greedy_outputs
357
358
359
360
361
362

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with beam search and strip padding token ids.

        Calls ``self.generate`` with ``num_beams`` and
        ``num_return_sequences`` both set to ``beam_width``, then removes
        every id equal to ``self.tokenizer.pad_token_id`` from each
        returned sequence. The decoded strings are left as returned.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        for i, (output_ids, output_str) in enumerate(outputs):
            for j, token_ids in enumerate(output_ids):
                output_ids[j] = [
                    token_id for token_id in token_ids
                    if token_id != self.tokenizer.pad_token_id
                ]
            outputs[i] = (output_ids, output_str)
        return outputs
Woosuk Kwon's avatar
Woosuk Kwon committed
378

379
380
381
382
    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Generate without sampling and return full log-prob tensors.

        For every prompt the model is run with ``do_sample=False`` and
        ``output_hidden_states=True``; each entry of the returned hidden
        states is projected through the model's output embeddings and
        passed through ``log_softmax`` (computed in float32).

        Args:
            prompts: Text prompts, processed one at a time.
            max_tokens: Passed as ``max_new_tokens``.
            images: Optional per-prompt images.
            videos: Optional per-prompt videos.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            For each prompt, a list with one log-prob tensor per entry of
            ``output.hidden_states``.
        """
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]
            if videos is not None and videos[i] is not None:
                processor_kwargs["videos"] = videos[i]

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            output_embeddings = self.model.get_output_embeddings()
            # Use getattr so heads without a `bias` attribute are handled,
            # matching `_hidden_states_to_logprobs`.
            output_bias = getattr(output_embeddings, "bias", None)

            seq_logprobs: List[torch.Tensor] = []
            for hidden_states in output.hidden_states:
                # presumably: last layer, first batch element (TODO confirm
                # against the HF generate() output layout)
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    output_embeddings.weight.t(),
                )
                if output_bias is not None:
                    logits += output_bias.unsqueeze(0)
                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
    def _hidden_states_to_logprobs(
        self,
        hidden_states,
        num_logprobs,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Convert generation hidden states into top-k log-prob dicts.

        Each entry of ``hidden_states`` is projected through the model's
        output embeddings (plus bias when one exists) and passed through
        ``log_softmax`` in float32. For the first entry only the last row
        is kept, which drops the prompt positions.

        Returns:
            A tuple ``(logprob_dicts, output_len)`` where ``logprob_dicts``
            has one ``{token_id: logprob}`` dict (``num_logprobs`` entries)
            per element of ``hidden_states`` and ``output_len`` is
            ``len(hidden_states)``.
        """
        output_len = len(hidden_states)

        step_logprobs: List[torch.Tensor] = []
        for step_hidden_states in hidden_states:
            last_hidden_states = step_hidden_states[-1][0]
            logits = torch.matmul(
                last_hidden_states,
                self.model.get_output_embeddings().weight.t(),
            )
            bias = getattr(self.model.get_output_embeddings(), "bias", None)
            if bias is not None:
                logits += self.model.get_output_embeddings().bias.unsqueeze(0)
            step_logprobs.append(
                F.log_softmax(logits, dim=-1, dtype=torch.float32))

        # convert to dict
        logprob_dicts: List[Dict[int, float]] = []
        for step_idx, logprobs in enumerate(step_logprobs):
            if step_idx == 0:
                # drop prompt logprobs
                logprobs = logprobs[-1, :].reshape(1, -1)
            top = logprobs.topk(num_logprobs)
            logprob_dicts.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(top.indices[0], top.values[0])
            })

        return logprob_dicts, output_len

463
464
465
466
467
    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Generate without sampling and return top-k log-probs per token.

        Args:
            prompts: Text prompts, processed one at a time.
            max_tokens: Passed as ``max_new_tokens``.
            num_logprobs: Number of top log-probs kept per generated token.
            images: Optional per-prompt images.
            audios: Optional per-prompt ``(audio, sampling_rate)`` items.
            videos: Optional per-prompt videos.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            One ``(output_ids, output_str, logprobs)`` tuple per prompt,
            where ``output_ids`` are the trailing generated ids,
            ``output_str`` is their ``self.tokenizer.decode`` result and
            ``logprobs`` holds one ``{token_id: logprob}`` dict per token.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            # Every modality is attached only when an item is actually
            # present for this prompt, as in the sibling generate methods.
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            if audios is not None and audios[i] is not None:
                audio, sr = audios[i]
                processor_kwargs["audio"] = audio
                processor_kwargs["sampling_rate"] = sr

            if videos is not None and videos[i] is not None:
                processor_kwargs["videos"] = videos[i]

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # Keep only the trailing `output_len` ids, i.e. the generated
            # part of the sequence.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        '''
        Greedy logprobs generation for encoder/decoder prompts.

        Each (encoder prompt, decoder prompt) pair is tokenized with
        `self.tokenizer`; a decoder prompt of None is passed to the model
        as `decoder_input_ids=None`. Returns one
        (output_ids, output_str, logprobs) tuple per prompt pair.
        '''

        results: List[Tuple[List[int], str, List[Dict[int, float]]]] = []

        prompt_pairs = to_enc_dec_tuple_list(encoder_decoder_prompts)
        for encoder_prompt, decoder_prompt in prompt_pairs:
            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_input_ids = self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids)

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            logprob_dicts, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            # Keep only the trailing `output_len` ids of the first sequence.
            generated_ids = output.sequences[0][-output_len:]
            results.append((
                generated_ids.tolist(),
                self.tokenizer.decode(generated_ids),
                logprob_dicts,
            ))

        return results

573
574
575
    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Return ``self.model.encode(prompts)``.

        NOTE(review): presumably only meaningful when the runner was
        created with ``is_embedding_model=True`` -- confirm that other
        model classes expose ``encode``.
        """
        return self.model.encode(prompts)

576
577
578
579
    def __enter__(self):
        """Enter the context manager; returns this runner."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model attribute and run the module-level ``cleanup()``.

        Nothing is returned, so an exception raised inside the ``with``
        block is not suppressed.
        """
        delattr(self, "model")
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
583

Cyrus Leung's avatar
Cyrus Leung committed
584
@pytest.fixture(scope="session")
def hf_runner():
    """Session-scoped fixture returning the ``HfRunner`` class itself.

    Tests instantiate the returned class with their own arguments.
    """
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Create the wrapped ``LLM`` instance.

        All named arguments are forwarded to ``LLM`` under the same names,
        except ``model_name`` and ``tokenizer_name`` which are passed as
        ``model`` and ``tokenizer``. ``trust_remote_code`` is always True.
        Any additional ``kwargs`` are forwarded unchanged.
        """
        engine_args: Dict[str, Any] = dict(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
        )
        self.model = LLM(**engine_args, **kwargs)

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate completions and return them prefixed with the prompt.

        Args:
            prompts: Text prompts.
            sampling_params: Passed to ``self.model.generate``.
            images: Optional per-prompt images; must match ``prompts`` in
                length when given.

        Returns:
            One ``(ids, texts)`` tuple per request. Each element of ``ids``
            is the prompt token ids followed by a sample's token ids, and
            each element of ``texts`` is the prompt followed by the
            sample's text.
        """
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = {"image": image}

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            samples = req_output.outputs
            sample_ids = [
                prompt_ids + list(sample.token_ids) for sample in samples
            ]
            sample_strs = [prompt_str + sample.text for sample in samples]
            outputs.append((sample_ids, sample_strs))
        return outputs

653
    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Convert request outputs into ``(ids, text, logprobs)`` tuples.

        Exactly one tuple is produced per request, taken from the LAST
        sample in ``req_output.outputs``.

        Raises:
            IndexError: If a request has no samples. Relying on loop
                variables here would silently reuse the previous request's
                values, or fail with ``UnboundLocalError`` on the first
                request.
        """
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            last_sample = req_output.outputs[-1]
            outputs.append((
                list(last_sample.token_ids),
                last_sample.text,
                last_sample.logprobs,
            ))
        return outputs

666
667
668
669
    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate completions and return them together with log-probs.

        Args:
            prompts: Text prompts.
            sampling_params: Must have ``logprobs`` set.
            images: Optional per-prompt images.
            audios: Optional per-prompt audio items.
            videos: Optional per-prompt videos.

        Each provided modality must match ``prompts`` in length. Every
        modality assigns the prompt's whole ``multi_modal_data`` entry, so
        when several are given only the last one applied (image, then
        audio, then video) is kept.

        Returns:
            One ``(token_ids, text, logprobs)`` tuple per request.
        """
        assert sampling_params.logprobs is not None

        if images is not None:
            assert len(prompts) == len(images)

        if audios is not None:
            assert len(prompts) == len(audios)

        if videos is not None:
            assert len(prompts) == len(videos)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = {"image": image}

        if audios is not None:
            for i, audio in enumerate(audios):
                inputs[i]["multi_modal_data"] = {"audio": audio}

        if videos is not None:
            for i, video in enumerate(videos):
                inputs[i]["multi_modal_data"] = {"video": video}

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)
        return self._final_steps_generate_w_logprobs(req_outputs)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Logprobs generation for vLLM encoder/decoder models.

        `sampling_params.logprobs` must be set. The prompts are passed to
        `self.model.generate` unchanged and the request outputs are
        converted by `_final_steps_generate_w_logprobs`.
        '''

        assert sampling_params.logprobs is not None
        return self._final_steps_generate_w_logprobs(
            self.model.generate(encoder_decoder_prompts,
                                sampling_params=sampling_params))
713

Woosuk Kwon's avatar
Woosuk Kwon committed
714
715
716
717
    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[int], str]]:
        """Generate with temperature 0 and keep the first sample per prompt.

        Returns:
            One ``(token_ids, text)`` tuple per prompt, as produced by
            ``self.generate`` (prompt included).
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        greedy_outputs: List[Tuple[List[int], str]] = []
        for sample_ids, sample_strs in self.generate(prompts,
                                                     greedy_params,
                                                     images=images):
            greedy_outputs.append((sample_ids[0], sample_strs[0]))
        return greedy_outputs
724

725
726
727
728
729
    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate with temperature 0 and request ``num_logprobs`` log-probs.

        Builds the sampling parameters and delegates to
        ``generate_w_logprobs``.

        Returns:
            One ``(token_ids, text, logprobs)`` tuple per request.
        """
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            stop_token_ids=stop_token_ids,
        )
        results: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for output_ids, output_str, output_logprobs in (
                self.generate_w_logprobs(prompts,
                                         greedy_logprobs_params,
                                         images=images,
                                         audios=audios,
                                         videos=videos)):
            results.append((output_ids, output_str, output_logprobs))
        return results

748
749
    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models.

        Uses temperature 0 with beam search disabled and delegates to
        `generate_encoder_decoder_w_logprobs`.
        '''
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                use_beam_search=False,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)

        results: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for output_ids, output_str, output_logprobs in (
                self.generate_encoder_decoder_w_logprobs(
                    encoder_decoder_prompts, greedy_logprobs_params)):
            results.append((output_ids, output_str, output_logprobs))
        return results

768
769
770
771
772
    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with beam search, returning ``beam_width`` samples.

        Builds sampling parameters with ``n=beam_width``,
        ``use_beam_search=True`` and temperature 0, then returns the result
        of ``self.generate`` unchanged.
        """
        return self.generate(
            prompts,
            SamplingParams(n=beam_width,
                           use_beam_search=True,
                           temperature=0.0,
                           max_tokens=max_tokens),
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
780

781
782
783
784
785
786
787
788
    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return the embedding of each request from ``self.model.encode``."""
        return [
            req_output.outputs.embedding
            for req_output in self.model.encode(prompts)
        ]

789
790
791
792
    def __enter__(self):
        """Enter the context manager; returns this runner."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model attribute and run the module-level ``cleanup()``.

        Nothing is returned, so an exception raised inside the ``with``
        block is not suppressed.
        """
        delattr(self, "model")
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
796

797
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture returning the ``VllmRunner`` class itself.

    Tests instantiate the returned class with their own arguments.
    """
    runner_cls = VllmRunner
    return runner_cls
800
801
802
803
804
805
806
807
808


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build the ``TokenizerPoolConfig`` for a tokenizer group type.

    Args:
        tokenizer_group_type: ``None``, the string "ray", or a class.

    Returns:
        ``None`` when ``tokenizer_group_type`` is ``None``; otherwise a
        ``TokenizerPoolConfig`` with ``pool_size=1``, an empty
        ``extra_config`` and ``pool_type`` set to "ray" or to the given
        class.

    Raises:
        ValueError: For any other value.
    """
    if tokenizer_group_type is None:
        return None

    if tokenizer_group_type == "ray":
        pool_type = "ray"
    elif isinstance(tokenizer_group_type, type):
        pool_type = tokenizer_group_type
    else:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")

    return TokenizerPoolConfig(pool_size=1,
                               pool_type=pool_type,
                               extra_config={})
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829


@pytest.fixture()
def temporary_enable_log_propagate():
    """Enable propagation on the "vllm" logger for the duration of a test.

    The logger's previous ``propagate`` value is restored afterwards
    instead of being forced to ``False``, so the fixture leaves the
    logging configuration exactly as it found it.
    """
    import logging

    # A distinct local name avoids shadowing the module-level `logger`.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog`` with "vllm" log propagation enabled.

    Propagation is switched on by the ``temporary_enable_log_propagate``
    fixture this one depends on.
    """
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
830
831
832
833
834
835
836


@pytest.fixture(scope="session")
def num_gpus_available():
    """Return the GPU count reported by ``cuda_device_count_stateless()``.

    Intended to obtain the number of GPUs without initializing the CUDA
    context in the current process.
    """
    gpu_count = cuda_device_count_stateless()
    return gpu_count
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860


# Directory used by the `dummy_opt_path` fixture: "dummy_opt" inside the
# system temporary directory.
temp_dir = tempfile.gettempdir()
_dummy_path = os.path.join(temp_dir, "dummy_opt")


@pytest.fixture
def dummy_opt_path():
    """Return a local copy of facebook/opt-125m with a patched config.

    The snapshot is downloaded into ``_dummy_path`` without weight files
    (see ``ignore_patterns``) and the "architectures" entry of its
    ``config.json`` is replaced by ``["MyOPTForCausalLM"]``.

    Returns:
        ``_dummy_path``.
    """
    json_path = os.path.join(_dummy_path, "config.json")
    patched_architectures = ["MyOPTForCausalLM"]

    # Key the download on the config file rather than on the directory, so
    # a directory left behind by an interrupted download is repaired
    # instead of being reused in a broken state.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)

    with open(json_path, "r") as f:
        config = json.load(f)
    # Rewrite the config only when it has not been patched yet.
    if config.get("architectures") != patched_architectures:
        config["architectures"] = patched_architectures
        with open(json_path, "w") as f:
            json.dump(config, f)
    return _dummy_path