conftest.py 26.6 KB
Newer Older
1
2
import contextlib
import gc
3
import json
4
import os
5
import sys
6
import tempfile
7
from collections import UserList
8
from enum import Enum
9
10
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
                    TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
13

import pytest
import torch
14
import torch.nn as nn
15
import torch.nn.functional as F
16
from huggingface_hub import snapshot_download
17
from PIL import Image
18
19
20
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
                          AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
                          BatchFeature)
Woosuk Kwon's avatar
Woosuk Kwon committed
21
22

from vllm import LLM, SamplingParams
23
from vllm.assets.image import ImageAsset
24
from vllm.config import TokenizerPoolConfig
25
from vllm.connections import global_http_connection
26
27
from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
28
29
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
30
from vllm.logger import init_logger
31
from vllm.outputs import RequestOutput
32
from vllm.sequence import SampleLogprobs
33
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
34
                        identity, is_cpu)
35

36
_TEST_DIR = os.path.dirname(__file__)
# Prompt files consumed by the prompt fixtures defined below.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

logger = init_logger(__name__)
41
42


43
def _read_prompts(filename: str) -> List[str]:
    """Return the lines of ``filename`` (trailing newlines kept).

    The file is decoded as UTF-8 explicitly so the result does not depend
    on the locale of the machine running the tests.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
47
48


49
50
51
class _ImageAssetPrompts(TypedDict):
    """Prompt text for each test image, keyed by image asset name."""

    stop_sign: str  # prompt paired with the ``stop_sign`` asset
    cherry_blossom: str  # prompt paired with the ``cherry_blossom`` asset
52
53
54
55
56
57
58


if sys.version_info >= (3, 9):

    class _ImageAssetsBase(UserList[ImageAsset]):
        """``UserList`` of :class:`ImageAsset` (subscripted form)."""
else:
    # UserList cannot be subscripted before Python 3.9, so fall back to the
    # unparameterized base class there.
    class _ImageAssetsBase(UserList):
        """``UserList`` of :class:`ImageAsset` (unsubscripted form)."""
62

63
64

class _ImageAssets(_ImageAssetsBase):
    """List of the test image assets: ``stop_sign``, then ``cherry_blossom``."""

    def __init__(self) -> None:
        assets = [ImageAsset(name) for name in ("stop_sign", "cherry_blossom")]
        super().__init__(assets)

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Return the prompt for each test image.

        The returned prompts are ordered like the assets are when iterating
        over this object, so the two can be zipped together.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
80
81
82
83
84
85


IMAGE_ASSETS: _ImageAssets = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""


86
87
88
89
90
91
92
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Make every test build its own async HTTP client.

    pytest_asyncio may use a different event loop per test, so a client
    created by an earlier test must not be reused.
    """
    connection = global_http_connection
    connection.reuse_client = False


93
94
def cleanup():
    """Tear down vLLM distributed state and free cached memory.

    Called after tests (via ``cleanup_fixture``) and when a runner exits.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # Best effort: an AssertionError from destroy_process_group is ignored
    # (presumably raised when no process group exists — confirm).
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass
    gc.collect()
    if is_cpu():
        return
    torch.cuda.empty_cache()
101
102


103
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether ``cleanup_fixture`` runs :func:`cleanup` after a test.

    Tests marked ``skip_global_cleanup`` opt out, and subdirectories can
    override this fixture to skip global cleanup entirely. Skipping gives
    roughly a 10x speedup for non-GPU unit tests, which then never need to
    initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker


116
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After each test, run :func:`cleanup` unless the test opted out."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
121
122


Woosuk Kwon's avatar
Woosuk Kwon committed
123
124
@pytest.fixture
def example_prompts() -> List[str]:
    """Every line of the example prompt files, in file order."""
    return [
        line for filename in _TEST_PROMPTS
        for line in _read_prompts(filename)
    ]


131
132
133
134
135
136
137
class DecoderPromptType(Enum):
    """Kinds of decoder prompt; for encoder/decoder models only.

    * ``CUSTOM``: an explicit decoder prompt string
    * ``NONE``: no decoder prompt at all
    * ``EMPTY_STR``: an empty-string decoder prompt
    """

    CUSTOM = 1
    NONE = 2
    EMPTY_STR = 3


138
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt lists, one per :class:`DecoderPromptType`.

    The encoder prompts are the lines of the example prompt files. Each
    value of the returned dict pairs those encoder prompts, index by index
    (via ``zip_enc_dec_prompts``), with one kind of decoder prompt:

    * ``DecoderPromptType.NONE``: ``None`` (no decoder prompt)
    * ``DecoderPromptType.EMPTY_STR``: ``""``
    * ``DecoderPromptType.CUSTOM``: the encoder prompts in reverse order

    Returns:
        Dict mapping each ``DecoderPromptType`` to its list of
        explicit encoder/decoder prompts.
    '''
    encoder_prompts: List[str] = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    num_prompts = len(encoder_prompts)
    decoder_prompts_by_type = {
        DecoderPromptType.NONE: [None] * num_prompts,
        DecoderPromptType.EMPTY_STR: [""] * num_prompts,
        DecoderPromptType.CUSTOM: encoder_prompts[::-1],
    }
    return {
        prompt_type: zip_enc_dec_prompts(encoder_prompts, decoder_prompts)
        for prompt_type, decoder_prompts in decoder_prompts_by_type.items()
    }


171
172
173
174
@pytest.fixture
def example_long_prompts() -> List[str]:
    """Every line of the long (summary) prompt files, in file order."""
    return [
        line for filename in _LONG_PROMPTS
        for line in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
177
178


179
180
181
182
183
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Session-wide handle to the shared ``IMAGE_ASSETS`` singleton."""
    assets = IMAGE_ASSETS
    return assets


184
# Values HfRunner.wrap_device accepts; each is moved with ``.to(device)`` and
# returned as the same type.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
185

Woosuk Kwon's avatar
Woosuk Kwon committed
186
187
188

class HfRunner:
    """Runs a model through HuggingFace ``transformers`` for comparison
    against vLLM in tests.

    Usable as a context manager; on exit the model is deleted and
    :func:`cleanup` is called.
    """

    def wrap_device(self, input: _T) -> _T:
        """Move ``input`` to ``"cuda"``, or to ``"cpu"`` when ``is_cpu()``."""
        if not is_cpu():
            return input.to("cuda")
        else:
            return input.to("cpu")

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
        is_encoder_decoder_model: bool = False,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        """Load the model, tokenizer and (if possible) processor.

        Args:
            model_name: HuggingFace model name or local path.
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE``.
            model_kwargs: Extra kwargs for ``from_pretrained`` (not used
                for embedding models).
            is_embedding_model: Load with ``SentenceTransformer`` instead.
            is_vision_model: Load with ``AutoModelForVision2Seq``.
            is_encoder_decoder_model: Load with ``AutoModelForSeq2SeqLM``.
            postprocess_inputs: Applied to the processor output of every
                prompt before it is passed to the model.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
            elif is_encoder_decoder_model:
                auto_cls = AutoModelForSeq2SeqLM
            else:
                auto_cls = AutoModelForCausalLM

            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            # don't put this import at the top level
            # it will call torch.cuda.device_count()
            from transformers import AutoProcessor  # noqa: F401
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception as exc:
            # Best effort: fall back to the tokenizer when no processor can
            # be loaded, but log why.
            logger.warning(
                "Unable to auto-load HuggingFace processor for model (%s). "
                "Using tokenizer instead. Reason: %s", model_name, exc)
            self.processor = self.tokenizer

        self.postprocess_inputs = postprocess_inputs

    def _process_prompt(
        self,
        prompt: str,
        image: Optional[Image.Image] = None,
    ) -> BatchEncoding:
        """Run the processor on one prompt (plus optional image) and apply
        ``postprocess_inputs`` to the result."""
        processor_kwargs: Dict[str, Any] = {
            "text": prompt,
            "return_tensors": "pt",
        }
        if image is not None:
            processor_kwargs["images"] = image

        inputs = self.processor(**processor_kwargs)
        return self.postprocess_inputs(inputs)

    def _last_layer_logprobs(self, hidden_state) -> torch.Tensor:
        """Convert one generation step's hidden states to float32 logprobs.

        ``hidden_state[-1][0]`` (last entry, first batch element) is projected
        with the output-embedding weight (plus its bias, when present) and
        normalised with ``log_softmax`` over the last dimension.
        """
        output_embeddings = self.model.get_output_embeddings()
        last_hidden_states = hidden_state[-1][0]
        logits = torch.matmul(last_hidden_states,
                              output_embeddings.weight.t())
        # getattr so output modules without a ``bias`` attribute also work.
        bias = getattr(output_embeddings, "bias", None)
        if bias is not None:
            logits += bias.unsqueeze(0)
        return F.log_softmax(logits, dim=-1, dtype=torch.float32)

    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate with ``model.generate`` for each prompt.

        Returns, per prompt, the returned sequences' token ids (prompt
        included, as produced by ``generate``) and their decoded strings.
        Extra ``kwargs`` are forwarded to ``model.generate``.
        """
        # ``is not None`` (not truthiness) so an empty image list is still
        # length-checked instead of failing later with an IndexError.
        if images is not None:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            image = images[i] if images is not None else None
            inputs = self._process_prompt(prompt, image)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            outputs.append((output_ids.cpu().tolist(), output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy decoding; returns the first sequence for each prompt."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search returning all ``beam_width`` beams per prompt, with
        pad tokens removed from each beam's token ids."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id

        stripped_outputs: List[Tuple[List[List[int]], List[str]]] = []
        for output_ids, output_str in outputs:
            stripped_ids = [[
                token_id for token_id in beam_ids if token_id != pad_token_id
            ] for beam_ids in output_ids]
            stripped_outputs.append((stripped_ids, output_str))
        return stripped_outputs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy decoding; returns, per prompt, one full-vocabulary float32
        logprob tensor per generation step."""
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            image = images[i] if images is not None else None
            inputs = self._process_prompt(prompt, image)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            all_logprobs.append([
                self._last_layer_logprobs(hidden_state)
                for hidden_state in output.hidden_states
            ])
        return all_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states,
        num_logprobs,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Convert per-step hidden states to top-``num_logprobs`` dicts.

        Returns:
            A list with one ``{token_id: logprob}`` dict per step, and the
            number of steps (``len(hidden_states)``).
        """
        seq_logprobs = [
            self._last_layer_logprobs(hidden_state)
            for hidden_state in hidden_states
        ]
        output_len = len(hidden_states)

        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs: the first step spans the prompt, keep only
            # its last row (reshaped to (1, vocab) like later steps).
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)
            seq_logprobs_lst.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            })

        return seq_logprobs_lst, output_len

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        """Greedy decoding; returns, per prompt, the generated token ids,
        their decoded text and the top-``num_logprobs`` dict per step."""
        outputs: List[Tuple[List[int], str, List[Dict[int, float]]]] = []
        for i, prompt in enumerate(prompts):
            image = images[i] if images is not None else None
            inputs = self._process_prompt(prompt, image)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)

            # Keep only the generated suffix of the returned sequence.
            output_ids = output.sequences[0][-output_len:]
            outputs.append((output_ids.tolist(),
                            self.tokenizer.decode(output_ids),
                            seq_logprobs_lst))
        return outputs

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        '''
        Greedy logprobs generation for HuggingFace encoder/decoder models.

        Each prompt pair is tokenized with ``self.tokenizer``; a ``None``
        decoder prompt is passed as ``decoder_input_ids=None``. Returns, per
        pair, the generated token ids, their decoded text and the
        top-``num_logprobs`` dict for each decoder step.
        '''
        outputs: List[Tuple[List[int], str, List[Dict[int, float]]]] = []
        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
            decoder_input_ids = (
                None if decoder_prompt is None else self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids))

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)

            output_ids = output.sequences[0][-output_len:]
            outputs.append((output_ids.tolist(),
                            self.tokenizer.decode(output_ids),
                            seq_logprobs_lst))
        return outputs

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Embed ``prompts`` with the model's own ``encode`` method."""
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
523

Cyrus Leung's avatar
Cyrus Leung committed
524
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the :class:`HfRunner` class (not an instance), so each test
    constructs a runner for the model it needs."""
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:
    """Runs a model through vLLM's :class:`LLM` entry point for tests.

    Usable as a context manager; on exit the model is deleted and
    :func:`cleanup` is called.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Build the underlying :class:`LLM` (always with
        ``trust_remote_code=True``); extra ``kwargs`` are forwarded to it."""
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    @staticmethod
    def _build_inputs(
        prompts: List[str],
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
    ) -> List[TextPrompt]:
        """Wrap each prompt in a :class:`TextPrompt`, attaching ``images[i]``
        as its ``"image"`` multi-modal data when images are given."""
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for prompt_input, image in zip(inputs, images):
                prompt_input["multi_modal_data"] = {"image": image}
        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate and return, per request, every sample's token ids and
        text, each prefixed with the prompt's ids / text."""
        inputs = self._build_inputs(prompts, images)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = [
                prompt_ids + list(sample.token_ids)
                for sample in req_output.outputs
            ]
            req_sample_output_strs: List[str] = [
                prompt_str + sample.text for sample in req_output.outputs
            ]
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def _final_steps_generate_w_logprobs(
        self,
        req_outputs: List[RequestOutput],
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Return one ``(token_ids, text, logprobs)`` tuple per request.

        When a request has several samples only the last one is reported,
        as before. Indexing explicitly (rather than reading loop variables
        after the loop) raises IndexError for a request without samples
        instead of silently repeating the previous request's output.
        """
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            sample = req_output.outputs[-1]
            outputs.append(
                (list(sample.token_ids), sample.text, sample.logprobs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Generate with ``sampling_params`` (which must request logprobs)
        and return ``(token_ids, text, logprobs)`` per request."""
        assert sampling_params.logprobs is not None

        inputs = self._build_inputs(prompts, images)
        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)
        return self._final_steps_generate_w_logprobs(req_outputs)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._final_steps_generate_w_logprobs(req_outputs)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str]]:
        """Greedy decoding (``temperature=0``); returns the first sample's
        prompt-prefixed token ids and text per request."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[Union[List[Image.Image],
                               List[List[Image.Image]]]] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        """Greedy decoding returning ``num_logprobs`` logprobs per step."""
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs,
                                                stop_token_ids=stop_token_ids)
        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                use_beam_search=False,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam search with ``n=beam_width``; returns every beam per prompt."""
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        return self.generate(prompts, beam_search_params)

    def encode(self, prompts: List[str]) -> List[List[float]]:
        """Return one embedding per prompt via ``LLM.encode``."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()

Woosuk Kwon's avatar
Woosuk Kwon committed
721

722
@pytest.fixture(scope="session")
def vllm_runner():
    """Provide the :class:`VllmRunner` class (not an instance), so each test
    constructs a runner for the model it needs."""
    runner_cls = VllmRunner
    return runner_cls
725
726
727
728
729
730
731
732
733


def get_tokenizer_pool_config(tokenizer_group_type):
    """Map a tokenizer-group selector to a single-worker pool config.

    ``None`` yields ``None``; ``"ray"`` or a class yields a
    ``TokenizerPoolConfig`` with ``pool_size=1``. Anything else raises
    ``ValueError``.
    """
    if tokenizer_group_type is None:
        return None

    if tokenizer_group_type == "ray":
        pool_type = "ray"
    elif isinstance(tokenizer_group_type, type):
        pool_type = tokenizer_group_type
    else:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")

    return TokenizerPoolConfig(pool_size=1,
                               pool_type=pool_type,
                               extra_config={})
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754


@pytest.fixture()
def temporary_enable_log_propagate():
    """Let records from the ``vllm`` logger propagate for one test.

    The logger's previous ``propagate`` value is restored afterwards, instead
    of being forced to ``False``, so tests that run later see the logging
    setup they started with.
    """
    import logging
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """``caplog`` that also captures records from the ``vllm`` logger.

    caplog only sees records that reach the root logger, which is why this
    fixture depends on ``temporary_enable_log_propagate``.
    """
    captured = caplog
    yield captured
755
756
757
758
759
760
761


@pytest.fixture(scope="session")
def num_gpus_available():
    """Number of GPUs, counted without initializing a CUDA context in this
    process (via ``cuda_device_count_stateless``)."""
    gpu_count = cuda_device_count_stateless()
    return gpu_count
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785


# Fixed location under the system temp dir, reused across test runs, where
# dummy_opt_path keeps its patched OPT checkpoint.
temp_dir = tempfile.gettempdir()
_dummy_path = os.path.join(temp_dir, "dummy_opt")


@pytest.fixture
def dummy_opt_path():
    """Return a local copy of ``facebook/opt-125m`` whose ``config.json``
    declares the ``MyOPTForCausalLM`` architecture.

    Weight files are excluded from the download. The download is skipped
    when ``config.json`` already exists, and the config is rewritten only
    when it is not patched yet. A run interrupted between download and
    patching is therefore repaired on the next call, instead of leaving an
    unpatched config behind forever.
    """
    json_path = os.path.join(_dummy_path, "config.json")
    if not os.path.exists(json_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)

    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    if config.get("architectures") != ["MyOPTForCausalLM"]:
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_path