# test_jinavl_reranker.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Union
4
5
6
7

import pytest
from transformers import AutoModel

8
9
10
11
12
from vllm.entrypoints.chat_utils import ChatCompletionContentPartImageParam
from vllm.entrypoints.score_utils import ScoreMultiModalParam

from ....conftest import HfRunner, VllmRunner

# Checkpoint exercised by every test in this module.
model_name = "jinaai/jina-reranker-m0"

# Forwarded verbatim to the vLLM runner as ``mm_processor_kwargs``. The
# bounds are written as multiples of 28 * 28 (i.e. 3136 and 602112);
# presumably 28 * 28 is the vision patch area -- TODO confirm against the
# model's processor config.
mm_processor_kwargs = dict(
    min_pixels=4 * 28 * 28,
    max_pixels=768 * 28 * 28,
)

# Per-prompt multimodal cap handed to the vLLM runner. Presumably sized so an
# image query and an image document fit in one prompt (see
# test_model_image_image) -- confirm.
limit_mm_per_prompt = dict(image=2)


def vllm_reranker(
    vllm_runner: type[VllmRunner],
    model_name: str,
    dtype: str,
    query_strs: list[str],
    document_strs: list[str],
    query_type: str = "text",
    doc_type: str = "text",
):
    """Score queries against documents with vLLM's pooling runner.

    Args:
        vllm_runner: Runner class used to launch the model as a context
            manager.
        model_name: Model checkpoint to load.
        dtype: Model dtype, e.g. ``"half"``.
        query_strs: Query texts, or image URLs when ``query_type == "image"``.
        document_strs: Document texts, or image URLs when
            ``doc_type == "image"``.
        query_type: Either ``"text"`` or ``"image"``.
        doc_type: Either ``"text"`` or ``"image"``.

    Returns:
        A list holding ``output.outputs.score`` for each output returned by
        ``LLM.score``.

    Raises:
        ValueError: If ``query_type`` or ``doc_type`` is neither ``"text"``
            nor ``"image"``. The check runs before the model is launched.
    """

    def create_image_param(url: str) -> ChatCompletionContentPartImageParam:
        # Chat-style content part that references the image by URL.
        return {"type": "image_url", "image_url": {"url": url}}

    def to_score_input(
        items: list[str],
        kind: str,
        arg_name: str,
    ) -> Union[list[str], ScoreMultiModalParam]:
        # Text inputs pass straight through; image URLs are wrapped into a
        # multimodal score param with one image part per URL.
        if kind == "text":
            return items
        if kind == "image":
            return ScoreMultiModalParam(
                content=[create_image_param(url) for url in items])
        raise ValueError(
            f"{arg_name} must be 'text' or 'image', got {kind!r}")

    query = to_score_input(query_strs, query_type, "query_type")
    documents = to_score_input(document_strs, doc_type, "doc_type")

    with vllm_runner(
            model_name,
            runner="pooling",
            dtype=dtype,
            max_num_seqs=2,
            max_model_len=2048,
            mm_processor_kwargs=mm_processor_kwargs,
            limit_mm_per_prompt=limit_mm_per_prompt,
    ) as vllm_model:
        outputs = vllm_model.llm.score(query, documents)

    return [output.outputs.score for output in outputs]


def hf_reranker(
    hf_runner: type[HfRunner],
    model_name: str,
    dtype: str,
    query_strs: list[str],
    document_strs: list[str],
    query_type: str = "text",
    doc_type: str = "text",
):
    """Score the first query against every document with the HF reference.

    Loads ``model_name`` through ``AutoModel`` with ``trust_remote_code=True``
    and delegates scoring to the model's own ``compute_score`` method.

    Args:
        hf_runner: Runner class used to load the model as a context manager.
        model_name: Model checkpoint to load.
        dtype: Model dtype, e.g. ``"half"``.
        query_strs: Queries. Only ``query_strs[0]`` is used; every document
            is paired with it.
        document_strs: Documents (texts or image URLs, per ``doc_type``).
        query_type: Forwarded to ``compute_score``.
        doc_type: Forwarded to ``compute_score``.

    Returns:
        Whatever ``compute_score`` returns for the (query, document) pairs.
        Callers in this module index it once per document.
    """
    # Passed to transformers as ``key_mapping`` (via ``model_kwargs``);
    # presumably renames checkpoint weight prefixes to the module layout the
    # remote model class expects -- confirm against the checkpoint.
    checkpoint_to_hf_mapper = {
        "visual.": "model.visual.",
        "model.": "model.language_model.",
    }

    # NOTE(review): only the first query is used, so any extra queries are
    # silently ignored. Every caller in this module passes exactly one.
    data_pairs = [[query_strs[0], d] for d in document_strs]

    with hf_runner(
            model_name,
            dtype=dtype,
            trust_remote_code=True,
            auto_cls=AutoModel,
            model_kwargs={"key_mapping": checkpoint_to_hf_mapper},
    ) as hf_model:
        return hf_model.model.compute_score(data_pairs,
                                            max_length=2048,
                                            query_type=query_type,
                                            doc_type=doc_type)


# Visual Documents Reranking
@pytest.mark.parametrize("model_name", [model_name])
@pytest.mark.parametrize("dtype", ["half"])
def test_model_text_image(hf_runner, vllm_runner, model_name, dtype):
    """HF and vLLM agree when a text query ranks image documents."""
    query = ["slm markdown"]
    documents = [
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
    ]

    hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
                             "text", "image")
    vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
                                 documents, "text", "image")

    # Both backends must return one score per document, each within 2%.
    assert len(hf_outputs) == len(vllm_outputs) == len(documents)
    for hf_score, vllm_score in zip(hf_outputs, vllm_outputs):
        assert hf_score == pytest.approx(vllm_score, rel=0.02)


# Textual Documents Reranking
@pytest.mark.parametrize("model_name", [model_name])
@pytest.mark.parametrize("dtype", ["half"])
def test_model_text_text(hf_runner, vllm_runner, model_name, dtype):
    """HF and vLLM agree when a text query ranks text documents."""
    query = ["slm markdown"]
    documents = [
        """We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient 
        web content extraction. Our model processes documents up to 512K tokens, transforming messy HTML 
        into clean Markdown or JSON formats with high accuracy -- making it an ideal tool for grounding 
        large language models. The models effectiveness results from two key innovations: (1) a three-stage 
        data synthesis pipeline that generates high quality, diverse training data by iteratively drafting, 
        refining, and critiquing web content extraction; and (2) a unified training framework combining 
        continuous pre-training with multi-objective optimization. Intensive evaluation demonstrates that 
        ReaderLM-v2 outperforms GPT-4o-2024-08-06 and other larger models by 15-20% on carefully curated 
        benchmarks, particularly excelling at documents exceeding 100K tokens, while maintaining significantly 
        lower computational requirements.""",  # noqa: E501
        "数据提取么?为什么不用正则啊,你用正则不就全解决了么?",
    ]

    hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
                             "text", "text")
    vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
                                 documents, "text", "text")

    # Both backends must return one score per document, each within 2%.
    assert len(hf_outputs) == len(vllm_outputs) == len(documents)
    for hf_score, vllm_score in zip(hf_outputs, vllm_outputs):
        assert hf_score == pytest.approx(vllm_score, rel=0.02)


# Image Querying for Textual Documents
@pytest.mark.parametrize("model_name", [model_name])
@pytest.mark.parametrize("dtype", ["half"])
def test_model_image_text(hf_runner, vllm_runner, model_name, dtype):
    """HF and vLLM agree when an image query ranks text documents."""
    query = [
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
    ]
    documents = [
        """We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient
        web content extraction. Our model processes documents up to 512K tokens, transforming messy HTML
        into clean Markdown or JSON formats with high accuracy -- making it an ideal tool for grounding
        large language models. The models effectiveness results from two key innovations: (1) a three-stage
        data synthesis pipeline that generates high quality, diverse training data by iteratively drafting,
        refining, and critiquing web content extraction; and (2) a unified training framework combining
        continuous pre-training with multi-objective optimization. Intensive evaluation demonstrates that
        ReaderLM-v2 outperforms GPT-4o-2024-08-06 and other larger models by 15-20% on carefully curated
        benchmarks, particularly excelling at documents exceeding 100K tokens, while maintaining significantly
        lower computational requirements.""",  # noqa: E501
        "数据提取么?为什么不用正则啊,你用正则不就全解决了么?",
    ]

    hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
                             "image", "text")
    vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
                                 documents, "image", "text")

    # Both backends must return one score per document, each within 2%.
    assert len(hf_outputs) == len(vllm_outputs) == len(documents)
    for hf_score, vllm_score in zip(hf_outputs, vllm_outputs):
        assert hf_score == pytest.approx(vllm_score, rel=0.02)


# Image Querying for Image Documents
@pytest.mark.parametrize("model_name", [model_name])
@pytest.mark.parametrize("dtype", ["half"])
def test_model_image_image(hf_runner, vllm_runner, model_name, dtype):
    """HF and vLLM agree when an image query ranks image documents."""
    query = [
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
    ]
    documents = [
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
    ]

    hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents,
                             "image", "image")
    vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query,
                                 documents, "image", "image")

    # Both backends must return one score per document, each within 2%.
    assert len(hf_outputs) == len(vllm_outputs) == len(documents)
    for hf_score, vllm_score in zip(hf_outputs, vllm_outputs):
        assert hf_score == pytest.approx(vllm_score, rel=0.02)