conftest.py 30.7 KB
Newer Older
1
2
import contextlib
import gc
3
import json
4
import os
5
import sys
6
import tempfile
7
from collections import UserList
8
from enum import Enum
9
10
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
13
14
import pytest
import torch
15
import torch.nn as nn
16
import torch.nn.functional as F
17
from huggingface_hub import snapshot_download
18
from PIL import Image
19
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
20
                          BatchFeature)
21
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
24
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm import LLM, SamplingParams
26
from vllm.assets.image import ImageAsset
27
from vllm.assets.video import VideoAsset
28
from vllm.config import TokenizerPoolConfig
29
from vllm.connections import global_http_connection
30
from vllm.distributed import (destroy_distributed_environment,
31
32
33
                              destroy_model_parallel,
                              init_distributed_environment,
                              initialize_model_parallel)
34
35
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
36
from vllm.logger import init_logger
37
from vllm.outputs import RequestOutput
38
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
39
                        identity, is_cpu)
40

41
logger = init_logger(__name__)

# Prompt fixture files live in ``<this directory>/prompts``.
_TEST_DIR = os.path.dirname(__file__)
_PROMPTS_DIR = os.path.join(_TEST_DIR, "prompts")
_TEST_PROMPTS = [os.path.join(_PROMPTS_DIR, "example.txt")]
_LONG_PROMPTS = [os.path.join(_PROMPTS_DIR, "summary.txt")]

# Per-prompt multi-modal inputs: either one item per prompt or a list of
# items per prompt. Audio items are (waveform, sampling_rate) pairs.
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
                         List[List[Tuple[np.ndarray, int]]]]
PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]]
51

52

53
def _read_prompts(filename: str) -> List[str]:
    """Return the lines of ``filename`` as prompts.

    Each line keeps its trailing newline, exactly as ``readlines`` yields
    it. The file is decoded as UTF-8 explicitly so the result does not
    depend on the platform's locale encoding.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
57
58


59
60
61
class _ImageAssetPrompts(TypedDict):
    """Prompt text for each bundled test image, keyed by asset name."""
    stop_sign: str
    cherry_blossom: str
62
63
64
65
66
67
68


if sys.version_info >= (3, 9):

    class _ImageAssetsBase(UserList[ImageAsset]):
        """Typed ``UserList`` base for :class:`_ImageAssets`."""
else:
    # UserList cannot be subscripted before Python 3.9
    class _ImageAssetsBase(UserList):
        """Untyped ``UserList`` base for :class:`_ImageAssets`."""
72

73
74

class _ImageAssets(_ImageAssetsBase):
    """The bundled test images, always in the same fixed order."""

    def __init__(self) -> None:
        stop_sign = ImageAsset("stop_sign")
        cherry_blossom = ImageAsset("cherry_blossom")
        super().__init__([stop_sign, cherry_blossom])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Map per-asset prompts to a list aligned with this collection.

        Entry ``i`` of the result is the prompt for the ``i``-th asset
        produced when iterating over this object.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
90
91


92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class _VideoAssetPrompts(TypedDict):
    """Prompt text for each bundled test video, keyed by asset name."""
    sample_demo_1: str


if sys.version_info >= (3, 9):

    class _VideoAssetsBase(UserList[VideoAsset]):
        """Typed ``UserList`` base for :class:`_VideoAssets`."""
else:
    # UserList cannot be subscripted before Python 3.9
    class _VideoAssetsBase(UserList):
        """Untyped ``UserList`` base for :class:`_VideoAssets`."""


class _VideoAssets(_VideoAssetsBase):
    """The bundled test videos, always in the same fixed order."""

    def __init__(self) -> None:
        demo_video = VideoAsset("sample_demo_1.mp4")
        super().__init__([demo_video])

    def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
        """Return the prompts aligned with the iteration order of the assets."""
        demo_prompt = prompts["sample_demo_1"]
        return [demo_prompt]


117
118
# Module-level singletons; the session-scoped ``image_assets`` and
# ``video_assets`` fixtures hand these same objects to tests.
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""

VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
121
122


123
124
125
126
127
128
129
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Force every test to build a fresh async HTTP client.

    pytest_asyncio may run each test on a different event loop, so the
    shared ``global_http_connection`` must not reuse a client created on
    another loop.
    """
    global_http_connection.reuse_client = False


130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
@pytest.fixture
def dist_init():
    """Set up a single-process (world_size=1, rank 0) distributed environment.

    Uses a file-based rendezvous on a fresh temporary file and the NCCL
    backend, calls ``initialize_model_parallel(1, 1)`` (presumably tensor-
    and pipeline-parallel sizes of 1 — confirm against its signature), and
    tears everything down with ``cleanup()`` after the test.
    """
    # mkstemp returns an already-open OS-level descriptor together with the
    # path; close it so the fixture does not leak one fd per test.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup()
    # The rendezvous file is no longer needed; it may already be gone.
    with contextlib.suppress(OSError):
        os.remove(temp_file)


145
146
def cleanup():
    """Tear down vLLM's distributed/model-parallel state and free memory.

    Destroys model-parallel groups, the distributed environment and the
    torch process group, then runs the garbage collector. Unless
    ``is_cpu()`` is true, the CUDA cache is emptied as well.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        # Presumably raised when no default process group exists; there is
        # nothing left to destroy in that case.
        pass
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()
153
154


155
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether the global ``cleanup()`` runs after the current test.

    Returns ``False`` for tests carrying the ``skip_global_cleanup`` marker.
    Subdirectories may also override this fixture; skipping cleanup can give
    a ~10x speedup for non-GPU unit tests since they don't need to
    initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
163
164


165
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After every test, run ``cleanup()`` unless the test opted out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
170
171


172
173
174
175
176
177
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call ``torch._dynamo.reset()`` once each test has finished.

    Presumably this drops compiled-graph state so it does not leak between
    tests.
    """
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
178
179
@pytest.fixture
def example_prompts() -> List[str]:
    """All lines of every file in ``_TEST_PROMPTS``, concatenated in order."""
    return [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]


186
187
188
189
190
191
192
class DecoderPromptType(Enum):
    """Which decoder prompt to pair with each encoder prompt.

    For encoder/decoder models only.
    """
    CUSTOM = 1  # a caller-chosen decoder prompt string
    NONE = 2  # no decoder prompt at all (``None``)
    EMPTY_STR = 3  # the empty string ``""`` as decoder prompt


193
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for every :class:`DecoderPromptType`.

    The encoder prompts are the lines of every file in ``_TEST_PROMPTS``.
    For each dictionary key they are combined, via ``zip_enc_dec_prompts``,
    with a decoder prompt list of the same length:

    * ``DecoderPromptType.NONE``: every decoder prompt is ``None``
    * ``DecoderPromptType.EMPTY_STR``: every decoder prompt is ``""``
    * ``DecoderPromptType.CUSTOM``: the decoder prompts are the encoder
      prompts in reverse order

    Returns:

    * Mapping from decoder prompt type to the list produced by
      ``zip_enc_dec_prompts`` for that type
    '''

    encoder_prompts = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    custom_decoder_prompts = encoder_prompts[::-1]
    empty_str_decoder_prompts = [""] * len(encoder_prompts)
    none_decoder_prompts = [None] * len(encoder_prompts)

    # One entry per decoder prompt type.
    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
    }


226
227
228
229
@pytest.fixture
def example_long_prompts() -> List[str]:
    """All lines of every file in ``_LONG_PROMPTS``, concatenated in order."""
    return [
        line for filename in _LONG_PROMPTS
        for line in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
232
233


234
235
236
237
238
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Session-wide access to the shared ``IMAGE_ASSETS`` instance."""
    return IMAGE_ASSETS


239
240
241
242
243
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    """Session-wide access to the shared ``VIDEO_ASSETS`` instance."""
    return VIDEO_ASSETS


244
# Objects that ``HfRunner.wrap_device`` accepts and returns with the same type.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
245

Woosuk Kwon's avatar
Woosuk Kwon committed
246
247
248

class HfRunner:
    """Reference runner that executes a model through HuggingFace.

    Holds a ``transformers`` model (a ``SentenceTransformer`` for embedding
    models), its tokenizer and its processor. The generation helpers are
    annotated with the same return types as the like-named ``VllmRunner``
    methods. Usable as a context manager: leaving the ``with`` block deletes
    the model and runs the global ``cleanup()``.
    """

    def wrap_device(self, input: _T, device: Optional[str] = None) -> _T:
        """Return ``input`` placed on ``device``.

        With ``device=None`` the target is ``"cpu"`` when ``is_cpu()`` is
        true and ``"cuda"`` otherwise. Objects whose ``.device.type`` already
        equals the target are returned unchanged; anything else goes through
        ``.to(device)``.
        """
        if device is None:
            return self.wrap_device(input, "cpu" if is_cpu() else "cuda")

        if hasattr(input, "device") and input.device.type == device:
            return input

        return input.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        """Load the model, tokenizer and processor for ``model_name``.

        Args:
            model_name: name or path passed to the HuggingFace loaders.
            dtype: key into ``STR_DTYPE_TO_TORCH_DTYPE`` selecting the torch
                dtype used when loading.
            model_kwargs: extra keyword arguments for
                ``auto_cls.from_pretrained`` (unused for embedding models).
            is_embedding_model: load a ``SentenceTransformer`` instead of
                ``auto_cls``.
            auto_cls: HuggingFace auto class for non-embedding models.
            postprocess_inputs: applied to every processor output before it
                reaches the model.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                    trust_remote_code=True,
                ).to(dtype=torch_dtype))
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        self.postprocess_inputs = postprocess_inputs

    def _prepare_inputs(
        self,
        index: int,
        prompt: str,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[List[np.ndarray]] = None,
    ) -> BatchEncoding:
        """Run the processor and ``postprocess_inputs`` for one prompt.

        ``images``/``audios``/``videos`` are per-prompt lists; element
        ``index`` is attached when the list is given and that element is not
        ``None``. An audio element is an ``(audio, sampling_rate)`` pair.
        """
        processor_kwargs: Dict[str, Any] = {
            "text": prompt,
            "return_tensors": "pt",
        }
        if images is not None and images[index] is not None:
            processor_kwargs["images"] = images[index]
        if audios is not None and audios[index] is not None:
            audio, sr = audios[index]
            processor_kwargs["audio"] = audio
            processor_kwargs["sampling_rate"] = sr
        if videos is not None and videos[index] is not None:
            processor_kwargs["videos"] = videos[index]

        inputs = self.processor(**processor_kwargs)
        return self.postprocess_inputs(inputs)

    def generate(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Call ``self.model.generate`` once per prompt.

        Extra ``kwargs`` are forwarded to ``model.generate``. For each prompt
        the result holds the id sequences returned by ``model.generate`` (as
        Python lists) and their ``batch_decode``-d strings.
        """
        if images:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            inputs = self._prepare_inputs(i, prompt, images=images,
                                          videos=videos)

            output_ids = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            outputs.append((output_ids.cpu().tolist(), output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy generation; returns the first sequence/string per prompt."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search returning ``beam_width`` sequences per prompt.

        Pad-token ids are stripped from every returned id sequence.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [([[tok for tok in seq if tok != pad_token_id]
                  for seq in output_ids], output_str)
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy generation returning per-step float32 logprob tensors.

        For each prompt the result is the list produced by
        ``_hidden_states_to_seq_logprobs`` for that generation.
        """
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            inputs = self._prepare_inputs(i, prompt, images=images,
                                          videos=videos)

            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs = self._hidden_states_to_seq_logprobs(
                output.hidden_states)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
    ) -> List[torch.Tensor]:
        """Turn per-step hidden states into float32 log-softmax tensors.

        For each step, ``hidden_state[-1][0]`` (presumably the last layer's
        states for batch item 0 — confirm against the transformers output
        format) is multiplied by the transposed output-embedding weight, the
        embedding bias is added if present, and ``log_softmax`` is taken over
        the last dimension.
        """
        output_embeddings = self.model.get_output_embeddings()

        seq_logprobs: List[torch.Tensor] = []
        for hidden_state in hidden_states:
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(output_embeddings.weight.device),
                output_embeddings.weight.t(),
            )
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Return top-``num_logprobs`` ``{token_id: logprob}`` dicts per step.

        Only the last row of the first step's tensor is kept (the earlier
        rows belong to the prompt). The second return value is the number of
        steps, ``len(hidden_states)``.
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        # convert to dict
        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            seq_logprobs_lst.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            })

        return seq_logprobs_lst, output_len

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        """Greedy generation returning ``(ids, text, top-k logprobs)``.

        ``ids`` are the last ``len(logprobs)`` tokens of the first returned
        sequence, i.e. one id per generation step. Per-prompt multi-modal
        entries that are ``None`` are skipped.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, prompt in enumerate(prompts):
            inputs = self._prepare_inputs(i, prompt, images=images,
                                          audios=audios, videos=videos)

            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for HuggingFace encoder/decoder models.

        Decoder prompts that are ``None`` are passed as
        ``decoder_input_ids=None``. Logprobs come from the decoder hidden
        states; results have the same layout as
        ``generate_greedy_logprobs_limit``.
        '''

        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):

            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids,
                device=self.model.device.type,
            )

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_input_ids = self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids,
                    device=self.model.device.type,
                )

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Delegate to ``self.model.encode`` (embedding models)."""
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
595

Cyrus Leung's avatar
Cyrus Leung committed
596
@pytest.fixture(scope="session")
def hf_runner():
    """Give tests the :class:`HfRunner` class itself, not an instance."""
    return HfRunner


class VllmRunner:
    """Test harness around :class:`vllm.LLM`.

    Builds an ``LLM`` with test-friendly defaults. The generation helpers are
    annotated with the same return types as the like-named ``HfRunner``
    methods. Usable as a context manager: leaving the ``with`` block deletes
    the engine and runs the global ``cleanup()``.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    @staticmethod
    def _get_inputs(
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[TextPrompt]:
        """Build one ``TextPrompt`` per prompt with its multi-modal data.

        Each given modality list must have one entry per prompt. ``None``
        entries are skipped. When several modalities are given for the same
        prompt, they are merged into one ``multi_modal_data`` dict instead of
        overwriting each other.
        """
        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        for modality, items in (("image", images), ("video", videos),
                                ("audio", audios)):
            if items is None:
                continue
            assert len(prompts) == len(items)
            for i, item in enumerate(items):
                if item is None:
                    continue
                inputs[i].setdefault("multi_modal_data", {})[modality] = item
        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with ``sampling_params``.

        For each request, returns every sample as prompt ids + generated ids
        and as prompt text + generated text.
        """
        inputs = self._get_inputs(prompts,
                                  images=images,
                                  videos=videos,
                                  audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                req_sample_output_ids.append(prompt_ids +
                                             list(sample.token_ids))
                req_sample_output_strs.append(prompt_str + sample.text)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[TokensTextLogprobsPromptLogprobs]:
        """Convert each request to ``(ids, text, logprobs, prompt_logprobs)``.

        Exactly one tuple is produced per request, built from the request's
        LAST sample.
        """
        outputs: List[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            # NOTE(review): only the last sample is kept (historical
            # behavior); with n > 1 the other samples are dropped.
            sample = req_output.outputs[-1]
            outputs.append((list(sample.token_ids), sample.text,
                            sample.logprobs, req_output.prompt_logprobs))
        return outputs

    @staticmethod
    def _maybe_drop_prompt_logprobs(
        outputs: List[TokensTextLogprobsPromptLogprobs],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Omit prompt logprobs if not required by sampling params."""
        if sampling_params.prompt_logprobs is None:
            return [x[0:-1] for x in outputs]
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return ids/text/logprobs (plus prompt logprobs if requested)."""
        inputs = self._get_inputs(prompts,
                                  images=images,
                                  videos=videos,
                                  audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        return self._maybe_drop_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._maybe_drop_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy generation; returns the first sample's ids/text per prompt."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Greedy (temperature 0) variant of ``generate_w_logprobs``."""
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            use_beam_search=False,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search through ``SamplingParams(use_beam_search=True)``."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def generate_beam_search_new(
        self,
        prompts: Union[List[str], List[List[int]]],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search through ``LLM.beam_search``; returns (tokens, texts)."""
        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return ``outputs.embedding`` of each ``LLM.encode`` request."""
        return [
            req_output.outputs.embedding
            for req_output in self.model.encode(prompts)
        ]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
838

839
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture returning ``VllmRunner`` uninstantiated.

    Returning the class rather than an instance lets each test construct
    its own runner with the arguments it needs.
    """
    return VllmRunner
842
843
844
845
846
847
848
849
850


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a single-worker ``TokenizerPoolConfig`` for a tokenizer group type.

    Args:
        tokenizer_group_type: ``None`` (no pool), the string ``"ray"``, or a
            class, which is passed through unchanged as ``pool_type``.

    Returns:
        ``None`` when ``tokenizer_group_type`` is ``None``. Otherwise a
        ``TokenizerPoolConfig`` with ``pool_size=1`` and an empty
        ``extra_config``.

    Raises:
        ValueError: if ``tokenizer_group_type`` is any other value.
    """
    if tokenizer_group_type is None:
        return None
    if tokenizer_group_type == "ray":
        return TokenizerPoolConfig(pool_size=1,
                                   pool_type="ray",
                                   extra_config={})
    if isinstance(tokenizer_group_type, type):
        return TokenizerPoolConfig(pool_size=1,
                                   pool_type=tokenizer_group_type,
                                   extra_config={})
    raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871


@pytest.fixture()
def temporary_enable_log_propagate():
    """Enable propagation on the ``vllm`` logger for the duration of a test.

    The logger's previous ``propagate`` value is saved and restored on
    teardown, rather than being forced to ``False``. A setup where the
    logger already propagated is therefore left intact for later tests.
    """
    import logging
    # Local name avoids shadowing the module-level ``logger``.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Expose pytest's ``caplog`` for records logged under ``vllm``.

    ``caplog`` only receives records that reach the root logger. This
    fixture therefore requests ``temporary_enable_log_propagate``, which
    makes the ``vllm`` logger propagate while the test runs.
    """
    yield caplog
872
873
874
875
876
877
878


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get the number of GPUs without initializing the CUDA context
    in the current process.

    The count comes from ``cuda_device_count_stateless()``.
    """
    return cuda_device_count_stateless()
880
881
882


# Fixed locations under the system temp directory where the dummy model
# fixtures cache their downloaded and patched snapshots.
temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
885
886
887
888


@pytest.fixture
def dummy_opt_path():
    """Return a local ``facebook/opt-125m`` snapshot patched to declare the
    ``MyOPTForCausalLM`` architecture.

    On first use the repo is downloaded into ``_dummy_opt_path``. Files
    matching ``*.bin``, ``*.bin.index.json``, ``*.pt``, ``*.h5`` and
    ``*.msgpack`` are skipped. ``config.json`` is then rewritten with a new
    ``architectures`` entry. If the directory already exists, it is
    returned as-is, without re-downloading or re-patching.
    """
    json_path = os.path.join(_dummy_opt_path, "config.json")
    if not os.path.exists(_dummy_opt_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_opt_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path, "r") as f:
            config = json.load(f)
        # NOTE(review): presumably ``MyOPTForCausalLM`` is a custom model
        # class registered by the tests that use this fixture — confirm.
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w") as f:
            json.dump(config, f)
    return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
    """Return a local ``llava-hf/llava-1.5-7b-hf`` snapshot patched to
    declare the ``MyLlava`` architecture.

    An existing ``_dummy_llava_path`` directory is returned untouched.
    Otherwise the repo is downloaded there, skipping files that match
    ``*.bin``, ``*.bin.index.json``, ``*.pt``, ``*.h5`` and ``*.msgpack``.
    Its ``config.json`` ``architectures`` entry is then overwritten.
    """
    if os.path.exists(_dummy_llava_path):
        return _dummy_llava_path

    snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                      local_dir=_dummy_llava_path,
                      ignore_patterns=[
                          "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                          "*.msgpack"
                      ])
    config_file = os.path.join(_dummy_llava_path, "config.json")
    assert os.path.exists(config_file)
    with open(config_file, "r") as fp:
        hf_config = json.load(fp)
    hf_config["architectures"] = ["MyLlava"]
    with open(config_file, "w") as fp:
        json.dump(hf_config, fp)
    return _dummy_llava_path