conftest.py 43.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
import math
5
import os
6
import tempfile
7
from enum import Enum
8
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast
Woosuk Kwon's avatar
Woosuk Kwon committed
9

10
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
import pytest
import torch
13
import torch.nn as nn
14
import torch.nn.functional as F
15
from huggingface_hub import snapshot_download
16
from PIL import Image
17
18
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BatchEncoding, BatchFeature)
19
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
20

21
22
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm import LLM, SamplingParams
24
from vllm.assets.audio import AudioAsset
25
from vllm.assets.image import ImageAsset
26
from vllm.assets.video import VideoAsset
27
from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype
28
from vllm.connections import global_http_connection
29
from vllm.distributed import (cleanup_dist_env_and_memory,
30
31
                              init_distributed_environment,
                              initialize_model_parallel)
32
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
33
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
34
from vllm.logger import init_logger
35
from vllm.outputs import RequestOutput
36
from vllm.sampling_params import BeamSearchParams
37
from vllm.sequence import Logprob
38
from vllm.transformers_utils.utils import maybe_model_redirect
39

40
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
41

42
43
44
# Directory containing this conftest; all fixture data paths are relative to it.
_TEST_DIR = os.path.dirname(__file__)
# Prompt files read by the `example_prompts` / encoder-decoder fixtures.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
# Prompt files read by the `example_long_prompts` fixture.
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
# File read by the `example_system_message` fixture.
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
46

Cyrus Leung's avatar
Cyrus Leung committed
47
# Element type of a multi-modal input (image, audio clip, video, ...).
_M = TypeVar("_M")

# A multi-modal input list is indexed by prompt; each entry is either a single
# item (`list[_M]`) or a list of items for that prompt (`list[list[_M]]`).
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
# Audio items are `(samples, sampling_rate)` pairs.
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
54

55

56
def _read_prompts(filename: str) -> list[str]:
    """Return the lines of *filename*, one prompt per line.

    Lines are returned exactly as `readlines()` yields them, i.e. each one
    keeps its trailing newline. The file is decoded as UTF-8 so the result
    does not depend on the locale of the machine running the tests.
    """
    with open(filename, encoding="utf-8") as f:
        return f.readlines()
Woosuk Kwon's avatar
Woosuk Kwon committed
60
61


62
class ImageAssetPrompts(TypedDict):
    """Prompt text for each test image, keyed by the image asset's name."""
    stop_sign: str
    cherry_blossom: str
65
66


67
class ImageTestAssets(list[ImageAsset]):
    """The fixed list of image assets used by the tests."""

    def __init__(self) -> None:
        super().__init__([
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
        ])

    def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
        """
        Convenience method to define the prompt for each test image.

        The order of the returned prompts matches the order of the
        assets when iterating through this object.
        """
        return [prompts["stop_sign"], prompts["cherry_blossom"]]
83
84


85
86
class VideoAssetPrompts(TypedDict):
    """Prompt text for each test video, keyed by the video asset's name."""
    baby_reading: str
87
88


89
class VideoTestAssets(list[VideoAsset]):
    """The fixed list of video assets used by the tests."""

    def __init__(self) -> None:
        super().__init__([
            VideoAsset("baby_reading"),
        ])

    def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
        """Return the prompts in the same order as the assets in this list."""
        return [prompts["baby_reading"]]
98
99


100
class AudioAssetPrompts(TypedDict):
    """Prompt text for each test audio clip, keyed by the audio asset's name."""
    mary_had_lamb: str
    winning_call: str


105
class AudioTestAssets(list[AudioAsset]):
    """The fixed list of audio assets used by the tests."""

    def __init__(self) -> None:
        super().__init__([
            AudioAsset("mary_had_lamb"),
            AudioAsset("winning_call"),
        ])

    def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
        """Return the prompts in the same order as the assets in this list."""
        return [prompts["mary_had_lamb"], prompts["winning_call"]]

116

117
IMAGE_ASSETS = ImageTestAssets()
"""Singleton instance of {class}`ImageTestAssets`."""
VIDEO_ASSETS = VideoTestAssets()
"""Singleton instance of {class}`VideoTestAssets`."""
AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of {class}`AudioTestAssets`."""
123
124


125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
    """
    Make sure "VLLM_USE_V1" does not leak from one test into the next.

    The V1 oracle sets "VLLM_USE_V1" during loading, so each test
    invocation may change the environment variable. Once monkeypatch has
    touched the variable, it restores the pre-test state on teardown,
    undoing whatever vLLM did to it during the test.

    This fixture is used by every test.
    """
    env_key = "VLLM_USE_V1"

    # Already set: nothing to register here.
    if env_key in os.environ:
        return

    # Not set: set it and delete it again. The variable ends up unset, but
    # monkeypatch has now recorded it and will clean it up upon exit if vLLM
    # modifies the value of envs.VLLM_USE_V1.
    monkeypatch.setenv(env_key, "")
    monkeypatch.delenv(env_key)


Joe Runde's avatar
Joe Runde committed
145
@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """Run the requesting test twice: once with V1 enabled and once without.

    Tests marked `skip_v1` are only run without V1; tests marked `skip_v0`
    are only run with V1. The selected engine is communicated through the
    "VLLM_USE_V1" environment variable, which monkeypatch restores afterwards.
    """
    use_v1 = request.param
    skip_v0 = request.node.get_closest_marker("skip_v0")
    skip_v1 = request.node.get_closest_marker("skip_v1")

    if use_v1:
        if skip_v1:
            pytest.skip("Skipping test on vllm V1")
        monkeypatch.setenv('VLLM_USE_V1', '1')
    else:
        if skip_v0:
            pytest.skip("Skipping test on vllm V0")
        monkeypatch.setenv('VLLM_USE_V1', '0')

    yield
Joe Runde's avatar
Joe Runde committed
163
164


165
166
167
168
169
170
171
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Disable client reuse on the global HTTP connection for every test."""
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


172
173
174
175
176
177
178
179
180
181
182
183
@pytest.fixture
def dist_init():
    """Set up a single-process (world size 1) distributed environment.

    A temporary file is used as the `file://` rendezvous point. The
    distributed state is torn down again after the test, including when
    initialization itself fails part-way.
    """
    # mkstemp() returns an *open* file descriptor; only the path is needed,
    # so close the descriptor right away instead of leaking it.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    try:
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend="nccl",
        )
        initialize_model_parallel(1, 1)
        yield
    finally:
        cleanup_dist_env_and_memory()
        # Best-effort removal of the rendezvous file; it may already be gone.
        try:
            os.remove(temp_file)
        except OSError:
            pass
185
186


187
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't need
    to initialize torch.

    Returns False when the test carries the `skip_global_cleanup` marker,
    True otherwise.
    """
    return not request.node.get_closest_marker("skip_global_cleanup")
195
196


197
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After each test, run the global distributed/memory cleanup if enabled."""
    yield
    if should_do_global_cleanup_after_test:
        cleanup_dist_env_and_memory()
202
203


204
205
206
207
208
209
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call `torch._dynamo.reset()` after every test."""
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
210
@pytest.fixture
def example_prompts() -> list[str]:
    """Return every prompt line from the files listed in `_TEST_PROMPTS`."""
    prompts: list[str] = []
    for filename in _TEST_PROMPTS:
        prompts.extend(_read_prompts(filename))
    return prompts


218
219
220
221
222
223
@pytest.fixture
def example_system_message() -> str:
    """Return the full text of the example system message file."""
    # Decode explicitly as UTF-8 so the result is locale-independent.
    with open(_SYS_MSG, encoding="utf-8") as f:
        return f.read()


224
225
226
227
228
229
230
class DecoderPromptType(Enum):
    """For encoder/decoder models only."""
    # Keys of the dict built by `example_encoder_decoder_prompts`:
    CUSTOM = 1  # decoder prompts are the encoder prompts in reverse order
    NONE = 2  # decoder prompt is None
    EMPTY_STR = 3  # decoder prompt is ""


231
@pytest.fixture
def example_encoder_decoder_prompts(
) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for each decoder prompt type.

    The encoder prompts are the lines of the files in `_TEST_PROMPTS`. For
    every `DecoderPromptType` they are zipped index-by-index with a decoder
    prompt list of the same length:

    * `NONE`: every decoder prompt is `None`
    * `EMPTY_STR`: every decoder prompt is `""`
    * `CUSTOM`: the encoder prompt list in reverse order

    Returns:

    * A dict mapping each decoder prompt type to its list of zipped
      (encoder prompt, decoder prompt) entries
    '''

    encoder_prompts: list[str] = []
    for filename in _TEST_PROMPTS:
        encoder_prompts.extend(_read_prompts(filename))

    custom_decoder_prompts = encoder_prompts[::-1]
    empty_str_decoder_prompts = [""] * len(encoder_prompts)
    none_decoder_prompts = [None] * len(encoder_prompts)

    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
    }


264
@pytest.fixture
def example_long_prompts() -> list[str]:
    """Return every prompt line from the files listed in `_LONG_PROMPTS`."""
    prompts: list[str] = []
    for filename in _LONG_PROMPTS:
        prompts.extend(_read_prompts(filename))
    return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
270
271


272
@pytest.fixture(scope="session")
def image_assets() -> ImageTestAssets:
    """Return the module-level `IMAGE_ASSETS` singleton."""
    return IMAGE_ASSETS


277
@pytest.fixture(scope="session")
def video_assets() -> VideoTestAssets:
    """Return the module-level `VIDEO_ASSETS` singleton."""
    return VIDEO_ASSETS


282
@pytest.fixture(scope="session")
def audio_assets() -> AudioTestAssets:
    """Return the module-level `AUDIO_ASSETS` singleton."""
    return AUDIO_ASSETS


287
# Types that `HfRunner.wrap_device` accepts and returns unchanged in kind.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_R = TypeVar("_R")
289

Woosuk Kwon's avatar
Woosuk Kwon committed
290
291
292

class HfRunner:
    """Run a model through HuggingFace libraries to produce reference outputs.

    Depending on the constructor flags the model is loaded with
    `SentenceTransformer`, `CrossEncoder`, or `auto_cls.from_pretrained`.
    Instances can be used as context managers; on exit the model is dropped
    and `cleanup_dist_env_and_memory()` is called.
    """

    def get_default_device(self):
        """Return "cpu" on CPU platforms, else the platform's device type."""
        from vllm.platforms import current_platform

        return ("cpu"
                if current_platform.is_cpu() else current_platform.device_type)

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move *x* to *device* (default: `self.device`).

        `None` and booleans are returned untouched, plain dicts are processed
        value by value, and objects already on the target device type are
        returned as-is. Everything else is moved with `x.to(device)`.
        """
        if x is None or isinstance(x, bool):
            return x

        if device is None:
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load config, model, tokenizer and processor for *model_name*.

        Args:
            model_name: Model identifier; passed through
                `maybe_model_redirect` first.
            dtype: Requested dtype, resolved by `_get_and_verify_dtype`.
            model_kwargs: Extra keyword arguments for the model loader.
                "torch_dtype" defaults to the resolved dtype. The caller's
                dict is not modified.
            trust_remote_code: Forwarded to every `from_pretrained` call.
            is_sentence_transformer: Load with `SentenceTransformer`.
            is_cross_encoder: Load with `CrossEncoder`.
            skip_tokenizer_init: Do not load `AutoTokenizer`; use the
                processor's tokenizer instead.
            auto_cls: Auto-model class used when neither of the two flags
                above is set.
        """
        model_name = maybe_model_redirect(model_name)
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Work on a copy so the default inserted below never leaks into a
        # dict the caller may reuse for another runner.
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if (getattr(model, "quantization_method", None) is None
                    and any(p.dtype != self.dtype
                            for p in model.parameters())):
                model = model.to(dtype=self.dtype)

            # Only move the model when it is not a bitsandbytes model and its
            # parameters are not already spread over several devices.
            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

    def get_inputs(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[Union[BatchFeature, BatchEncoding]]:
        """Run the processor on each prompt (plus its multi-modal data).

        Each of *images*, *videos* and *audios*, when given, must have one
        entry per prompt; a `None` entry means "no data for this prompt".
        `BatchFeature` results are cast to `self.dtype`.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_inputs := audios[i]) is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
        """Return the model's input embeddings for each prompt's token ids."""
        embed_tokens = self.model.get_input_embeddings()
        embeddings = []
        for inputs in self.get_inputs(prompts):
            input_ids = self.wrap_device(inputs)["input_ids"]
            embeddings.append(embed_tokens(input_ids).squeeze(0))
        return embeddings

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return per-prompt class scores computed from the final logits.

        The config's `problem_type` selects the post-processing: raw logits
        for "regression", sigmoid for "multi_label_classification", and
        softmax over the last dimension otherwise.
        """
        # The problem type is a property of the model config, not of the
        # individual prompt, so look it up once.
        problem_type = getattr(self.config, "problem_type", "")

        outputs = []
        for inputs in self.get_inputs(prompts):
            output = self.model(**self.wrap_device(inputs))

            if problem_type == "regression":
                logits = output.logits[0].tolist()
            elif problem_type == "multi_label_classification":
                logits = output.logits.sigmoid()[0].tolist()
            else:
                logits = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(logits)

        return outputs

    def generate(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate for each prompt separately.

        Extra keyword arguments are forwarded to `self.model.generate`.

        Returns:
            One `(output_ids, output_strs)` pair per prompt, where both
            elements hold one entry per returned sequence.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for inputs in all_inputs:
            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy-decode up to *max_tokens* new tokens for each prompt.

        Returns:
            One `(token_ids, text)` pair per prompt (the first returned
            sequence of each generation).
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Beam-search each prompt, returning *beam_width* sequences each.

        Pad tokens are removed from the returned token id lists.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width,
                                images=images,
                                videos=videos,
                                audios=audios)

        pad_token_id = self.tokenizer.pad_token_id
        return [([[tok for tok in seq if tok != pad_token_id]
                  for seq in output_ids], output_str)
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[list[torch.Tensor]]:
        """Greedy-decode each prompt and return its log-probability tensors.

        Returns:
            For each prompt, the list produced by
            `_hidden_states_to_seq_logprobs` from the generation's hidden
            states.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[torch.Tensor]] = []
        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs = self._hidden_states_to_seq_logprobs(
                output.hidden_states)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
    ) -> list[torch.Tensor]:
        """Project hidden states through the output embeddings to logprobs.

        For every element of *hidden_states*, the last inner tensor's first
        entry is multiplied with the transposed output-embedding weight
        (plus bias, if any) and passed through a float32 log-softmax.
        """
        output_embeddings = self.model.get_output_embeddings()

        seq_logprobs: list[torch.Tensor] = []
        for hidden_state in hidden_states:
            # presumably: last layer, first (only) batch element
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(
                    device=output_embeddings.weight.device,
                    dtype=output_embeddings.weight.dtype,
                ),
                output_embeddings.weight.t(),
            )
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
        num_logprobs: Optional[int],
    ) -> tuple[list[dict[int, float]], int]:
        """Convert hidden states to top-k `{token_id: logprob}` dicts.

        Returns:
            `(logprob_dicts, output_len)` where `logprob_dicts` has one dict
            of the *num_logprobs* most likely tokens per element of
            *hidden_states*, and `output_len` is `len(hidden_states)`.
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        # convert to dict
        seq_logprobs_lst: list[dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            }
            seq_logprobs_lst.append(tok_logprobs_dct)

        return (
            seq_logprobs_lst,
            output_len,
        )

    def generate_greedy_logprobs_limit(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: Optional[int],
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """Greedy-decode each prompt and keep the top *num_logprobs* logprobs.

        Returns:
            One `(output_ids, output_str, logprob_dicts)` tuple per prompt.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            # output_len equals len(seq_logprobs_lst): the helper emits one
            # dict per element of the hidden-states tuple.
            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # Keep only the trailing generated tokens.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: Optional[int],
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models

        Each prompt is split into its encoder and decoder parts. The encoder
        prompt (plus the optional image at the same index) goes through the
        processor; a non-None decoder prompt is tokenized and passed as
        `decoder_input_ids`.

        Returns one (output_ids, output_str, logprob_dicts) tuple per prompt.
        '''

        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for i, (encoder_prompt, decoder_prompt) in enumerate(
                to_enc_dec_tuple_list(encoder_decoder_prompts)):
            processor_kwargs: dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            encoder_inputs = self.processor(**processor_kwargs)
            encoder_inputs = self.wrap_device(encoder_inputs)

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_inputs = self.tokenizer(decoder_prompt,
                                                return_tensors="pt")
                decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)

            output = self.model.generate(
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **encoder_inputs,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        """Forward to `self.model.encode`."""
        return self.model.encode(prompts, *args, **kwargs)

    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:
        """Forward to `self.model.predict` with `convert_to_tensor=True`."""
        return self.model.predict(prompts,
                                  *args,
                                  convert_to_tensor=True,
                                  **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the model reference before the global memory cleanup runs.
        del self.model
        cleanup_dist_env_and_memory()
757

Woosuk Kwon's avatar
Woosuk Kwon committed
758

Cyrus Leung's avatar
Cyrus Leung committed
759
@pytest.fixture(scope="session")
def hf_runner():
    """Return the `HfRunner` class so tests can construct their own runners."""
    return HfRunner


class VllmRunner:
765
766
    """
    The default values of some arguments have been modified from
767
    {class}`~vllm.LLM` as follows:
768

769
770
771
    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
772
773
    - `block_size`: To reduce memory usage, set default to `64` if on XPU
        devices, otherwise default to `16`.
774
775
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
776
    - `enforce_eager`: Set to `False` to test CUDA graph.
777
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
778
779
780
781

    def __init__(
        self,
        model_name: str,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        seed: Optional[int] = 0,
        max_model_len: Optional[int] = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16 if not torch.xpu.is_available() else 64,
        enable_chunked_prefill: Optional[bool] = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Build the wrapped ``LLM`` and keep it as ``self.llm``.

        Each named argument is passed to ``LLM`` under the keyword of the
        same name, except ``model_name`` (passed as ``model``) and
        ``tokenizer_name`` (passed as ``tokenizer``). Any additional
        ``kwargs`` are forwarded unchanged.
        """
        # Keywords are listed in the same order as the signature above.
        self.llm = LLM(
            model=model_name,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            seed=seed,
            max_model_len=max_model_len,
            dtype=dtype,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            **kwargs,
        )

817
    def get_inputs(
        self,
        prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[TextPrompt]:
        """Pair each prompt with its multimodal data and build engine inputs.

        Args:
            prompts: One entry per request. A ``str`` is stored under
                ``"prompt"``, a ``list`` (token ids) under
                ``"prompt_token_ids"``, and anything else under
                ``"prompt_embeds"``.
            images: Optional per-prompt image inputs; an entry may be
                ``None`` to attach nothing for that prompt.
            videos: Optional per-prompt video inputs, same convention.
            audios: Optional per-prompt audio inputs, same convention.

        Returns:
            One ``TextPrompt`` per prompt. ``"multi_modal_data"`` is ``None``
            when no modality was attached to that prompt.

        Raises:
            ValueError: If a non-``None`` multimodal argument does not have
                the same length as ``prompts``.
        """
        if any(x is not None and len(x) != len(prompts)
               for x in [images, videos, audios]):
            raise ValueError(
                "All non-None multimodal inputs must have the same length as "
                "prompts")

        inputs = []
        for i, prompt in enumerate(prompts):
            multi_modal_data = {}
            if images is not None and (image := images[i]) is not None:
                multi_modal_data["image"] = image
            if videos is not None and (video := videos[i]) is not None:
                multi_modal_data["video"] = video
            if audios is not None and (audio := audios[i]) is not None:
                multi_modal_data["audio"] = audio

            # An empty dict collapses to None so text-only prompts carry no
            # multimodal payload.
            text_prompt_kwargs: dict[str, Any] = {
                "multi_modal_data": multi_modal_data or None
            }
            if isinstance(prompt, str):
                text_prompt_kwargs["prompt"] = prompt
            elif isinstance(prompt, list):
                text_prompt_kwargs["prompt_token_ids"] = prompt
            else:
                text_prompt_kwargs["prompt_embeds"] = prompt

            inputs.append(TextPrompt(**text_prompt_kwargs))

        return inputs

    def generate(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate completions and return them with the prompt prepended.

        For every request the result holds two parallel lists with one entry
        per sample: prompt token ids followed by the sampled token ids, and
        the prompt text followed by the sampled text. A missing prompt text
        is treated as the empty string. Extra ``kwargs`` go to
        ``self.llm.generate``.
        """
        engine_inputs = self.get_inputs(prompts,
                                        images=images,
                                        videos=videos,
                                        audios=audios)
        req_outputs = self.llm.generate(engine_inputs,
                                        sampling_params=sampling_params,
                                        **kwargs)

        results: list[tuple[list[list[int]], list[str]]] = []
        for req_output in req_outputs:
            prompt_text = req_output.prompt or ""
            prompt_ids = req_output.prompt_token_ids

            sample_ids: list[list[int]] = []
            sample_texts: list[str] = []
            for sample in req_output.outputs:
                sample_ids.append(prompt_ids + list(sample.token_ids))
                sample_texts.append(prompt_text + sample.text)

            results.append((sample_ids, sample_texts))
        return results

887
    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        """Flatten request outputs into ``(ids, text, logprobs, prompt_logprobs)``.

        Exactly one tuple is produced per request, built from the LAST sample
        in ``req_output.outputs``; earlier samples are not reported. Each
        request must contain at least one sample.
        """
        results: list[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            samples = req_output.outputs
            assert len(samples) > 0
            # Only the final sample's values end up in the result.
            final_sample = samples[-1]
            results.append((list(final_sample.token_ids), final_sample.text,
                            final_sample.logprobs,
                            req_output.prompt_logprobs))
        return results

902
903
    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate with ``sampling_params`` and return tokens, text, logprobs.

        Each result is ``(token_ids, text, logprobs)``; the prompt logprobs
        are appended as a fourth element only when
        ``sampling_params.prompt_logprobs`` is not ``None``. Extra ``kwargs``
        go to ``self.llm.generate``.
        """
        engine_inputs = self.get_inputs(prompts,
                                        images=images,
                                        videos=videos,
                                        audios=audios)
        req_outputs = self.llm.generate(engine_inputs,
                                        sampling_params=sampling_params,
                                        **kwargs)
        full_results = self._final_steps_generate_w_logprobs(req_outputs)

        if sampling_params.prompt_logprobs is not None:
            return full_results
        # Omit prompt logprobs if not required by sampling params
        return [result[0:-1] for result in full_results]
927
928
929

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Logprobs generation for vLLM encoder/decoder models.

        ``sampling_params.logprobs`` must be set. Each result is
        ``(token_ids, text, logprobs)``, with the prompt logprobs appended
        only when ``sampling_params.prompt_logprobs`` is not ``None``.
        """
        assert sampling_params.logprobs is not None

        req_outputs = self.llm.generate(encoder_decoder_prompts,
                                        sampling_params=sampling_params)
        full_results = self._final_steps_generate_w_logprobs(req_outputs)

        if sampling_params.prompt_logprobs is not None:
            return full_results
        # Omit prompt logprobs if not required by sampling params
        return [result[0:-1] for result in full_results]
947

Woosuk Kwon's avatar
Woosuk Kwon committed
948
949
    def generate_greedy(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy-decode up to ``max_tokens`` tokens for every prompt.

        Sampling uses ``temperature=0.0``. Only the first sample of each
        request is returned, as ``(token_ids, text)`` in the format produced
        by ``generate``.
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        per_request = self.generate(prompts,
                                    greedy_params,
                                    images=images,
                                    videos=videos,
                                    audios=audios,
                                    **kwargs)

        first_samples: list[tuple[list[int], str]] = []
        for sample_ids, sample_texts in per_request:
            first_samples.append((sample_ids[0], sample_texts[0]))
        return first_samples
966

967
968
    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: Optional[int],
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[list[int]] = None,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Greedy-decode and return tokens, text and logprobs per prompt.

        Builds ``SamplingParams`` with ``temperature=0.0`` plus the given
        ``max_tokens``, ``num_logprobs``, ``num_prompt_logprobs`` and stop
        settings, then delegates to ``generate_w_logprobs``; ``kwargs`` are
        forwarded to that call.
        """
        sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop,
        )
        results = self.generate_w_logprobs(prompts,
                                           sampling_params,
                                           images=images,
                                           audios=audios,
                                           videos=videos,
                                           **kwargs)
        return results
995

996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
    def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
        """Compute the perplexity of each prompt from its prompt logprobs.

        One token is generated greedily with ``num_prompt_logprobs=0``. The
        first prompt position must have no logprob entry and every later
        position exactly one; the perplexity is ``exp`` of the negated mean
        of those logprobs.

        Args:
            prompts: Prompts to score.

        Returns:
            One perplexity value per prompt, in input order.
        """
        outputs = self.generate_greedy_logprobs(prompts,
                                                max_tokens=1,
                                                num_logprobs=None,
                                                num_prompt_logprobs=0)

        perplexities = []
        for output in outputs:
            output = cast(TokensTextLogprobsPromptLogprobs, output)
            prompt_logprobs = cast(list[Optional[dict[int, Logprob]]],
                                   output[3])
            assert prompt_logprobs[0] is None

            log_probs = []
            for entry in prompt_logprobs[1:]:
                assert entry is not None
                assert len(entry) == 1
                # The single value in the dict is this position's logprob.
                log_probs.append(next(iter(entry.values())).logprob)

            mean_neg_log_prob = -sum(log_probs) / len(log_probs)
            perplexities.append(math.exp(mean_neg_log_prob))

        return perplexities

1025
1026
    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: Optional[int],
        num_prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Greedy logprobs generation for vLLM encoder/decoder models.

        Builds ``SamplingParams`` with ``temperature=0.0`` and the given
        ``max_tokens``, ``num_logprobs``, ``num_prompt_logprobs`` and
        ``skip_special_tokens``, then delegates to
        ``generate_encoder_decoder_w_logprobs``.

        Returns:
            One ``(token_ids, text, logprobs)`` tuple per prompt, with the
            prompt logprobs appended when ``num_prompt_logprobs`` is not
            ``None``.
        """
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            skip_special_tokens=skip_special_tokens,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

1048
    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        concurrency_limit: Optional[int] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run beam search and return every sequence's tokens and text.

        For each prompt the result is ``(token_ids, texts)``, two parallel
        lists with one entry per sequence in that prompt's output.
        """
        engine_inputs = self.get_inputs(prompts,
                                        images=images,
                                        videos=videos,
                                        audios=audios)
        search_params = BeamSearchParams(beam_width=beam_width,
                                         max_tokens=max_tokens)
        outputs = self.llm.beam_search(engine_inputs,
                                       search_params,
                                       concurrency_limit=concurrency_limit)

        return [([seq.tokens for seq in output.sequences],
                 [seq.text for seq in output.sequences])
                for output in outputs]

1074
    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.probs`` of each ``self.llm.classify`` result."""
        probs_per_prompt = []
        for req_output in self.llm.classify(prompts):
            probs_per_prompt.append(req_output.outputs.probs)
        return probs_per_prompt

1078
1079
1080
1081
1082
1083
1084
    def embed(self,
              prompts: list[str],
              images: Optional[PromptImageInput] = None,
              videos: Optional[PromptVideoInput] = None,
              audios: Optional[PromptAudioInput] = None,
              *args,
              **kwargs) -> list[list[float]]:
        """Embed ``prompts`` (with optional multimodal data) via ``llm.embed``.

        Extra positional and keyword arguments are forwarded to
        ``self.llm.embed``. Returns ``outputs.embedding`` of each result.
        """
        engine_inputs = self.get_inputs(prompts,
                                        images=images,
                                        videos=videos,
                                        audios=audios)

        embeddings = []
        for req_output in self.llm.embed(engine_inputs, *args, **kwargs):
            embeddings.append(req_output.outputs.embedding)
        return embeddings
1092

1093
    def encode(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.data`` of each ``self.llm.encode`` result."""
        data_per_prompt = []
        for req_output in self.llm.encode(prompts):
            data_per_prompt.append(req_output.outputs.data)
        return data_per_prompt

1097
1098
1099
1100
    def reward(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.data`` of each ``self.llm.reward`` result."""
        data_per_prompt = []
        for req_output in self.llm.reward(prompts):
            data_per_prompt.append(req_output.outputs.data)
        return data_per_prompt

1101
1102
    def score(
        self,
        text_1: Union[str, list[str]],
        text_2: Union[str, list[str]],
        *args,
        **kwargs,
    ) -> list[float]:
        """Return ``outputs.score`` of each ``self.llm.score`` result.

        ``text_1``, ``text_2`` and all extra arguments are passed through to
        ``self.llm.score`` unchanged.
        """
        scores = []
        for req_output in self.llm.score(text_1, text_2, *args, **kwargs):
            scores.append(req_output.outputs.score)
        return scores
1110

1111
    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Apply ``func`` to the underlying model and return the results.

        When the engine exposes ``model_executor``, its ``apply_model`` is
        used. Otherwise ``func`` is dispatched through the engine's
        ``collective_rpc`` and called on each worker's ``get_model()``.
        """
        engine = self.llm.llm_engine

        if not hasattr(engine, "model_executor"):
            # This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
            def _apply_model(self):
                return func(self.get_model())

            return engine.collective_rpc(_apply_model)

        # This works either in V0 or in V1 with
        # VLLM_ENABLE_V1_MULTIPROCESSING=0
        return engine.model_executor.apply_model(func)
1123

1124
1125
1126
1127
    def __enter__(self):
        """Enter the ``with`` block, returning this runner unchanged."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the ``LLM`` reference, then run distributed/memory cleanup.

        Nothing is returned, so an exception raised inside the ``with`` body
        is not suppressed.
        """
        # Remove the attribute first so the cleanup call below runs after
        # this runner no longer references the engine.
        del self.llm
        cleanup_dist_env_and_memory()
1130

Woosuk Kwon's avatar
Woosuk Kwon committed
1131

1132
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture exposing the ``VllmRunner`` class.

    The class itself (not an instance) is returned; callers instantiate it
    with whatever arguments they need.
    """
    return VllmRunner
1135
1136


1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
@pytest.fixture()
def temporary_enable_log_propagate():
    """Enable propagation on the ``"vllm"`` logger for one test.

    ``propagate`` is set to ``True`` before the test runs and put back to
    the value it had beforehand during teardown, so a logger that was
    already propagating is not left switched off.
    """
    import logging

    # Named distinctly so the module-level ``logger`` is not shadowed.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog`` while vLLM log propagation is enabled.

    Requesting ``temporary_enable_log_propagate`` turns on ``propagate`` for
    the ``"vllm"`` logger for the duration of the test.
    """
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
1151
1152
1153
1154
1155
1156
1157


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process.

    The count is whatever ``current_platform.device_count()`` reports.
    """
    # Imported lazily, inside the fixture, rather than at module import.
    from vllm.platforms import current_platform

    device_count = current_platform.device_count()
    return device_count
1160
1161
1162


# Locations of the patched dummy model checkouts used by the fixtures below;
# all of them live directly under the system temporary directory.
temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
1166
1167
1168
1169


@pytest.fixture
def dummy_opt_path():
    """Return a local OPT-125m checkout with a custom architecture name.

    If ``_dummy_opt_path`` does not exist yet, the ``facebook/opt-125m``
    snapshot is downloaded there (skipping the ignored weight formats) and
    ``config.json`` is rewritten with ``architectures`` set to
    ``["MyOPTForCausalLM"]``. An existing directory is reused as-is.
    """
    if os.path.exists(_dummy_opt_path):
        return _dummy_opt_path

    snapshot_download(repo_id="facebook/opt-125m",
                      local_dir=_dummy_opt_path,
                      ignore_patterns=[
                          "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                          "*.msgpack"
                      ])

    config_file = os.path.join(_dummy_opt_path, "config.json")
    assert os.path.exists(config_file)
    with open(config_file) as f:
        config = json.load(f)
    config["architectures"] = ["MyOPTForCausalLM"]
    with open(config_file, "w") as f:
        json.dump(config, f)
    return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
    """Return a local LLaVA-1.5 checkout with a custom architecture name.

    If ``_dummy_llava_path`` does not exist yet, the
    ``llava-hf/llava-1.5-7b-hf`` snapshot is downloaded there (skipping the
    ignored weight formats) and ``config.json`` is rewritten with
    ``architectures`` set to ``["MyLlava"]``. An existing directory is
    reused as-is.
    """
    if os.path.exists(_dummy_llava_path):
        return _dummy_llava_path

    snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                      local_dir=_dummy_llava_path,
                      ignore_patterns=[
                          "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                          "*.msgpack"
                      ])

    config_file = os.path.join(_dummy_llava_path, "config.json")
    assert os.path.exists(config_file)
    with open(config_file) as f:
        config = json.load(f)
    config["architectures"] = ["MyLlava"]
    with open(config_file, "w") as f:
        json.dump(config, f)
    return _dummy_llava_path
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Return a local Gemma2 embedding checkout with a custom architecture.

    If ``_dummy_gemma2_embedding_path`` does not exist yet, the
    ``BAAI/bge-multilingual-gemma2`` snapshot is downloaded there (skipping
    the ignored weight formats) and ``config.json`` is rewritten with
    ``architectures`` set to ``["MyGemma2Embedding"]``. An existing
    directory is reused as-is.
    """
    if os.path.exists(_dummy_gemma2_embedding_path):
        return _dummy_gemma2_embedding_path

    snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
                      local_dir=_dummy_gemma2_embedding_path,
                      ignore_patterns=[
                          "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                          "*.msgpack"
                      ])

    config_file = os.path.join(_dummy_gemma2_embedding_path, "config.json")
    assert os.path.exists(config_file)
    with open(config_file) as f:
        config = json.load(f)
    config["architectures"] = ["MyGemma2Embedding"]
    with open(config_file, "w") as f:
        json.dump(config, f)
    return _dummy_gemma2_embedding_path
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240


# The `--optional` flag allows running tests that are marked with
# @pytest.mark.optional
def pytest_addoption(parser):
    """Register the boolean ``--optional`` command-line flag (default off)."""
    option_settings = {
        "action": "store_true",
        "default": False,
        "help": "run optional test",
    }
    parser.addoption("--optional", **option_settings)


def pytest_collection_modifyitems(config, items):
    """Skip tests carrying the ``optional`` keyword unless ``--optional`` is set.

    With ``--optional`` on the command line the collected items are left
    untouched; otherwise every item whose keywords contain ``"optional"``
    gets a skip marker.
    """
    run_optional = config.getoption("--optional")
    if run_optional:
        # --optional given in cli: do not skip optional tests
        return

    skip_optional = pytest.mark.skip(reason="need --optional option to run")
    optional_items = [item for item in items if "optional" in item.keywords]
    for item in optional_items:
        item.add_marker(skip_optional)
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253


@pytest.fixture(scope="session")
def cli_config_file():
    """Return the path to the CLI config file.

    The file is ``config/test_config.yaml`` under this test directory.
    """
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config.yaml")


@pytest.fixture(scope="session")
def cli_config_file_with_model():
    """Return the path to the CLI config file with model.

    The file is ``config/test_config_with_model.yaml`` under this test
    directory.
    """
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config_with_model.yaml")