test_minicpmv_tp.py 3.78 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
from typing import List

import pytest

import vllm
8
from tests.utils import fork_new_process_for_each_test
9
10
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
11
from vllm.platforms import current_platform
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

# Model identifier passed to `vllm.LLM` by every test in this file.
MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

# Prompt sent verbatim as the "prompt" field of every request built in
# `do_sample`. It asks what is in the image and ends with the assistant
# header so that generation continues from there.
# NOTE(review): `(<image>./</image>)` is presumably the model's image
# placeholder - confirm against the model's chat template.
PROMPT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "(<image>./</image>)\nWhat is in the image?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n")

# Images to query; `do_sample` builds one request per asset.
IMAGE_ASSETS = [
    ImageAsset("stop_sign"),
]

# After fine-tuning with LoRA, all generated content should begin with `A`.
# The tests only check that each entry here starts with the generated text,
# because generation is capped at `max_tokens=5` in `do_sample`.
EXPECTED_OUTPUT = [
    "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.",  # noqa: E501
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    """Generate a short completion for every image in ``IMAGE_ASSETS``.

    Each request pairs ``PROMPT_TEMPLATE`` with one image and is sampled with
    ``temperature=0`` and at most 5 new tokens.

    Args:
        llm: Engine used to run generation.
        lora_path: Path handed to ``LoRARequest`` when an adapter is used.
        lora_id: Integer adapter id. A falsy value (``0``) sends the
            requests without any LoRA adapter.

    Returns:
        The whitespace-stripped text of the first completion of each
        request, in the order returned by ``llm.generate``.
    """
    sampling_params = vllm.SamplingParams(
        temperature=0,
        max_tokens=5,
        stop_token_ids=[128001, 128009],  # eos_id, eot_id
    )

    inputs = [{
        "prompt": PROMPT_TEMPLATE,
        "multi_modal_data": {
            "image": asset.pil_image
        },
    } for asset in IMAGE_ASSETS]

    # The first LoRARequest argument is the adapter id rendered as a string.
    lora_request = (LoRARequest(str(lora_id), lora_id, lora_path)
                    if lora_id else None)
    outputs = llm.generate(
        inputs,
        sampling_params,
        lora_request=lora_request,
    )

    # Collect and echo the outputs so they show up in the test log.
    generated_texts: List[str] = []
    for output in outputs:
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Generated text: {generated_text!r}")
    return generated_texts


59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@pytest.mark.xfail(
    current_platform.is_rocm(),
    reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_lora(minicpmv_lora_files):
    """Check LoRA generation with no tensor-parallel size configured.

    The same adapter files are requested under adapter ids 1 and 2; for both,
    every generated text must be a non-empty prefix of the matching entry in
    ``EXPECTED_OUTPUT``.
    """
    llm = vllm.LLM(
        MODEL_PATH,
        max_num_seqs=2,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        enforce_eager=True,
        trust_remote_code=True,
        enable_chunked_prefill=True,
    )
    for lora_id in (1, 2):
        generated = do_sample(llm, minicpmv_lora_files, lora_id=lora_id)
        for i, expected in enumerate(EXPECTED_OUTPUT):
            # An empty string is a prefix of every string, so the prefix
            # check alone would pass even if nothing had been generated.
            assert generated[i], f"empty generation for lora_id={lora_id}"
            assert expected.startswith(generated[i])


@pytest.mark.xfail(
    current_platform.is_rocm(),
    reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
    """Check LoRA generation with ``tensor_parallel_size=4``.

    ``fully_sharded_loras`` is not passed here. Every text generated with
    adapter id 1 must be a non-empty prefix of the matching entry in
    ``EXPECTED_OUTPUT``.
    """
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=2,
        max_loras=4,
        max_lora_rank=64,
        tensor_parallel_size=4,
        trust_remote_code=True,
        enforce_eager=True,
        enable_chunked_prefill=True,
    )
    output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
    for i, expected in enumerate(EXPECTED_OUTPUT):
        # An empty string is a prefix of every string, so the prefix check
        # alone would pass even if nothing had been generated.
        assert output_tp[i], "empty generation for lora_id=1"
        assert expected.startswith(output_tp[i])


103
104
105
106
107
@pytest.mark.xfail(
    current_platform.is_rocm(),
    reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
    """Check LoRA generation with TP=4 and ``fully_sharded_loras=True``.

    The same adapter files are requested under adapter ids 1 and 2; for both,
    every generated text must be a non-empty prefix of the matching entry in
    ``EXPECTED_OUTPUT``.
    """
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=2,
        max_loras=2,
        max_lora_rank=8,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=True,
        enable_chunked_prefill=True,
    )
    for lora_id in (1, 2):
        output_tp = do_sample(llm, minicpmv_lora_files, lora_id=lora_id)
        for i, expected in enumerate(EXPECTED_OUTPUT):
            # An empty string is a prefix of every string, so the prefix
            # check alone would pass even if nothing had been generated.
            assert output_tp[i], f"empty generation for lora_id={lora_id}"
            assert expected.startswith(output_tp[i])