conftest.py 27 KB
Newer Older
1
2
import contextlib
import gc
3
import json
4
import os
5
import sys
6
import tempfile
7
from collections import UserList
8
from enum import Enum
9
10
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
                    TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
13

import pytest
import torch
14
import torch.nn as nn
15
import torch.nn.functional as F
16
from huggingface_hub import snapshot_download
17
from PIL import Image
18
19
20
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
                          AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
                          BatchFeature)
Woosuk Kwon's avatar
Woosuk Kwon committed
21
22

from vllm import LLM, SamplingParams
23
from vllm.assets.image import ImageAsset
24
from vllm.config import TokenizerPoolConfig
25
from vllm.connections import global_http_connection
26
from vllm.distributed import (destroy_distributed_environment,
27
28
29
                              destroy_model_parallel,
                              init_distributed_environment,
                              initialize_model_parallel)
30
31
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
32
from vllm.logger import init_logger
33
from vllm.outputs import RequestOutput
34
from vllm.sequence import SampleLogprobs
35
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
36
                        identity, is_cpu)
37

38
# Module-level logger for this conftest, created through vLLM's
# ``init_logger`` helper.
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
39

40
41
42
# Directory containing this conftest; prompt files live in its ``prompts``
# sub-directory.
_TEST_DIR = os.path.dirname(__file__)
# Prompt files read by the ``example_prompts`` and
# ``example_encoder_decoder_prompts`` fixtures.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
# Prompt files read by the ``example_long_prompts`` fixture.
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
43
44


45
def _read_prompts(filename: str) -> List[str]:
46
    with open(filename, "r") as f:
47
48
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
49
50


51
52
53
class _ImageAssetPrompts(TypedDict):
    """Prompt text for each test image, keyed by asset name.

    The keys are the same names that ``_ImageAssets`` passes to
    ``ImageAsset``.
    """
    # Prompt to pair with the "stop_sign" image asset.
    stop_sign: str
    # Prompt to pair with the "cherry_blossom" image asset.
    cherry_blossom: str
54
55
56
57
58
59
60


# Pick the base class for ``_ImageAssets`` depending on whether the running
# interpreter allows ``UserList`` to be subscripted.
if sys.version_info >= (3, 9):

    class _ImageAssetsBase(UserList[ImageAsset]):
        pass
else:
    # UserList cannot be subscripted
    class _ImageAssetsBase(UserList):
        pass
64

65
66

class _ImageAssets(_ImageAssetsBase):
    """List of the image assets used by the tests, in a fixed order."""

    def __init__(self) -> None:
        asset_names = ("stop_sign", "cherry_blossom")
        super().__init__([ImageAsset(name) for name in asset_names])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Convenience method to define the prompt for each test image.

        The returned list holds the ``stop_sign`` prompt followed by the
        ``cherry_blossom`` prompt, i.e. the same order in which the assets
        are yielded when iterating through this object.
        """
        ordered_keys = ("stop_sign", "cherry_blossom")
        return [prompts[key] for key in ordered_keys]
82
83
84
85
86
87


# Built once at import time; the session-scoped ``image_assets`` fixture
# returns this same object.
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""


88
89
90
91
92
93
94
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Disable client reuse on the global HTTP connection for every test."""
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
@pytest.fixture
def dist_init():
    """Set up a single-process distributed environment around a test.

    Initializes the distributed environment with ``world_size=1`` and
    ``rank=0`` over the ``nccl`` backend, using a fresh temporary file as the
    ``file://`` init method. It then initializes model parallelism with both
    sizes set to 1. After the test, ``cleanup()`` tears everything down.
    """
    # ``mkstemp`` returns an already-open OS-level file descriptor together
    # with the path. Only the path is needed here, so close the descriptor
    # immediately instead of leaking one per test.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup()


110
111
def cleanup():
    """Tear down parallel/distributed state and free cached memory.

    Destroys the model-parallel groups and the distributed environment, makes
    a best-effort attempt to destroy the torch process group, runs the
    garbage collector and, unless ``is_cpu()`` is true, empties the CUDA
    cache.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        # Deliberately ignored -- presumably raised when there is no process
        # group left to destroy (TODO confirm).
        pass
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()
118
119


120
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell ``cleanup_fixture`` whether to run the global cleanup.

    Subdirectories can override this fixture to skip global cleanup, which can
    give a ~10x speedup for non-GPU unit tests because they then don't need
    to initialize torch. A test carrying the ``skip_global_cleanup`` marker
    opts out as well.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker


133
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run ``cleanup()`` after every test unless the test opted out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
138
139


Woosuk Kwon's avatar
Woosuk Kwon committed
140
141
@pytest.fixture
def example_prompts() -> List[str]:
    """Return every line of every file listed in ``_TEST_PROMPTS``."""
    return [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]


148
149
150
151
152
153
154
class DecoderPromptType(Enum):
    """For encoder/decoder models only."""
    # Decoder prompts are explicit strings (the
    # ``example_encoder_decoder_prompts`` fixture uses the reversed encoder
    # prompt list for this case).
    CUSTOM = 1
    # Decoder prompts are ``None``.
    NONE = 2
    # Decoder prompts are empty strings.
    EMPTY_STR = 3


155
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for every decoder prompt type.

    The encoder prompts are the lines of the files in ``_TEST_PROMPTS``. For
    each :class:`DecoderPromptType` they are zipped (via
    ``zip_enc_dec_prompts``) with a decoder prompt list of the same length:

    * ``NONE``: every decoder prompt is ``None``
    * ``EMPTY_STR``: every decoder prompt is the empty string
    * ``CUSTOM``: the encoder prompt list in reverse order

    Returns:

    * Dict mapping each decoder prompt type to its list of
      encoder/decoder prompts
    '''

    encoder_prompts = [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]
    num_prompts = len(encoder_prompts)

    none_decoder_prompts = [None] * num_prompts
    empty_str_decoder_prompts = [""] * num_prompts
    custom_decoder_prompts = encoder_prompts[::-1]

    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
    }


188
189
190
191
@pytest.fixture
def example_long_prompts() -> List[str]:
    """Return every line of every file listed in ``_LONG_PROMPTS``."""
    return [
        line for filename in _LONG_PROMPTS
        for line in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
194
195


196
197
198
199
200
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Return the module-level ``IMAGE_ASSETS`` object (shared per session)."""
    return IMAGE_ASSETS


201
# Constrained type variable for the kinds of object that
# ``HfRunner.wrap_device`` moves between devices with ``.to(...)``.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
202

Woosuk Kwon's avatar
Woosuk Kwon committed
203
204
205

class HfRunner:
    """Test helper that runs a HuggingFace model to produce reference outputs.

    Depending on the constructor flags, the wrapped model is a
    ``SentenceTransformer`` or one of the ``transformers`` auto classes. The
    runner can be used as a context manager; on exit the model is dropped and
    ``cleanup()`` is called.
    """

    def wrap_device(self, input: _T) -> _T:
        """Move ``input`` to ``"cpu"`` if ``is_cpu()`` is true, else ``"cuda"``."""
        return input.to("cpu" if is_cpu() else "cuda")

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
        is_encoder_decoder_model: bool = False,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        """Load the model, tokenizer and processor for ``model_name``.

        Args:
            model_name: Model identifier passed to the ``from_pretrained``
                loaders (or to ``SentenceTransformer``).
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE`` selecting the torch
                dtype.
            model_kwargs: Extra keyword arguments forwarded to the auto
                class' ``from_pretrained``. Not used for embedding models.
            is_embedding_model: Load the model with ``SentenceTransformer``.
            is_vision_model: Load with ``AutoModelForVision2Seq``.
            is_encoder_decoder_model: Load with ``AutoModelForSeq2SeqLM``.
                When none of the three flags is set,
                ``AutoModelForCausalLM`` is used.
            postprocess_inputs: Callable applied to the processor output
                before it is handed to the model.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
            elif is_encoder_decoder_model:
                auto_cls = AutoModelForSeq2SeqLM
            else:
                auto_cls = AutoModelForCausalLM

            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            # don't put this import at the top level
            # it will call torch.cuda.device_count()
            from transformers import AutoProcessor  # noqa: F401
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception as exc:
            # Best effort: any failure to load a processor falls back to the
            # tokenizer, which is then used wherever a processor is expected.
            logger.warning(
                "Unable to auto-load HuggingFace processor for model (%s). "
                "Using tokenizer instead. Reason: %s", model_name, exc)
            self.processor = self.tokenizer

        self.postprocess_inputs = postprocess_inputs

    def _prepare_inputs(
        self,
        prompt: str,
        image: Optional[Image.Image],
    ):
        """Run the processor (and ``postprocess_inputs``) on one prompt.

        ``image`` is only passed to the processor when it is not ``None``.
        The result is NOT yet moved to the target device.
        """
        processor_kwargs: Dict[str, Any] = {
            "text": prompt,
            "return_tensors": "pt",
        }
        if image is not None:
            processor_kwargs["images"] = image

        inputs = self.processor(**processor_kwargs)
        return self.postprocess_inputs(inputs)

    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for each prompt separately with ``model.generate``.

        Args:
            prompts: Prompts to run, one ``model.generate`` call each.
            images: Optional per-prompt images; must have the same length as
                ``prompts``. Individual entries may be ``None``.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            One ``(output_ids, output_strs)`` tuple per prompt: the generated
            id sequences as nested lists and their decoded strings (special
            tokens skipped).
        """
        if images is not None:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            image = None if images is None else images[i]
            inputs = self._prepare_inputs(prompt, image)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            outputs.append((output_ids.cpu().tolist(), output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy generation (``do_sample=False``), one result per prompt.

        Returns:
            For each prompt, the first generated id sequence and the first
            decoded string returned by :meth:`generate`.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search generation returning ``beam_width`` sequences per prompt.

        The tokenizer's pad token id is removed from every returned id
        sequence; the decoded strings are returned as produced by
        :meth:`generate`.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [([[token_id for token_id in sequence
                   if token_id != pad_token_id] for sequence in output_ids],
                 output_str) for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy generation returning full log-probability tensors.

        Returns:
            For each prompt, a list with one float32 log-softmax tensor per
            entry of the ``hidden_states`` returned by ``model.generate``.
        """
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            image = None if images is None else images[i]
            inputs = self._prepare_inputs(prompt, image)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            all_logprobs.append(
                self._hidden_states_to_seq_logprobs(output.hidden_states))
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states,
    ) -> List[torch.Tensor]:
        """Project hidden states through the output embeddings to logprobs.

        For every entry of ``hidden_states`` the last element's first item is
        multiplied with the transposed output-embedding weight, the bias is
        added when the embedding module has one, and a float32 log-softmax
        over the last dimension is taken.
        """
        output_embeddings = self.model.get_output_embeddings()
        # ``getattr`` so that embedding modules without a ``bias`` attribute
        # are handled the same way as those whose bias is ``None``.
        bias = getattr(output_embeddings, "bias", None)

        seq_logprobs: List[torch.Tensor] = []
        for hidden_state in hidden_states:
            # presumably: last layer, first (only) sequence in the batch --
            # TODO confirm against the HF generate output layout
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states,
                output_embeddings.weight.t(),
            )
            if bias is not None:
                logits += bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)
        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states,
        num_logprobs,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Convert hidden states to per-token top-k logprob dictionaries.

        Returns:
            ``(seq_logprobs_lst, output_len)`` where ``seq_logprobs_lst`` has
            one ``{token_id: logprob}`` dict (top ``num_logprobs`` entries)
            per element of ``hidden_states`` and ``output_len`` is
            ``len(hidden_states)``.
        """
        output_len = len(hidden_states)
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)

        # convert to dict
        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            seq_logprobs_lst.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            })

        return (
            seq_logprobs_lst,
            output_len,
        )

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Greedy generation returning top-``num_logprobs`` logprobs per token.

        Returns:
            One ``(output_ids, output_str, logprobs)`` tuple per prompt, where
            ``output_ids`` are the trailing ``output_len`` ids of the first
            returned sequence, ``output_str`` is their tokenizer decoding and
            ``logprobs`` is the list of ``{token_id: logprob}`` dicts.
        """
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, prompt in enumerate(prompts):
            image = None if images is None else images[i]
            inputs = self._prepare_inputs(prompt, image)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            # ``output_len`` equals ``len(seq_logprobs_lst)``: both count the
            # entries of ``output.hidden_states``.
            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models

        Each prompt is split into its encoder and decoder parts, which are
        tokenized with the tokenizer; a ``None`` decoder prompt is passed to
        ``model.generate`` as ``decoder_input_ids=None``. Logprobs are
        computed from ``output.decoder_hidden_states``.
        '''

        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
            decoder_input_ids = (
                None if decoder_prompt is None else self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids))

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Return ``self.model.encode(prompts)`` unchanged."""
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the model reference first so ``cleanup()`` (gc + cache
        # emptying) can actually reclaim it.
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
540

Cyrus Leung's avatar
Cyrus Leung committed
541
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the ``HfRunner`` class itself; tests instantiate it."""
    return HfRunner


class VllmRunner:
    """Test helper wrapping a vLLM ``LLM`` instance.

    The runner can be used as a context manager; on exit the model is dropped
    and ``cleanup()`` is called.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Create the underlying ``LLM``.

        All named arguments are forwarded to ``LLM`` under the same keyword
        (``model_name``/``tokenizer_name`` as ``model``/``tokenizer``),
        ``trust_remote_code`` is always ``True`` and any extra ``kwargs`` are
        passed through unchanged.
        """
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def _build_inputs(
        self,
        prompts: List[str],
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]],
    ) -> List[TextPrompt]:
        """Wrap prompts as ``TextPrompt`` and attach images when given.

        When ``images`` is not ``None`` it must have the same length as
        ``prompts``; entry ``i`` is attached to prompt ``i`` as
        ``multi_modal_data={"image": ...}``.
        """
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = {"image": image}
        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with the given sampling parameters.

        Returns:
            One ``(ids, strs)`` tuple per request. For every sample of the
            request, ``ids`` holds the prompt token ids followed by the
            sample's token ids and ``strs`` holds the prompt text followed by
            the sample's text.
        """
        inputs = self._build_inputs(prompts, images)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                req_sample_output_ids.append(prompt_ids +
                                             list(sample.token_ids))
                req_sample_output_strs.append(prompt_str + sample.text)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def _final_steps_generate_w_logprobs(
        self,
        req_outputs: List[RequestOutput],
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Reduce request outputs to ``(token_ids, text, logprobs)`` tuples.

        Exactly one tuple is produced per request.
        """
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                output_logprobs = sample.logprobs
            # NOTE(review): the append is outside the inner loop, so only the
            # LAST sample of each request is kept (and a request with no
            # samples would reuse stale values). Left as-is because callers
            # receive one entry per prompt -- confirm this is intended.
            outputs.append((output_ids, output_str, output_logprobs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate and return token ids, text and logprobs per prompt.

        ``sampling_params.logprobs`` must be set. Unlike :meth:`generate`,
        the returned ids/text contain only the generated part, without the
        prompt.
        """
        assert sampling_params.logprobs is not None

        inputs = self._build_inputs(prompts, images)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)
        return self._final_steps_generate_w_logprobs(req_outputs)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._final_steps_generate_w_logprobs(req_outputs)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy generation (``temperature=0.0``), one result per prompt.

        Returns:
            For each prompt, the first sample's ids and text as produced by
            :meth:`generate` (i.e. including the prompt).
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Greedy generation requesting ``num_logprobs`` logprobs per token."""
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs,
                                                stop_token_ids=stop_token_ids)
        outputs = self.generate_w_logprobs(prompts,
                                           greedy_logprobs_params,
                                           images=images)

        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                use_beam_search=False,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)

        outputs = self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search generation with ``n=beam_width`` samples per prompt."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return ``outputs.embedding`` of every request from ``model.encode``."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the model reference first so ``cleanup()`` (gc + cache
        # emptying) can actually reclaim it.
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
738

739
@pytest.fixture(scope="session")
def vllm_runner():
    """Provide the ``VllmRunner`` class itself; tests instantiate it."""
    return VllmRunner
742
743
744
745
746
747
748
749
750


def get_tokenizer_pool_config(tokenizer_group_type):
    """Build a ``TokenizerPoolConfig`` for the given tokenizer group type.

    Returns ``None`` for ``None``. For the string ``"ray"`` or for any class
    object, returns a config with ``pool_size=1``, empty ``extra_config`` and
    ``pool_type`` set to ``"ray"`` or to the class respectively.

    Raises:
        ValueError: for any other value.
    """
    if tokenizer_group_type is None:
        return None

    is_ray = tokenizer_group_type == "ray"
    if is_ray or isinstance(tokenizer_group_type, type):
        pool_type = "ray" if is_ray else tokenizer_group_type
        return TokenizerPoolConfig(pool_size=1,
                                   pool_type=pool_type,
                                   extra_config={})

    raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771


@pytest.fixture()
def temporary_enable_log_propagate():
    """Enable propagation on the ``"vllm"`` logger for the test's duration.

    The logger's previous ``propagate`` value is restored afterwards, rather
    than being forced to ``False``, so the fixture leaves no trace if
    propagation was already enabled.
    """
    import logging

    # Named ``vllm_logger`` to avoid shadowing the module-level ``logger``.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog`` with vLLM log propagation switched on."""
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
772
773
774
775
776
777
778


@pytest.fixture(scope="session")
def num_gpus_available():
    """Return the GPU count reported by ``cuda_device_count_stateless()``.

    This gets the number of GPUs without initializing the CUDA context in
    the current process.
    """
    return cuda_device_count_stateless()
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802


# System temporary directory, and the ``dummy_opt`` sub-directory inside it
# that the ``dummy_opt_path`` fixture downloads into and returns.
temp_dir = tempfile.gettempdir()
_dummy_path = os.path.join(temp_dir, "dummy_opt")


@pytest.fixture
def dummy_opt_path():
    """Return a local copy of ``facebook/opt-125m`` with a patched config.

    If ``_dummy_path`` already exists it is returned as-is. Otherwise the
    repository is downloaded there without its weight files (``*.bin``,
    ``*.bin.index.json``, ``*.pt``, ``*.h5``, ``*.msgpack``) and the
    ``architectures`` entry of its ``config.json`` is rewritten to
    ``["MyOPTForCausalLM"]``.
    """
    if os.path.exists(_dummy_path):
        return _dummy_path

    config_file = os.path.join(_dummy_path, "config.json")
    skipped_patterns = [
        "*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"
    ]
    snapshot_download(repo_id="facebook/opt-125m",
                      local_dir=_dummy_path,
                      ignore_patterns=skipped_patterns)
    assert os.path.exists(config_file)

    with open(config_file, "r") as src:
        config = json.load(src)
    config["architectures"] = ["MyOPTForCausalLM"]
    with open(config_file, "w") as dst:
        json.dump(config, dst)

    return _dummy_path