test_oot_registration.py 3.31 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import pytest
4

5
from vllm import LLM, SamplingParams
6
from vllm.assets.image import ImageAsset
7

8
from ..utils import create_new_process_for_each_test
9
10


11
@create_new_process_for_each_test()
def test_plugin(
    monkeypatch: pytest.MonkeyPatch,
    dummy_opt_path: str,
):
    """Loading the dummy OPT checkpoint must fail when no plugin is enabled.

    Unlike the other tests in this module, ``VLLM_PLUGINS`` is set to an empty
    string instead of ``"register_dummy_model"``. Constructing ``LLM`` is then
    expected to raise an error saying the architecture has no vLLM
    implementation. Presumably this is because the dummy model is never
    registered; confirm against vLLM's plugin loader.
    """
    expected = ("has no vLLM implementation and the Transformers "
                "implementation is not compatible with vLLM")

    with monkeypatch.context() as env:
        # V1 shuts down rather than raising an error here, so disable it to
        # get an exception that pytest can catch.
        env.setenv("VLLM_USE_V1", "0")
        env.setenv("VLLM_PLUGINS", "")

        with pytest.raises(Exception) as raised:
            LLM(model=dummy_opt_path, load_format="dummy")
        assert expected in str(raised.value)
25
26


27
@create_new_process_for_each_test()
def test_oot_registration_text_generation(
    monkeypatch: pytest.MonkeyPatch,
    dummy_opt_path: str,
):
    """Greedy text generation works for an out-of-tree-registered model.

    With the ``register_dummy_model`` plugin enabled, the dummy OPT checkpoint
    is expected to generate nothing but token id 0 for every prompt.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_PLUGINS", "register_dummy_model")
        prompts = ["Hello, my name is", "The text does not matter"]
        sampling_params = SamplingParams(temperature=0)
        llm = LLM(model=dummy_opt_path, load_format="dummy")
        first_token = llm.get_tokenizer().decode(0)
        outputs = llm.generate(prompts, sampling_params)

        # Without this, the loop below would pass without checking anything
        # if no outputs came back.
        assert len(outputs) == len(prompts)
        for output in outputs:
            completion = output.outputs[0]
            # Check the ids directly. The text comparison alone passes
            # trivially when the generated text is empty, e.g. if token 0 is
            # a special token that detokenization drops (tokenizer-dependent).
            token_ids = completion.token_ids
            assert len(token_ids) > 0
            assert all(token_id == 0 for token_id in token_ids)
            # make sure only the first token is generated
            rest = completion.text.replace(first_token, "")
            assert rest == ""
45
46


47
@create_new_process_for_each_test()
def test_oot_registration_embedding(
    monkeypatch: pytest.MonkeyPatch,
    dummy_gemma2_embedding_path: str,
):
    """Embedding works for an out-of-tree-registered pooling model.

    With the ``register_dummy_model`` plugin enabled, every embedding returned
    for the dummy Gemma2 checkpoint is expected to contain only zeros.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_PLUGINS", "register_dummy_model")
        prompts = ["Hello, my name is", "The text does not matter"]
        llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
        outputs = llm.embed(prompts)

        # Require one result per prompt. Otherwise the loop below could pass
        # without checking anything.
        assert len(outputs) == len(prompts)
        for output in outputs:
            embedding = output.outputs.embedding
            # all() is vacuously true for an empty vector, so require one.
            assert len(embedding) > 0
            assert all(v == 0 for v in embedding)
60
61


62
63
64
# Shared RGB test image, used as the "image" input of both prompts in the
# multimodal test below.
# NOTE(review): this statement runs at module import time, including during
# test collection. Whatever loading/fetching ImageAsset performs therefore
# happens even when only the text or embedding tests are selected. Consider
# moving it into a fixture if that becomes costly.
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")


65
@create_new_process_for_each_test()
def test_oot_registration_multimodal(
    monkeypatch: pytest.MonkeyPatch,
    dummy_llava_path: str,
):
    """Image+text generation works for an out-of-tree-registered LLaVA model.

    With the ``register_dummy_model`` plugin enabled, greedy generation on
    the dummy LLaVA checkpoint is expected to produce only token id 0 for
    each image prompt.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_PLUGINS", "register_dummy_model")
        prompts = [{
            "prompt": "What's in the image?<image>",
            "multi_modal_data": {
                "image": image
            },
        }, {
            "prompt": "Describe the image<image>",
            "multi_modal_data": {
                "image": image
            },
        }]

        sampling_params = SamplingParams(temperature=0)
        # Each prompt above carries exactly one image, which matches
        # limit_mm_per_prompt.
        llm = LLM(model=dummy_llava_path,
                  load_format="dummy",
                  max_num_seqs=1,
                  trust_remote_code=True,
                  gpu_memory_utilization=0.98,
                  max_model_len=4096,
                  enforce_eager=True,
                  limit_mm_per_prompt={"image": 1})

        first_token = llm.get_tokenizer().decode(0)
        outputs = llm.generate(prompts, sampling_params)

        # Without this, the loop below would pass without checking anything
        # if no outputs came back.
        assert len(outputs) == len(prompts)
        for output in outputs:
            completion = output.outputs[0]
            # Check the ids directly. The text comparison alone passes
            # trivially when the generated text is empty, e.g. if token 0 is
            # a special token that detokenization drops (tokenizer-dependent).
            token_ids = completion.token_ids
            assert len(token_ids) > 0
            assert all(token_id == 0 for token_id in token_ids)
            # make sure only the first token is generated
            rest = completion.text.replace(first_token, "")
            assert rest == ""