conftest.py 43.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
import math
5
import os
6
import tempfile
7
from enum import Enum
8
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast
Woosuk Kwon's avatar
Woosuk Kwon committed
9

10
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
import pytest
import torch
13
import torch.nn as nn
14
import torch.nn.functional as F
15
from huggingface_hub import snapshot_download
16
from PIL import Image
17
18
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BatchEncoding, BatchFeature)
19
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
20

21
22
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm import LLM, SamplingParams
24
from vllm.assets.audio import AudioAsset
25
from vllm.assets.image import ImageAsset
26
from vllm.assets.video import VideoAsset
27
from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype
28
from vllm.connections import global_http_connection
29
from vllm.distributed import (cleanup_dist_env_and_memory,
30
31
                              init_distributed_environment,
                              initialize_model_parallel)
32
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
33
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
34
from vllm.logger import init_logger
35
from vllm.outputs import RequestOutput
36
from vllm.sampling_params import BeamSearchParams
37
from vllm.sequence import Logprob
38
from vllm.transformers_utils.utils import maybe_model_redirect
39

40
logger = init_logger(__name__)

# Directory that holds this conftest; the text fixtures below are resolved
# relative to it.
_TEST_DIR = os.path.dirname(__file__)
# Files read line by line by the `example_prompts` / `example_long_prompts`
# fixtures.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
# File returned verbatim by the `example_system_message` fixture.
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
46

Cyrus Leung's avatar
Cyrus Leung committed
47
# Element type of a multimodal input (image, audio clip, video, ...).
_M = TypeVar("_M")

# Multimodal data for a batch of prompts: either a flat list of items or a
# list of item lists.
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
# Each audio item is an (array, int) pair; `HfRunner.get_inputs` forwards the
# int to the processor as `sampling_rate`.
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
54

55

56
def _read_prompts(filename: str) -> list[str]:
57
    with open(filename) as f:
58
59
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
60
61


62
class ImageAssetPrompts(TypedDict):
    """Prompt text for each test image, keyed by asset name."""
    stop_sign: str
    cherry_blossom: str
65
66


67
class ImageTestAssets(list[ImageAsset]):
    """List of the image assets used by tests, in a fixed order."""

    def __init__(self) -> None:
        assets = [
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
        """
        Convenience method to define the prompt for each test image.

        The returned prompts follow the same order as the assets
        yielded when iterating through this object.
        """
        stop_sign_prompt = prompts["stop_sign"]
        cherry_blossom_prompt = prompts["cherry_blossom"]
        return [stop_sign_prompt, cherry_blossom_prompt]
83
84


85
86
class VideoAssetPrompts(TypedDict):
    """Prompt text for each test video, keyed by asset name."""
    baby_reading: str
87
88


89
class VideoTestAssets(list[VideoAsset]):
    """List of the video assets used by tests, in a fixed order."""

    def __init__(self) -> None:
        assets = [
            VideoAsset("baby_reading"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
        """Return the prompt for each test video, in asset order."""
        baby_reading_prompt = prompts["baby_reading"]
        return [baby_reading_prompt]
98
99


100
class AudioAssetPrompts(TypedDict):
    """Prompt text for each test audio clip, keyed by asset name."""
    mary_had_lamb: str
    winning_call: str


105
class AudioTestAssets(list[AudioAsset]):
    """List of the audio assets used by tests, in a fixed order."""

    def __init__(self) -> None:
        assets = [
            AudioAsset("mary_had_lamb"),
            AudioAsset("winning_call"),
        ]
        super().__init__(assets)

    def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
        """Return the prompt for each test audio clip, in asset order."""
        mary_had_lamb_prompt = prompts["mary_had_lamb"]
        winning_call_prompt = prompts["winning_call"]
        return [mary_had_lamb_prompt, winning_call_prompt]

116

117
# Module-level asset collections, built once at import time and handed out
# by the session-scoped `*_assets` fixtures.
IMAGE_ASSETS = ImageTestAssets()
"""Singleton instance of {class}`ImageTestAssets`."""
VIDEO_ASSETS = VideoTestAssets()
"""Singleton instance of {class}`VideoTestAssets`."""
AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of {class}`AudioTestAssets`."""
123
124


125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
    """
    Keep changes to "VLLM_USE_V1" from leaking between tests.

    The V1 oracle sets "VLLM_USE_V1" during loading, so each invocation
    of a test may change the environment variable. Touching the variable
    through monkeypatch lets monkeypatch undo whatever vLLM did to it
    once the test finishes.

    This fixture is used by every test.
    """
    if "VLLM_USE_V1" in os.environ:
        return

    # The variable is unset: set it and delete it again so that monkeypatch
    # records it and cleans up VLLM_USE_V1 upon exit if vLLM modifies the
    # value of envs.VLLM_USE_V1.
    monkeypatch.setenv("VLLM_USE_V1", "")
    monkeypatch.delenv("VLLM_USE_V1")


Joe Runde's avatar
Joe Runde committed
145
@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """Run the requesting test twice: once with V1 enabled, once without.

    Tests marked `skip_v1` are skipped on the V1 pass and tests marked
    `skip_v0` are skipped on the V0 pass. For the pass that does run,
    "VLLM_USE_V1" is set accordingly through monkeypatch.
    """
    use_v1 = request.param
    if use_v1:
        marker_name = "skip_v1"
        skip_message = "Skipping test on vllm V1"
        env_value = '1'
    else:
        marker_name = "skip_v0"
        skip_message = "Skipping test on vllm V0"
        env_value = '0'

    if request.node.get_closest_marker(marker_name):
        pytest.skip(skip_message)

    monkeypatch.setenv('VLLM_USE_V1', env_value)

    yield
Joe Runde's avatar
Joe Runde committed
163
164


165
166
167
168
169
170
171
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Turn off client reuse on the global HTTP connection for every test.

    pytest_asyncio may use a different event loop per test, so the async
    client has to be created anew rather than reused.
    """
    global_http_connection.reuse_client = False


172
173
174
175
176
177
178
179
180
181
182
183
@pytest.fixture
def dist_init():
    """Set up a single-process distributed environment around a test.

    Initializes the distributed environment with world size 1 on the
    "nccl" backend, using a temporary file as the rendezvous point, then
    initializes model parallelism with sizes (1, 1). After the test the
    distributed state is cleaned up and the temporary file is removed.
    """
    # mkstemp() returns an *open* descriptor along with the path. Only the
    # path is needed, so close the descriptor right away instead of leaking it.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    try:
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend="nccl",
        )
        initialize_model_parallel(1, 1)
        yield
        cleanup_dist_env_and_memory()
    finally:
        # Best-effort removal: the file may already have been deleted by
        # whoever consumed it, which is fine.
        try:
            os.remove(temp_file)
        except OSError:
            pass
185
186


187
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell `cleanup_fixture` whether to run the global cleanup.

    Returns False for tests carrying the `skip_global_cleanup` marker and
    True otherwise. Subdirectories may also override this fixture to skip
    the cleanup, which can speed up non-GPU unit tests considerably since
    they then don't need to initialize torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
195
196


197
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After each test, clean up distributed state and memory if requested."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory()
202
203


204
205
206
207
208
209
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call `torch._dynamo.reset()` after every test."""
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
210
@pytest.fixture
def example_prompts() -> list[str]:
    """Return every line of the files listed in `_TEST_PROMPTS`."""
    return [
        prompt for filename in _TEST_PROMPTS
        for prompt in _read_prompts(filename)
    ]


218
219
220
221
222
223
@pytest.fixture
def example_system_message() -> str:
    """Return the full contents of the `_SYS_MSG` file as one string."""
    with open(_SYS_MSG) as sys_msg_file:
        system_message = sys_msg_file.read()
    return system_message


224
225
226
227
228
229
230
class DecoderPromptType(Enum):
    """Kind of decoder prompt to use; for encoder/decoder models only."""
    # Keys of the dict built by the `example_encoder_decoder_prompts` fixture.
    CUSTOM = 1
    NONE = 2
    EMPTY_STR = 3


231
@pytest.fixture
def example_encoder_decoder_prompts(
) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]:
    '''
    Build example encoder/decoder prompt pairs for each decoder prompt type.

    The encoder prompts are the lines of the files in `_TEST_PROMPTS`. Each
    encoder prompt is paired with the same-index decoder prompt via
    `zip_enc_dec_prompts`.

    Returns:

    A dict keyed by `DecoderPromptType`, where the decoder prompts are:

    * NONE: `None` for every encoder prompt
    * EMPTY_STR: the empty string for every encoder prompt
    * CUSTOM: the reverse of the encoder prompt list
    '''
    encoder_prompts = [
        prompt for filename in _TEST_PROMPTS
        for prompt in _read_prompts(filename)
    ]
    num_prompts = len(encoder_prompts)

    decoder_prompts_by_type = {
        DecoderPromptType.NONE: [None] * num_prompts,
        DecoderPromptType.EMPTY_STR: [""] * num_prompts,
        DecoderPromptType.CUSTOM: encoder_prompts[::-1],
    }

    return {
        prompt_type: zip_enc_dec_prompts(encoder_prompts, decoder_prompts)
        for prompt_type, decoder_prompts in decoder_prompts_by_type.items()
    }


264
@pytest.fixture
def example_long_prompts() -> list[str]:
    """Return every line of the files listed in `_LONG_PROMPTS`."""
    return [
        prompt for filename in _LONG_PROMPTS
        for prompt in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
270
271


272
@pytest.fixture(scope="session")
def image_assets() -> ImageTestAssets:
    """Provide the shared `IMAGE_ASSETS` instance to tests."""
    return IMAGE_ASSETS


277
@pytest.fixture(scope="session")
def video_assets() -> VideoTestAssets:
    """Provide the shared `VIDEO_ASSETS` instance to tests."""
    return VIDEO_ASSETS


282
@pytest.fixture(scope="session")
def audio_assets() -> AudioTestAssets:
    """Provide the shared `AUDIO_ASSETS` instance to tests."""
    return AUDIO_ASSETS


287
# Types that `HfRunner.wrap_device` accepts and returns unchanged in kind.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
# Unconstrained generic type variable.
_R = TypeVar("_R")
289

Woosuk Kwon's avatar
Woosuk Kwon committed
290
291
292

class HfRunner:

293
    def get_default_device(self):
        """Return "cpu" on a CPU platform, else the platform's device type."""
        # Imported lazily rather than at module level.
        from vllm.platforms import current_platform

        if current_platform.is_cpu():
            return "cpu"
        return current_platform.device_type
298
299

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move ``x`` to ``device`` (default: ``self.device``).

        ``None`` and booleans are returned untouched. A dict is rebuilt
        with every value wrapped recursively. An object whose
        ``device.type`` already equals the target is returned as is;
        anything else is moved with ``x.to(device)``.
        """
        if x is None or isinstance(x, bool):
            return x

        target = self.device if device is None else device

        if isinstance(x, dict):
            return {
                key: self.wrap_device(value, target)
                for key, value in x.items()
            }

        already_on_target = hasattr(x, "device") and x.device.type == target
        if already_on_target:
            return x

        return x.to(target)
313

Woosuk Kwon's avatar
Woosuk Kwon committed
314
315
316
    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load a Hugging Face model plus its tokenizer and processor.

        Args:
            model_name: Model identifier; passed through
                `maybe_model_redirect` before use.
            dtype: Requested dtype, resolved via `_get_and_verify_dtype`.
            model_kwargs: Extra keyword arguments for model construction.
                "torch_dtype" defaults to the resolved dtype. The dict
                passed in is copied and never modified.
            trust_remote_code: Forwarded to every `from_pretrained` call
                and to the sentence-transformers constructors.
            is_sentence_transformer: Load the model as a
                `SentenceTransformer`.
            is_cross_encoder: Load the model as a `CrossEncoder` (only
                consulted when `is_sentence_transformer` is false).
            skip_tokenizer_init: Do not load an `AutoTokenizer`; use the
                processor's tokenizer instead.
            auto_cls: Auto class used to load the model when neither
                sentence-transformers flag is set.
        """
        model_name = maybe_model_redirect(model_name)
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Work on a copy so that `setdefault` below does not leak a
        # "torch_dtype" entry into the dict owned by the caller.
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if (getattr(model, "quantization_method", None) is None
                    and any(p.dtype != self.dtype
                            for p in model.parameters())):
                model = model.to(dtype=self.dtype)

            # Only move the model when it is not bitsandbytes-quantized and
            # its parameters do not already span several devices.
            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer
Woosuk Kwon's avatar
Woosuk Kwon committed
401

402
    def get_inputs(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[Union[BatchFeature, BatchEncoding]]:
        """Run the processor on each prompt and its multimodal data.

        Every modality that is given must have one entry per prompt. Each
        prompt is processed on its own with ``return_tensors="pt"``;
        `BatchFeature` results are cast to ``self.dtype``.

        Returns:
            One processor output per prompt, in prompt order.
        """
        for modality in (images, videos, audios):
            if modality is not None:
                assert len(prompts) == len(modality)

        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
        for idx, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }

            image = images[idx] if images is not None else None
            if image is not None:
                processor_kwargs["images"] = image

            video = videos[idx] if videos is not None else None
            if video is not None:
                processor_kwargs["videos"] = video

            audio_inputs = audios[idx] if audios is not None else None
            if audio_inputs is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

446
447
448
449
450
451
452
453
454
    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
        """Return the model's input embeddings for each prompt.

        Each prompt is processed separately; dimension 0 of the embedding
        result is squeezed away before it is collected.
        """
        embeddings = []
        for inputs in self.get_inputs(prompts):
            input_ids = self.wrap_device(inputs)["input_ids"]
            embed_layer = self.model.get_input_embeddings()
            embeddings.append(embed_layer(input_ids).squeeze(0))
        return embeddings

455
    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Score each prompt with the model and return the final outputs.

        The config's ``problem_type`` selects the post-processing of the
        logits: "regression" returns them unchanged,
        "multi_label_classification" applies a sigmoid, and anything else
        applies a softmax over the last dimension.

        Returns:
            For each prompt, entry 0 of the processed logits converted
            with ``tolist()``.
        """
        # output is final logits
        problem_type = getattr(self.config, "problem_type", "")

        outputs = []
        for inputs in self.get_inputs(prompts):
            output = self.model(**self.wrap_device(inputs))
            if problem_type == "regression":
                scores = output.logits
            elif problem_type == "multi_label_classification":
                scores = output.logits.sigmoid()
            else:
                scores = output.logits.softmax(dim=-1)
            outputs.append(scores[0].tolist())

        return outputs

473
474
    def generate(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate from each prompt with the HF model.

        Extra keyword arguments are forwarded to ``self.model.generate``.

        Returns:
            For each prompt, a tuple of (generated token id lists, the
            same sequences decoded with special tokens skipped).
        """
        outputs: list[tuple[list[list[int]], list[str]]] = []
        per_prompt_inputs = self.get_inputs(prompts,
                                            images=images,
                                            videos=videos,
                                            audios=audios)
        for inputs in per_prompt_inputs:
            generated = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            decoded = self.processor.batch_decode(
                generated,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            outputs.append((generated.cpu().tolist(), decoded))
        return outputs

    def generate_greedy(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy generation (sampling disabled), capped at ``max_tokens``.

        Returns:
            For each prompt, the first generated sequence as a tuple of
            (token ids, decoded text).
        """
        per_prompt_outputs = self.generate(prompts,
                                           do_sample=False,
                                           max_new_tokens=max_tokens,
                                           images=images,
                                           videos=videos,
                                           audios=audios,
                                           **kwargs)

        greedy_outputs = []
        for output_ids, output_str in per_prompt_outputs:
            greedy_outputs.append((output_ids[0], output_str[0]))
        return greedy_outputs
521
522
523

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Beam-search generation returning ``beam_width`` sequences.

        Pad token ids are removed from every returned token id list; the
        decoded strings are returned as produced by `generate`.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width,
                                images=images,
                                videos=videos,
                                audios=audios)

        for idx, (beam_ids, beam_strs) in enumerate(outputs):
            unpadded_ids = [[
                token_id for token_id in sequence
                if token_id != self.tokenizer.pad_token_id
            ] for sequence in beam_ids]
            outputs[idx] = (unpadded_ids, beam_strs)
        return outputs
Woosuk Kwon's avatar
Woosuk Kwon committed
549

550
551
    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[list[torch.Tensor]]:
        """Greedy generation returning log-probability tensors.

        Returns:
            For each prompt, the list produced by
            `_hidden_states_to_seq_logprobs` from the hidden states that
            ``generate`` returned.
        """
        all_logprobs: list[list[torch.Tensor]] = []
        per_prompt_inputs = self.get_inputs(prompts,
                                            images=images,
                                            videos=videos,
                                            audios=audios)
        for inputs in per_prompt_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            hidden_states = output.hidden_states
            all_logprobs.append(
                self._hidden_states_to_seq_logprobs(hidden_states))
        return all_logprobs

580
    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
    ) -> list[torch.Tensor]:
        """Project hidden states through the output embeddings.

        For each entry of ``hidden_states``, the last element's item 0 is
        multiplied by the transposed output-embedding weight, the bias is
        added when the layer has one, and a float32 log-softmax is taken
        over the last dimension.
        """
        output_embeddings = self.model.get_output_embeddings()

        seq_logprobs: list[torch.Tensor] = []
        for hidden_state in hidden_states:
            last_hidden_states = hidden_state[-1][0]
            # Match the device and dtype of the projection weight first.
            projected_input = last_hidden_states.to(
                device=output_embeddings.weight.device,
                dtype=output_embeddings.weight.dtype,
            )
            logits = torch.matmul(projected_input,
                                  output_embeddings.weight.t())
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
            seq_logprobs.append(
                F.log_softmax(logits, dim=-1, dtype=torch.float32))

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
        num_logprobs: Optional[int],
    ) -> tuple[list[dict[int, float]], int]:
        """Convert hidden states into top-k logprob dicts.

        Returns:
            A tuple of (one ``{token_id: logprob}`` dict per entry of
            ``hidden_states`` holding the top ``num_logprobs`` entries,
            ``len(hidden_states)``).
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)

        # convert to dict
        seq_logprobs_lst: list[dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            if tok_idx == 0:
                # drop prompt logprobs: keep only the last row
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            seq_logprobs_lst.append({
                token_id.item(): logprob.item()
                for token_id, logprob in zip(topk.indices[0], topk.values[0])
            })

        return seq_logprobs_lst, len(hidden_states)

630
631
    def generate_greedy_logprobs_limit(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: Optional[int],
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """Greedy generation with the top ``num_logprobs`` logprobs.

        Returns:
            For each prompt, a tuple of (output token ids, decoded output
            text, list of ``{token_id: logprob}`` dicts).
        """
        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        per_prompt_inputs = self.get_inputs(prompts,
                                            images=images,
                                            videos=videos,
                                            audios=audios)
        for inputs in per_prompt_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs_lst, _ = self._hidden_states_to_logprobs(
                output.hidden_states, num_logprobs)
            all_logprobs.append(seq_logprobs_lst)

            # The output tokens are the trailing ones, one per logprob dict.
            output_len = len(seq_logprobs_lst)
            output_ids = output.sequences[0][-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: Optional[int],
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for encoder/decoder models.

        The encoder prompt (plus its image, if any) goes through the
        processor; a decoder prompt that is not None is tokenized and
        passed as `decoder_input_ids`.

        Returns:
            For each prompt pair, a tuple of (output token ids, decoded
            output text, list of `{token_id: logprob}` dicts).
        '''
        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        prompt_pairs = to_enc_dec_tuple_list(encoder_decoder_prompts)
        for idx, (encoder_prompt, decoder_prompt) in enumerate(prompt_pairs):
            processor_kwargs: dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[idx] is not None:
                processor_kwargs["images"] = images[idx]

            encoder_inputs = self.wrap_device(
                self.processor(**processor_kwargs))

            decoder_input_ids = None
            if decoder_prompt is not None:
                decoder_inputs = self.tokenizer(decoder_prompt,
                                                return_tensors="pt")
                decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)

            output = self.model.generate(
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **encoder_inputs,
                **kwargs,
            )

            seq_logprobs_lst, output_len = self._hidden_states_to_logprobs(
                output.decoder_hidden_states, num_logprobs)
            all_logprobs.append(seq_logprobs_lst)

            # Keep only the trailing `output_len` tokens of the sequence.
            output_ids = output.sequences[0][-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

739
740
741
    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        """Forward ``prompts`` and any extra arguments to ``model.encode``."""
        encoded = self.model.encode(prompts, *args, **kwargs)
        return encoded
742

743
744
745
746
747
748
    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:
        """Forward to ``model.predict``, always with ``convert_to_tensor``."""
        predictions = self.model.predict(prompts,
                                         *args,
                                         convert_to_tensor=True,
                                         **kwargs)
        return predictions
749

750
751
752
753
    def __enter__(self):
        """Enter the context manager; the runner itself is the target."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model, then clean up distributed state and memory."""
        # Remove the attribute first so the cleanup call below runs without
        # this runner still referencing the model.
        del self.model
        cleanup_dist_env_and_memory()
756

Woosuk Kwon's avatar
Woosuk Kwon committed
757

Cyrus Leung's avatar
Cyrus Leung committed
758
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the `HfRunner` class (not an instance) to tests."""
    return HfRunner


class VllmRunner:
764
765
    """
    The default values of some arguments have been modified from
    {class}`~vllm.LLM` as follows:

    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
    - `block_size`: To reduce memory usage, set default to `64` if on XPU
      devices, otherwise default to `16`.
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
    - `enforce_eager`: Set to `False` to test CUDA graph.
776
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
777
778
779
780

    def __init__(
        self,
        model_name: str,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        seed: Optional[int] = 0,
        max_model_len: Optional[int] = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16 if not torch.xpu.is_available() else 64,
        enable_chunked_prefill: Optional[bool] = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Build the wrapped ``LLM`` and store it as ``self.llm``.

        Every named argument is passed to ``LLM`` under the same keyword,
        except ``model_name`` (passed as ``model``) and ``tokenizer_name``
        (passed as ``tokenizer``). Any additional keyword arguments are
        forwarded to ``LLM`` untouched.
        """
        # Options this runner names explicitly, keyed by the LLM keyword.
        engine_options: dict[str, Any] = {
            "model": model_name,
            "runner": runner,
            "convert": convert,
            "tokenizer": tokenizer_name,
            "tokenizer_mode": tokenizer_mode,
            "trust_remote_code": trust_remote_code,
            "dtype": dtype,
            "seed": seed,
            "swap_space": swap_space,
            "enforce_eager": enforce_eager,
            "disable_log_stats": disable_log_stats,
            "tensor_parallel_size": tensor_parallel_size,
            "max_model_len": max_model_len,
            "block_size": block_size,
            "enable_chunked_prefill": enable_chunked_prefill,
        }
        self.llm = LLM(**engine_options, **kwargs)

816
    def get_inputs(
        self,
        prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[TextPrompt]:
        """Pair each prompt with its multimodal data as a ``TextPrompt``.

        Each element of ``prompts`` is dispatched on its type:

        - ``str`` is stored under ``"prompt"``;
        - ``list`` (token ids) is stored under ``"prompt_token_ids"``;
        - anything else is stored under ``"prompt_embeds"``.

        ``images``, ``videos`` and ``audios`` are indexed in parallel with
        ``prompts``. A ``None`` entry means "no data of this modality for
        this prompt". When a prompt has no multimodal data at all,
        ``"multi_modal_data"`` is set to ``None``.

        :param prompts: one entry per request (text, token ids or embeds)
        :param images: optional per-prompt image inputs
        :param videos: optional per-prompt video inputs
        :param audios: optional per-prompt audio inputs
        :return: one ``TextPrompt`` per entry of ``prompts``, in order
        :raises ValueError: if a non-``None`` multimodal input does not have
            the same length as ``prompts``
        """
        # (key used in multi_modal_data, per-prompt inputs) in a fixed order.
        modalities = (("image", images), ("video", videos), ("audio", audios))

        if any(items is not None and len(items) != len(prompts)
               for _, items in modalities):
            raise ValueError(
                "All non-None multimodal inputs must have the same length as "
                "prompts")

        inputs = []
        for i, prompt in enumerate(prompts):
            multi_modal_data = {
                modality: item
                for modality, items in modalities
                if items is not None and (item := items[i]) is not None
            }

            # An empty mapping is normalised to None.
            text_prompt_kwargs: dict[str, Any] = {
                "multi_modal_data": multi_modal_data or None
            }
            if isinstance(prompt, str):
                text_prompt_kwargs["prompt"] = prompt
            elif isinstance(prompt, list):
                text_prompt_kwargs["prompt_token_ids"] = prompt
            else:
                text_prompt_kwargs["prompt_embeds"] = prompt

            inputs.append(TextPrompt(**text_prompt_kwargs))

        return inputs

    def generate(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run ``self.llm.generate`` and return prompt+completion per sample.

        For every request the result holds two parallel lists with one
        entry per sample: the prompt token ids followed by the sampled
        token ids, and the prompt text (empty string when the request has
        none) followed by the sampled text. Extra keyword arguments are
        forwarded to ``self.llm.generate``.
        """
        llm_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)
        req_outputs = self.llm.generate(llm_inputs,
                                        sampling_params=sampling_params,
                                        **kwargs)

        results: list[tuple[list[list[int]], list[str]]] = []
        for request in req_outputs:
            prefix_text = request.prompt or ""
            prefix_ids = request.prompt_token_ids
            ids_per_sample: list[list[int]] = [
                prefix_ids + list(sample.token_ids)
                for sample in request.outputs
            ]
            text_per_sample: list[str] = [
                prefix_text + sample.text for sample in request.outputs
            ]
            results.append((ids_per_sample, text_per_sample))
        return results

886
    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        """Flatten request outputs into one tuple per request.

        Each tuple is ``(token_ids, text, logprobs, prompt_logprobs)``. The
        first three fields come from the request's last sample and the
        fourth from the request itself. Every request must have at least
        one sample.
        """
        results: list[TokensTextLogprobsPromptLogprobs] = []
        for request in req_outputs:
            assert len(request.outputs) > 0
            # Only the final sample is reported, one tuple per request.
            # NOTE(review): earlier samples are dropped - confirm that
            # callers never rely on more than one sample per request.
            final_sample = request.outputs[-1]
            results.append(
                (list(final_sample.token_ids), final_sample.text,
                 final_sample.logprobs, request.prompt_logprobs))
        return results

901
902
    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate with ``sampling_params`` and return logprob tuples.

        Each request yields ``(token_ids, text, logprobs)``. A fourth
        element, the prompt logprobs, is kept only when
        ``sampling_params.prompt_logprobs`` is not ``None``. Extra keyword
        arguments are forwarded to ``self.llm.generate``.
        """
        llm_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)
        req_outputs = self.llm.generate(llm_inputs,
                                        sampling_params=sampling_params,
                                        **kwargs)
        full_results = self._final_steps_generate_w_logprobs(req_outputs)

        if sampling_params.prompt_logprobs is not None:
            return full_results
        # Prompt logprobs were not requested: drop the trailing field.
        return [entry[0:-1] for entry in full_results]
926
927
928

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Logprobs generation for vLLM encoder/decoder models.

        ``sampling_params.logprobs`` must be set. Each request yields
        ``(token_ids, text, logprobs)``, plus the prompt logprobs as a
        fourth element when ``sampling_params.prompt_logprobs`` is not
        ``None``.
        """
        assert sampling_params.logprobs is not None

        req_outputs = self.llm.generate(encoder_decoder_prompts,
                                        sampling_params=sampling_params)
        full_results = self._final_steps_generate_w_logprobs(req_outputs)

        if sampling_params.prompt_logprobs is not None:
            return full_results
        # Prompt logprobs were not requested: drop the trailing field.
        return [entry[0:-1] for entry in full_results]
946

Woosuk Kwon's avatar
Woosuk Kwon committed
947
948
    def generate_greedy(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Generate up to ``max_tokens`` tokens per prompt at temperature 0.

        Delegates to ``generate`` and keeps only the first sample of each
        request, returned as ``(token_ids, text)``.
        """
        params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        per_request = self.generate(prompts,
                                    params,
                                    images=images,
                                    videos=videos,
                                    audios=audios,
                                    **kwargs)

        first_samples: list[tuple[list[int], str]] = []
        for ids_per_sample, text_per_sample in per_request:
            first_samples.append((ids_per_sample[0], text_per_sample[0]))
        return first_samples
965

966
967
    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: Optional[int],
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[list[int]] = None,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Temperature-0 generation that also returns logprobs.

        Builds ``SamplingParams`` from ``max_tokens``, ``num_logprobs``,
        ``num_prompt_logprobs``, ``stop_token_ids`` and ``stop``, then
        delegates to ``generate_w_logprobs`` with the multimodal inputs and
        any extra keyword arguments.
        """
        params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop,
        )
        results = self.generate_w_logprobs(prompts,
                                           params,
                                           images=images,
                                           audios=audios,
                                           videos=videos,
                                           **kwargs)
        return results
994

995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
    def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
        """
        Return the perplexity score associated with generating the prompts

        One token is generated per prompt with ``num_prompt_logprobs=0``.
        The score is ``exp(-mean(logprob))`` over every prompt position
        except the first, whose entry must be ``None``. Each remaining
        position must carry exactly one logprob. A prompt with no position
        after the first makes the mean divide by zero.

        :param prompts: list of prompts to score
        :return: perplexity score of each prompt
        """
        results = self.generate_greedy_logprobs(prompts,
                                                max_tokens=1,
                                                num_logprobs=None,
                                                num_prompt_logprobs=0)

        scores = []
        for result in results:
            result = cast(TokensTextLogprobsPromptLogprobs, result)
            prompt_logprobs = cast(list[Optional[dict[int, Logprob]]],
                                   result[3])
            assert prompt_logprobs[0] is None

            logprob_values = []
            for position in prompt_logprobs[1:]:
                assert position is not None
                assert len(position) == 1
                only_entry = next(iter(position.values()))
                logprob_values.append(only_entry.logprob)

            scores.append(
                math.exp(-sum(logprob_values) / len(logprob_values)))

        return scores

1024
1025
    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: Optional[int],
        num_prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Greedy logprobs generation for vLLM encoder/decoder models.

        Builds temperature-0 ``SamplingParams`` from ``max_tokens``,
        ``num_logprobs``, ``num_prompt_logprobs`` and
        ``skip_special_tokens``, then delegates to
        ``generate_encoder_decoder_w_logprobs``.
        """
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            skip_special_tokens=skip_special_tokens,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

1047
    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        concurrency_limit: Optional[int] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run ``self.llm.beam_search`` over the prompts.

        For each prompt the result is a pair of parallel lists with one
        entry per returned sequence: the sequence's ``tokens`` and its
        ``text``.
        """
        llm_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)
        search_params = BeamSearchParams(beam_width=beam_width,
                                         max_tokens=max_tokens)
        beam_results = self.llm.beam_search(
            llm_inputs, search_params, concurrency_limit=concurrency_limit)

        return [([seq.tokens for seq in result.sequences],
                 [seq.text for seq in result.sequences])
                for result in beam_results]

1073
    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.probs`` of each result of ``self.llm.classify``."""
        all_probs: list[list[float]] = []
        for result in self.llm.classify(prompts):
            all_probs.append(result.outputs.probs)
        return all_probs

1077
1078
1079
1080
1081
1082
1083
    def embed(self,
              prompts: list[str],
              images: Optional[PromptImageInput] = None,
              videos: Optional[PromptVideoInput] = None,
              audios: Optional[PromptAudioInput] = None,
              *args,
              **kwargs) -> list[list[float]]:
        """Return ``outputs.embedding`` of each result of ``self.llm.embed``.

        The prompts and multimodal inputs are first combined through
        ``get_inputs``; extra positional and keyword arguments are
        forwarded to ``self.llm.embed``.
        """
        llm_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        embeddings: list[list[float]] = []
        for result in self.llm.embed(llm_inputs, *args, **kwargs):
            embeddings.append(result.outputs.embedding)
        return embeddings
1091

1092
    def encode(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.data`` of each result of ``self.llm.encode``."""
        encoded: list[list[float]] = []
        for result in self.llm.encode(prompts):
            encoded.append(result.outputs.data)
        return encoded

1096
1097
1098
1099
    def reward(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.data`` of each result of ``self.llm.reward``."""
        rewards: list[list[float]] = []
        for result in self.llm.reward(prompts):
            rewards.append(result.outputs.data)
        return rewards

1100
1101
    def score(
        self,
        text_1: Union[str, list[str]],
        text_2: Union[str, list[str]],
        *args,
        **kwargs,
    ) -> list[float]:
        """Return ``outputs.score`` of each result of ``self.llm.score``.

        ``text_1``, ``text_2`` and all extra arguments are forwarded to
        ``self.llm.score`` as given.
        """
        scores: list[float] = []
        for result in self.llm.score(text_1, text_2, *args, **kwargs):
            scores.append(result.outputs.score)
        return scores
1109

1110
    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Apply ``func`` to the model held by the engine.

        When the engine exposes ``model_executor``, the call goes through
        ``model_executor.apply_model``. Otherwise a small function that
        applies ``func`` to ``get_model()`` is dispatched via the engine's
        ``collective_rpc``.
        """
        engine = self.llm.llm_engine

        if hasattr(engine, "model_executor"):
            # This works either in V0 or in V1 with
            # VLLM_ENABLE_V1_MULTIPROCESSING=0
            return engine.model_executor.apply_model(func)

        # This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
        def _apply_model(self):
            # ``self`` here is whatever object collective_rpc binds the
            # call to, not this runner.
            return func(self.get_model())

        return engine.collective_rpc(_apply_model)
1122

1123
1124
1125
1126
    def __enter__(self):
        """Enter the ``with`` block; the runner itself is the bound value."""
        runner = self
        return runner

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop ``self.llm`` and then run ``cleanup_dist_env_and_memory()``.

        Nothing is returned, so an exception raised inside the ``with``
        block is not suppressed.
        """
        delattr(self, "llm")
        cleanup_dist_env_and_memory()
1129

Woosuk Kwon's avatar
Woosuk Kwon committed
1130

1131
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture that hands tests the ``VllmRunner`` class.

    The class itself is returned, not an instance; tests construct it.
    """
    runner_cls = VllmRunner
    return runner_cls
1134
1135


1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
@pytest.fixture()
def temporary_enable_log_propagate():
    """Turn on propagation for the ``"vllm"`` logger during a test.

    The logger's previous ``propagate`` value is saved before the test and
    put back afterwards, so a setting made elsewhere is not overwritten.
    """
    import logging

    # Use a distinct local name so the module-level ``logger`` stays visible.
    vllm_logger = logging.getLogger("vllm")
    previous_propagate = vllm_logger.propagate
    vllm_logger.propagate = True
    try:
        yield
    finally:
        vllm_logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog`` with vllm log propagation switched on.

    To capture vllm log, propagate=True must be enabled temporarily,
    because caplog depends on logs propagated to the root logger. Depending
    on ``temporary_enable_log_propagate`` does exactly that.
    """
    log_capture = caplog
    yield log_capture
1150
1151
1152
1153
1154
1155
1156


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process.

    The count is whatever ``current_platform.device_count()`` reports.
    """
    # Imported here rather than at module import time.
    from vllm.platforms import current_platform

    device_total = current_platform.device_count()
    return device_total
1159
1160
1161


# Base directory for the patched dummy checkpoints used by the fixtures below.
temp_dir = tempfile.gettempdir()

# One sub-directory of ``temp_dir`` per dummy model.
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir,
                                            "dummy_gemma2_embedding")
1165
1166
1167
1168


def _prepare_dummy_model_dir(repo_id: str, local_dir: str,
                             architecture: str) -> str:
    """Create ``local_dir`` from ``repo_id`` with a patched architecture.

    Nothing happens when ``local_dir`` already exists. Otherwise the
    repository is fetched with ``snapshot_download``, skipping files that
    match the weight-file patterns listed below. Its ``config.json`` is then
    rewritten so that ``"architectures"`` is ``[architecture]``.

    :param repo_id: repository id passed to ``snapshot_download``
    :param local_dir: directory to create and populate
    :param architecture: name stored as the only ``"architectures"`` entry
    :return: ``local_dir``
    """
    json_path = os.path.join(local_dir, "config.json")
    if not os.path.exists(local_dir):
        snapshot_download(repo_id=repo_id,
                          local_dir=local_dir,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = [architecture]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return local_dir


@pytest.fixture
def dummy_opt_path():
    """Path to a copy of facebook/opt-125m declared as MyOPTForCausalLM."""
    return _prepare_dummy_model_dir("facebook/opt-125m", _dummy_opt_path,
                                    "MyOPTForCausalLM")


@pytest.fixture
def dummy_llava_path():
    """Path to a copy of llava-hf/llava-1.5-7b-hf declared as MyLlava."""
    return _prepare_dummy_model_dir("llava-hf/llava-1.5-7b-hf",
                                    _dummy_llava_path, "MyLlava")


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Path to a copy of BAAI/bge-multilingual-gemma2 declared as
    MyGemma2Embedding."""
    return _prepare_dummy_model_dir("BAAI/bge-multilingual-gemma2",
                                    _dummy_gemma2_embedding_path,
                                    "MyGemma2Embedding")
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239


# Add the `--optional` flag to allow running tests
# that are marked with @pytest.mark.optional
def pytest_addoption(parser):
    """Register the ``--optional`` command-line flag (off by default)."""
    flag_settings = {
        "action": "store_true",
        "default": False,
        "help": "run optional test",
    }
    parser.addoption("--optional", **flag_settings)


def pytest_collection_modifyitems(config, items):
    """Skip tests carrying the ``optional`` keyword unless ``--optional``.

    When ``--optional`` is given on the command line the items are left
    untouched. Otherwise a skip marker is added to every item whose
    keywords contain ``"optional"``.
    """
    run_optional = config.getoption("--optional")
    if not run_optional:
        skip_marker = pytest.mark.skip(
            reason="need --optional option to run")
        for test_item in items:
            if "optional" not in test_item.keywords:
                continue
            test_item.add_marker(skip_marker)
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252


@pytest.fixture(scope="session")
def cli_config_file():
    """Return the path to the CLI config file.

    The file is ``config/test_config.yaml`` under this file's directory.
    """
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config.yaml")


@pytest.fixture(scope="session")
def cli_config_file_with_model():
    """Return the path to the CLI config file with model.

    The file is ``config/test_config_with_model.yaml`` under this file's
    directory.
    """
    config_dir = os.path.join(_TEST_DIR, "config")
    return os.path.join(config_dir, "test_config_with_model.yaml")