conftest.py 29.6 KB
Newer Older
1
2
import contextlib
import gc
3
import json
4
import os
5
import sys
6
import tempfile
7
from collections import UserList
8
from enum import Enum
9
10
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
13
14
import pytest
import torch
15
import torch.nn as nn
16
import torch.nn.functional as F
17
from huggingface_hub import snapshot_download
18
from PIL import Image
19
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
20
                          BatchFeature)
21
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
24
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm import LLM, SamplingParams
26
from vllm.assets.image import ImageAsset
27
from vllm.assets.video import VideoAsset
28
from vllm.config import TokenizerPoolConfig
29
from vllm.connections import global_http_connection
30
from vllm.distributed import (destroy_distributed_environment,
31
32
33
                              destroy_model_parallel,
                              init_distributed_environment,
                              initialize_model_parallel)
34
35
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
36
from vllm.logger import init_logger
37
from vllm.outputs import RequestOutput
38
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
39
                        identity, is_cpu)
40

41
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
42

43
44
45
# Directory holding this conftest; the prompt files live in its "prompts"
# subdirectory.
_TEST_DIR = os.path.dirname(__file__)

# Each constant is a one-element list of prompt-file paths: the short example
# prompts and the long (summary) prompts respectively.
_TEST_PROMPTS, _LONG_PROMPTS = (
    [os.path.join(_TEST_DIR, "prompts", prompt_file)]
    for prompt_file in ("example.txt", "summary.txt"))
# Multimodal inputs accepted by the runners. Each alias is either a flat list
# (presumably one item per prompt) or a nested list (presumably several items
# per prompt).
PromptImageInput = Union[
    List[Image.Image],
    List[List[Image.Image]],
]
# Audio items are (waveform, sampling rate) pairs.
PromptAudioInput = Union[
    List[Tuple[np.ndarray, int]],
    List[List[Tuple[np.ndarray, int]]],
]
PromptVideoInput = Union[
    List[np.ndarray],
    List[List[np.ndarray]],
]
def _read_prompts(filename: str) -> List[str]:
    """Read a prompt file and return its lines.

    Each line of the file is one prompt. Lines are returned exactly as
    ``readlines()`` yields them, i.e. including their trailing newline.

    Args:
        filename: Path to a text file with one prompt per line.

    Returns:
        The lines of the file, in order (empty for an empty file).
    """
    # Pin the encoding so the prompts read do not depend on the platform's
    # locale default (e.g. cp1252 on Windows).
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()
# Prompt text for each image asset, keyed by asset name; consumed by
# ``_ImageAssets.prompts``.
_ImageAssetPrompts = TypedDict(
    "_ImageAssetPrompts",
    {
        "stop_sign": str,
        "cherry_blossom": str,
    },
)
if sys.version_info >= (3, 9):
    # Subscripting UserList gives type checkers the element type.
    class _ImageAssetsBase(UserList[ImageAsset]):
        pass
else:
    # UserList cannot be subscripted before Python 3.9.
    class _ImageAssetsBase(UserList):
        pass
class _ImageAssets(_ImageAssetsBase):
    """List of the image assets shared by tests.

    The order is fixed: ``stop_sign`` first, then ``cherry_blossom``.
    """

    def __init__(self) -> None:
        stop_sign = ImageAsset("stop_sign")
        cherry_blossom = ImageAsset("cherry_blossom")
        super().__init__([stop_sign, cherry_blossom])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """Return one prompt per image asset, in iteration order.

        Convenience method for defining the prompt of each test image: the
        returned list lines up with the assets when iterating this object.

        Args:
            prompts: Prompt text keyed by asset name.

        Returns:
            The prompts for ``stop_sign`` and ``cherry_blossom``, in that
            order.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
# Prompt text for each video asset, keyed by asset name; consumed by
# ``_VideoAssets.prompts``.
_VideoAssetPrompts = TypedDict(
    "_VideoAssetPrompts",
    {
        "sample_demo_1": str,
    },
)
if sys.version_info >= (3, 9):
    # Subscripting UserList gives type checkers the element type.
    class _VideoAssetsBase(UserList[VideoAsset]):
        pass
else:
    # UserList cannot be subscripted before Python 3.9.
    class _VideoAssetsBase(UserList):
        pass
class _VideoAssets(_VideoAssetsBase):
    """List of the video assets shared by tests (currently one clip)."""

    def __init__(self) -> None:
        demo_clip = VideoAsset("sample_demo_1.mp4")
        super().__init__([demo_clip])

    def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
        """Return one prompt per video asset, in iteration order.

        Args:
            prompts: Prompt text keyed by asset name.

        Returns:
            A single-element list holding the ``sample_demo_1`` prompt.
        """
        demo_prompt = prompts["sample_demo_1"]
        return [demo_prompt]
# Module-wide singleton instance of ``_ImageAssets``.
IMAGE_ASSETS = _ImageAssets()

# Module-wide singleton instance of ``_VideoAssets``.
VIDEO_ASSETS = _VideoAssets()
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Stop the global HTTP connection from reusing its client.

    pytest_asyncio may run each test on its own event loop, so the async
    client has to be created anew instead of reused across tests.
    """
    global_http_connection.reuse_client = False
@pytest.fixture
def dist_init():
    """Set up a single-process distributed environment for one test.

    Initializes a world of size 1 (rank 0, local rank 0) over NCCL with a
    file-based rendezvous, calls ``initialize_model_parallel(1, 1)``
    (presumably tensor/pipeline parallel sizes of 1 — confirm against vLLM),
    yields to the test and finally runs ``cleanup()``.
    """
    # mkstemp returns an *open* OS-level descriptor together with the path.
    # Only the path is needed for the rendezvous URL, so close the descriptor
    # immediately instead of leaking it.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup()
    # The rendezvous file may no longer exist at this point; otherwise remove
    # it so repeated runs do not litter the temp directory.
    with contextlib.suppress(FileNotFoundError):
        os.remove(temp_file)
def cleanup():
    """Tear down distributed state and release cached memory.

    Destroys vLLM's model-parallel groups and distributed environment, then
    the torch default process group, runs the garbage collector and, unless
    running on CPU, empties the CUDA cache.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        # Tolerated: presumably raised when there is no process group left
        # to destroy.
        pass
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether ``cleanup()`` should run after the current test.

    Returns False when the test carries the ``skip_global_cleanup`` marker.
    Subdirectories may also override this fixture to skip global cleanup,
    which can give a ~10x speedup for non-GPU unit tests since they then do
    not need to initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return skip_marker is None
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After every test, run ``cleanup()`` unless the test opted out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
@pytest.fixture
def example_prompts() -> List[str]:
    """All lines of the short example prompt files, concatenated in order."""
    return [
        prompt for prompt_file in _TEST_PROMPTS
        for prompt in _read_prompts(prompt_file)
    ]
class DecoderPromptType(Enum):
    """Kind of decoder prompt paired with each encoder prompt.

    For encoder/decoder models only.
    """

    CUSTOM = 1  # explicit decoder prompts supplied by the test
    NONE = 2  # decoder prompt is None
    EMPTY_STR = 3  # decoder prompt is the empty string
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    """Build encoder/decoder prompt pairs for every ``DecoderPromptType``.

    The encoder prompts are the lines of the example prompt files. For each
    decoder prompt type they are zipped (via ``zip_enc_dec_prompts``) with a
    same-length list of decoder prompts:

    * ``DecoderPromptType.NONE``: every decoder prompt is ``None``
    * ``DecoderPromptType.EMPTY_STR``: every decoder prompt is ``""``
    * ``DecoderPromptType.CUSTOM``: the encoder prompts in reverse order

    Returns:
        A dict mapping each ``DecoderPromptType`` to the zipped list of
        encoder/decoder prompts for that type.
    """
    encoder_prompts: List[str] = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    num_prompts = len(encoder_prompts)
    none_decoder_prompts = [None] * num_prompts
    empty_str_decoder_prompts = [""] * num_prompts
    custom_decoder_prompts = encoder_prompts[::-1]

    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
    }
@pytest.fixture
def example_long_prompts() -> List[str]:
    """All lines of the long prompt files, concatenated in order."""
    return [
        prompt for prompt_file in _LONG_PROMPTS
        for prompt in _read_prompts(prompt_file)
    ]
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Expose the shared ``IMAGE_ASSETS`` singleton to tests."""
    assets = IMAGE_ASSETS
    return assets
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    """Expose the shared ``VIDEO_ASSETS`` singleton to tests."""
    assets = VIDEO_ASSETS
    return assets
# Types accepted by ``HfRunner.wrap_device``; constraining the TypeVar makes
# the return type match the argument type.
_T = TypeVar(
    "_T",
    nn.Module,
    torch.Tensor,
    BatchEncoding,
    BatchFeature,
)
class HfRunner:
    """Reference runner that executes a model with HuggingFace transformers.

    Loads the model (a ``SentenceTransformer`` for embedding models,
    otherwise ``auto_cls.from_pretrained``), its tokenizer and its
    processor, and offers generation helpers. Use it as a context manager
    so the model is deleted and ``cleanup()`` runs on exit.
    """

    def wrap_device(self, input: _T) -> _T:
        """Move ``input`` to the device the tests run on.

        The target is ``"cpu"`` when ``is_cpu()`` is true and ``"cuda"``
        otherwise. Objects exposing a ``device`` that already is the target
        are returned unchanged; anything else goes through ``.to(target)``.
        """
        target_device = "cpu" if is_cpu() else "cuda"
        if hasattr(input, "device") and input.device.type == target_device:
            return input  # already on the target device, nothing to move
        return input.to(target_device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        """Load the model, tokenizer and processor for ``model_name``.

        Args:
            model_name: Model name or path passed to the HF loaders.
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE`` selecting the torch
                dtype.
            model_kwargs: Extra keyword arguments for
                ``auto_cls.from_pretrained`` (unused for embedding models).
            is_embedding_model: Load the model with ``SentenceTransformer``
                instead of ``auto_cls``.
            auto_cls: HF auto class used to load non-embedding models.
            postprocess_inputs: Applied to every processor output before it
                is fed to the model.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        self.postprocess_inputs = postprocess_inputs

    @staticmethod
    def _get_item(items: Optional[List[Any]], index: int) -> Optional[Any]:
        """Return ``items[index]``, or None when no items were supplied."""
        return None if items is None else items[index]

    def _get_processor_inputs(
        self,
        prompt: str,
        image: Optional[Any] = None,
        video: Optional[Any] = None,
        audio: Optional[Tuple[np.ndarray, int]] = None,
    ) -> Any:
        """Run the processor on one prompt plus its optional media.

        Media that are None are not passed to the processor at all. An audio
        item is unpacked as ``(audio, sampling_rate)``. The processor output
        goes through ``self.postprocess_inputs`` before being returned.
        """
        processor_kwargs: Dict[str, Any] = {
            "text": prompt,
            "return_tensors": "pt",
        }
        if image is not None:
            processor_kwargs["images"] = image
        if video is not None:
            processor_kwargs["videos"] = video
        if audio is not None:
            audio_data, sampling_rate = audio
            processor_kwargs["audio"] = audio_data
            processor_kwargs["sampling_rate"] = sampling_rate

        inputs = self.processor(**processor_kwargs)
        return self.postprocess_inputs(inputs)

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states,
    ) -> List[torch.Tensor]:
        """Turn per-step hidden states into per-step log-probabilities.

        For each step, takes ``hidden_state[-1][0]`` (presumably the final
        layer's output for the first batch element), multiplies it by the
        transposed output-embedding weight, adds the output-embedding bias
        when one exists, and applies ``log_softmax`` in float32.
        """
        output_embeddings = self.model.get_output_embeddings()
        # Not every output layer has a ``bias`` attribute, so use getattr.
        output_bias = getattr(output_embeddings, "bias", None)

        seq_logprobs: List[torch.Tensor] = []
        for hidden_state in hidden_states:
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states,
                output_embeddings.weight.t(),
            )
            if output_bias is not None:
                logits += output_bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)
        return seq_logprobs

    def generate(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for each prompt with ``self.model.generate``.

        Extra ``kwargs`` are forwarded to ``generate``.

        Returns:
            One ``(output_ids, output_strs)`` pair per prompt, where both
            hold one entry per returned sequence (decoded with
            ``skip_special_tokens=True``).
        """
        if images is not None:
            assert len(prompts) == len(images)
        if videos is not None:
            assert len(prompts) == len(videos)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            inputs = self._get_processor_inputs(
                prompt,
                image=self._get_item(images, i),
                video=self._get_item(videos, i),
            )

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy decoding; returns the single sequence's ids and text."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search returning all ``beam_width`` beams per prompt.

        Pad tokens are stripped from every returned token-id sequence.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [([[tok for tok in beam_ids if tok != pad_token_id]
                  for beam_ids in output_ids], output_str)
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy decoding that returns the full log-prob tensor per step.

        Returns:
            For each prompt, one float32 log-prob tensor per generation step.
        """
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            inputs = self._get_processor_inputs(
                prompt,
                image=self._get_item(images, i),
                video=self._get_item(videos, i),
            )

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            all_logprobs.append(
                self._hidden_states_to_seq_logprobs(output.hidden_states))
        return all_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states,
        num_logprobs,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Convert per-step hidden states to top-``num_logprobs`` dicts.

        For the first step only the last position is kept, dropping the
        prompt positions.

        Returns:
            ``(seq_logprobs, output_len)``: one ``{token_id: logprob}`` dict
            per step, and ``len(hidden_states)``.
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)
            seq_logprobs_lst.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            })

        return seq_logprobs_lst, output_len

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        """Greedy decoding returning ids, text and top-k logprobs per prompt.

        Missing (None) image, audio or video items are simply not passed to
        the processor.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, prompt in enumerate(prompts):
            inputs = self._get_processor_inputs(
                prompt,
                image=self._get_item(images, i),
                video=self._get_item(videos, i),
                audio=self._get_item(audios, i),
            )

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # The generated tokens are the trailing ``output_len`` ids.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''

        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
            decoder_input_ids = (
                None if decoder_prompt is None else self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids))

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Embed ``prompts`` with the model's own ``encode`` method."""
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()
@pytest.fixture(scope="session")
def hf_runner():
    """Hand tests the ``HfRunner`` class itself (not an instance)."""
    runner_cls = HfRunner
    return runner_cls
class VllmRunner:
    """Test helper that runs a model through vLLM's ``LLM`` class.

    Use it as a context manager so the model is deleted and ``cleanup()``
    runs when the block exits.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    @staticmethod
    def _get_inputs(
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[TextPrompt]:
        """Wrap each prompt in a ``TextPrompt`` with its multimodal data.

        All modalities supplied for a prompt are merged into one
        ``multi_modal_data`` dict (keys ``"image"``, ``"audio"``,
        ``"video"``), so passing several modalities no longer silently
        drops all but the last. None items are skipped, and prompts without
        any media get no ``multi_modal_data`` entry at all.
        """
        inputs: List[TextPrompt] = []
        for i, prompt in enumerate(prompts):
            multi_modal_data: Dict[str, Any] = {}
            if images is not None and images[i] is not None:
                multi_modal_data["image"] = images[i]
            if audios is not None and audios[i] is not None:
                multi_modal_data["audio"] = audios[i]
            if videos is not None and videos[i] is not None:
                multi_modal_data["video"] = videos[i]

            text_prompt = TextPrompt(prompt=prompt)
            if multi_modal_data:
                text_prompt["multi_modal_data"] = multi_modal_data
            inputs.append(text_prompt)
        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with ``sampling_params`` for each prompt.

        Returns:
            One ``(ids, texts)`` pair per request, with one entry per sample;
            each sample's ids/text are prefixed by the prompt's ids/text.
        """
        if images is not None:
            assert len(prompts) == len(images)

        inputs = self._get_inputs(prompts, images=images)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = [
                prompt_ids + list(sample.token_ids)
                for sample in req_output.outputs
            ]
            req_sample_output_strs: List[str] = [
                prompt_str + sample.text for sample in req_output.outputs
            ]
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[TokensTextLogprobsPromptLogprobs]:
        """Collect ids, text, logprobs and prompt logprobs per request.

        Only the last sample of each request is reported, so the result has
        exactly one entry per request.
        """
        outputs: List[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            last_sample = req_output.outputs[-1]
            outputs.append((list(last_sample.token_ids), last_sample.text,
                            last_sample.logprobs, req_output.prompt_logprobs))
        return outputs

    @staticmethod
    def _maybe_omit_prompt_logprobs(
        toks_str_logsprobs_prompt_logprobs: List[
            TokensTextLogprobsPromptLogprobs],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Drop the trailing prompt-logprobs element unless it was requested.

        It is dropped when ``sampling_params.prompt_logprobs`` is None.
        """
        if sampling_params.prompt_logprobs is None:
            return [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
        return toks_str_logsprobs_prompt_logprobs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return token ids, text and logprobs per request.

        Prompt logprobs are included only when
        ``sampling_params.prompt_logprobs`` is set.
        """
        assert sampling_params.logprobs is not None

        if images is not None:
            assert len(prompts) == len(images)
        if audios is not None:
            assert len(prompts) == len(audios)
        if videos is not None:
            assert len(prompts) == len(videos)

        inputs = self._get_inputs(prompts,
                                  images=images,
                                  videos=videos,
                                  audios=audios)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        return self._maybe_omit_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._maybe_omit_prompt_logprobs(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy decoding (temperature 0); returns ids and text per prompt."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Greedy decoding returning the top ``num_logprobs`` per step."""
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            use_beam_search=False,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search (``n=beam_width``) returning every beam per prompt."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return the embedding of each prompt via ``LLM.encode``."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()
@pytest.fixture(scope="session")
def vllm_runner():
    """Hand tests the ``VllmRunner`` class itself (not an instance)."""
    runner_cls = VllmRunner
    return runner_cls
def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a single-worker ``TokenizerPoolConfig`` for a tokenizer group.

    Args:
        tokenizer_group_type: None (no pool), the string ``"ray"``, or a
            class used directly as the pool type.

    Returns:
        None for None, otherwise a ``TokenizerPoolConfig`` with
        ``pool_size=1`` and an empty ``extra_config``.

    Raises:
        ValueError: for any other value.
    """
    if tokenizer_group_type is None:
        return None

    if tokenizer_group_type == "ray":
        pool_type = "ray"
    elif isinstance(tokenizer_group_type, type):
        pool_type = tokenizer_group_type
    else:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")

    return TokenizerPoolConfig(pool_size=1,
                               pool_type=pool_type,
                               extra_config={})
@pytest.fixture()
def temporary_enable_log_propagate():
    """Let records from the ``vllm`` logger propagate for one test.

    Propagation is switched on before the test and restored to its previous
    value afterwards, so the fixture leaves the logger as it found it
    rather than forcing ``propagate = False``.
    """
    import logging

    # Named to avoid shadowing the module-level ``logger``.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate
@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """A ``caplog`` that also captures records from the vllm logger.

    caplog only sees records propagated to the root logger, which is why
    this depends on ``temporary_enable_log_propagate``.
    """
    yield caplog
@pytest.fixture(scope="session")
def num_gpus_available():
    """Number of CUDA devices, counted via ``cuda_device_count_stateless``.

    Counting this way avoids initializing the CUDA context in the current
    process.
    """
    gpu_count = cuda_device_count_stateless()
    return gpu_count
# Local directory that ``dummy_opt_path`` fills with a patched copy of
# facebook/opt-125m.
temp_dir = tempfile.gettempdir()
_dummy_path = os.path.join(temp_dir, "dummy_opt")
@pytest.fixture
def dummy_opt_path():
    """Return a local facebook/opt-125m snapshot patched to a custom arch.

    Downloads the repo into ``_dummy_path`` (skipping files matching the
    ``ignore_patterns`` below) and makes sure ``config.json`` lists
    ``["MyOPTForCausalLM"]`` as its ``architectures``.
    """
    json_path = os.path.join(_dummy_path, "config.json")
    # Key off config.json rather than the directory: an interrupted earlier
    # download can leave the directory behind without a config.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)

    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Patch idempotently so a config left unpatched by a previously failed
    # run is repaired instead of being reused as-is forever.
    if config.get("architectures") != ["MyOPTForCausalLM"]:
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_path