conftest.py 27.2 KB
Newer Older
1
2
import contextlib
import gc
3
import json
4
import os
5
import sys
6
import tempfile
7
from collections import UserList
8
from enum import Enum
9
10
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
                    TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
13
14
import pytest
import torch
15
import torch.nn as nn
16
import torch.nn.functional as F
17
from huggingface_hub import snapshot_download
18
from PIL import Image
19
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
20
                          BatchFeature)
Woosuk Kwon's avatar
Woosuk Kwon committed
21
22

from vllm import LLM, SamplingParams
23
from vllm.assets.image import ImageAsset
24
from vllm.config import TokenizerPoolConfig
25
from vllm.connections import global_http_connection
26
from vllm.distributed import (destroy_distributed_environment,
27
28
29
                              destroy_model_parallel,
                              init_distributed_environment,
                              initialize_model_parallel)
30
31
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
32
from vllm.logger import init_logger
33
from vllm.outputs import RequestOutput
34
from vllm.sequence import SampleLogprobs
35
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
36
                        identity, is_cpu)
37

38
# Logger for messages emitted by the fixtures and runners in this file.
logger = init_logger(__name__)

# Directory holding this conftest; prompt files are resolved relative to it so
# the fixtures work regardless of the current working directory.
_TEST_DIR: str = os.path.dirname(__file__)
# Prompt files read line-by-line by the ``example_prompts`` fixture.
_TEST_PROMPTS: List[str] = [
    os.path.join(_TEST_DIR, "prompts", "example.txt"),
]
# Prompt files read line-by-line by the ``example_long_prompts`` fixture.
_LONG_PROMPTS: List[str] = [
    os.path.join(_TEST_DIR, "prompts", "summary.txt"),
]

# Images for a batch of prompts: one image per prompt, or a list of images
# per prompt.
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
# Audio for a batch of prompts as (audio array, sampling rate) pairs: one pair
# per prompt, or a list of pairs per prompt.
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
                         List[List[Tuple[np.ndarray, int]]]]

48

49
def _read_prompts(filename: str) -> List[str]:
50
    with open(filename, "r") as f:
51
52
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
53
54


55
56
57
class _ImageAssetPrompts(TypedDict):
    stop_sign: str
    cherry_blossom: str
58
59
60
61
62
63
64


if sys.version_info >= (3, 9):
    # Subscripting UserList is supported here, which records the element
    # type (ImageAsset) for type checkers.
    class _ImageAssetsBase(UserList[ImageAsset]):
        pass
else:
    # Older interpreters cannot subscript UserList at runtime, so fall back
    # to the plain, untyped base class.
    class _ImageAssetsBase(UserList):
        pass
68

69
70

class _ImageAssets(_ImageAssetsBase):
    """List of the image assets used by multimodal tests.

    Iterating an instance yields the "stop_sign" asset followed by the
    "cherry_blossom" asset.
    """

    def __init__(self) -> None:
        stop_sign_asset = ImageAsset("stop_sign")
        cherry_blossom_asset = ImageAsset("cherry_blossom")
        super().__init__([stop_sign_asset, cherry_blossom_asset])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Return the per-image prompts from ``prompts`` as a list.

        Element ``i`` of the result is the prompt for the ``i``-th asset
        yielded when iterating this object.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]


IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""


92
93
94
95
96
97
98
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Disable reuse of the shared HTTP connection's async client per test.

    pytest_asyncio may give each test its own event loop, so the async client
    has to be created anew rather than reused across tests.
    """
    connection = global_http_connection
    connection.reuse_client = False


99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
@pytest.fixture
def dist_init():
    """Set up a single-process (world size 1) distributed environment.

    A temporary file is used as the ``file://`` init method for
    ``init_distributed_environment`` with the NCCL backend, and model
    parallelism is initialised with sizes (1, 1). After the test,
    ``cleanup()`` tears the state down and the temporary file is removed.
    """
    # mkstemp returns an open OS-level descriptor along with the path; close
    # it immediately so every use of this fixture does not leak a descriptor.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup()
    # NOTE(review): the file may already be gone after process-group
    # teardown; ignore that rather than failing the test teardown.
    with contextlib.suppress(OSError):
        os.remove(temp_file)


114
115
def cleanup():
    """Tear down distributed state and free cached memory.

    Destroys the model-parallel groups and the distributed environment,
    tries to destroy the default torch process group (ignoring
    ``AssertionError``), runs the garbage collector, and empties the CUDA
    cache unless ``is_cpu()`` reports a CPU-only platform.
    """
    for teardown in (destroy_model_parallel, destroy_distributed_environment):
        teardown()
    # Presumably raised when no default process group is active; that case
    # needs no cleanup.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()
122
123


124
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether the global ``cleanup()`` runs after the current test.

    Tests opt out with the ``skip_global_cleanup`` marker, and subdirectories
    may override this fixture entirely. Skipping cleanup can give roughly a
    10x speedup for non-GPU unit tests, which then never initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker


137
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run ``cleanup()`` after each test unless the test opted out.

    The decision comes from the ``should_do_global_cleanup_after_test``
    fixture.
    """
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
142
143


Woosuk Kwon's avatar
Woosuk Kwon committed
144
145
@pytest.fixture
def example_prompts() -> List[str]:
    """Lines of every file in ``_TEST_PROMPTS``, concatenated in order."""
    return [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]


152
153
154
155
156
157
158
class DecoderPromptType(Enum):
    """Kind of decoder prompt paired with each encoder prompt.

    Only used for encoder/decoder models.
    """
    # An explicit decoder prompt string.
    CUSTOM = 1
    # No decoder prompt at all (``None``).
    NONE = 2
    # An empty-string decoder prompt.
    EMPTY_STR = 3


159
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt lists, one per ``DecoderPromptType``.

    The encoder prompts are the lines of the files in ``_TEST_PROMPTS``. For
    every key, the encoder prompts are paired index-by-index with a list of
    decoder prompts via ``zip_enc_dec_prompts``:

    * ``DecoderPromptType.NONE``: every decoder prompt is ``None``.
    * ``DecoderPromptType.EMPTY_STR``: every decoder prompt is ``""``.
    * ``DecoderPromptType.CUSTOM``: the decoder prompts are the encoder
      prompts in reverse order.

    Returns:
        A dict mapping each ``DecoderPromptType`` to its list of
        encoder/decoder prompts.
    '''

    encoder_prompts = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    custom_decoder_prompts = encoder_prompts[::-1]
    empty_str_decoder_prompts = [""] * len(encoder_prompts)
    none_decoder_prompts = [None] * len(encoder_prompts)

    return {
        # No decoder prompt.
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
        # Empty-string decoder prompt.
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
        # Decoder prompts are the encoder prompts reversed.
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
    }


192
193
194
195
@pytest.fixture
def example_long_prompts() -> List[str]:
    """Lines of every file in ``_LONG_PROMPTS``, concatenated in order."""
    return [
        line for filename in _LONG_PROMPTS
        for line in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
198
199


200
201
202
203
204
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Expose the module-wide ``IMAGE_ASSETS`` singleton to tests."""
    shared_assets = IMAGE_ASSETS
    return shared_assets


# Value types accepted by ``HfRunner.wrap_device``; each is moved via ``.to``.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
206

Woosuk Kwon's avatar
Woosuk Kwon committed
207
208
209

class HfRunner:
    """Run a HuggingFace ``transformers`` model as a reference in tests.

    Intended to be used as a context manager: on exit the model is deleted
    and ``cleanup()`` is called.
    """

    def wrap_device(self, input: _T) -> _T:
        """Move ``input`` to CUDA, or to CPU when ``is_cpu()`` is true."""
        if not is_cpu():
            return input.to("cuda")
        else:
            return input.to("cpu")

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        auto_cls=AutoModelForCausalLM,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        """Load the model, tokenizer and, if possible, processor.

        Args:
            model_name: HF model name or path.
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE`` selecting the torch
                dtype used to load the model.
            model_kwargs: Extra keyword arguments for
                ``auto_cls.from_pretrained``.
            is_embedding_model: Load the model with
                ``sentence_transformers.SentenceTransformer`` instead of
                ``auto_cls``.
            auto_cls: ``transformers`` auto class used for non-embedding
                models.
            postprocess_inputs: Applied to every processor output before it
                is passed to the model.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            # don't put this import at the top level
            # it will call torch.cuda.device_count()
            from transformers import AutoProcessor  # noqa: F401
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception as exc:
            # Best effort: fall back to the tokenizer when no processor can
            # be loaded for this model.
            logger.warning(
                "Unable to auto-load HuggingFace processor for model (%s). "
                "Using tokenizer instead. Reason: %s", model_name, exc)
            self.processor = self.tokenizer

        self.postprocess_inputs = postprocess_inputs

    def _prepare_inputs(
        self,
        prompt: str,
        image: Optional[Image.Image] = None,
        audio: Optional[Tuple[np.ndarray, int]] = None,
    ) -> BatchEncoding:
        """Run the processor on one prompt and apply ``postprocess_inputs``.

        Args:
            prompt: Text prompt.
            image: Optional image, passed to the processor as ``images``.
            audio: Optional ``(audio, sampling_rate)`` pair, passed to the
                processor as ``audio`` and ``sampling_rate``.

        Returns:
            The post-processed processor output.
        """
        processor_kwargs: Dict[str, Any] = {
            "text": prompt,
            "return_tensors": "pt",
        }
        if image is not None:
            processor_kwargs["images"] = image
        if audio is not None:
            audio_data, sampling_rate = audio
            processor_kwargs["audio"] = audio_data
            processor_kwargs["sampling_rate"] = sampling_rate

        inputs = self.processor(**processor_kwargs)
        return self.postprocess_inputs(inputs)

    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Call ``model.generate`` once per prompt.

        Args:
            prompts: Text prompts.
            images: Optional per-prompt images (``None`` entries allowed).
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            Per prompt, the generated sequences as token-id lists and their
            decoded texts (special tokens skipped).
        """
        if images:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            image = images[i] if images is not None else None
            inputs = self._prepare_inputs(prompt, image=image)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode each prompt; return the first sequence's ids/text."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search each prompt, returning every beam.

        Pad-token ids are stripped from the returned id lists.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        for i, (output_ids, output_str) in enumerate(outputs):
            for j in range(len(output_ids)):
                output_ids[j] = [
                    x for x in output_ids[j] if x != pad_token_id
                ]
            outputs[i] = (output_ids, output_str)
        return outputs

    def _compute_step_logprobs(self, hidden_states) -> List[torch.Tensor]:
        """Project each step's hidden states to float32 log-probabilities.

        Args:
            hidden_states: ``hidden_states`` (or ``decoder_hidden_states``)
                from ``generate(..., output_hidden_states=True,
                return_dict_in_generate=True)``: one entry per generation
                step, indexed here as ``[layer][batch]`` per the HF
                generation-output layout.

        Returns:
            One log-probability tensor per step, computed from the last
            layer of the first batch element through the output embeddings.
        """
        output_embeddings = self.model.get_output_embeddings()
        weight_t = output_embeddings.weight.t()
        # Not every output-embedding module defines a ``bias`` attribute.
        bias = getattr(output_embeddings, "bias", None)

        step_logprobs: List[torch.Tensor] = []
        for hidden_state in hidden_states:
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(last_hidden_states, weight_t)
            if bias is not None:
                logits += bias.unsqueeze(0)
            step_logprobs.append(
                F.log_softmax(logits, dim=-1, dtype=torch.float32))
        return step_logprobs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy-decode each prompt and return per-step log-prob tensors.

        The first step's tensor is not trimmed, so it has one row per
        position seen at that step (prompt positions included).
        """
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            image = images[i] if images is not None else None
            inputs = self._prepare_inputs(prompt, image=image)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            all_logprobs.append(
                self._compute_step_logprobs(output.hidden_states))
        return all_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states,
        num_logprobs,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Convert generation hidden states into top-k logprob dicts.

        Args:
            hidden_states: Per-step hidden states from ``generate``.
            num_logprobs: Number of top tokens kept per step.

        Returns:
            ``(per_step_dicts, output_len)``: one ``{token_id: logprob}``
            dict per step, and the number of steps.
        """
        seq_logprobs = self._compute_step_logprobs(hidden_states)
        output_len = len(hidden_states)

        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs: the first step also covers the prompt
            # positions, so keep only its final row.
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            }
            seq_logprobs_lst.append(tok_logprobs_dct)

        return seq_logprobs_lst, output_len

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[List[Image.Image]] = None,
        audios: Optional[List[Tuple[np.ndarray, int]]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Greedy-decode each prompt, keeping the top ``num_logprobs`` per step.

        Returns:
            Per prompt, ``(output_ids, output_text, per_step_logprobs)``
            covering only the generated tokens.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, prompt in enumerate(prompts):
            image = images[i] if images is not None else None
            audio = audios[i] if audios is not None else None
            inputs = self._prepare_inputs(prompt, image=image, audio=audio)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # The last ``output_len`` ids are the generated tokens.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models

        Returns:
            Per prompt pair, ``(output_ids, output_text, per_step_logprobs)``
            computed from the decoder hidden states.
        '''

        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
            decoder_input_ids = (
                None if decoder_prompt is None else self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids))

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Return ``self.model.encode(prompts)``."""
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
542

Cyrus Leung's avatar
Cyrus Leung committed
543
@pytest.fixture(scope="session")
def hf_runner():
    """Hand tests the :class:`HfRunner` class so they can instantiate it."""
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:
    """Run a model through vLLM's ``LLM`` entry point in tests.

    Intended to be used as a context manager: on exit the engine is deleted
    and ``cleanup()`` is called.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def _get_inputs(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[TextPrompt]:
        """Build one ``TextPrompt`` per prompt, attaching multimodal data.

        ``images[i]`` and ``audios[i]`` (when given) are stored under the
        ``"image"`` and ``"audio"`` keys of prompt ``i``'s
        ``multi_modal_data``; supplying both keeps both.
        """
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i].setdefault("multi_modal_data", {})["image"] = image
        if audios is not None:
            for i, audio in enumerate(audios):
                inputs[i].setdefault("multi_modal_data", {})["audio"] = audio
        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with ``sampling_params``, returning every sample.

        Returns:
            Per request, ``(ids, texts)`` where entry ``j`` is the prompt
            token ids / prompt text followed by sample ``j``'s output token
            ids / text.
        """
        inputs = self._get_inputs(prompts, images=images)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def _final_steps_generate_w_logprobs(
        self,
        req_outputs: List[RequestOutput],
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Reduce each request to ``(token_ids, text, logprobs)``.

        One entry per request, taken from its last sample. Indexing the
        sample directly (instead of reusing loop variables) makes a request
        without samples fail loudly rather than repeat the previous
        request's result.
        """
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            sample = req_output.outputs[-1]
            outputs.append(
                (list(sample.token_ids), sample.text, sample.logprobs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate with logprobs enabled in ``sampling_params``."""
        assert sampling_params.logprobs is not None

        inputs = self._get_inputs(prompts, images=images, audios=audios)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)
        return self._final_steps_generate_w_logprobs(req_outputs)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._final_steps_generate_w_logprobs(req_outputs)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode (temperature 0); return the first sample per prompt."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Greedy-decode, returning ``num_logprobs`` logprobs per token."""
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs,
                                                stop_token_ids=stop_token_ids)
        outputs = self.generate_w_logprobs(prompts,
                                           greedy_logprobs_params,
                                           images=images,
                                           audios=audios)

        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                use_beam_search=False,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)

        outputs = self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search with ``beam_width`` beams, returning all of them."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        outputs = self.generate(prompts, beam_search_params)
        return outputs

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return the embedding reported by ``LLM.encode`` for each prompt."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
744

745
@pytest.fixture(scope="session")
def vllm_runner():
    """Hand tests the :class:`VllmRunner` class so they can instantiate it."""
    runner_cls = VllmRunner
    return runner_cls
748
749
750
751
752
753
754
755
756


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a single-worker ``TokenizerPoolConfig`` for tests.

    Args:
        tokenizer_group_type: ``None`` (no pool), the string ``"ray"``, or a
            class to use as the pool type.

    Returns:
        ``None`` when ``tokenizer_group_type`` is ``None``; otherwise a
        ``TokenizerPoolConfig`` with ``pool_size=1`` and an empty
        ``extra_config``.

    Raises:
        ValueError: If ``tokenizer_group_type`` is none of the above.
    """
    if tokenizer_group_type is None:
        return None

    if tokenizer_group_type == "ray":
        pool_type = "ray"
    elif isinstance(tokenizer_group_type, type):
        pool_type = tokenizer_group_type
    else:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")

    return TokenizerPoolConfig(pool_size=1,
                               pool_type=pool_type,
                               extra_config={})
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777


@pytest.fixture()
def temporary_enable_log_propagate():
    """Let records from the ``vllm`` logger propagate for one test.

    The logger's previous ``propagate`` value is restored afterwards, rather
    than being forced to ``False``.
    """
    import logging
    # Distinct name so the module-level ``logger`` is not shadowed.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """``caplog`` that also sees records logged by the ``vllm`` logger.

    ``caplog`` only captures records that reach the root logger, so this
    fixture depends on ``temporary_enable_log_propagate`` to switch
    propagation on while the test runs.
    """
    captured = caplog
    yield captured
778
779
780
781
782
783
784


@pytest.fixture(scope="session")
def num_gpus_available():
    """Number of CUDA devices, computed once per session.

    Uses ``cuda_device_count_stateless`` so that no CUDA context is
    initialized in the current process.
    """
    gpu_count = cuda_device_count_stateless()
    return gpu_count
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808


temp_dir = tempfile.gettempdir()
_dummy_path = os.path.join(temp_dir, "dummy_opt")


@pytest.fixture
def dummy_opt_path():
    """Local copy of ``facebook/opt-125m`` relabelled as ``MyOPTForCausalLM``.

    The repository is downloaded into ``<tempdir>/dummy_opt`` (skipping the
    file patterns in ``ignore_patterns``) when its ``config.json`` is
    missing. The config's ``architectures`` field is then set to
    ``["MyOPTForCausalLM"]`` if it does not already hold that value.

    Returns:
        The path of the local model directory.
    """
    json_path = os.path.join(_dummy_path, "config.json")
    # Key the download on config.json rather than on the directory, so a
    # directory left behind by an interrupted download is completed instead
    # of being returned without a config.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)

    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Patch the architecture whenever it is not patched yet, including for a
    # directory whose earlier patch step never ran.
    if config.get("architectures") != ["MyOPTForCausalLM"]:
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_path