# test_omni_input_preprocessor.py
from vllm_omni.inputs.preprocess import OmniInputPreprocessor


def _make_preprocessor(monkeypatch):
    """Return an ``OmniInputPreprocessor`` with its heavy helpers stubbed out.

    The instance is created with ``object.__new__`` so ``__init__`` never runs
    (presumably so no tokenizer / model config is required — confirm against
    the real constructor). Three helper methods are then replaced on the
    instance via ``monkeypatch``:

    * ``_truncate_inputs`` returns the tokens it is given, unchanged.
    * ``_process_multimodal`` always yields ``prompt_token_ids == [1, 2, 3]``.
    * ``_tokenize_prompt`` always yields ``[9, 8, 7]``.

    The two token-id lists differ so a test can tell which code path produced
    the final ``prompt_token_ids``.
    """

    def _passthrough_truncate(tokens, tokenization_kwargs=None):
        return tokens

    def _fake_multimodal(*args, **kwargs):
        return {"prompt_token_ids": [1, 2, 3]}

    def _fake_tokenize(prompt_text, tokenization_kwargs=None):
        return [9, 8, 7]

    stub = object.__new__(OmniInputPreprocessor)
    replacements = (
        ("_truncate_inputs", _passthrough_truncate),
        ("_process_multimodal", _fake_multimodal),
        ("_tokenize_prompt", _fake_tokenize),
    )
    for attr_name, replacement in replacements:
        monkeypatch.setattr(stub, attr_name, replacement)
    return stub


def test_process_tokens_keeps_additional_information(monkeypatch):
    """``_process_tokens`` must carry ``prompt_embeds`` and ``additional_information`` through."""
    token_prompt = dict(
        prompt_token_ids=[1, 2, 3],
        prompt_embeds="embeds",
        additional_information={"task": ["tts"], "lang": ["auto"]},
    )
    processor = _make_preprocessor(monkeypatch)

    result = OmniInputPreprocessor._process_tokens(processor, token_prompt)

    # Compare against fresh literals rather than the input objects, so any
    # in-place mutation by the preprocessor would still be caught.
    expected = {
        "prompt_token_ids": [1, 2, 3],
        "prompt_embeds": "embeds",
        "additional_information": {"task": ["tts"], "lang": ["auto"]},
    }
    for key, value in expected.items():
        assert result[key] == value, key


def test_process_text_keeps_additional_information(monkeypatch):
    """``_process_text`` tokenizes the prompt and keeps the extra Omni fields."""
    text_prompt = dict(
        prompt="hello",
        prompt_embeds="embeds",
        additional_information={"speaker": ["alice"]},
    )
    processor = _make_preprocessor(monkeypatch)

    result = OmniInputPreprocessor._process_text(processor, text_prompt)

    # [9, 8, 7] is the value of the stubbed _tokenize_prompt, i.e. the plain
    # text path (not the multimodal stub) produced the token ids.
    expected = {
        "prompt_token_ids": [9, 8, 7],
        "prompt_embeds": "embeds",
        "additional_information": {"speaker": ["alice"]},
    }
    for key, value in expected.items():
        assert result[key] == value, key


def test_process_text_multimodal_skips_empty_payloads(monkeypatch):
    """``None`` payloads for the Omni-specific fields must not appear in the output."""
    processor = _make_preprocessor(monkeypatch)
    mm_prompt = dict(
        prompt="hello",
        multi_modal_data={"image": "fake"},
        prompt_embeds=None,
        additional_information=None,
    )

    result = OmniInputPreprocessor._process_text(processor, mm_prompt)

    # [1, 2, 3] is the value of the stubbed _process_multimodal, so the
    # multimodal branch was taken.
    assert result["prompt_token_ids"] == [1, 2, 3]
    # Fields whose payload was None are expected to be omitted entirely.
    assert all(
        key not in result for key in ("prompt_embeds", "additional_information")
    )