conftest.py 25.8 KB
Newer Older
1
2
import contextlib
import gc
3
import os
4
import sys
5
from collections import UserList
6
from enum import Enum
7
8
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
                    TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
11

import pytest
import torch
12
import torch.nn as nn
13
import torch.nn.functional as F
14
from PIL import Image
15
16
17
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
                          AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
                          BatchFeature)
Woosuk Kwon's avatar
Woosuk Kwon committed
18
19

from vllm import LLM, SamplingParams
20
from vllm.assets.image import ImageAsset
21
from vllm.config import TokenizerPoolConfig
22
from vllm.connections import global_http_connection
23
24
from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
25
26
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
27
from vllm.logger import init_logger
28
from vllm.outputs import RequestOutput
29
from vllm.sequence import SampleLogprobs
30
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
31
                        identity, is_cpu)
32

33
# Module-level logger; HfRunner uses it to warn when a HuggingFace processor
# cannot be auto-loaded.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
34

35
36
37
# Directory containing this file; prompt files live in its "prompts" folder.
_TEST_DIR = os.path.dirname(__file__)
# Files read by the example_prompts / example_encoder_decoder_prompts fixtures.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
# Files read by the example_long_prompts fixture.
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
38
39


40
def _read_prompts(filename: str) -> List[str]:
    """Read a prompt file and return its lines.

    Each entry keeps its trailing newline (if any), exactly as
    ``readlines()`` yields it.

    The file is decoded as UTF-8 explicitly so the result does not depend
    on the platform's default locale encoding.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
44
45


46
47
48
class _ImageAssetPrompts(TypedDict):
    """Prompt text keyed by test image asset name."""
    # Prompt for the "stop_sign" image asset.
    stop_sign: str
    # Prompt for the "cherry_blossom" image asset.
    cherry_blossom: str
49
50
51
52
53
54
55


# ``UserList`` can only be subscripted from Python 3.9 onwards, so older
# interpreters fall back to the unparameterized base class.
if sys.version_info >= (3, 9):

    class _ImageAssetsBase(UserList[ImageAsset]):
        pass
else:

    class _ImageAssetsBase(UserList):
        pass
59

60
61

class _ImageAssets(_ImageAssetsBase):
    """List holding the image assets used by the tests, in a fixed order."""

    def __init__(self) -> None:
        stop_sign = ImageAsset("stop_sign")
        cherry_blossom = ImageAsset("cherry_blossom")
        super().__init__([stop_sign, cherry_blossom])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Turn a per-asset prompt mapping into a list of prompts.

        The returned list follows the same order as the assets yielded
        when iterating over this object (stop sign, then cherry blossom).
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
77
78
79
80
81
82


# Built once at import time; the session-scoped ``image_assets`` fixture
# returns this same object.
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""


83
84
85
86
87
88
89
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Disable client reuse on the global HTTP connection for every test."""
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


90
91
def cleanup():
    """Tear down distributed state and free memory after a test.

    Destroys the model-parallel groups and the distributed environment,
    attempts to destroy the torch process group, runs the garbage
    collector and, unless ``is_cpu()`` is true, empties the CUDA cache.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        # Best effort: an AssertionError from this teardown is ignored.
        pass
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()
98
99


100
@pytest.fixture()
101
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell ``cleanup_fixture`` whether to run the global cleanup.

    Subdirectories may override this fixture to opt out of the cleanup,
    which can give a ~10x speedup for non-GPU unit tests because they
    then never need to initialize torch. A test carrying the
    ``skip_global_cleanup`` marker opts out as well.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker


113
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run ``cleanup()`` after each test unless the test opted out."""
    yield
    # Everything below runs as teardown, after the test body has finished.
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
118
119


Woosuk Kwon's avatar
Woosuk Kwon committed
120
121
@pytest.fixture
def example_prompts() -> List[str]:
    """Return every line of the files listed in ``_TEST_PROMPTS``."""
    return [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]


128
129
130
131
132
133
134
class DecoderPromptType(Enum):
    """For encoder/decoder models only."""
    # Explicit decoder prompt; the example_encoder_decoder_prompts fixture
    # pairs each encoder prompt with the encoder prompts in reverse order.
    CUSTOM = 1
    # Decoder prompt is ``None``.
    NONE = 2
    # Decoder prompt is the empty string.
    EMPTY_STR = 3


135
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for every decoder prompt type.

    The encoder prompts are the lines of the files in ``_TEST_PROMPTS``.
    For each decoder prompt type, the encoder prompts are paired by index
    with a same-length decoder prompt list via ``zip_enc_dec_prompts``.

    Returns:

    * Dict mapping each :class:`DecoderPromptType` to its prompt list:
      ``NONE`` pairs with ``None``, ``EMPTY_STR`` pairs with ``""`` and
      ``CUSTOM`` pairs with the encoder prompt list reversed.
    '''

    encoder_prompts = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    num_prompts = len(encoder_prompts)

    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, [None] * num_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, [""] * num_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, encoder_prompts[::-1]),
    }


168
169
170
171
@pytest.fixture
def example_long_prompts() -> List[str]:
    """Return every line of the files listed in ``_LONG_PROMPTS``."""
    return [
        line for filename in _LONG_PROMPTS
        for line in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
174
175


176
177
178
179
180
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Return the module-level ``IMAGE_ASSETS`` object."""
    return IMAGE_ASSETS


181
# Types accepted by ``HfRunner.wrap_device``, which moves them with ``.to()``.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
182

Woosuk Kwon's avatar
Woosuk Kwon committed
183
184
185

class HfRunner:

186
    def wrap_device(self, input: _T) -> _T:
        """Move ``input`` to "cpu" when ``is_cpu()`` is true, else to "cuda"."""
        target_device = "cpu" if is_cpu() else "cuda"
        return input.to(target_device)

Woosuk Kwon's avatar
Woosuk Kwon committed
192
193
194
195
    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
        is_encoder_decoder_model: bool = False,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        """Load a HuggingFace model together with its tokenizer/processor.

        Args:
            model_name: Name passed to the HuggingFace loaders.
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE``.
            model_kwargs: Extra keyword arguments for ``from_pretrained``
                (not used for embedding models).
            is_embedding_model: Load the model via ``SentenceTransformer``.
            is_vision_model: Load via ``AutoModelForVision2Seq``.
            is_encoder_decoder_model: Load via ``AutoModelForSeq2SeqLM``;
                when neither this nor ``is_vision_model`` is set,
                ``AutoModelForCausalLM`` is used.
            postprocess_inputs: Applied to the processor output before it
                is handed to the model by the generate methods.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            st_model = SentenceTransformer(model_name, device="cpu")
            self.model = self.wrap_device(st_model.to(dtype=torch_dtype))
        else:
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
            elif is_encoder_decoder_model:
                auto_cls = AutoModelForSeq2SeqLM
            else:
                auto_cls = AutoModelForCausalLM

            if model_kwargs is None:
                model_kwargs = {}

            hf_model = auto_cls.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
                **model_kwargs,
            )
            self.model = self.wrap_device(hf_model)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            # don't put this import at the top level
            # it will call torch.cuda.device_count()
            from transformers import AutoProcessor  # noqa: F401

            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception as exc:
            # Any failure to load a processor falls back to the tokenizer.
            logger.warning(
                "Unable to auto-load HuggingFace processor for model (%s). "
                "Using tokenizer instead. Reason: %s", model_name, exc)
            self.processor = self.tokenizer

        self.postprocess_inputs = postprocess_inputs

Woosuk Kwon's avatar
Woosuk Kwon committed
256
257
258
    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run ``self.model.generate`` on each prompt, one at a time.

        Args:
            prompts: Text prompts.
            images: Optional per-prompt images; when given, must have the
                same length as ``prompts``. A ``None`` entry means the
                prompt at that index has no image.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            One ``(output_ids, output_strs)`` tuple per prompt, where
            ``output_ids`` is the generated id tensor converted to nested
            lists and ``output_strs`` is its batch-decoded text.
        """
        # Check against None (not truthiness) so that an empty image list
        # with non-empty prompts fails here instead of with an IndexError
        # inside the loop below.
        if images is not None:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Generate without sampling and keep the first sequence per prompt.

        Calls ``generate`` with ``do_sample=False`` and
        ``max_new_tokens=max_tokens``; extra ``kwargs`` are forwarded.
        """
        per_prompt = self.generate(prompts,
                                   do_sample=False,
                                   max_new_tokens=max_tokens,
                                   images=images,
                                   **kwargs)

        greedy_outputs = []
        for ids_per_seq, text_per_seq in per_prompt:
            greedy_outputs.append((ids_per_seq[0], text_per_seq[0]))
        return greedy_outputs
306
307
308
309
310
311

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate ``beam_width`` beams per prompt and drop pad tokens.

        Calls ``generate`` with ``num_beams`` and ``num_return_sequences``
        both set to ``beam_width``, then removes every id equal to the
        tokenizer's ``pad_token_id`` from the returned id lists. The
        decoded strings are returned as produced by ``generate``.
        """
        beam_outputs = self.generate(prompts,
                                     do_sample=False,
                                     max_new_tokens=max_tokens,
                                     num_beams=beam_width,
                                     num_return_sequences=beam_width)
        pad_id = self.tokenizer.pad_token_id
        return [([[token for token in beam_ids if token != pad_id]
                  for beam_ids in output_ids], output_str)
                for output_ids, output_str in beam_outputs]
Woosuk Kwon's avatar
Woosuk Kwon committed
327

328
329
330
331
    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Generate without sampling and return log-probabilities per step.

        For every prompt the model is run with ``output_hidden_states=True``.
        Each entry of ``output.hidden_states`` is projected through the
        model's output embeddings and normalized with ``log_softmax``
        (computed in float32).

        Args:
            prompts: Text prompts.
            max_tokens: Passed as ``max_new_tokens``.
            images: Optional per-prompt images; a ``None`` entry means the
                prompt at that index has no image.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            For each prompt, a list with one log-probability tensor per
            entry of ``output.hidden_states``.
        """
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            # Look these up once per prompt instead of once per step.
            output_embeddings = self.model.get_output_embeddings()
            weight_t = output_embeddings.weight.t()
            # getattr: tolerate output embeddings without a ``bias``
            # attribute, matching _hidden_states_to_logprobs.
            bias = getattr(output_embeddings, "bias", None)

            seq_logprobs: List[torch.Tensor] = []
            for hidden_states in output.hidden_states:
                # NOTE(review): presumably [-1] is the last layer and [0]
                # the single batch entry — confirm.
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(last_hidden_states, weight_t)
                if bias is not None:
                    logits += bias.unsqueeze(0)
                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
    def _hidden_states_to_logprobs(
        self,
        hidden_states,
        num_logprobs,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Convert per-step hidden states into top-k log-probability dicts.

        Args:
            hidden_states: Sequence with one entry per generation step, as
                returned by ``generate`` with ``output_hidden_states=True``.
            num_logprobs: Number of top entries kept per step.

        Returns:
            ``(seq_logprobs_lst, output_len)`` where ``seq_logprobs_lst``
            holds, per step, a dict mapping token id to log-probability
            for the ``num_logprobs`` most likely tokens, and ``output_len``
            is ``len(hidden_states)``.
        """
        # Loop-invariant lookups on the output projection.
        output_embeddings = self.model.get_output_embeddings()
        weight_t = output_embeddings.weight.t()
        bias = getattr(output_embeddings, "bias", None)

        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, hidden_state in enumerate(hidden_states):
            # NOTE(review): presumably [-1] is the last layer and [0] the
            # single batch entry — confirm.
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(last_hidden_states, weight_t)
            if bias is not None:
                logits += bias.unsqueeze(0)
            tok_logprobs = F.log_softmax(logits,
                                         dim=-1,
                                         dtype=torch.float32)

            # drop prompt logprobs: for the first step keep only the last
            # row, reshaped to a single-row matrix
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            # convert to dict
            seq_logprobs_lst.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            })

        return (
            seq_logprobs_lst,
            len(hidden_states),
        )

409
410
411
412
413
    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Generate without sampling and return top-k logprobs per step.

        Args:
            prompts: Text prompts.
            max_tokens: Passed as ``max_new_tokens``.
            num_logprobs: Number of top log-probabilities kept per step.
            images: Optional per-prompt images; a ``None`` entry means the
                prompt at that index has no image.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            One ``(output_ids, output_str, logprobs)`` tuple per prompt.
            ``output_ids`` are the trailing ids of the first returned
            sequence (one per step), ``output_str`` is their tokenizer
            decoding and ``logprobs`` holds one ``{token_id: logprob}``
            dict per step.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            # output_len equals the number of per-step dicts, so it is used
            # directly rather than being recomputed from the list.
            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # Keep only the trailing ids, one per generation step.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        '''
        Greedy logprobs generation for HuggingFace encoder/decoder models.

        Each prompt is split into an (encoder prompt, decoder prompt) pair.
        Both are tokenized with ``self.tokenizer``; a decoder prompt of
        ``None`` is passed on as ``decoder_input_ids=None``. Log-probs are
        computed from ``output.decoder_hidden_states``.

        Returns:

        * One ``(output_ids, output_str, logprobs)`` tuple per prompt, where
          ``logprobs`` holds one ``{token_id: logprob}`` dict with the top
          ``num_logprobs`` entries per generation step.
        '''

        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
            decoder_input_ids = (
                None if decoder_prompt is None else self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids))

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # Keep only the trailing ids, one per generation step.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

510
511
512
    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Return ``self.model.encode(prompts)`` unchanged.

        NOTE(review): presumably only meaningful when the runner was built
        with ``is_embedding_model=True`` (a SentenceTransformer) — confirm.
        """
        return self.model.encode(prompts)

513
514
515
516
    def __enter__(self):
        """Enter the ``with`` block; the runner itself is the context value."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model reference and run the global ``cleanup()``.

        Returns ``None``, so an exception raised inside the ``with`` block
        is not suppressed.
        """
        # Remove the attribute first so cleanup()'s garbage collection can
        # reclaim the model if nothing else references it.
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
520

Cyrus Leung's avatar
Cyrus Leung committed
521
@pytest.fixture(scope="session")
Woosuk Kwon's avatar
Woosuk Kwon committed
522
523
524
525
526
527
528
529
530
531
def hf_runner():
    """Return the ``HfRunner`` class itself (not an instance)."""
    return HfRunner


class VllmRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Create the underlying ``LLM`` with test-friendly defaults.

        Every named parameter is forwarded to ``LLM`` under the same
        keyword (``model_name``/``tokenizer_name`` as ``model``/
        ``tokenizer``), ``trust_remote_code`` is always enabled and any
        extra ``kwargs`` are passed through unchanged.
        """
        engine_args: Dict[str, Any] = dict(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
        )
        self.model = LLM(**engine_args, **kwargs)

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate completions and return them prefixed with the prompt.

        Args:
            prompts: Text prompts.
            sampling_params: Sampling parameters passed to the engine.
            images: Optional per-prompt image data; when given, must have
                the same length as ``prompts``.

        Returns:
            One ``(ids, strs)`` tuple per request. For every sample of the
            request, ``ids`` holds the prompt token ids followed by the
            sample's token ids and ``strs`` holds the prompt text followed
            by the sample's text.
        """
        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            assert len(prompts) == len(images)
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = {"image": image}

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            samples = req_output.outputs
            sample_ids: List[List[int]] = [
                prompt_ids + list(sample.token_ids) for sample in samples
            ]
            sample_strs: List[str] = [
                prompt_str + sample.text for sample in samples
            ]
            outputs.append((sample_ids, sample_strs))
        return outputs

591
592
593
594
595
596
597
598
    def _final_steps_generate_w_logprobs(
        self,
        req_outputs: List[RequestOutput],
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Collect ``(token_ids, text, logprobs)`` for each request.

        Only the last sample of each request is reported; unlike
        ``generate``, the prompt is not prepended.

        Raises:
            ValueError: If a request has no samples. Previously this case
                silently reused the values of the preceding request (or
                failed with ``UnboundLocalError`` on the first request).
        """
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            if not req_output.outputs:
                raise ValueError("Request output contains no samples")
            # Take the last sample explicitly instead of relying on
            # variables leaking out of an inner loop.
            sample = req_output.outputs[-1]
            outputs.append(
                (list(sample.token_ids), sample.text, sample.logprobs))
        return outputs

604
605
606
607
    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate with logprobs enabled in ``sampling_params``.

        ``sampling_params.logprobs`` must be set. When ``images`` is given
        it must have the same length as ``prompts``. The request outputs
        are post-processed by ``_final_steps_generate_w_logprobs``.
        """
        assert sampling_params.logprobs is not None

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            assert len(prompts) == len(images)
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = {"image": image}

        return self._final_steps_generate_w_logprobs(
            self.model.generate(inputs, sampling_params=sampling_params))

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Generate with logprobs for vLLM encoder/decoder models.

        ``sampling_params.logprobs`` must be set; the prompts are handed to
        the engine as given.
        '''

        assert sampling_params.logprobs is not None
        return self._final_steps_generate_w_logprobs(
            self.model.generate(encoder_decoder_prompts,
                                sampling_params=sampling_params))
638

Woosuk Kwon's avatar
Woosuk Kwon committed
639
640
641
642
    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Generate with temperature 0 and keep the first sample per prompt."""
        per_prompt = self.generate(
            prompts,
            SamplingParams(temperature=0.0, max_tokens=max_tokens),
            images=images,
        )
        greedy_outputs = []
        for sample_ids, sample_strs in per_prompt:
            greedy_outputs.append((sample_ids[0], sample_strs[0]))
        return greedy_outputs
649

650
651
652
653
654
    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate with temperature 0 and request ``num_logprobs`` logprobs.

        Builds the sampling parameters and delegates to
        ``generate_w_logprobs``.
        """
        params = SamplingParams(temperature=0.0,
                                max_tokens=max_tokens,
                                logprobs=num_logprobs,
                                stop_token_ids=stop_token_ids)
        results = self.generate_w_logprobs(prompts, params, images=images)
        # Hand back a fresh list of the (ids, text, logprobs) tuples.
        return list(results)

670
671
    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models

        Uses temperature 0 with beam search disabled, requests
        ``num_logprobs`` logprobs and delegates to
        ``generate_encoder_decoder_w_logprobs``.
        '''
        # The docstring must be the first statement of the body; it used to
        # follow the SamplingParams construction and was therefore ignored.
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                use_beam_search=False,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)

        outputs = self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

690
691
692
693
694
    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run ``generate`` with beam search returning ``beam_width`` samples."""
        return self.generate(
            prompts,
            SamplingParams(n=beam_width,
                           use_beam_search=True,
                           temperature=0.0,
                           max_tokens=max_tokens),
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
702

703
704
705
706
707
708
709
710
    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return the embedding of every request produced by ``model.encode``."""
        return [
            req_output.outputs.embedding
            for req_output in self.model.encode(prompts)
        ]

711
712
713
714
    def __enter__(self):
        """Enter the ``with`` block; the runner itself is the context value."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the engine reference and run the global ``cleanup()``.

        Returns ``None``, so an exception raised inside the ``with`` block
        is not suppressed.
        """
        # Remove the attribute first so cleanup()'s garbage collection can
        # reclaim the engine if nothing else references it.
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
718

719
@pytest.fixture(scope="session")
Woosuk Kwon's avatar
Woosuk Kwon committed
720
721
def vllm_runner():
    """Return the ``VllmRunner`` class itself (not an instance)."""
    return VllmRunner
722
723
724
725
726
727
728
729
730


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a single-worker ``TokenizerPoolConfig`` for the given type.

    Returns ``None`` when ``tokenizer_group_type`` is ``None``. The string
    ``"ray"`` or a class is used as the ``pool_type`` of a config with
    ``pool_size=1`` and an empty ``extra_config``.

    Raises:
        ValueError: For any other value.
    """
    if tokenizer_group_type is None:
        return None

    if tokenizer_group_type == "ray":
        pool_type = "ray"
    elif isinstance(tokenizer_group_type, type):
        pool_type = tokenizer_group_type
    else:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")

    return TokenizerPoolConfig(pool_size=1,
                               pool_type=pool_type,
                               extra_config={})
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751


@pytest.fixture()
def temporary_enable_log_propagate():
    """Enable propagation on the "vllm" logger for the duration of a test.

    The logger's previous ``propagate`` value is restored afterwards
    (rather than being forced to ``False``), so the fixture leaves no
    lasting change when propagation was already enabled.
    """
    import logging

    # Named vllm_logger to avoid shadowing the module-level ``logger``.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog`` while "vllm" log propagation is enabled."""
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
752
753
754
755
756
757
758


@pytest.fixture(scope="session")
def num_gpus_available():
    """Return the GPU count reported by ``cuda_device_count_stateless``.

    The stateless helper is used so that the CUDA context is not
    initialized in the current process.
    """
    gpu_count = cuda_device_count_stateless()
    return gpu_count