"docs/getting_started/installation/cpu.arm.inc.md" did not exist on "1ad69e8375e841095c2f682299be487fd9b8f47e"
score_utils.py 8.92 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, TypeAlias, cast

from torch.nn import CosineSimilarity
from typing_extensions import Required, TypedDict

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    BaseMultiModalItemTracker,
    ChatCompletionContentPartImageEmbedsParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartTextParam,
    ChatTemplateResolutionError,
    MultiModalItemTracker,
    _ContentPart,
    _parse_chat_message_content_part,
    apply_hf_chat_template,
)
from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.outputs import PoolingRequestOutput
from vllm.tokenizers import TokenizerLike

# A single multimodal content part accepted by the score API: either of the
# two image content-part param types imported from vllm.entrypoints.chat_utils.
ScoreContentPartParam: TypeAlias = (
    ChatCompletionContentPartImageParam | ChatCompletionContentPartImageEmbedsParam
)


class ScoreMultiModalParam(TypedDict, total=False):
    """
    A specialized parameter type for scoring multimodal content.

    Reasons for not reusing `CustomChatCompletionMessageParam` directly:
    1. Score tasks don't need the 'role' field (user/assistant/system) that is required in chat completions.
    2. Including chat-specific fields would confuse users about their purpose in scoring.
    3. This is a more focused interface that only exposes what is needed for scoring.
    """  # noqa: E501

    # Required even though the TypedDict is declared with total=False.
    content: Required[list[ScoreContentPartParam]]
    """The multimodal contents"""

def _cosine_similarity(
    tokenizer: TokenizerLike,
    embed_1: list[PoolingRequestOutput],
    embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
    """Score embedding pairs element-wise by cosine similarity.

    ``embed_1[i]`` is paired with ``embed_2[i]``.

    Args:
        tokenizer: Used only for its ``pad_token_id``. When that is set, it is
            inserted between the two prompts' token ids of each result.
        embed_1: Embedding outputs for the first side of each pair.
        embed_2: Embedding outputs for the second side of each pair.

    Returns:
        One finished ``PoolingRequestOutput`` per pair. Its ``outputs`` is the
        cosine similarity. Its request id, prompt token ids and cached-token
        count combine those of the two inputs.

    Raises:
        ValueError: If ``embed_1`` and ``embed_2`` have different lengths,
            instead of silently dropping the unpaired entries.
    """
    # Similarity is taken along dim 0; assumes each ``outputs.data`` is a
    # 1-D embedding tensor (TODO confirm for multi-vector poolers).
    scorer = CosineSimilarity(0)

    # The separator token is identical for every pair, so resolve it once.
    pad_token_id = tokenizer.pad_token_id
    padding: list[int] = [] if pad_token_id is None else [pad_token_id]

    scores: list[PoolingRequestOutput] = []
    for emb_1, emb_2 in zip(embed_1, embed_2, strict=True):
        pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)

        # ``+`` builds a fresh list, so the shared ``padding`` is never mutated.
        tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

        scores.append(
            PoolingRequestOutput(
                request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                outputs=pair_score,
                prompt_token_ids=tokens,
                num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
                finished=True,
            )
        )

    return scores


def _validate_score_input_lens(
76
77
    data_1: list[str] | list[ScoreContentPartParam],
    data_2: list[str] | list[ScoreContentPartParam],
78
):
79
80
81
82
    len_1 = len(data_1)
    len_2 = len(data_2)

    if len_1 > 1 and len_1 != len_2:
83
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
84
    if len_1 == 0:
85
        raise ValueError("At least one text element must be given")
86
87
88
89
90
    if len_2 == 0:
        raise ValueError("At least one text_pair element must be given")


def parse_score_data(
    data_1: str | ScoreContentPartParam,
    data_2: str | ScoreContentPartParam,
    model_config: ModelConfig,
) -> tuple[str, str, MultiModalDataDict | None]:
    """Turn both sides of a score pair into prompt strings.

    Both inputs are parsed against a single ``MultiModalItemTracker`` built
    from ``model_config``. The tracker's ``all_mm_data()`` result is returned
    alongside the two prompts.

    Returns:
        ``(prompt_1, prompt_2, mm_data)``, where ``mm_data`` may be ``None``.

    Raises:
        ValueError: If either parsed side is not a string. The first side is
            checked first.
    """
    mm_tracker = MultiModalItemTracker(model_config)

    # Parse both sides, in order, before validating either of them.
    parsed_pair = [_parse_score_content(side, mm_tracker) for side in (data_1, data_2)]

    prompts: list[str] = []
    for content in parsed_pair:
        if not isinstance(content, str):
            raise ValueError(f"Only string content is supported, but got {content}.")
        prompts.append(content)

    prompt_1, prompt_2 = prompts
    return prompt_1, prompt_2, mm_tracker.all_mm_data()


def _parse_score_content(
    data: str | ScoreContentPartParam,
    mm_tracker: BaseMultiModalItemTracker,
) -> _ContentPart | None:
    """Parse one side of a score pair into prompt content.

    A plain string is wrapped as a text content part, and any other part is
    used as-is. The part is then run through the chat content-part parser,
    using a parser created from ``mm_tracker``.

    Returns:
        The parser's direct result when it produces one (for text this may be
        an empty string). Otherwise, the single multimodal placeholder recorded
        by the parser.

    Raises:
        ValueError: If there is no direct result and the parser did not record
            exactly one multimodal placeholder.
    """
    if isinstance(data, str):
        part = ChatCompletionContentPartTextParam(type="text", text=data)
    else:
        part = data

    mm_parser = mm_tracker.create_parser()

    parse_res = _parse_chat_message_content_part(
        part,
        mm_parser,
        wrap_dicts=False,
        interleave_strings=False,
    )

    # Compare against None, not truthiness. An empty text input parses to "",
    # which is a valid prompt and must not be misreported as a bad multimodal
    # item below.
    if parse_res is not None:
        return parse_res

    mm_placeholder_storage = mm_parser.mm_placeholder_storage()

    # Exactly one modality holding exactly one placeholder is supported.
    if (
        len(mm_placeholder_storage) != 1
        or len(next(iter(mm_placeholder_storage.values()))) != 1
    ):
        raise ValueError("Only one multi-modal item is supported")

    return next(iter(mm_placeholder_storage.values()))[0]


def _apply_model_score_template(
    model_config: ModelConfig, prompt_1: str, prompt_2: str
) -> str:
    """Build the full prompt with the model class's own score template.

    Raises:
        ValueError: If the model class resolved from ``model_config`` does not
            support score templates, or its ``get_score_template`` returns
            ``None``.
    """
    # NOTE(Simon): lazy import to avoid bringing in all dependencies (e.g. gguf)
    from vllm.model_executor.model_loader import get_model_cls

    model_cls = get_model_cls(model_config)
    if not supports_score_template(model_cls):
        arch = model_config.architecture
        raise ValueError(f"Unsupported model architecture: {arch}")

    template = model_cls.get_score_template(prompt_1, prompt_2)
    if template is None:
        raise ValueError("Get empty score template from model")
    return template


def post_process_tokens(
    model_config: ModelConfig,
    prompt: TokensPrompt,
) -> None:
    """
    Perform architecture-specific manipulations on the input tokens.

    The model class resolved from ``model_config`` gets to modify ``prompt``
    only when it supports score templates. Otherwise, this does nothing.

    Note:
        This is an in-place operation.
    """
    # NOTE(Simon): lazy import to avoid bringing in all dependencies (e.g. gguf)
    from vllm.model_executor.model_loader import get_model_cls

    model_cls = get_model_cls(model_config)
    if not supports_score_template(model_cls):
        return
    model_cls.post_process_tokens(prompt)


def get_score_prompt(
    model_config: ModelConfig,
    tokenizer: TokenizerLike,
    tokenization_kwargs: dict[str, Any],
    data_1: str | ScoreContentPartParam,
    data_2: str | ScoreContentPartParam,
    score_template: str | None = None,
) -> tuple[str, TokensPrompt]:
    """Build the text prompt and engine prompt for one (query, document) pair.

    Args:
        model_config: Configuration of the scoring model.
        tokenizer: Tokenizer used to encode the combined prompt.
        tokenization_kwargs: Extra keyword arguments forwarded to every
            ``tokenizer(...)`` call.
        data_1: First input, used as the ``query`` role when a template is
            applied.
        data_2: Second input, used as the ``document`` role when a template is
            applied.
        score_template: Optional chat template. When ``None``, or when it
            cannot be resolved, the default encoding is used.

    Returns:
        ``(full_prompt, engine_prompt)``: the combined prompt text, and a
        ``TokensPrompt`` with token ids plus optional ``token_type_ids`` and
        ``multi_modal_data``.
    """
    prompt_1, prompt_2, mm_data = parse_score_data(
        data_1,
        data_2,
        model_config,
    )
    # Lazy import to avoid bringing in all model-loader dependencies at module
    # import time.
    from vllm.model_executor.model_loader import get_model_cls

    model = get_model_cls(model_config)

    def default_tokenizer_encode():
        """Encode the pair without a chat template; return (text, encoding)."""
        if supports_score_template(model):
            # The model class provides its own pair template.
            full_prompt = _apply_model_score_template(model_config, prompt_1, prompt_2)
            prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
        else:
            if model_config.use_pad_token:
                # cross_encoder models default to using pad_token.
                prompt_inputs = tokenizer(
                    text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
                )
                full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
            else:
                # `llm as reranker` models default to not using pad_token.
                full_prompt = prompt_1 + prompt_2
                prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
        return full_prompt, prompt_inputs

    # FIXME: For now, we only apply a template when one is explicitly provided.
    # We cannot rely on the tokenizer's chat template because many models
    # inherit junk templates from their base LLM, which breaks both the models
    # and the tests that use them.
    if score_template is None:
        full_prompt, prompt_inputs = default_tokenizer_encode()
    else:
        # Try the explicitly provided score template; if it cannot be
        # resolved, fall back to the default encoding.
        try:
            full_prompt = apply_hf_chat_template(
                tokenizer,
                [
                    {"role": "query", "content": prompt_1},
                    {"role": "document", "content": prompt_2},
                ],
                score_template,
                tools=None,
                model_config=model_config,
            )
        except ChatTemplateResolutionError:
            full_prompt, prompt_inputs = default_tokenizer_encode()
        else:
            # Tokenize outside the ``try`` so that only template resolution
            # can trigger the fallback path.
            prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)

    engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])

    # Forward token type ids only when the tokenizer produced them.
    if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
        engine_prompt["token_type_ids"] = token_type_ids

    post_process_tokens(model_config, engine_prompt)

    if mm_data is not None:
        engine_prompt["multi_modal_data"] = mm_data
    return full_prompt, engine_prompt


def compress_token_type_ids(token_type_ids: list[int]) -> int:
    """
    Return position of the first 1 or the length of the list
    if not found.

    The ids must be a (possibly empty) run of zeros followed by a (possibly
    empty) run of ones. Together with the list length, the index where the
    ones start then describes the whole list.

    Args:
        token_type_ids: Token type ids, each expected to be 0 or 1.

    Returns:
        Index of the first 1, or ``len(token_type_ids)`` if there is none.

    Raises:
        ValueError: If any id is not 0 or 1, or if a 0 appears after a 1.
    """
    first_one = len(token_type_ids)

    err_msg = (
        "Token type ids are expected to be a sequence"
        " of zeros followed by a sequence of ones"
    )

    for i, type_id in enumerate(token_type_ids):
        if type_id == 0 and first_one < i:
            # A zero after the ones have started breaks the 0...0 1...1 shape.
            raise ValueError(err_msg)
        elif type_id == 1 and first_one > i:
            # First 1 encountered; later ones leave first_one unchanged.
            first_one = i
        elif type_id not in (0, 1):
            # Reject every other value, negatives included.
            raise ValueError(err_msg)

    return first_one