conftest.py 44.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
import math
5
import os
6
import tempfile
7
from enum import Enum
8
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast
9
10
import pytest
import pytest_html
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
13
14
import pytest
import torch
15
import torch.nn as nn
16
import torch.nn.functional as F
17
from huggingface_hub import snapshot_download
18
from PIL import Image
19
20
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BatchEncoding, BatchFeature)
21
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Woosuk Kwon's avatar
Woosuk Kwon committed
22

23
24
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm import LLM, SamplingParams
26
from vllm.assets.audio import AudioAsset
27
from vllm.assets.image import ImageAsset
28
from vllm.assets.video import VideoAsset
29
from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype
30
from vllm.connections import global_http_connection
31
from vllm.distributed import (cleanup_dist_env_and_memory,
32
33
                              init_distributed_environment,
                              initialize_model_parallel)
34
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
35
                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
36
from vllm.logger import init_logger
37
from vllm.outputs import RequestOutput
38
from vllm.sampling_params import BeamSearchParams
zhuwenwen's avatar
zhuwenwen committed
39

40
from vllm.sequence import Logprob
41
from vllm.transformers_utils.utils import maybe_model_redirect
zhuwenwen's avatar
zhuwenwen committed
42
from .utils import models_path_prefix
43

44
logger = init_logger(__name__)
Woosuk Kwon's avatar
Woosuk Kwon committed
45

46
47
48
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
49
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
50

Cyrus Leung's avatar
Cyrus Leung committed
51
_M = TypeVar("_M")
52

53
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]
Cyrus Leung's avatar
Cyrus Leung committed
54
55

PromptImageInput = _PromptMultiModalInput[Image.Image]
56
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
Cyrus Leung's avatar
Cyrus Leung committed
57
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
58

59

60
def _read_prompts(filename: str) -> list[str]:
61
    with open(filename) as f:
62
63
        prompts = f.readlines()
        return prompts
Woosuk Kwon's avatar
Woosuk Kwon committed
64
65


66
class ImageAssetPrompts(TypedDict):
    """Mapping from each test image asset name to the prompt used with it."""
    stop_sign: str
    cherry_blossom: str
69
70


71
class ImageTestAssets(list[ImageAsset]):
    """List of the image assets used by the tests.

    The list is always populated with the ``stop_sign`` and
    ``cherry_blossom`` assets, in that order.
    """

    def __init__(self) -> None:
        super().__init__([
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
        ])

    def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
        """
        Convenience method to define the prompt for each test image.

        The order of the returned prompts matches the order of the
        assets when iterating through this object.
        """
        return [prompts["stop_sign"], prompts["cherry_blossom"]]
87
88


89
90
class VideoAssetPrompts(TypedDict):
    # Prompt text to pair with the "baby_reading" video asset.
    baby_reading: str
91
92


93
class VideoTestAssets(list[VideoAsset]):
    """List of the video assets used by the tests.

    The list is always populated with the single ``baby_reading`` asset.
    """

    def __init__(self) -> None:
        super().__init__([
            VideoAsset("baby_reading"),
        ])

    def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
        """Return the prompts in the same order as the assets in this list."""
        return [prompts["baby_reading"]]
102
103


104
class AudioAssetPrompts(TypedDict):
    """Mapping from each test audio asset name to the prompt used with it."""
    mary_had_lamb: str
    winning_call: str
107
108


109
class AudioTestAssets(list[AudioAsset]):
    """List of the audio assets used by the tests.

    The list is always populated with the ``mary_had_lamb`` and
    ``winning_call`` assets, in that order.
    """

    def __init__(self) -> None:
        super().__init__([
            AudioAsset("mary_had_lamb"),
            AudioAsset("winning_call"),
        ])

    def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
        """Return the prompts in the same order as the assets in this list."""
        return [prompts["mary_had_lamb"], prompts["winning_call"]]
119
120


121
# Module-level asset lists, created once at import time and handed out by
# the session-scoped ``*_assets`` fixtures.
IMAGE_ASSETS = ImageTestAssets()
"""Singleton instance of {class}`ImageTestAssets`."""
VIDEO_ASSETS = VideoTestAssets()
"""Singleton instance of {class}`VideoTestAssets`."""
AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of {class}`AudioTestAssets`."""
127
128


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
@pytest.fixture(scope="function", autouse=True)
def cleanup_VLLM_USE_V1(monkeypatch):
    """
    Make sure "VLLM_USE_V1" does not leak from one test into the next.

    The V1 oracle sets "VLLM_USE_V1" during loading, so each test
    invocation may change the env variable. Touching the variable through
    monkeypatch registers it for restoration, which means any change vLLM
    makes during the test is undone when the test finishes.

    This fixture is used by every test.
    """
    # Already set by the environment: nothing to register.
    if "VLLM_USE_V1" in os.environ:
        return

    # Set-then-delete leaves the variable unset for the test, but makes
    # monkeypatch remember the "unset" state and restore it on exit.
    monkeypatch.setenv("VLLM_USE_V1", "")
    monkeypatch.delenv("VLLM_USE_V1")


Joe Runde's avatar
Joe Runde committed
149
@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    """Run the requesting test twice: once with V1 and once without.

    The fixture is parametrized over ``True`` (V1) and ``False`` (V0) and
    sets the ``VLLM_USE_V1`` environment variable accordingly through
    ``monkeypatch``.

    Tests carrying the ``skip_v1`` marker are skipped for the V1 run, and
    tests carrying the ``skip_v0`` marker are skipped for the V0 run.
    """
    use_v1 = request.param
    skip_v0 = request.node.get_closest_marker("skip_v0")
    skip_v1 = request.node.get_closest_marker("skip_v1")

    if use_v1:
        if skip_v1:
            pytest.skip("Skipping test on vllm V1")
        monkeypatch.setenv('VLLM_USE_V1', '1')
    else:
        if skip_v0:
            pytest.skip("Skipping test on vllm V0")
        monkeypatch.setenv('VLLM_USE_V1', '0')

    yield
Joe Runde's avatar
Joe Runde committed
167
168


169
170
171
172
173
174
175
@pytest.fixture(autouse=True)
def init_test_http_connection():
    """Turn off client reuse on the global HTTP connection for every test."""
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


176
177
178
179
180
181
182
183
184
185
186
187
@pytest.fixture
def dist_init():
    """Set up a single-process distributed environment for one test.

    Initializes the distributed environment with ``world_size=1`` and
    ``rank=0`` over the ``nccl`` backend, using a temporary file as the
    ``file://`` init method, then initializes model parallelism with
    sizes ``(1, 1)``. After the test, the distributed environment is
    cleaned up and the temporary file is removed.
    """
    # mkstemp() returns an *open* OS-level descriptor along with the path.
    # Only the path is needed, so close the descriptor immediately instead
    # of leaking it for the lifetime of the test session.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)
    try:
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method=f"file://{temp_file}",
            local_rank=0,
            backend="nccl",
        )
        initialize_model_parallel(1, 1)
        yield
        cleanup_dist_env_and_memory()
    finally:
        # Best effort: the file may already have been removed by whoever
        # consumed the init method, which is not an error for the test.
        try:
            os.remove(temp_file)
        except OSError:
            pass
189
190


191
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't need
    to initialize torch.

    By default, cleanup is requested unless the test (or an enclosing
    scope) carries the ``skip_global_cleanup`` marker.
    """
    skip_marker = request.node.get_closest_marker("skip_global_cleanup")
    return not skip_marker
199
200


201
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run the global distributed/memory cleanup after each test.

    The cleanup is skipped when ``should_do_global_cleanup_after_test``
    evaluates to false.
    """
    yield
    if should_do_global_cleanup_after_test:
        cleanup_dist_env_and_memory()
206
207


208
209
210
211
212
213
@pytest.fixture(autouse=True)
def dynamo_reset():
    """Call ``torch._dynamo.reset()`` after every test (autouse)."""
    yield
    # Teardown: executed once the test body has finished.
    torch._dynamo.reset()


Woosuk Kwon's avatar
Woosuk Kwon committed
214
@pytest.fixture
def example_prompts() -> list[str]:
    """Return every line of every file listed in ``_TEST_PROMPTS``.

    Lines are returned in file order, with line endings as produced by
    ``_read_prompts``.
    """
    return [
        prompt for filename in _TEST_PROMPTS
        for prompt in _read_prompts(filename)
    ]


222
223
224
225
226
227
@pytest.fixture
def example_system_message() -> str:
    """Return the full text of the system-message file at ``_SYS_MSG``."""
    # Decode explicitly as UTF-8 so the result does not depend on the
    # locale of the machine running the tests.
    with open(_SYS_MSG, encoding="utf-8") as f:
        return f.read()


228
229
230
231
232
233
234
class DecoderPromptType(Enum):
    """For encoder/decoder models only."""
    # The `example_encoder_decoder_prompts` fixture keys its result by these
    # members: CUSTOM pairs each encoder prompt with an explicit decoder
    # prompt, NONE pairs it with `None`, EMPTY_STR pairs it with "".
    CUSTOM = 1
    NONE = 2
    EMPTY_STR = 3


235
@pytest.fixture
def example_encoder_decoder_prompts(
) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]:
    '''
    Build encoder/decoder prompt pairs for every decoder prompt type.

    The encoder prompts are the lines of the files in `_TEST_PROMPTS`.
    Each encoder prompt is zipped with the same-index decoder prompt.

    Returns:

    A dict keyed by `DecoderPromptType`:

    * `NONE`: every decoder prompt is `None`
    * `EMPTY_STR`: every decoder prompt is the empty string
    * `CUSTOM`: decoder prompts are the encoder prompt list reversed
    '''

    encoder_prompts = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    custom_decoder_prompts = encoder_prompts[::-1]
    empty_str_decoder_prompts = [""] * len(encoder_prompts)
    none_decoder_prompts = [None] * len(encoder_prompts)

    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
    }


268
@pytest.fixture
def example_long_prompts() -> list[str]:
    """Return every line of every file listed in ``_LONG_PROMPTS``.

    Lines are returned in file order, with line endings as produced by
    ``_read_prompts``.
    """
    return [
        prompt for filename in _LONG_PROMPTS
        for prompt in _read_prompts(filename)
    ]
Woosuk Kwon's avatar
Woosuk Kwon committed
274
275


276
@pytest.fixture(scope="session")
def image_assets() -> ImageTestAssets:
    """Return the shared module-level ``IMAGE_ASSETS`` list."""
    return IMAGE_ASSETS


281
@pytest.fixture(scope="session")
def video_assets() -> VideoTestAssets:
    """Return the shared module-level ``VIDEO_ASSETS`` list."""
    return VIDEO_ASSETS


286
@pytest.fixture(scope="session")
def audio_assets() -> AudioTestAssets:
    """Return the shared module-level ``AUDIO_ASSETS`` list."""
    return AUDIO_ASSETS


291
# Constrained type variable used by `HfRunner.wrap_device` so that its
# return annotation matches the type of the object passed in.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
# Unconstrained type variable (no use is visible in this part of the file).
_R = TypeVar("_R")
293

Woosuk Kwon's avatar
Woosuk Kwon committed
294
295
296

class HfRunner:

297
    def get_default_device(self):
        """Return the device string to place the HF model and inputs on.

        Returns ``"cpu"`` when the current vLLM platform reports CPU,
        otherwise the platform's ``device_type``.
        """
        # Imported lazily, inside the method, as in the original code.
        from vllm.platforms import current_platform

        if current_platform.is_cpu():
            return "cpu"
        return current_platform.device_type
302
303

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move ``x`` to ``device`` (default: ``self.device``).

        Args:
            x: Object to move. ``None`` and ``bool`` values are returned
                unchanged; a ``dict`` is rebuilt with every value wrapped
                recursively; anything else must provide ``.to(device)``.
            device: Target device type string. Falls back to
                ``self.device`` when ``None``.

        Returns:
            ``x`` itself if it already reports the requested device type
            through ``x.device.type``, otherwise the result of
            ``x.to(device)`` (or a new dict for dict inputs).
        """
        if x is None or isinstance(x, bool):
            return x

        if device is None:
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        # Skip the transfer when the object is already on the right device.
        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)
317

Woosuk Kwon's avatar
Woosuk Kwon committed
318
319
320
    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load a HuggingFace model together with its tokenizer/processor.

        Args:
            model_name: Model identifier; passed through
                ``maybe_model_redirect`` before use.
            dtype: Requested dtype, resolved with ``_get_and_verify_dtype``.
            model_kwargs: Extra keyword arguments for the model loader.
                ``torch_dtype`` is filled in with the resolved dtype when
                absent. The caller's dict is not modified.
            trust_remote_code: Forwarded to every ``from_pretrained`` call
                and to the sentence-transformers constructors.
            is_sentence_transformer: Load the model with
                ``SentenceTransformer``.
            is_cross_encoder: Load the model with ``CrossEncoder`` (only
                consulted when ``is_sentence_transformer`` is false).
            skip_tokenizer_init: Do not load ``AutoTokenizer``; the
                processor's ``tokenizer`` attribute is used instead.
            auto_cls: Auto-model class used when neither of the
                sentence-transformers flags is set.
        """
        model_name = maybe_model_redirect(model_name)
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Work on a shallow copy so that `setdefault` below never writes
        # into a dict owned (and possibly reused) by the caller.
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if (getattr(model, "quantization_method", None) is None
                    and any(p.dtype != self.dtype
                            for p in model.parameters())):
                model = model.to(dtype=self.dtype)

            # Only move the model when it is not a bitsandbytes model and
            # its parameters do not already span several devices.
            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer
Woosuk Kwon's avatar
Woosuk Kwon committed
405

406
    def get_inputs(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[Union[BatchFeature, BatchEncoding]]:
        """Run the processor once per prompt and collect the results.

        Args:
            prompts: Text prompts, one processor call each.
            images: Optional per-prompt images; must have one entry per
                prompt. ``None`` entries are not passed to the processor.
            videos: Optional per-prompt videos; same rules as ``images``.
            audios: Optional per-prompt audio inputs; same rules as
                ``images``. A two-element entry is unpacked into the
                audio data and its sampling rate.

        Returns:
            The processor output for each prompt, in prompt order.
            ``BatchFeature`` outputs are converted with
            ``.to(dtype=self.dtype)``.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_inputs := audios[i]) is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

450
451
452
453
454
455
456
457
458
    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
        """Return the model's input embeddings for each prompt.

        Each prompt is processed with ``get_inputs``; its ``input_ids``
        are moved to the runner's device, passed through the model's input
        embedding layer, and squeezed along dimension 0.
        """

        def _embed(encoded):
            token_ids = self.wrap_device(encoded)["input_ids"]
            embedding_layer = self.model.get_input_embeddings()
            return embedding_layer(token_ids).squeeze(0)

        return [_embed(encoded) for encoded in self.get_inputs(prompts)]

459
    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Run the model on each prompt and return per-class scores.

        The post-processing depends on ``self.config.problem_type``:

        * ``"regression"``: the raw logits are returned;
        * ``"multi_label_classification"``: a sigmoid is applied;
        * anything else (including a missing attribute): a softmax over
          the last dimension is applied.

        Returns:
            For each prompt, the scores of the first batch row as a list.
        """
        # output is final logits
        all_inputs = self.get_inputs(prompts)
        outputs = []
        problem_type = getattr(self.config, "problem_type", "")

        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            if problem_type == "regression":
                logits = output.logits[0].tolist()
            elif problem_type == "multi_label_classification":
                logits = output.logits.sigmoid()[0].tolist()
            else:
                logits = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(logits)

        return outputs

477
478
    def generate(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate completions for each prompt, one ``generate`` call each.

        Args:
            prompts: Text prompts.
            images: Optional per-prompt images (see ``get_inputs``).
            videos: Optional per-prompt videos (see ``get_inputs``).
            audios: Optional per-prompt audio inputs (see ``get_inputs``).
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            For each prompt, a tuple of the generated token-id sequences
            (as nested lists) and their decoded strings, decoded with
            ``skip_special_tokens=True``.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for inputs in all_inputs:
            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy generation (``do_sample=False``) for each prompt.

        Args:
            prompts: Text prompts.
            max_tokens: Passed to ``generate`` as ``max_new_tokens``.
            images: Optional per-prompt images.
            videos: Optional per-prompt videos.
            audios: Optional per-prompt audio inputs.
            **kwargs: Forwarded to ``self.generate``.

        Returns:
            For each prompt, the first generated sequence as a
            ``(token_ids, text)`` tuple.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]
525
526
527

    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Beam-search generation for each prompt.

        Args:
            prompts: Text prompts.
            beam_width: Used as both ``num_beams`` and
                ``num_return_sequences``.
            max_tokens: Passed to ``generate`` as ``max_new_tokens``.
            images: Optional per-prompt images.
            videos: Optional per-prompt videos.
            audios: Optional per-prompt audio inputs.

        Returns:
            For each prompt, the returned token-id sequences with every
            occurrence of the tokenizer's pad token id removed, and the
            decoded strings.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width,
                                images=images,
                                videos=videos,
                                audios=audios)

        pad_token_id = self.tokenizer.pad_token_id
        for i, (output_ids, output_str) in enumerate(outputs):
            # Drop pad tokens from every returned sequence.
            unpadded_ids = [[
                token_id for token_id in sequence
                if token_id != pad_token_id
            ] for sequence in output_ids]
            outputs[i] = (unpadded_ids, output_str)
        return outputs
Woosuk Kwon's avatar
Woosuk Kwon committed
553

554
555
    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[list[torch.Tensor]]:
        """Greedy generation returning full log-probability tensors.

        Args:
            prompts: Text prompts.
            max_tokens: Passed to ``generate`` as ``max_new_tokens``.
            images: Optional per-prompt images.
            videos: Optional per-prompt videos.
            audios: Optional per-prompt audio inputs.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            For each prompt, the list of log-probability tensors computed
            by ``_hidden_states_to_seq_logprobs`` from the hidden states
            returned by ``generate``.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[torch.Tensor]] = []
        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs = self._hidden_states_to_seq_logprobs(
                output.hidden_states)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

584
    def _hidden_states_to_seq_logprobs(
585
        self,
586
587
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
    ) -> list[torch.Tensor]:
588
589
        output_embeddings = self.model.get_output_embeddings()

590
        seq_logprobs: list[torch.Tensor] = []
591
592
593
        for _, hidden_state in enumerate(hidden_states):
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
594
595
596
597
                last_hidden_states.to(
                    device=output_embeddings.weight.device,
                    dtype=output_embeddings.weight.dtype,
                ),
598
                output_embeddings.weight.t(),
599
            )
600
601
            if getattr(output_embeddings, "bias", None) is not None:
                logits += output_embeddings.bias.unsqueeze(0)
602
603
604
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

605
606
607
608
        return seq_logprobs

    def _hidden_states_to_logprobs(
        self,
609
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
610
        num_logprobs: Optional[int],
611
    ) -> tuple[list[dict[int, float]], int]:
612
613
614
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)

615
        # convert to dict
616
        seq_logprobs_lst: list[dict[int, float]] = []
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {}
            for token_id, logprob in zip(topk.indices[0], topk.values[0]):
                tok_logprobs_dct[token_id.item()] = logprob.item()

            seq_logprobs_lst.append(tok_logprobs_dct)

        return (
            seq_logprobs_lst,
            output_len,
        )

634
635
    def generate_greedy_logprobs_limit(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: Optional[int],
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        """Greedy generation returning top-k logprobs per generated token.

        Args:
            prompts: Text prompts.
            max_tokens: Passed to ``generate`` as ``max_new_tokens``.
            num_logprobs: Number of top candidates kept per token.
            images: Optional per-prompt images.
            audios: Optional per-prompt audio inputs.
            videos: Optional per-prompt videos.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            For each prompt, a tuple of the generated token ids, their
            decoded text, and the per-token top-k logprob dicts.
        """
        all_inputs = self.get_inputs(prompts,
                                     images=images,
                                     videos=videos,
                                     audios=audios)

        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            # `output_len` equals len(seq_logprobs_lst): the helper emits
            # one dict per hidden-state entry, so no recount is needed.
            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # The generated tokens are the trailing `output_len` ids.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: Optional[int],
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
        '''
        Greedy logprobs generation for encoder/decoder models, run through
        the HuggingFace model held by this runner.

        Args:
            encoder_decoder_prompts: Explicit encoder/decoder prompts,
                split into (encoder, decoder) pairs with
                `to_enc_dec_tuple_list`.
            max_tokens: Passed to `generate` as `max_new_tokens`.
            num_logprobs: Number of top candidates kept per token.
            images: Optional per-prompt images; `None` entries are not
                passed to the processor.
            **kwargs: Forwarded to `self.model.generate`.

        Returns:
            For each prompt, a tuple of the generated token ids, their
            decoded text, and the per-token top-k logprob dicts.
        '''

        all_logprobs: list[list[dict[int, float]]] = []
        all_output_ids: list[list[int]] = []
        all_output_strs: list[str] = []

        for i, (encoder_prompt, decoder_prompt) in enumerate(
                to_enc_dec_tuple_list(encoder_decoder_prompts)):
            processor_kwargs: dict[str, Any] = {
                "text": encoder_prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            encoder_inputs = self.processor(**processor_kwargs)
            encoder_inputs = self.wrap_device(encoder_inputs)

            # A `None` decoder prompt is forwarded as
            # `decoder_input_ids=None`.
            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_inputs = self.tokenizer(decoder_prompt,
                                                return_tensors="pt")
                decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)

            output = self.model.generate(
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **encoder_inputs,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            # The generated tokens are the trailing `output_len` ids.
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

743
744
745
    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        """Forward ``prompts`` and any extra arguments to ``self.model.encode``.

        NOTE(review): presumably only meaningful when the runner was built
        with ``is_sentence_transformer=True`` - confirm against callers.
        """
        return self.model.encode(prompts, *args, **kwargs)
746

747
748
749
750
751
752
    def predict(self, prompts: list[list[str]], *args,
                **kwargs) -> torch.Tensor:
        """Forward ``prompts`` to ``self.model.predict``, requesting a tensor.

        ``convert_to_tensor=True`` is always passed; other arguments are
        forwarded unchanged.

        NOTE(review): presumably only meaningful when the runner was built
        with ``is_cross_encoder=True`` - confirm against callers.
        """
        return self.model.predict(prompts,
                                  *args,
                                  convert_to_tensor=True,
                                  **kwargs)
753

754
755
756
757
    def __enter__(self):
        """Enter the context manager; the runner itself is the context value."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model reference and run the global cleanup.

        Returns ``None``, so an exception raised inside the ``with`` block
        is not suppressed.
        """
        # Remove the attribute first so the model is no longer referenced
        # by the runner when the cleanup below runs.
        del self.model
        cleanup_dist_env_and_memory()
760

Woosuk Kwon's avatar
Woosuk Kwon committed
761

Cyrus Leung's avatar
Cyrus Leung committed
762
@pytest.fixture(scope="session")
def hf_runner():
    """Return the ``HfRunner`` class itself (not an instance).

    Tests call the returned class with their own constructor arguments.
    """
    return HfRunner


class VllmRunner:
    """
    The default values of some arguments have been modified from
    {class}`~vllm.LLM` as follows:

    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
    - `block_size`: To reduce memory usage, set default to `64` if on XPU
      devices, otherwise default to `16`.
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
    - `enforce_eager`: Set to `False` to test CUDA graph.
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
781
782
783
784

    def __init__(
        self,
        model_name: str,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        seed: Optional[int] = 0,
        max_model_len: Optional[int] = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16 if not torch.xpu.is_available() else 64,
        enable_chunked_prefill: Optional[bool] = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Build the wrapped ``LLM`` and keep it on ``self.llm``.

        Every named parameter is forwarded to ``LLM`` unchanged, except that
        ``model_name`` is passed as ``model`` and ``tokenizer_name`` as
        ``tokenizer``. Any additional keyword arguments are forwarded as-is.

        Note that the ``block_size`` default is computed once, when the class
        body is executed, from ``torch.xpu.is_available()``.
        """
        self.llm = LLM(
            model=model_name,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

820
    def get_inputs(
        self,
        prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[TextPrompt]:
        """Pair each prompt with its multimodal data as a ``TextPrompt``.

        A ``str`` prompt is stored under ``"prompt"``, a ``list`` prompt under
        ``"prompt_token_ids"``, and anything else under ``"prompt_embeds"``.
        ``"multi_modal_data"`` is ``None`` when no image, video or audio is
        present for that prompt.

        :param prompts: one entry per request
        :param images: optional per-prompt images; entries may be ``None``
        :param videos: optional per-prompt videos; entries may be ``None``
        :param audios: optional per-prompt audios; entries may be ``None``
        :return: one ``TextPrompt`` per prompt, in order
        :raises ValueError: if a non-``None`` multimodal list differs in
            length from ``prompts``
        """
        if any(x is not None and len(x) != len(prompts)
               for x in [images, videos, audios]):
            raise ValueError(
                "All non-None multimodal inputs must have the same length as "
                "prompts")

        # Key used in ``multi_modal_data`` -> the per-prompt list it comes from.
        modalities = (("image", images), ("video", videos), ("audio", audios))

        inputs = []
        for i, prompt in enumerate(prompts):
            multi_modal_data = {
                key: items[i]
                for key, items in modalities
                if items is not None and items[i] is not None
            }

            text_prompt_kwargs: dict[str, Any] = {
                "multi_modal_data": multi_modal_data or None
            }
            if isinstance(prompt, str):
                text_prompt_kwargs["prompt"] = prompt
            elif isinstance(prompt, list):
                text_prompt_kwargs["prompt_token_ids"] = prompt
            else:
                text_prompt_kwargs["prompt_embeds"] = prompt

            inputs.append(TextPrompt(**text_prompt_kwargs))

        return inputs

    def generate(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate completions and return them prefixed with their prompt.

        :param prompts: one entry per request
        :param sampling_params: forwarded to ``self.llm.generate``
        :param images: optional per-prompt images
        :param videos: optional per-prompt videos
        :param audios: optional per-prompt audios
        :param kwargs: extra keyword arguments for ``self.llm.generate``
        :return: per request, a pair ``(token_id_lists, texts)`` with one
            entry per sample; each entry is the prompt followed by the
            sampled continuation
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.llm.generate(inputs,
                                        sampling_params=sampling_params,
                                        **kwargs)

        outputs: list[tuple[list[list[int]], list[str]]] = []
        for req_output in req_outputs:
            # The prompt text may be None; treat it as empty in that case.
            prompt_str = req_output.prompt or ""
            prompt_ids = req_output.prompt_token_ids
            samples = req_output.outputs
            req_sample_output_ids = [
                prompt_ids + list(sample.token_ids) for sample in samples
            ]
            req_sample_output_strs = [
                prompt_str + sample.text for sample in samples
            ]
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

890
    @staticmethod
891
    def _final_steps_generate_w_logprobs(
892
893
894
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        outputs: list[TokensTextLogprobsPromptLogprobs] = []
895
        for req_output in req_outputs:
896
            assert len(req_output.outputs) > 0
897
898
            for sample in req_output.outputs:
                output_str = sample.text
899
                output_ids = list(sample.token_ids)
900
                output_logprobs = sample.logprobs
901
902
            outputs.append((output_ids, output_str, output_logprobs,
                            req_output.prompt_logprobs))
903
904
        return outputs

905
906
    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate and return token ids, text and logprobs per request.

        :param prompts: one entry per request
        :param sampling_params: forwarded to ``self.llm.generate``
        :param images: optional per-prompt images
        :param audios: optional per-prompt audios
        :param videos: optional per-prompt videos
        :param kwargs: extra keyword arguments for ``self.llm.generate``
        :return: ``(token_ids, text, logprobs)`` per request, with the prompt
            logprobs appended as a fourth element when
            ``sampling_params.prompt_logprobs`` is not ``None``
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.llm.generate(inputs,
                                        sampling_params=sampling_params,
                                        **kwargs)

        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        if sampling_params.prompt_logprobs is None:
            # Omit prompt logprobs if not required by sampling params
            return [x[:-1] for x in toks_str_logsprobs_prompt_logprobs]
        return toks_str_logsprobs_prompt_logprobs
930
931
932

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for vLLM encoder/decoder models

        :param encoder_decoder_prompts: prompts passed to ``self.llm.generate``
        :param sampling_params: must have ``logprobs`` set
        :return: ``(token_ids, text, logprobs)`` per request, with the prompt
            logprobs appended as a fourth element when
            ``sampling_params.prompt_logprobs`` is not ``None``
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.llm.generate(encoder_decoder_prompts,
                                        sampling_params=sampling_params)
        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        if sampling_params.prompt_logprobs is None:
            # Omit prompt logprobs if not required by sampling params
            return [x[:-1] for x in toks_str_logsprobs_prompt_logprobs]
        return toks_str_logsprobs_prompt_logprobs
950

Woosuk Kwon's avatar
Woosuk Kwon committed
951
952
    def generate_greedy(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Generate with ``temperature=0.0`` and keep the first sample.

        :param prompts: one entry per request
        :param max_tokens: ``max_tokens`` for the sampling parameters
        :param images: optional per-prompt images
        :param videos: optional per-prompt videos
        :param audios: optional per-prompt audios
        :param kwargs: extra keyword arguments for :meth:`generate`
        :return: per request, ``(token_ids, text)`` of the first sample as
            produced by :meth:`generate`
        """
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts,
                                greedy_params,
                                images=images,
                                videos=videos,
                                audios=audios,
                                **kwargs)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]
969

970
971
    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: Optional[int],
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[list[int]] = None,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        """Generate with ``temperature=0.0`` and return logprobs.

        Builds the sampling parameters and delegates to
        :meth:`generate_w_logprobs`.

        :param prompts: one entry per request
        :param max_tokens: ``max_tokens`` for the sampling parameters
        :param num_logprobs: passed as ``logprobs``
        :param num_prompt_logprobs: passed as ``prompt_logprobs``
        :param images: optional per-prompt images
        :param audios: optional per-prompt audios
        :param videos: optional per-prompt videos
        :param stop_token_ids: passed as ``stop_token_ids``
        :param stop: passed as ``stop``
        :param kwargs: extra keyword arguments for :meth:`generate_w_logprobs`
        :return: the result of :meth:`generate_w_logprobs`
        """
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos,
                                        **kwargs)
998

999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
    def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
        """
        Compute a perplexity score for each prompt from its prompt logprobs

        The score is ``exp`` of the negated mean logprob over every prompt
        position except the first, whose entry must be ``None``.

        :param prompts: list of prompts to score
        :return: perplexity score of each prompt
        """

        def _sole_logprob(entry: Optional[dict[int, Logprob]]) -> float:
            # Each position must carry exactly one candidate token.
            assert entry is not None
            assert len(entry) == 1
            return next(iter(entry.values())).logprob

        results = self.generate_greedy_logprobs(prompts,
                                                max_tokens=1,
                                                num_logprobs=None,
                                                num_prompt_logprobs=0)

        scores = []
        for result in results:
            result = cast(TokensTextLogprobsPromptLogprobs, result)
            prompt_entries = cast(list[Optional[dict[int, Logprob]]],
                                  result[3])
            assert prompt_entries[0] is None
            logprob_values = [
                _sole_logprob(entry) for entry in prompt_entries[1:]
            ]
            scores.append(
                math.exp(-sum(logprob_values) / len(logprob_values)))

        return scores

1028
1029
    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: Optional[int],
        num_prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Union[list[TokensTextLogprobs],
               list[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models

        :param encoder_decoder_prompts: prompts to generate from
        :param max_tokens: ``max_tokens`` for the sampling parameters
        :param num_logprobs: passed as ``logprobs``
        :param num_prompt_logprobs: passed as ``prompt_logprobs``
        :param skip_special_tokens: passed as ``skip_special_tokens``
        :return: the result of
            :meth:`generate_encoder_decoder_w_logprobs`
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            skip_special_tokens=skip_special_tokens,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

1051
    def generate_beam_search(
        self,
        prompts: list[str],
        beam_width: int,
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
        concurrency_limit: Optional[int] = None,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Run beam search and return every sequence of every request.

        :param prompts: one entry per request
        :param beam_width: ``beam_width`` for ``BeamSearchParams``
        :param max_tokens: ``max_tokens`` for ``BeamSearchParams``
        :param images: optional per-prompt images
        :param videos: optional per-prompt videos
        :param audios: optional per-prompt audios
        :param concurrency_limit: forwarded to ``self.llm.beam_search``
        :return: per request, ``(token_lists, texts)`` with one entry per
            returned sequence
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        outputs = self.llm.beam_search(inputs,
                                       BeamSearchParams(beam_width=beam_width,
                                                        max_tokens=max_tokens),
                                       concurrency_limit=concurrency_limit)
        return [([seq.tokens for seq in output.sequences],
                 [seq.text for seq in output.sequences])
                for output in outputs]

1077
    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.probs`` of each ``self.llm.classify`` result.

        :param prompts: texts to classify
        :return: one ``probs`` value per prompt, in order
        """
        req_outputs = self.llm.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

1081
1082
1083
1084
1085
1086
1087
    def embed(self,
              prompts: list[str],
              images: Optional[PromptImageInput] = None,
              videos: Optional[PromptVideoInput] = None,
              audios: Optional[PromptAudioInput] = None,
              *args,
              **kwargs) -> list[list[float]]:
        """Return ``outputs.embedding`` of each ``self.llm.embed`` result.

        :param prompts: one entry per request
        :param images: optional per-prompt images
        :param videos: optional per-prompt videos
        :param audios: optional per-prompt audios
        :param args: extra positional arguments for ``self.llm.embed``
        :param kwargs: extra keyword arguments for ``self.llm.embed``
        :return: one ``embedding`` value per request, in order
        """
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.llm.embed(inputs, *args, **kwargs)
        return [req_output.outputs.embedding for req_output in req_outputs]
1095

1096
    def encode(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.data`` of each ``self.llm.encode`` result.

        :param prompts: texts to encode
        :return: one ``data`` value per prompt, in order
        """
        req_outputs = self.llm.encode(prompts)
        return [req_output.outputs.data for req_output in req_outputs]

1100
1101
1102
1103
    def reward(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.data`` of each ``self.llm.reward`` result.

        :param prompts: texts to score
        :return: one ``data`` value per prompt, in order
        """
        collected = []
        for result in self.llm.reward(prompts):
            collected.append(result.outputs.data)
        return collected

1104
1105
    def score(
        self,
        text_1: Union[str, list[str]],
        text_2: Union[str, list[str]],
        *args,
        **kwargs,
    ) -> list[float]:
        """Return ``outputs.score`` of each ``self.llm.score`` result.

        :param text_1: first text or list of texts
        :param text_2: second text or list of texts
        :param args: extra positional arguments for ``self.llm.score``
        :param kwargs: extra keyword arguments for ``self.llm.score``
        :return: one ``score`` value per result, in order
        """
        req_outputs = self.llm.score(text_1, text_2, *args, **kwargs)
        return [req_output.outputs.score for req_output in req_outputs]
1113

1114
    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Apply ``func`` to the model held by the engine.

        :param func: callable that receives the model
        :return: the executor's ``apply_model`` result when the engine has a
            ``model_executor``; otherwise the ``collective_rpc`` result
        """
        engine = self.llm.llm_engine
        if hasattr(engine, "model_executor"):
            # This works either in V0 or in V1 with
            # VLLM_ENABLE_V1_MULTIPROCESSING=0
            return engine.model_executor.apply_model(func)

        # This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
        def _apply_model(worker):
            # ``worker`` is the object the RPC is invoked on, not this runner.
            return func(worker.get_model())

        return engine.collective_rpc(_apply_model)
1126

1127
1128
1129
    def get_llm(self) -> LLM:
        """Return the wrapped ``LLM`` instance stored on ``self.llm``."""
        return self.llm

1130
1131
1132
1133
    def __enter__(self):
        """Return the runner itself as the context-managed object."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the wrapped ``LLM`` and run the shared cleanup helper.

        The exception arguments are ignored and ``None`` is returned, so an
        exception raised inside the ``with`` body still propagates.
        """
        del self.llm
        cleanup_dist_env_and_memory()
1136

Woosuk Kwon's avatar
Woosuk Kwon committed
1137

1138
@pytest.fixture(scope="session")
def vllm_runner():
    """Session-scoped fixture exposing the ``VllmRunner`` class.

    The class itself is returned, not an instance.
    """
    return VllmRunner
1141
1142


1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
@pytest.fixture()
def temporary_enable_log_propagate():
    """Turn on propagation for the ``"vllm"`` logger during a test.

    The logger's previous ``propagate`` value is restored afterwards, so the
    fixture leaves the logging configuration as it found it.
    """
    import logging
    logger = logging.getLogger("vllm")
    previous_propagate = logger.propagate
    logger.propagate = True
    try:
        yield
    finally:
        # Restore rather than force False: the logger may have been
        # propagating before this fixture ran.
        logger.propagate = previous_propagate


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    """Yield pytest's ``caplog`` with vllm log propagation enabled."""
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog
1157
1158
1159
1160
1161
1162
1163


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process."""

    # Imported here rather than at module level, so the platform module is
    # only loaded when a test actually requests this fixture.
    from vllm.platforms import current_platform
    return current_platform.device_count()
1166
1167


1168
# temp_dir = tempfile.gettempdir()
zhuwenwen's avatar
zhuwenwen committed
1169
1170
1171
1172
# Directories under ``models_path_prefix`` that hold the "dummy" model
# snapshots; the ``dummy_*_path`` fixtures in this file return these paths.
_dummy_opt_path = os.path.join(models_path_prefix, "dummy_opt")
_dummy_llava_path = os.path.join(models_path_prefix, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(models_path_prefix, "dummy_gemma2_embedding")

1173
1174
1175
1176


@pytest.fixture
def dummy_opt_path():
    """Return the dummy OPT snapshot directory, creating it on first use.

    When the directory does not exist, ``facebook/opt-125m`` is downloaded
    into it without weight files and ``architectures`` in its ``config.json``
    is rewritten to ``["MyOPTForCausalLM"]``. An existing directory is
    returned untouched.
    """
    json_path = os.path.join(_dummy_opt_path, "config.json")
    if not os.path.exists(_dummy_opt_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_opt_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_opt_path

1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210

# Pytest hook that post-processes each test report after it has been built.
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
    """Attach a browser screenshot to the report of a failed test call.

    Nothing is done unless the "call" phase failed and the test item has a
    ``browser`` entry in its ``funcargs``.
    """
    # Obtain the report produced for this test phase.
    outcome = yield
    report = outcome.get_result()

    # Only failures of the test body itself are of interest.
    if report.when != "call" or not report.failed:
        return

    # A screenshot needs a browser instance on the test item.
    if not hasattr(item, "funcargs") or "browser" not in item.funcargs:
        return

    screenshot_path = "screenshot.png"  # where the screenshot is written
    item.funcargs["browser"].save_screenshot(screenshot_path)

    # Add the screenshot when the report carries an ``extra`` attribute.
    if hasattr(report, "extra"):
        report.extra.append(pytest_html.extras.image(screenshot_path))
zhuwenwen's avatar
zhuwenwen committed
1211
1212


1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
@pytest.fixture
def dummy_llava_path():
    """Return the dummy LLaVA snapshot directory, creating it on first use.

    When the directory does not exist, ``llava-hf/llava-1.5-7b-hf`` is
    downloaded into it without weight files and ``architectures`` in its
    ``config.json`` is rewritten to ``["MyLlava"]``. An existing directory is
    returned untouched.
    """
    json_path = os.path.join(_dummy_llava_path, "config.json")
    if not os.path.exists(_dummy_llava_path):
        snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf",
                          local_dir=_dummy_llava_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = ["MyLlava"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_llava_path
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Return the dummy Gemma2 embedding snapshot directory.

    When the directory does not exist, ``BAAI/bge-multilingual-gemma2`` is
    downloaded into it without weight files and ``architectures`` in its
    ``config.json`` is rewritten to ``["MyGemma2Embedding"]``. An existing
    directory is returned untouched.
    """
    json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
    if not os.path.exists(_dummy_gemma2_embedding_path):
        snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
                          local_dir=_dummy_gemma2_embedding_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path, encoding="utf-8") as f:
            config = json.load(f)
        config["architectures"] = ["MyGemma2Embedding"]
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
    return _dummy_gemma2_embedding_path
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267


# Register the `--optional` command-line flag, which allows tests
# marked with @pytest.mark.optional to run.
def pytest_addoption(parser):
    """Add the boolean ``--optional`` switch (off by default) to ``parser``."""
    option_settings = {
        "action": "store_true",
        "default": False,
        "help": "run optional test",
    }
    parser.addoption("--optional", **option_settings)


def pytest_collection_modifyitems(config, items):
    """Mark ``optional`` tests as skipped unless ``--optional`` was given."""
    run_optional = config.getoption("--optional")
    if not run_optional:
        # Without --optional on the command line, optional tests are skipped.
        skip_marker = pytest.mark.skip(reason="need --optional option to run")
        optional_items = (entry for entry in items
                          if "optional" in entry.keywords)
        for entry in optional_items:
            entry.add_marker(skip_marker)
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279


@pytest.fixture(scope="session")
def cli_config_file():
    """Path of ``config/test_config.yaml`` under the test directory."""
    config_path = os.path.join(_TEST_DIR, "config", "test_config.yaml")
    return config_path


@pytest.fixture(scope="session")
def cli_config_file_with_model():
    """Path of ``config/test_config_with_model.yaml`` under the test directory."""
    config_path = os.path.join(_TEST_DIR, "config",
                               "test_config_with_model.yaml")
    return config_path