conftest.py 25.1 KB
Newer Older
1
2
import contextlib
import gc
3
import os
4
import sys
5
from collections import UserList
6
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
9

import pytest
import torch
10
import torch.nn as nn
11
import torch.nn.functional as F
12
from PIL import Image
13
14
15
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
                          AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
                          BatchFeature)
Woosuk Kwon's avatar
Woosuk Kwon committed
16

17
from tests.models.utils import DecoderPromptType
Woosuk Kwon's avatar
Woosuk Kwon committed
18
from vllm import LLM, SamplingParams
19
from vllm.assets.image import ImageAsset
20
from vllm.config import TokenizerPoolConfig
21
from vllm.connections import global_http_connection
22
23
from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
24
from vllm.inputs import TextPrompt
25
from vllm.logger import init_logger
26
from vllm.outputs import RequestOutput
27
from vllm.sequence import SampleLogprobs
28
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
29
30
                        is_cpu, to_enc_dec_tuple_list,
                        zip_enc_dec_prompt_lists)
31

32
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
33

34
35
36
# Directory containing this file; prompt fixture files live in its
# ``prompts`` subdirectory.
_TEST_DIR = os.path.dirname(__file__)
_PROMPTS_DIR = os.path.join(_TEST_DIR, "prompts")
_TEST_PROMPTS = [os.path.join(_PROMPTS_DIR, "example.txt")]
_LONG_PROMPTS = [os.path.join(_PROMPTS_DIR, "summary.txt")]
37
38


39
def _read_prompts(filename: str) -> List[str]:
    """Return every line of ``filename`` as a list (line endings kept).

    The file is decoded as UTF-8 explicitly so the result does not depend
    on the platform's default locale encoding.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
43
44


45
46
47
# Prompt text for each bundled test image, keyed by asset name.
_ImageAssetPrompts = TypedDict("_ImageAssetPrompts", {
    "stop_sign": str,
    "cherry_blossom": str,
})
48
49
50
51
52
53
54


if sys.version_info >= (3, 9):

    class _ImageAssetsBase(UserList[ImageAsset]):
        """List base class parameterized with ``ImageAsset``."""

else:
    # ``UserList`` cannot be subscripted before Python 3.9, so fall back
    # to the unparameterized base there.
    class _ImageAssetsBase(UserList):
        """Unparameterized list base class for Python < 3.9."""
58

59
60

class _ImageAssets(_ImageAssetsBase):
    """List of the bundled test images: ``stop_sign`` then ``cherry_blossom``."""

    def __init__(self) -> None:
        assets = [
            ImageAsset(asset_name)
            for asset_name in ("stop_sign", "cherry_blossom")
        ]
        super().__init__(assets)

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Return one prompt per image, taken from ``prompts`` by asset name.

        The returned list is in the same order as iterating over this
        object, so it can be zipped with the assets directly.
        """
        return [
            prompts[asset_name]
            for asset_name in ("stop_sign", "cherry_blossom")
        ]
76
77
78
79
80
81


#: Shared singleton instance of :class:`_ImageAssets`.
IMAGE_ASSETS: _ImageAssets = _ImageAssets()


82
83
84
85
86
87
88
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Stop ``global_http_connection`` from reusing its client across tests.

    pytest_asyncio may run each test on a different event loop, so the
    async client has to be recreated rather than reused.
    """
    global_http_connection.reuse_client = False


89
90
def cleanup():
    """Tear down vLLM parallel/distributed state and free memory.

    Destroys model-parallel groups, the distributed environment and the
    torch process group, runs the garbage collector, and empties the CUDA
    cache unless running on CPU.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # The process group may already be gone (or never existed); the
    # resulting AssertionError is deliberately ignored.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()
97
98


99
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether :func:`cleanup` runs after the current test.

    Returns ``False`` for tests marked ``skip_global_cleanup``, else
    ``True``. Subdirectories may override this fixture to skip the global
    cleanup; that can make non-GPU unit tests ~10x faster because torch
    never needs initializing.
    """
    return not request.node.get_closest_marker("skip_global_cleanup")


112
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After each test, call :func:`cleanup` unless the test opted out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
117
118


Woosuk Kwon's avatar
Woosuk Kwon committed
119
120
@pytest.fixture
def example_prompts() -> List[str]:
    """Every line of every file in ``_TEST_PROMPTS``, in file order."""
    return [
        line for prompt_file in _TEST_PROMPTS
        for line in _read_prompts(prompt_file)
    ]


127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
@pytest.fixture
def example_encoder_decoder_prompts() \
    -> Dict[DecoderPromptType,
            Tuple[List[str], List[Optional[str]]]]:
    '''
    Build encoder/decoder prompt pairings for each ``DecoderPromptType``.

    The encoder prompts are the lines of the files in ``_TEST_PROMPTS``.
    Each dict value is ``zip_enc_dec_prompt_lists`` applied to the encoder
    prompt list and a decoder prompt list of the same length, so
    same-index entries form one (encoder prompt, decoder prompt) pair.

    Returns:

    * ``DecoderPromptType.NONE``: every decoder prompt is ``None``
    * ``DecoderPromptType.EMPTY_STR``: every decoder prompt is ``""``
    * ``DecoderPromptType.CUSTOM``: the decoder prompts are the encoder
      prompts in reverse order
    '''

    encoder_prompts = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    custom_decoder_prompts = encoder_prompts[::-1]
    empty_str_decoder_prompts = [""] * len(encoder_prompts)
    none_decoder_prompts = [None] * len(encoder_prompts)

    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
    }


161
162
163
164
@pytest.fixture
def example_long_prompts() -> List[str]:
    """Every line of every file in ``_LONG_PROMPTS``, in file order."""
    return [
        line for prompt_file in _LONG_PROMPTS
        for line in _read_prompts(prompt_file)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
167
168


169
170
171
172
173
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Session-scoped access to the shared ``IMAGE_ASSETS`` instance."""
    assets = IMAGE_ASSETS
    return assets


174
# Types accepted (and returned unchanged in type) by ``HfRunner.wrap_device``.
_T = TypeVar(
    "_T",
    nn.Module,
    torch.Tensor,
    BatchEncoding,
    BatchFeature,
)
175

Woosuk Kwon's avatar
Woosuk Kwon committed
176
177
178

class HfRunner:
    """Reference runner backed by a HuggingFace model.

    Loads a ``SentenceTransformer`` (embedding models) or an
    ``AutoModelFor*`` class chosen from the constructor flags, plus the
    tokenizer and, when loadable, the processor (otherwise the tokenizer
    is used as the processor). Use as a context manager: ``__exit__``
    drops the model and calls :func:`cleanup`.
    """

    def wrap_device(self, input: _T) -> _T:
        """Move ``input`` to ``"cpu"`` if :func:`is_cpu`, else ``"cuda"``."""
        if not is_cpu():
            return input.to("cuda")
        else:
            return input.to("cpu")

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
        is_encoder_decoder_model: bool = False,
    ) -> None:
        """Load the model, tokenizer and processor for ``model_name``.

        Args:
            model_name: HuggingFace model id or local path.
            dtype: key into ``STR_DTYPE_TO_TORCH_DTYPE``.
            model_kwargs: extra kwargs for ``from_pretrained`` (not used
                for embedding models).
            is_embedding_model: load via ``SentenceTransformer``.
            is_vision_model: load via ``AutoModelForVision2Seq``.
            is_encoder_decoder_model: load via ``AutoModelForSeq2SeqLM``.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
            elif is_encoder_decoder_model:
                auto_cls = AutoModelForSeq2SeqLM
            else:
                auto_cls = AutoModelForCausalLM

            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            # don't put this import at the top level
            # it will call torch.cuda.device_count()
            from transformers import AutoProcessor  # noqa: F401
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception:
            # Best effort: models without a processor fall back to the
            # tokenizer, and the fallback is logged.
            logger.warning(
                "Unable to auto-load processor from HuggingFace for "
                "model %s. Using tokenizer instead.", model_name)
            self.processor = self.tokenizer

    @staticmethod
    def _per_prompt_images(
        prompts: List[str],
        images: Optional[List[Image.Image]],
    ) -> List[Optional[Image.Image]]:
        """Return one (possibly ``None``) image per prompt.

        Raises:
            AssertionError: if ``images`` is given with a different length
                than ``prompts``.
        """
        if images is None:
            return [None] * len(prompts)
        assert len(prompts) == len(images)
        return list(images)

    def _process_prompt(
        self,
        prompt: str,
        image: Optional[Image.Image],
    ) -> Any:
        """Run ``self.processor`` on one prompt (plus its image, if any),
        requesting PyTorch tensors."""
        processor_kwargs: Dict[str, Any] = {
            "text": prompt,
            "return_tensors": "pt",
        }
        if image is not None:
            processor_kwargs["images"] = image
        return self.processor(**processor_kwargs)

    def _logprobs_from_hidden_state(self, hidden_state) -> torch.Tensor:
        """Project one generation step's hidden states to float32 logprobs.

        Uses the last entry of ``hidden_state`` and its first row (batch
        index 0 — assumes a single-sequence batch), multiplies by the
        output-embedding weight, adds its bias when present, and applies
        ``log_softmax`` over the last dimension.
        """
        output_embeddings = self.model.get_output_embeddings()
        last_hidden_states = hidden_state[-1][0]
        logits = torch.matmul(
            last_hidden_states,
            output_embeddings.weight.t(),
        )
        # getattr: some output heads may not define a ``bias`` attribute.
        bias = getattr(output_embeddings, "bias", None)
        if bias is not None:
            logits += bias.unsqueeze(0)
        return F.log_softmax(logits, dim=-1, dtype=torch.float32)

    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for each prompt; return (token ids, decoded texts).

        Both elements of each tuple come from ``model.generate`` output
        (one entry per returned sequence); texts are decoded with special
        tokens skipped. Extra ``kwargs`` are forwarded to
        ``model.generate``.
        """
        outputs: List[Tuple[List[List[int]], List[str]]] = []
        per_prompt_images = self._per_prompt_images(prompts, images)
        for prompt, image in zip(prompts, per_prompt_images):
            inputs = self._process_prompt(prompt, image)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy decoding; return the first sequence's ids and text."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search returning ``beam_width`` sequences per prompt, with
        pad tokens removed from the token ids."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [([[tok for tok in seq_ids if tok != pad_token_id]
                  for seq_ids in output_ids], output_str)
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy decoding; return per-step full-vocab logprob tensors.

        One list per prompt, one tensor per entry of ``output.hidden_states``
        (the first entry covers every prompt position).
        """
        all_logprobs: List[List[torch.Tensor]] = []
        per_prompt_images = self._per_prompt_images(prompts, images)
        for prompt, image in zip(prompts, per_prompt_images):
            inputs = self._process_prompt(prompt, image)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            all_logprobs.append([
                self._logprobs_from_hidden_state(hidden_state)
                for hidden_state in output.hidden_states
            ])
        return all_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states,
        num_logprobs,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Convert per-step hidden states to top-``num_logprobs`` dicts.

        For the first step only the last position is kept (earlier positions
        are prompt logprobs). Returns ``(list of {token_id: logprob},
        number of steps)``.
        """
        seq_logprobs = [
            self._logprobs_from_hidden_state(hidden_state)
            for hidden_state in hidden_states
        ]
        output_len = len(hidden_states)

        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)
            seq_logprobs_lst.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            })

        return seq_logprobs_lst, output_len

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Greedy decoding; return (generated ids, decoded text, top-k
        logprob dicts) per prompt."""
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        per_prompt_images = self._per_prompt_images(prompts, images)
        for prompt, image in zip(prompts, per_prompt_images):
            inputs = self._process_prompt(prompt, image)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # Explicit start index: ``seq_ids[-0:]`` would return the whole
            # sequence when no step was produced.
            output_ids = seq_ids[len(seq_ids) - output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: Tuple[List[str], List[str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        '''
        Greedy logprobs generation for encoder/decoder models.

        A ``None`` decoder prompt is passed to ``model.generate`` as
        ``decoder_input_ids=None``. Logprobs are computed from
        ``output.decoder_hidden_states``.
        '''
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
            decoder_input_ids = (
                None if decoder_prompt is None else self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids))

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # Explicit start index: ``seq_ids[-0:]`` would return the whole
            # sequence when no step was produced.
            output_ids = seq_ids[len(seq_ids) - output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Embed ``prompts`` via the model's ``encode`` method."""
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
506

Cyrus Leung's avatar
Cyrus Leung committed
507
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the :class:`HfRunner` class itself; tests instantiate it."""
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:
    """Test wrapper around :class:`vllm.LLM`.

    Always passes ``trust_remote_code=True``. Use as a context manager:
    ``__exit__`` drops the engine and calls :func:`cleanup`.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    @staticmethod
    def _build_inputs(
        prompts: List[str],
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]],
    ) -> List[TextPrompt]:
        """Wrap each prompt in a ``TextPrompt``, attaching the same-index
        image as ``multi_modal_data`` when ``images`` is given.

        Raises:
            AssertionError: if ``images`` differs in length from ``prompts``.
        """
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for prompt_input, image in zip(inputs, images):
                prompt_input["multi_modal_data"] = {"image": image}
        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate and return, per request, every sample as prompt+output.

        Each tuple holds ``prompt_token_ids + sample token ids`` and
        ``prompt + sample text`` for every sample of the request.
        """
        inputs = self._build_inputs(prompts, images)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = [
                prompt_ids + list(sample.token_ids)
                for sample in req_output.outputs
            ]
            req_sample_output_strs: List[str] = [
                prompt_str + sample.text for sample in req_output.outputs
            ]
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def _final_steps_generate_w_logprobs(
        self,
        req_outputs: List[RequestOutput],
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Return one (token ids, text, logprobs) tuple per request.

        Only the last sample of each request is reported; with greedy
        sampling (``n=1``) that is the only sample. A request without
        samples raises ``IndexError`` rather than silently reusing the
        previous request's values.
        """
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            sample = req_output.outputs[-1]
            outputs.append((sample.token_ids, sample.text, sample.logprobs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate with ``sampling_params.logprobs`` set; see
        :meth:`_final_steps_generate_w_logprobs` for the result format."""
        assert sampling_params.logprobs is not None

        inputs = self._build_inputs(prompts, images)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)
        return self._final_steps_generate_w_logprobs(req_outputs)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: Tuple[List[str], List[str]],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._final_steps_generate_w_logprobs(req_outputs)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy decoding (temperature 0); return the first sample's
        prompt+output ids and text per request."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Greedy decoding returning ``num_logprobs`` logprobs per token."""
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs,
                                                stop_token_ids=stop_token_ids)
        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: Tuple[List[str], List[str]],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                use_beam_search=False,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search with ``beam_width`` beams, returning every beam."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return the ``outputs.embedding`` of each request, in order."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
702

703
@pytest.fixture(scope="session")
def vllm_runner():
    """Provide the :class:`VllmRunner` class itself; tests instantiate it."""
    runner_cls = VllmRunner
    return runner_cls
706
707
708
709
710
711
712
713
714


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a single-worker ``TokenizerPoolConfig`` for a tokenizer group.

    Args:
        tokenizer_group_type: ``None``, the string ``"ray"``, or a class
            to be used directly as the pool type.

    Returns:
        ``None`` when ``tokenizer_group_type`` is ``None``; otherwise a
        ``TokenizerPoolConfig`` with ``pool_size=1``, ``pool_type`` set to
        ``tokenizer_group_type`` and an empty ``extra_config``.

    Raises:
        ValueError: for any other value.
    """
    if tokenizer_group_type is None:
        return None
    is_supported = (tokenizer_group_type == "ray"
                    or isinstance(tokenizer_group_type, type))
    if not is_supported:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")
    return TokenizerPoolConfig(pool_size=1,
                               pool_type=tokenizer_group_type,
                               extra_config={})
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735


@pytest.fixture()
def temporary_enable_log_propagate():
    """Make the ``vllm`` logger propagate records while a test runs.

    The logger's previous ``propagate`` value is restored on teardown
    (inside ``finally``) instead of being forced to ``False``.
    """
    import logging

    # Named so it does not shadow this module's ``logger``.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog``, able to see ``vllm`` log records.

    ``caplog`` depends on records propagating to the root logger, so this
    requests ``temporary_enable_log_propagate`` to turn propagation on for
    the ``vllm`` logger during the test.
    """
    yield caplog
736
737
738
739
740
741
742


@pytest.fixture(scope="session")
def num_gpus_available():
    """Number of visible CUDA devices, obtained without initializing a CUDA
    context in this process."""
    gpu_count = cuda_device_count_stateless()
    return gpu_count