"docs/vscode:/vscode.git/clone" did not exist on "a8c6fcf65bfb2c64114e70285c4c3efe425382e6"
conftest.py 40.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import json
4
import os
5
import tempfile
6
from collections import UserList
7
from enum import Enum
8
9
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union)
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
import pytest
import torch
14
import torch.nn as nn
15
import torch.nn.functional as F
16
from huggingface_hub import snapshot_download
17
from PIL import Image
18
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
19
                          BatchFeature)
20
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
21

22
23
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm import LLM, SamplingParams
25
from vllm.assets.image import ImageAsset
26
from vllm.assets.video import VideoAsset
27
from vllm.config import LoadFormat, TaskOption, TokenizerPoolConfig
28
from vllm.connections import global_http_connection
29
from vllm.distributed import (cleanup_dist_env_and_memory,
30
31
                              init_distributed_environment,
                              initialize_model_parallel)
32
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
youkaichao's avatar
youkaichao committed
33
34
                         TokensPrompt, to_enc_dec_tuple_list,
                         zip_enc_dec_prompts)
35
from vllm.logger import init_logger
36
from vllm.outputs import RequestOutput
37
from vllm.sampling_params import BeamSearchParams
38
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
youkaichao's avatar
youkaichao committed
39
                        identity, is_list_of)
40

41
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
42

43
44
45
# Directory containing this conftest; prompt fixture files are resolved
# relative to it.
_TEST_DIR = os.path.dirname(__file__)
# Files whose lines are used as the short example prompts.
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
# Files whose lines are used as the long example prompts.
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
46
# Path of the system-message text returned by the ``example_system_message``
# fixture.
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
47

Cyrus Leung's avatar
Cyrus Leung committed
48
# Element type parameter for the generic multimodal prompt-input alias.
_M = TypeVar("_M")
49
50
51
52
53
54
55
56
57
58
59

# Models whose weights are mirrored in ``MODEL_WEIGHTS_S3_BUCKET``. When a test
# builds a ``VllmRunner`` for one of these names without an explicit
# ``load_format``, the weights are streamed from S3 instead of downloaded from
# the HuggingFace Hub. Each entry should appear only once.
MODELS_ON_S3 = [
    "distilbert/distilgpt2",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Llama-3.2-1B",
    "meta-llama/Llama-3.2-1B-Instruct",
    "openai-community/gpt2",
    "ArthurZ/Ilama-3.2-1B",
    "llava-hf/llava-1.5-7b-hf",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "ai21labs/Jamba-tiny-random",
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    "nm-testing/Phi-3-mini-128k-instruct-FP8",
    "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV",
    "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
    "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue",
    "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse",
    "AMead10/Llama-3.2-1B-Instruct-AWQ",
    "shuyuej/Llama-3.2-1B-Instruct-GPTQ",
    "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head",
    "ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024",
    "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
    "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test",
    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
    "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
    "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama",
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
    "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
    "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
    "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
    "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
    "nm-testing/tinyllama-oneshot-w4a16-channel-v2",
    "nm-testing/tinyllama-oneshot-w4a16-group128-v2",
    "nm-testing/tinyllama-oneshot-w8a16-per-channel",
    "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
    "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test",
    "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme",
    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor",
    "nm-testing/llama2.c-stories42M-pruned2.4-compressed",
]

# S3 prefix under which the models listed above are stored
# (as "<bucket>/<model_name>").
MODEL_WEIGHTS_S3_BUCKET = "s3://vllm-ci-model-weights"

Cyrus Leung's avatar
Cyrus Leung committed
114
115
116
117
118
# A per-prompt multimodal argument. The outer list is indexed in lockstep with
# the prompt list (``HfRunner.get_inputs`` asserts equal lengths and reads
# item ``[i]`` for prompt ``i``). The ``List[List[_M]]`` form presumably
# supplies several items for a single prompt — TODO confirm with callers.
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
# Audio items are ``(waveform, sampling_rate)`` pairs; ``get_inputs`` unpacks
# them and forwards the second element as ``sampling_rate``.
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
119

120

121
def _read_prompts(filename: str) -> List[str]:
122
    with open(filename) as f:
123
124
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
125
126


127
128
129
class _ImageAssetPrompts(TypedDict):
    """Prompt text for each test image, keyed by the image asset name."""
    stop_sign: str
    cherry_blossom: str
130
131


132
133
class _ImageAssetsBase(UserList[ImageAsset]):
    """List-like container of :class:`ImageAsset` objects."""
    pass
134

135
136

class _ImageAssets(_ImageAssetsBase):
    """The standard pair of test images: ``stop_sign`` then ``cherry_blossom``."""

    def __init__(self) -> None:
        asset_names = ("stop_sign", "cherry_blossom")
        super().__init__([ImageAsset(name) for name in asset_names])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Map the per-image prompt dictionary to a list of prompts.

        The returned list is ordered like the assets held by this object,
        so it can be zipped directly with iteration over it.
        """
        ordered_keys = ("stop_sign", "cherry_blossom")
        return [prompts[key] for key in ordered_keys]
152
153


154
155
156
157
class _VideoAssetPrompts(TypedDict):
    """Prompt text for each test video, keyed by the video asset name."""
    sample_demo_1: str


158
159
class _VideoAssetsBase(UserList[VideoAsset]):
    """List-like container of :class:`VideoAsset` objects."""
    pass
160
161
162
163
164
165
166
167
168
169
170
171
172


class _VideoAssets(_VideoAssetsBase):
    """The single standard test video, ``sample_demo_1.mp4``."""

    def __init__(self) -> None:
        demo_video = VideoAsset("sample_demo_1.mp4")
        super().__init__([demo_video])

    def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
        """Return the prompt for each held video, in asset order."""
        demo_prompt = prompts["sample_demo_1"]
        return [demo_prompt]


173
174
# Module-level instance returned by the session-scoped ``image_assets`` fixture.
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
175
176
# Module-level instance returned by the session-scoped ``video_assets`` fixture.
VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
177
178


Joe Runde's avatar
Joe Runde committed
179
@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """Parametrize the requesting test over both engine versions.

    The test runs once with ``VLLM_USE_V1=1`` and once with ``VLLM_USE_V1=0``.
    Tests carrying the ``skip_v1`` marker are skipped in the V1 run only.
    """
    v1_marker = request.node.get_closest_marker("skip_v1")

    if not request.param:
        monkeypatch.setenv('VLLM_USE_V1', '0')
    else:
        if v1_marker:
            pytest.skip("Skipping test on vllm V1")
        monkeypatch.setenv('VLLM_USE_V1', '1')

    yield
Joe Runde's avatar
Joe Runde committed
194
195


196
197
198
199
200
201
202
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Make the global HTTP connection build a fresh async client per test.

    pytest_asyncio may run each test on its own event loop, so a client that
    was created on an earlier test's loop must not be reused.
    """
    connection = global_http_connection
    connection.reuse_client = False


203
204
205
206
207
208
209
210
211
212
213
214
@pytest.fixture
def dist_init():
    """Set up a single-process distributed environment for one test.

    Initialises world size 1 / rank 0 with the NCCL backend, using a fresh
    temporary file as the ``file://`` rendezvous, then calls
    ``initialize_model_parallel(1, 1)`` (presumably tensor- and
    pipeline-parallel sizes of 1 — confirm against vllm.distributed).
    Everything is torn down after the test.
    """
    # mkstemp() returns an already-open OS-level descriptor together with the
    # path. Only the path is needed, so close the descriptor rather than
    # leaking one per test.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup_dist_env_and_memory()
    # mkstemp() leaves deletion to the caller. The rendezvous store may already
    # have removed the file, so a missing file is deliberately ignored.
    try:
        os.remove(temp_file)
    except FileNotFoundError:
        pass
216
217


218
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Decide whether ``cleanup_fixture`` runs the global cleanup.

    Returns ``False`` for tests marked ``skip_global_cleanup``; subdirectories
    can also override this fixture. Skipping the cleanup can make non-GPU unit
    tests roughly 10x faster, because they then never need to initialize
    torch.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
226
227


228
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """After each test, run ``cleanup_dist_env_and_memory`` unless the
    ``should_do_global_cleanup_after_test`` fixture says to skip it."""
    yield
    if not should_do_global_cleanup_after_test:
        return
    cleanup_dist_env_and_memory()
233
234


235
236
237
238
239
240
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call ``torch._dynamo.reset()`` after every test.

    This presumably keeps compiled state from one test out of the next.
    """
    # Nothing to set up; the reset is teardown-only.
    yield
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
241
242
@pytest.fixture
def example_prompts() -> List[str]:
    """Every line of every file in ``_TEST_PROMPTS``, in file order."""
    return [
        line for path in _TEST_PROMPTS for line in _read_prompts(path)
    ]


249
250
251
252
253
254
@pytest.fixture
def example_system_message() -> str:
    """Full text of the system-message file at ``_SYS_MSG``."""
    with open(_SYS_MSG) as message_file:
        message = message_file.read()
    return message


255
256
257
258
259
260
261
class DecoderPromptType(Enum):
    """For encoder/decoder models only.

    Selects which kind of decoder prompt ``example_encoder_decoder_prompts``
    pairs with each encoder prompt.
    """
    # Decoder prompts are the encoder prompts in reverse order.
    CUSTOM = 1
    # Decoder prompt is ``None``.
    NONE = 2
    # Decoder prompt is the empty string ``""``.
    EMPTY_STR = 3


262
@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for every :class:`DecoderPromptType`.

    The encoder prompts are the lines of the files in ``_TEST_PROMPTS``. For
    each decoder-prompt variant, every encoder prompt is paired with:

    * ``DecoderPromptType.NONE``: a ``None`` decoder prompt
    * ``DecoderPromptType.EMPTY_STR``: an empty-string decoder prompt
    * ``DecoderPromptType.CUSTOM``: the encoder prompt list in reverse order

    Returns:
        Dict mapping each :class:`DecoderPromptType` to the list produced by
        ``zip_enc_dec_prompts(encoder_prompts, decoder_prompts)`` for that
        variant.
    '''

    encoder_prompts = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    custom_decoder_prompts = encoder_prompts[::-1]
    empty_str_decoder_prompts = [""] * len(encoder_prompts)
    none_decoder_prompts = [None] * len(encoder_prompts)

    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
    }


295
296
297
298
@pytest.fixture
def example_long_prompts() -> List[str]:
    """Every line of every file in ``_LONG_PROMPTS``, in file order."""
    collected: List[str] = []
    for path in _LONG_PROMPTS:
        collected.extend(_read_prompts(path))
    return collected
Woosuk Kwon's avatar
Woosuk Kwon committed
301
302


303
304
305
306
307
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    """Expose the shared ``IMAGE_ASSETS`` instance for the whole session."""
    shared_images = IMAGE_ASSETS
    return shared_images


308
309
310
311
312
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    """Expose the shared ``VIDEO_ASSETS`` instance for the whole session."""
    shared_videos = VIDEO_ASSETS
    return shared_videos


313
# Value types that ``HfRunner.wrap_device`` accepts and moves between devices.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
314
# Generic result type variable; no use of it is visible in this part of the
# file — TODO confirm whether it is still needed.
_R = TypeVar("_R")
315

Woosuk Kwon's avatar
Woosuk Kwon committed
316
317
318

class HfRunner:
    """Wrap a HuggingFace model behind a small test-oriented API.

    The wrapped model is a ``transformers`` auto-model, a
    ``sentence_transformers`` ``SentenceTransformer``, or a ``CrossEncoder``.
    It supports generation (greedy, beam search, with logprobs), classification,
    embedding and cross-encoder scoring. Instances are context managers that
    release the model on exit.
    """

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move ``x`` to ``device`` and return it.

        ``None`` and ``bool`` values are returned unchanged. ``dict`` inputs are
        handled recursively and returned as a plain ``dict``. When ``device``
        is ``None`` it defaults to ``"cpu"`` on CPU platforms and ``"cuda"``
        otherwise. Objects whose ``device.type`` already matches are returned
        as-is; anything else is moved with ``x.to(device)``.
        """
        from vllm.platforms import current_platform

        if x is None or isinstance(x, (bool, )):
            return x

        if device is None:
            device = "cpu" if current_platform.is_cpu() else "cuda"

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[..., BatchEncoding] = identity,
    ) -> None:
        """Load the model, tokenizer and processor for ``model_name``.

        Args:
            model_name: HuggingFace model identifier or path.
            dtype: Key into ``STR_DTYPE_TO_TORCH_DTYPE`` selecting the weight
                dtype.
            model_kwargs: Extra kwargs for ``auto_cls.from_pretrained``
                (only used for the plain ``transformers`` path).
            is_sentence_transformer: Load a ``SentenceTransformer`` instead.
            is_cross_encoder: Load a ``CrossEncoder`` instead.
            skip_tokenizer_init: Do not load a separate tokenizer; use the
                processor's tokenizer instead.
            auto_cls: Auto-model class used for the ``transformers`` path.
            postprocess_inputs: Applied to each processor output in
                :meth:`get_inputs`, called with ``dtype=self.dtype``.
        """
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                    trust_remote_code=True,
                ).to(dtype=torch_dtype))
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder
            self.model = CrossEncoder(model_name,
                                      device="cpu",
                                      trust_remote_code=True)
            self.model.model = (self.wrap_device(self.model.model).to(
                dtype=torch_dtype))
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

        self.dtype = dtype
        self.postprocess_inputs = postprocess_inputs

    def get_inputs(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[BatchEncoding]:
        """Run the processor separately on each prompt.

        Multimodal lists, when given, must be the same length as ``prompts``.
        A ``None`` entry means "no item for this prompt". Audio entries are
        ``(audio, sampling_rate)`` pairs. Each processor output is passed
        through ``self.postprocess_inputs``.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: List[BatchEncoding] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_tuple := audios[i]) is not None:
                audio, sr = audio_tuple
                processor_kwargs["audio"] = audio
                processor_kwargs["sampling_rate"] = sr

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs, dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

    def classify(self, prompts: List[str]) -> List[List[float]]:
        """Return softmax class probabilities for each prompt.

        Each element is the probability vector of the first (only) batch row
        for the corresponding prompt, converted to a Python list of floats.
        """
        all_inputs = self.get_inputs(prompts)
        outputs = []
        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            # Probabilities, not raw logits: softmax over the class dimension.
            probs = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(probs)

        return outputs

    def generate(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Call ``self.model.generate`` on each prompt individually.

        Extra ``kwargs`` are forwarded to ``generate``. For each prompt it
        returns ``(token_id_lists, decoded_strings)``, with one entry per
        returned sequence. Special tokens are stripped from the strings.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for inputs in all_inputs:
            output_ids = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode up to ``max_tokens`` new tokens per prompt.

        Returns the first sequence's ``(token_ids, text)`` for each prompt.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search decode, returning all ``beam_width`` beams per prompt.

        Pad tokens (``self.tokenizer.pad_token_id``) are removed from the
        returned token id lists. The decoded strings are left as produced by
        :meth:`generate`.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_token_id = self.tokenizer.pad_token_id
        return [([[tok for tok in beam_ids if tok != pad_token_id]
                  for beam_ids in output_ids], output_str)
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        """Greedy-decode each prompt and return per-step full logprob tensors.

        The logprobs are recomputed from the generation hidden states via
        :meth:`_hidden_states_to_seq_logprobs`. There is one list per prompt
        and one tensor per generation step.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: List[List[torch.Tensor]] = []
        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs = self._hidden_states_to_seq_logprobs(
                output.hidden_states)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def _hidden_states_to_seq_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
    ) -> List[torch.Tensor]:
        """Project each step's hidden states through the output embedding.

        For every generation step, take the last element of that step's tuple
        (presumably the final layer) at batch index 0. Multiply it by the output
        embedding weights, add the bias if present, and return float32
        log-softmax over the vocabulary.
        """
        output_embeddings = self.model.get_output_embeddings()

        seq_logprobs: List[torch.Tensor] = []
        for hidden_state in hidden_states:
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(output_embeddings.weight.device),
                output_embeddings.weight.t(),
            )
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
    ) -> Tuple[List[Dict[int, float]], int]:
        """Convert hidden states into top-``num_logprobs`` dicts per step.

        Returns ``(per_step_dicts, output_len)``. Each dict maps token id to
        logprob, and ``output_len`` is the number of generation steps. For the
        first step only the last row is kept, which drops the prompt-position
        logprobs.
        """
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

        # convert to dict
        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {}
            for token_id, logprob in zip(topk.indices[0], topk.values[0]):
                tok_logprobs_dct[token_id.item()] = logprob.item()

            seq_logprobs_lst.append(tok_logprobs_dct)

        return (
            seq_logprobs_lst,
            output_len,
        )

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        """Greedy-decode each prompt, keeping the top-``num_logprobs`` per step.

        Returns ``(output_token_ids, decoded_text, per_step_logprob_dicts)``
        for each prompt. Only the generated tail of the sequence is returned,
        one token per generation step.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # output_len == len(seq_logprobs_lst): one entry per step.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for HuggingFace encoder/decoder models.

        The encoder prompt (plus optional image) goes through the processor.
        The decoder prompt, if not ``None``, goes through the tokenizer.
        Returns ``(output_token_ids, decoded_text, per_step_logprob_dicts)``
        for each prompt pair, computed from the decoder hidden states.
        '''

        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, (encoder_prompt, decoder_prompt) in enumerate(
                to_enc_dec_tuple_list(encoder_decoder_prompts)):
            processor_kwargs: Dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            encoder_input_ids = self.wrap_device(
                self.processor(**processor_kwargs).input_ids,
                device=self.model.device.type,
            )

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_input_ids = self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids,
                    device=self.model.device.type,
                )

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        return list(zip(all_output_ids, all_output_strs, all_logprobs))

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        """Delegate to ``self.model.encode``.

        Presumably only meaningful with ``is_sentence_transformer=True``.
        """
        return self.model.encode(prompts)

    def predict(self, prompts: List[List[str]]) -> torch.Tensor:
        """Delegate to ``self.model.predict(..., convert_to_tensor=True)``.

        Presumably only meaningful with ``is_cross_encoder=True``.
        """
        return self.model.predict(prompts, convert_to_tensor=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the model reference before the global cleanup so its memory
        # can be reclaimed.
        del self.model
        cleanup_dist_env_and_memory()
720

Woosuk Kwon's avatar
Woosuk Kwon committed
721

Cyrus Leung's avatar
Cyrus Leung committed
722
@pytest.fixture(scope="session")
def hf_runner():
    """Provide the :class:`HfRunner` class itself, not an instance.

    Tests construct runners from it; the class also works as a context
    manager.
    """
    runner_cls = HfRunner
    return runner_cls


class VllmRunner:
    """Test harness around :class:`vllm.LLM`.

    Builds an ``LLM`` with test-oriented defaults (small ``max_model_len``,
    ``dtype="half"``, ``trust_remote_code=True``) and exposes helpers that
    return plain Python lists/tuples so results can be compared against
    reference outputs. Usable as a context manager: on exit the engine is
    dropped and ``cleanup_dist_env_and_memory()`` is called.
    """

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        load_format: Optional[LoadFormat] = None,
        **kwargs,
    ) -> None:
        # Models mirrored to the S3 bucket are streamed from there unless the
        # caller explicitly chose a load format.
        if model_name in MODELS_ON_S3 and not load_format:
            model_name = f"{MODEL_WEIGHTS_S3_BUCKET}/{model_name}"
            load_format = LoadFormat.RUNAI_STREAMER
        if not load_format:
            load_format = LoadFormat.AUTO

        self.model = LLM(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            load_format=load_format,
            **kwargs,
        )

    def get_inputs(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[TextPrompt]:
        """Build one ``TextPrompt`` per prompt, attaching multimodal data.

        Each of ``images`` / ``videos`` / ``audios``, when given, must have
        exactly one entry per prompt; a ``None`` entry means that prompt has
        no data of that modality. Data of several modalities for the same
        prompt is merged into a single ``multi_modal_data`` dict.
        """
        modalities = (("image", images), ("video", videos), ("audio", audios))
        for _, items in modalities:
            if items is not None:
                assert len(prompts) == len(items)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        for modality, items in modalities:
            if items is None:
                continue
            for text_prompt, item in zip(inputs, items):
                if item is not None:
                    # Merge instead of overwrite: previously a later modality
                    # replaced the whole dict and silently dropped the others.
                    mm_data = text_prompt.setdefault("multi_modal_data", {})
                    mm_data[modality] = item
        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate and return ``(token_ids, texts)`` for every prompt.

        Both lists hold one entry per sample of the request; each entry is
        the prompt concatenated with that sample's completion.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[TokensTextLogprobsPromptLogprobs]:
        """Convert request outputs to ``(ids, text, logprobs, prompt_lps)``.

        Only the LAST sample of each request is reported, so this is meant
        for requests producing a single sample per prompt.
        """
        outputs: List[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            sample = req_output.outputs[-1]
            outputs.append((list(sample.token_ids), sample.text,
                            sample.logprobs, req_output.prompt_logprobs))
        return outputs

    @staticmethod
    def _drop_prompt_logprobs_if_unrequested(
        results: List[TokensTextLogprobsPromptLogprobs],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Strip the trailing prompt-logprobs element unless it was asked for.

        Prompt logprobs are considered requested when
        ``sampling_params.prompt_logprobs`` is not ``None``.
        """
        if sampling_params.prompt_logprobs is None:
            return [result[:-1] for result in results]
        return results

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return token ids, text and logprobs per prompt.

        Prompt logprobs are appended to each tuple only when
        ``sampling_params.prompt_logprobs`` is set.
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params,
                                          **kwargs)
        return self._drop_prompt_logprobs_if_unrequested(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Logprobs generation for vLLM encoder/decoder models.

        ``sampling_params.logprobs`` must be set. Prompt logprobs are
        included only when ``sampling_params.prompt_logprobs`` is set.
        """
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        return self._drop_prompt_logprobs_if_unrequested(
            self._final_steps_generate_w_logprobs(req_outputs),
            sampling_params)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        """Greedy-decode and return ``(token_ids, text)`` per prompt.

        Both values include the prompt and come from the first sample.
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts,
                                greedy_params,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Greedy decoding (temperature 0) returning top-``num_logprobs``."""
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos,
                                        **kwargs)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        """Greedy logprobs generation for vLLM encoder/decoder models."""
        # The docstring used to sit after the SamplingParams construction,
        # where it was a no-op string expression rather than documentation.
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )
        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: Union[List[str], List[List[int]]],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Run beam search; return ``(token_ids, texts)`` per prompt.

        ``prompts`` is either all strings or all token-id lists. Each
        returned list has one entry per beam sequence.
        """
        if is_list_of(prompts, str, check="all"):
            inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        else:
            inputs = [
                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
            ]
        outputs = self.model.beam_search(
            inputs,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))

        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def classify(self, prompts: List[str]) -> List[List[float]]:
        """Return the class probabilities reported for each prompt."""
        req_outputs = self.model.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

    def encode(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[List[float]]:
        """Return the embedding produced by ``LLM.embed`` for each prompt."""
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.embed(inputs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, List[str]],
        text_2: Union[str, List[str]],
    ) -> List[float]:
        """Return the scores reported by ``LLM.score`` for the text pairs."""
        req_outputs = self.model.score(text_1, text_2)
        return [req_output.outputs.score for req_output in req_outputs]

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run ``func`` on the model through the engine's executor."""
        executor = self.model.llm_engine.model_executor
        return executor.apply_model(func)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the engine, then tear down distributed state and memory.
        del self.model
        cleanup_dist_env_and_memory()
1026

Woosuk Kwon's avatar
Woosuk Kwon committed
1027

1028
@pytest.fixture(scope="session")
def vllm_runner():
    """Session fixture handing out the ``VllmRunner`` class (not an instance).

    Tests construct the runner themselves, typically as a context manager.
    """
    runner_cls = VllmRunner
    return runner_cls
1031
1032
1033
1034
1035
1036
1037
1038
1039


def get_tokenizer_pool_config(tokenizer_group_type):
    """Translate a test parameter into a ``TokenizerPoolConfig`` (or ``None``).

    ``None`` means no pool. ``"ray"`` selects the Ray pool type. A class is
    passed through as a custom pool type. Every pool gets ``pool_size=1``
    and empty ``extra_config``.

    Raises:
        ValueError: for any other value.
    """
    if tokenizer_group_type is None:
        return None
    use_ray = tokenizer_group_type == "ray"
    if not use_ray and not isinstance(tokenizer_group_type, type):
        message = f"Unknown tokenizer_group_type: {tokenizer_group_type}"
        raise ValueError(message)
    pool_type = "ray" if use_ray else tokenizer_group_type
    return TokenizerPoolConfig(pool_size=1,
                               pool_type=pool_type,
                               extra_config={})
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060


@pytest.fixture()
def temporary_enable_log_propagate():
    """Let records of the ``vllm`` logger propagate for one test.

    On teardown the logger's previous ``propagate`` value is restored,
    rather than being forced to ``False``. This way the fixture leaves the
    logger exactly as it found it, even when teardown happens through
    generator close.
    """
    import logging
    logger = logging.getLogger("vllm")
    previous_propagate = logger.propagate
    logger.propagate = True
    try:
        yield
    finally:
        logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """``caplog`` that can also see records emitted by the ``vllm`` logger.

    ``caplog`` only captures records that reach the root logger, so this
    fixture depends on ``temporary_enable_log_propagate`` to switch
    propagation on while the test runs.
    """
    captured_logs = caplog
    yield captured_logs
1061
1062
1063
1064
1065
1066
1067


@pytest.fixture(scope="session")
def num_gpus_available():
    """Number of GPUs, counted without creating a CUDA context here.

    Delegates to ``cuda_device_count_stateless`` so the current process does
    not initialize CUDA just to answer this question.
    """
    gpu_count = cuda_device_count_stateless()
    return gpu_count
1069
1070
1071


temp_dir = tempfile.gettempdir()

# Local directories (under the system temp dir) that the dummy-model
# fixtures download into; a later run reuses them if they still exist.
_dummy_opt_path, _dummy_llava_path, _dummy_gemma2_embedding_path = [
    os.path.join(temp_dir, dirname)
    for dirname in ("dummy_opt", "dummy_llava", "dummy_gemma2_embedding")
]
1075
1076
1077
1078


@pytest.fixture
def dummy_opt_path():
    """Local copy of ``facebook/opt-125m`` with a renamed architecture.

    The repo is downloaded once into ``_dummy_opt_path``, skipping files
    that match ``*.bin``, ``*.bin.index.json``, ``*.pt``, ``*.h5`` and
    ``*.msgpack``. Its ``config.json`` is then patched so that
    ``architectures == ["MyOPTForCausalLM"]``.

    Returns:
        Path of the local model directory.
    """
    json_path = os.path.join(_dummy_opt_path, "config.json")
    # Key the download on config.json, not on the directory. Otherwise a
    # directory left by an interrupted download would be reused forever
    # without a config.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_opt_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
    with open(json_path, encoding="utf-8") as f:
        config = json.load(f)
    # Patch idempotently so a run interrupted between download and rewrite
    # still ends up with the custom architecture.
    if config.get("architectures") != ["MyOPTForCausalLM"]:
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_opt_path


@pytest.fixture
def dummy_llava_path():
    """Local copy of ``llava-hf/llava-1.5-7b-hf`` with a renamed architecture.

    The repo is downloaded once into ``_dummy_llava_path``, skipping files
    that match ``*.bin``, ``*.bin.index.json``, ``*.pt``, ``*.h5`` and
    ``*.msgpack``. Its ``config.json`` is then patched so that
    ``architectures == ["MyLlava"]``.

    Returns:
        Path of the local model directory.
    """
    json_path = os.path.join(_dummy_llava_path, "config.json")
    # Key the download on config.json, not on the directory. Otherwise a
    # directory left by an interrupted download would be reused forever
    # without a config.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                          local_dir=_dummy_llava_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
    with open(json_path, encoding="utf-8") as f:
        config = json.load(f)
    # Patch idempotently so a run interrupted between download and rewrite
    # still ends up with the custom architecture.
    if config.get("architectures") != ["MyLlava"]:
        config["architectures"] = ["MyLlava"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_llava_path
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Local copy of ``BAAI/bge-multilingual-gemma2``, architecture renamed.

    The repo is downloaded once into ``_dummy_gemma2_embedding_path``,
    skipping files that match ``*.bin``, ``*.bin.index.json``, ``*.pt``,
    ``*.h5`` and ``*.msgpack``. Its ``config.json`` is then patched so
    that ``architectures == ["MyGemma2Embedding"]``.

    Returns:
        Path of the local model directory.
    """
    json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
    # Key the download on config.json, not on the directory. Otherwise a
    # directory left by an interrupted download would be reused forever
    # without a config.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
                          local_dir=_dummy_gemma2_embedding_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
    with open(json_path, encoding="utf-8") as f:
        config = json.load(f)
    # Patch idempotently so a run interrupted between download and rewrite
    # still ends up with the custom architecture.
    if config.get("architectures") != ["MyGemma2Embedding"]:
        config["architectures"] = ["MyGemma2Embedding"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_gemma2_embedding_path
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150


# Register the `--optional` flag, which enables tests
# marked with @pytest.mark.optional.
def pytest_addoption(parser):
    """Add the boolean ``--optional`` command-line option (off by default)."""
    option_kwargs = dict(action="store_true",
                         default=False,
                         help="run optional test")
    parser.addoption("--optional", **option_kwargs)


def pytest_collection_modifyitems(config, items):
    """Skip tests marked ``optional`` unless ``--optional`` was given."""
    run_optional = config.getoption("--optional")
    if not run_optional:
        marker = pytest.mark.skip(reason="need --optional option to run")
        optional_items = (it for it in items if "optional" in it.keywords)
        for optional_item in optional_items:
            optional_item.add_marker(marker)