test_oot_registration.py 2.82 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4

5
import pytest
6

7
from vllm import LLM, SamplingParams
8
from vllm.assets.image import ImageAsset
9

10
from ..utils import fork_new_process_for_each_test
11
12


13
@fork_new_process_for_each_test
def test_plugin(dummy_opt_path, monkeypatch):
    """Check that the dummy OPT model cannot be loaded without the plugin.

    ``VLLM_PLUGINS`` is set to the empty string (presumably so that no
    plugin gets loaded -- confirm against the plugin loader), after which
    building an ``LLM`` for ``dummy_opt_path`` must raise an exception whose
    message reports that neither a vLLM nor a compatible Transformers
    implementation exists.
    """
    # V1 shuts down rather than raising an error here.
    monkeypatch.setenv("VLLM_USE_V1", "0")
    # Set the variable through monkeypatch rather than os.environ so it is
    # restored automatically, the same way VLLM_USE_V1 is handled above.
    monkeypatch.setenv("VLLM_PLUGINS", "")

    with pytest.raises(Exception) as excinfo:
        LLM(model=dummy_opt_path, load_format="dummy")

    error_msg = ("has no vLLM implementation and "
                 "the Transformers implementation is not compatible with vLLM")
    assert error_msg in str(excinfo.value)
23
24


25
@fork_new_process_for_each_test
def test_oot_registration_text_generation(dummy_opt_path):
    """Generate text with the dummy OPT model via ``register_dummy_model``.

    Each completion must be made up exclusively of the text that the
    tokenizer decodes for token id 0.
    """
    os.environ["VLLM_PLUGINS"] = "register_dummy_model"

    texts = ["Hello, my name is", "The text does not matter"]
    params = SamplingParams(temperature=0)

    engine = LLM(model=dummy_opt_path, load_format="dummy")
    token_zero_text = engine.get_tokenizer().decode(0)

    for result in engine.generate(texts, params):
        completion = result.outputs[0].text
        # Removing every occurrence of the id-0 token must leave nothing.
        leftover = completion.replace(token_zero_text, "")
        assert leftover == ""
39
40


41
42
43
44
45
@fork_new_process_for_each_test
def test_oot_registration_embedding(dummy_gemma2_embedding_path):
    """Embed two prompts with the dummy Gemma2 embedding model.

    With the ``register_dummy_model`` plugin selected, every component of
    every returned embedding must compare equal to zero.
    """
    os.environ["VLLM_PLUGINS"] = "register_dummy_model"

    texts = ["Hello, my name is", "The text does not matter"]
    engine = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")

    for result in engine.embed(texts):
        vector = result.outputs.embedding
        assert all(component == 0 for component in vector)


52
53
54
55
# Shared test image: the "cherry_blossom" asset converted to RGB. Built once
# at import time and passed as the "image" entry of the multimodal prompts
# in test_oot_registration_multimodal.
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")


@fork_new_process_for_each_test
def test_oot_registration_multimodal(dummy_llava_path, monkeypatch):
    """Generate from image prompts with the dummy LLaVA model.

    Two prompts, each carrying the module-level ``image``, are run through an
    ``LLM`` built for ``dummy_llava_path`` with the ``register_dummy_model``
    plugin selected. Each completion must be made up exclusively of the text
    that the tokenizer decodes for token id 0.
    """
    # Use the monkeypatch fixture (previously requested but unused) so the
    # environment variable is restored automatically after the test.
    monkeypatch.setenv("VLLM_PLUGINS", "register_dummy_model")

    # Both prompts share the same image and differ only in their text.
    prompts = [{
        "prompt": text,
        "multi_modal_data": {
            "image": image
        },
    } for text in ("What's in the image?<image>", "Describe the image<image>")]

    sampling_params = SamplingParams(temperature=0)
    llm = LLM(model=dummy_llava_path,
              load_format="dummy",
              max_num_seqs=1,
              trust_remote_code=True,
              gpu_memory_utilization=0.98,
              max_model_len=4096,
              enforce_eager=True,
              limit_mm_per_prompt={"image": 1})
    first_token = llm.get_tokenizer().decode(0)
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        generated_text = output.outputs[0].text
        # make sure only the first token is generated
        rest = generated_text.replace(first_token, "")
        assert rest == ""