conftest.py 30.7 KB
Newer Older
1
2
import contextlib
import gc
3
import json
4
import os
5
import sys
6
import tempfile
7
from collections import UserList
8
from enum import Enum
9
10
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
13
14
import pytest
import torch
15
import torch.nn as nn
16
import torch.nn.functional as F
17
from huggingface_hub import snapshot_download
18
from PIL import Image
19
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
20
                          BatchFeature)
21
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
24
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm import LLM, SamplingParams
26
from vllm.assets.image import ImageAsset
27
from vllm.assets.video import VideoAsset
28
from vllm.config import TokenizerPoolConfig
29
from vllm.connections import global_http_connection
30
from vllm.distributed import (destroy_distributed_environment,
31
32
33
                              destroy_model_parallel,
                              init_distributed_environment,
                              initialize_model_parallel)
34
35
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
36
from vllm.logger import init_logger
37
from vllm.outputs import RequestOutput
38
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
39
                        identity, is_cpu)
40

41
# Module-level vLLM logger for this conftest.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
42

43
44
45
_TEST_DIR = os.path.dirname(__file__)
# Prompt files read line-by-line by the ``example_prompts`` /
# ``example_long_prompts`` fixtures.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Multimodal inputs for a batch of prompts: either one item per prompt, or a
# list of items per prompt.
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
# Audio items are ``(waveform, sampling_rate)`` pairs.
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
                         List[List[Tuple[np.ndarray, int]]]]
PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]]
51

52

53
def _read_prompts(filename: str) -> List[str]:
54
    with open(filename, "r") as f:
55
56
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
57
58


59
60
61
class _ImageAssetPrompts(TypedDict):
    stop_sign: str
    cherry_blossom: str
62
63
64
65
66
67
68


# Base list type for ``_ImageAssets``; parameterised with ``ImageAsset`` only
# where the interpreter supports subscripting ``UserList``.
if sys.version_info < (3, 9):
    # UserList cannot be subscripted
    class _ImageAssetsBase(UserList):
        pass
else:

    class _ImageAssetsBase(UserList[ImageAsset]):
        pass
72

73
74

class _ImageAssets(_ImageAssetsBase):
    """List of the image assets used by the tests (stop sign, cherry blossom)."""

    def __init__(self) -> None:
        super().__init__([
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
        ])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Convenience method to define the prompt for each test image.

        The order of the returned prompts matches the order of the
        assets when iterating through this object.
        """
        return [prompts["stop_sign"], prompts["cherry_blossom"]]
90
91


92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class _VideoAssetPrompts(TypedDict):
    sample_demo_1: str


# Base list type for ``_VideoAssets``. On Python 3.9+ the base is
# parameterised with ``VideoAsset``; older interpreters get a plain UserList.
if sys.version_info >= (3, 9):

    class _VideoAssetsBase(UserList[VideoAsset]):
        pass
else:
    # UserList cannot be subscripted before Python 3.9
    class _VideoAssetsBase(UserList):
        pass


class _VideoAssets(_VideoAssetsBase):
    """List of the video assets used by the tests."""

    def __init__(self) -> None:
        assets = [VideoAsset("sample_demo_1.mp4")]
        super().__init__(assets)

    def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
        """Return one prompt per video asset, in iteration order."""
        ordered_keys = ("sample_demo_1", )
        return [prompts[key] for key in ordered_keys]


117
118
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""

VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
121
122


123
124
125
126
127
128
129
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Autouse fixture: disable client reuse on ``global_http_connection``
    before every test."""
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
@pytest.fixture
def dist_init():
    """Set up a single-rank NCCL distributed environment for one test.

    Initialises the distributed environment (world size 1, rank 0) and model
    parallel state, then runs :func:`cleanup` and removes the rendezvous file
    after the test.
    """
    # mkstemp returns an *open* OS-level file descriptor; only the path is
    # needed for the file:// init method, so close the fd to avoid leaking it.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    # presumably tensor- and pipeline-parallel sizes of 1 — confirm signature
    initialize_model_parallel(1, 1)
    yield
    cleanup()
    # The store may already have removed the file; that's fine.
    with contextlib.suppress(OSError):
        os.remove(temp_file)


145
146
def cleanup():
    """Tear down vLLM's parallel/distributed state and release memory.

    Destroys model-parallel and distributed state, destroys the default
    torch process group, runs the garbage collector and, unless running on
    CPU, empties the CUDA cache.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # NOTE(review): presumably raises AssertionError when no process group
    # was initialised; that case is ignored.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    if not is_cpu():
        torch.cuda.empty_cache()
153
154


155
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't need
    to initialize torch.

    By default, tests marked with ``skip_global_cleanup`` skip the cleanup.
    """
    return not request.node.get_closest_marker("skip_global_cleanup")
163
164


165
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Autouse fixture: run :func:`cleanup` after each test unless
    ``should_do_global_cleanup_after_test`` says otherwise."""
    yield
    if should_do_global_cleanup_after_test:
        cleanup()
170
171


172
173
174
175
176
177
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Autouse fixture: call ``torch._dynamo.reset()`` after every test.

    Presumably so compiled/cached Dynamo state from one test does not affect
    the next one.
    """
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
178
179
@pytest.fixture
def example_prompts() -> List[str]:
    """All lines of the files in ``_TEST_PROMPTS``, concatenated in order."""
    return [
        prompt for filename in _TEST_PROMPTS
        for prompt in _read_prompts(filename)
    ]


186
187
188
189
190
191
192
class DecoderPromptType(Enum):
    """For encoder/decoder models only.

    Selects which decoder prompts ``example_encoder_decoder_prompts`` pairs
    with the encoder prompts.
    """
    CUSTOM = 1  # decoder prompts are the encoder prompts in reverse order
    NONE = 2  # every decoder prompt is None
    EMPTY_STR = 3  # every decoder prompt is the empty string


193
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Build (encoder prompt, decoder prompt) pairs for every
    :class:`DecoderPromptType`.

    The encoder prompts are the lines of the ``_TEST_PROMPTS`` files. For each
    decoder prompt type they are zipped, index by index, with:

    * ``NONE``: ``None`` for every decoder prompt
    * ``EMPTY_STR``: ``""`` for every decoder prompt
    * ``CUSTOM``: the encoder prompt list in reverse order

    Returns:
        Mapping from decoder prompt type to the list produced by
        ``zip_enc_dec_prompts``.
    '''
    encoder_prompts: List[str] = [
        prompt for filename in _TEST_PROMPTS
        for prompt in _read_prompts(filename)
    ]

    num_prompts = len(encoder_prompts)
    decoder_prompts_by_type = {
        DecoderPromptType.NONE: [None] * num_prompts,
        DecoderPromptType.EMPTY_STR: [""] * num_prompts,
        DecoderPromptType.CUSTOM: encoder_prompts[::-1],
    }
    return {
        prompt_type: zip_enc_dec_prompts(encoder_prompts, decoder_prompts)
        for prompt_type, decoder_prompts in decoder_prompts_by_type.items()
    }


226
227
228
229
@pytest.fixture
def example_long_prompts() -> List[str]:
    """All lines of the files in ``_LONG_PROMPTS``, concatenated in order."""
    return [
        prompt for filename in _LONG_PROMPTS
        for prompt in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
232
233


234
235
236
237
238
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Session-scoped fixture returning the shared ``IMAGE_ASSETS`` object."""
    return IMAGE_ASSETS


239
240
241
242
243
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    """Session-scoped fixture returning the shared ``VIDEO_ASSETS`` object."""
    return VIDEO_ASSETS


244
# Object types ``HfRunner.wrap_device`` accepts; the value it returns has the
# same type as the one passed in.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
245

Woosuk Kwon's avatar
Woosuk Kwon committed
246
247
248

class HfRunner:
    """Reference runner backed by a HuggingFace ``transformers`` model.

    Its helpers mirror those of ``VllmRunner`` so test outputs can be
    compared. Use it as a context manager: ``__exit__`` deletes the model and
    calls :func:`cleanup`.
    """

    def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
        """Move ``input`` to ``device`` and return it.

        ``device`` defaults to ``"cpu"`` when ``is_cpu()`` is true and
        ``"cuda"`` otherwise. Objects that already report being on the target
        device are returned unchanged.
        """
        if device is None:
            device = "cpu" if is_cpu() else "cuda"

        if hasattr(input, "device") and input.device.type == device:
            return input

        return input.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        """Load the model, tokenizer and processor for ``model_name``.

        Args:
            model_name: HF model id or path.
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE``.
            model_kwargs: Extra kwargs for ``auto_cls.from_pretrained``.
            is_embedding_model: Load a ``SentenceTransformer`` instead of
                ``auto_cls``.
            auto_cls: Auto model class used for non-embedding models.
            postprocess_inputs: Applied to every processor output before
                generation.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        self.postprocess_inputs = postprocess_inputs

    def _get_inputs(
        self,
        prompt: str,
        image: Optional[Any] = None,
        video: Optional[Any] = None,
        audio: Optional[Tuple[np.ndarray, int]] = None,
    ) -> BatchEncoding:
        """Run the processor on one prompt plus its optional multimodal data
        and apply ``postprocess_inputs``.

        ``audio`` is a ``(waveform, sampling_rate)`` pair. A ``None`` item is
        simply not passed to the processor.
        """
        processor_kwargs: Dict[str, Any] = {
            "text": prompt,
            "return_tensors": "pt",
        }
        if image is not None:
            processor_kwargs["images"] = image
        if video is not None:
            processor_kwargs["videos"] = video
        if audio is not None:
            audio_data, sampling_rate = audio
            processor_kwargs["audio"] = audio_data
            processor_kwargs["sampling_rate"] = sampling_rate

        return self.postprocess_inputs(self.processor(**processor_kwargs))

    def generate(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run ``model.generate`` on each prompt separately.

        Returns:
            One ``(output_ids, output_strs)`` pair per prompt, each holding one
            entry per sequence returned by ``model.generate`` (the ids are
            the raw ``generate`` output rows; strings are decoded with special
            tokens skipped).
        """
        if images is not None:
            assert len(prompts) == len(images)
        if videos is not None:
            assert len(prompts) == len(videos)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            inputs = self._get_inputs(
                prompt,
                image=None if images is None else images[i],
                video=None if videos is None else videos[i],
            )

            output_ids = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            outputs.append((output_ids.cpu().tolist(), output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy decoding; returns the first sequence's ids and text for
        each prompt."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search returning all ``beam_width`` beams per prompt, with pad
        tokens removed from the token ids."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [([[token_id for token_id in ids if token_id != pad_token_id]
                  for ids in output_ids], output_str)
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy decoding returning, per prompt, the full float32 logprob
        tensor of every generation step."""
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            inputs = self._get_inputs(
                prompt,
                image=None if images is None else images[i],
                video=None if videos is None else videos[i],
            )

            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            all_logprobs.append(
                self._hidden_states_to_seq_logprobs(output.hidden_states))
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
    ) -> List[torch.Tensor]:
        """Project each step's final hidden states through the output
        embeddings and return float32 log-softmax logprobs per step."""
        output_embeddings = self.model.get_output_embeddings()

        seq_logprobs: List[torch.Tensor] = []
        for hidden_state in hidden_states:
            # Presumably the last layer ([-1]) of batch element 0 ([0]);
            # HF returns one per-layer tuple per generation step.
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(output_embeddings.weight.device),
                output_embeddings.weight.t(),
            )
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Convert hidden states into top-``num_logprobs`` dicts.

        Returns:
            A list with one ``{token_id: logprob}`` dict per generation step,
            and the number of steps (``len(hidden_states)``).
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs: the first step has one row per prompt
            # position, keep only the last one
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            seq_logprobs_lst.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            })

        return seq_logprobs_lst, output_len

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        """Greedy decoding returning ``(token_ids, text, top-k logprobs)`` for
        the generated part of each prompt's first sequence."""
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, prompt in enumerate(prompts):
            inputs = self._get_inputs(
                prompt,
                image=None if images is None else images[i],
                video=None if videos is None else videos[i],
                audio=None if audios is None else audios[i],
            )

            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            # Keep only the generated tokens (one per hidden-state step).
            output_ids = output.sequences[0][-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for HF encoder/decoder models (reference
        for the vLLM encoder/decoder runner).
        '''
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):

            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids,
                device=self.model.device.type,
            )

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_input_ids = self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids,
                    device=self.model.device.type,
                )

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            output_ids = output.sequences[0][-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Delegate to the underlying model's ``encode``."""
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
594

Cyrus Leung's avatar
Cyrus Leung committed
595
@pytest.fixture(scope="session")
def hf_runner():
    """Return the :class:`HfRunner` class itself (not an instance) so tests
    can construct it with their own model arguments."""
    return HfRunner


class VllmRunner:
    """Runner backed by :class:`vllm.LLM` whose helpers mirror ``HfRunner``.

    Use it as a context manager: ``__exit__`` deletes the engine and calls
    :func:`cleanup`.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    @staticmethod
    def _build_inputs(
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> List[TextPrompt]:
        """Wrap each prompt in a ``TextPrompt`` and attach its multimodal data.

        Every supplied modality must provide exactly one item per prompt.
        When several modalities are supplied they are merged into the same
        ``multi_modal_data`` dict rather than overwriting each other.
        """
        modalities = (("image", images), ("audio", audios), ("video", videos))
        for modality, items in modalities:
            if items is not None:
                assert len(prompts) == len(items), (
                    f"Got {len(prompts)} prompts but {len(items)} "
                    f"{modality} inputs")

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        for modality, items in modalities:
            if items is None:
                continue
            for prompt_input, item in zip(inputs, items):
                prompt_input.setdefault("multi_modal_data", {})[modality] = item
        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for each prompt.

        Returns:
            One ``(ids, texts)`` pair per request; each list has one entry per
            sample, with the prompt ids/text prepended to the sample's output.
        """
        inputs = self._build_inputs(prompts, images=images)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                req_sample_output_ids.append(prompt_ids +
                                             list(sample.token_ids))
                req_sample_output_strs.append(prompt_str + sample.text)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[TokensTextLogprobsPromptLogprobs]:
        """Convert request outputs into one ``(token_ids, text, logprobs,
        prompt_logprobs)`` tuple per request, taken from the request's last
        sample."""
        outputs: List[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            # NOTE(review): with several samples per request only the last
            # one is reported; the greedy helpers below do not set ``n``.
            sample = req_output.outputs[-1]
            outputs.append((list(sample.token_ids), sample.text,
                            sample.logprobs, req_output.prompt_logprobs))
        return outputs

    def _logprobs_outputs(
        self,
        req_outputs: List[RequestOutput],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Format request outputs, dropping the trailing prompt-logprobs
        element when the sampling params did not request prompt logprobs."""
        toks_str_logprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        if sampling_params.prompt_logprobs is None:
            return [x[0:-1] for x in toks_str_logprobs_prompt_logprobs]
        return toks_str_logprobs_prompt_logprobs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Generate with logprobs; prompt logprobs are included only when
        ``sampling_params.prompt_logprobs`` is set."""
        inputs = self._build_inputs(prompts,
                                    images=images,
                                    audios=audios,
                                    videos=videos)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)
        return self._logprobs_outputs(req_outputs, sampling_params)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._logprobs_outputs(req_outputs, sampling_params)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy (temperature 0) decoding; returns the first sample's ids and
        text (prompt included) for each prompt."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Greedy decoding with ``num_logprobs`` sample logprobs (and
        optionally ``num_prompt_logprobs`` prompt logprobs)."""
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            use_beam_search=False,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )
        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search via ``SamplingParams(use_beam_search=True)`` with
        ``beam_width`` returned sequences."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def generate_beam_search_new(
        self,
        prompts: Union[List[str], List[List[int]]],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search via ``LLM.beam_search``; returns ``(token_ids, texts)``
        per prompt with one entry per returned beam."""
        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return the embedding of each prompt from ``LLM.encode``."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
837

838
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture that hands tests the ``VllmRunner`` class.

    The class object itself is returned (not an instance); tests construct
    runners from it as needed.
    """
    runner_cls = VllmRunner
    return runner_cls
841
842
843
844
845
846
847
848
849


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a single-worker ``TokenizerPoolConfig`` for a tokenizer group.

    Args:
        tokenizer_group_type: ``None`` for no pool, the string ``"ray"``, or
            a class object to use as the pool type.

    Returns:
        ``None`` when ``tokenizer_group_type`` is ``None``; otherwise a
        ``TokenizerPoolConfig`` with ``pool_size=1``, empty ``extra_config``
        and ``pool_type`` set to ``"ray"`` or the given class.

    Raises:
        ValueError: for any other value.
    """
    if tokenizer_group_type is None:
        return None
    if tokenizer_group_type == "ray":
        pool_type = "ray"
    elif isinstance(tokenizer_group_type, type):
        pool_type = tokenizer_group_type
    else:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")
    return TokenizerPoolConfig(pool_size=1,
                               pool_type=pool_type,
                               extra_config={})
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870


@pytest.fixture()
def temporary_enable_log_propagate():
    """Temporarily let records from the ``vllm`` logger reach the root logger.

    ``propagate`` is switched on for the duration of the test and afterwards
    restored to whatever value it had before, instead of being forced to
    ``False``, so the fixture does not clobber a setting made elsewhere.
    The restore sits in ``finally`` so it also runs if the generator is
    closed early.
    """
    import logging

    # Named vllm_logger to avoid shadowing the module-level ``logger``.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """A ``caplog`` that also captures records from the ``vllm`` logger.

    ``caplog`` only sees records that propagate up to the root logger, so
    this fixture depends on ``temporary_enable_log_propagate`` to switch
    propagation on while the test runs.
    """
    captured = caplog
    yield captured
871
872
873
874
875
876
877


@pytest.fixture(scope="session")
def num_gpus_available():
    """Number of visible GPUs, evaluated once per test session.

    Uses ``cuda_device_count_stateless`` so that counting does not
    initialize a CUDA context in the current process.
    """
    gpu_count = cuda_device_count_stateless()
    return gpu_count
879
880
881


# Fixed locations under the system temp directory used by the dummy model
# fixtures (one directory per dummy model).
temp_dir = tempfile.gettempdir()
_dummy_opt_path, _dummy_llava_path = (
    os.path.join(temp_dir, dirname) for dirname in ("dummy_opt", "dummy_llava"))
884
885
886
887


@pytest.fixture
def dummy_opt_path():
    """Return a local copy of ``facebook/opt-125m`` whose ``config.json``
    declares ``architectures = ["MyOPTForCausalLM"]``.

    The snapshot is downloaded into ``_dummy_opt_path`` (skipping ``*.bin``,
    ``*.bin.index.json``, ``*.pt``, ``*.h5`` and ``*.msgpack`` files) when its
    ``config.json`` is missing, and reused otherwise.

    The architecture override is checked on every call rather than only
    right after a fresh download. A previous run that was interrupted after
    creating the directory therefore gets repaired instead of silently
    leaving the original architecture in place.
    """
    json_path = os.path.join(_dummy_opt_path, "config.json")
    # Gate on config.json rather than on the directory: the directory can
    # exist without a usable snapshot after an interrupted download.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_opt_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Only rewrite the file when the override is not already in place.
    if config.get("architectures") != ["MyOPTForCausalLM"]:
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
    """Return a local copy of ``llava-hf/llava-1.5-7b-hf`` whose
    ``config.json`` declares ``architectures = ["MyLlava"]``.

    The snapshot is downloaded into ``_dummy_llava_path`` (skipping ``*.bin``,
    ``*.bin.index.json``, ``*.pt``, ``*.h5`` and ``*.msgpack`` files) when its
    ``config.json`` is missing, and reused otherwise.

    The architecture override is checked on every call rather than only
    right after a fresh download. A previous run that was interrupted after
    creating the directory therefore gets repaired instead of silently
    leaving the original architecture in place.
    """
    json_path = os.path.join(_dummy_llava_path, "config.json")
    # Gate on config.json rather than on the directory: the directory can
    # exist without a usable snapshot after an interrupted download.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                          local_dir=_dummy_llava_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Only rewrite the file when the override is not already in place.
    if config.get("architectures") != ["MyLlava"]:
        config["architectures"] = ["MyLlava"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_llava_path