score_utils.py 7.69 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, Optional, Union, cast
4
5

from torch.nn import CosineSimilarity
6
from typing_extensions import Required, TypeAlias, TypedDict
7

8
9
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
10
11
12
13
14
15
16
17
    BaseMultiModalItemTracker,
    ChatCompletionContentPartImageEmbedsParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartTextParam,
    MultiModalItemTracker,
    _ContentPart,
    _parse_chat_message_content_part,
)
18
19
20
from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict
21
from vllm.outputs import PoolingRequestOutput
22
23
24
25
26
from vllm.transformers_utils.tokenizer import (
    AnyTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)
27

28
# Multimodal content parts accepted by the score API: either an image part
# or a precomputed image-embeddings part (the chat-completion content-part
# types imported from chat_utils).
ScoreContentPartParam: TypeAlias = Union[
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartImageEmbedsParam,
]


class ScoreMultiModalParam(TypedDict, total=False):
    """
    A specialized parameter type for scoring multimodal content.

    Reasons for not reusing `CustomChatCompletionMessageParam` directly:

    1. Score tasks don't need the `role` field (user/assistant/system)
       that is required in chat completions.
    2. Including chat-specific fields would confuse users about their
       purpose in scoring.
    3. This is a more focused interface that only exposes what is needed
       for scoring.
    """

    content: Required[list[ScoreContentPartParam]]
    """The multimodal content parts (image or image-embeddings parts)."""

46
47
48

def _cosine_similarity(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    embed_1: list[PoolingRequestOutput],
    embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
    """
    Score embedding pairs by cosine similarity.

    The i-th element of ``embed_1`` is paired with the i-th element of
    ``embed_2``; as with ``zip``, extra elements of the longer list are
    ignored.

    Args:
        tokenizer: Only used to look up ``pad_token_id``; when it is set,
            that token id is placed between the two prompts' token ids in
            each result.
        embed_1: Pooling outputs for the first side of each pair.
        embed_2: Pooling outputs for the second side of each pair.

    Returns:
        One finished ``PoolingRequestOutput`` per pair. Its ``outputs`` is
        the cosine similarity of the two embeddings, its ``request_id`` is
        ``"<id_1>_<id_2>"`` and its ``prompt_token_ids`` concatenates both
        prompts' token ids (with the optional pad separator).
    """
    # Similarity along dim 0 of each embedding tensor
    # (assumes 1-D embedding vectors — TODO confirm).
    scorer = CosineSimilarity(0)

    # The separator is identical for every pair, so resolve it once.
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    padding: list[int] = [] if pad_token_id is None else [pad_token_id]

    scores: list[PoolingRequestOutput] = []
    for emb_1, emb_2 in zip(embed_1, embed_2):
        pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)

        # `+` builds a new list, so sharing `padding` across pairs is safe.
        tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

        scores.append(
            PoolingRequestOutput(
                request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                outputs=pair_score,
                prompt_token_ids=tokens,
                finished=True,
            )
        )

    return scores


def _validate_score_input_lens(
    data_1: Union[list[str], list[ScoreContentPartParam]],
    data_2: Union[list[str], list[ScoreContentPartParam]],
):
    """
    Check that the two sides of a score request have compatible lengths.

    Accepted shapes are 1:1, 1:N (one ``data_1`` item with any number of
    ``data_2`` items) and N:N.

    Raises:
        ValueError: if ``data_1`` has more than one item and its length
            differs from ``data_2``'s, or if either side is empty. The
            length-mismatch check is performed first.
    """
    n_first, n_second = len(data_1), len(data_2)

    # With several items on the first side, pairing is only defined 1-to-1.
    if n_first > 1 and n_first != n_second:
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if not n_first:
        raise ValueError("At least one text element must be given")
    if not n_second:
        raise ValueError("At least one text_pair element must be given")


def parse_score_data(
    data_1: Union[str, ScoreContentPartParam],
    data_2: Union[str, ScoreContentPartParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> tuple[str, str, Optional[MultiModalDataDict]]:
    """
    Convert both sides of a score pair into prompt strings.

    Both sides are parsed with a single ``MultiModalItemTracker``, so any
    multimodal items they reference are collected together and returned as
    the third element (the tracker's ``all_mm_data()``).

    Raises:
        ValueError: if either side does not parse to a ``str``.
    """
    tracker = MultiModalItemTracker(model_config, tokenizer)

    # Parse both sides before validating either of them.
    parsed_parts = [
        _parse_score_content(item, tracker) for item in (data_1, data_2)
    ]

    def ensure_str(content: Optional[_ContentPart]) -> str:
        # isinstance() is False for None, so it alone rejects missing content.
        if not isinstance(content, str):
            raise ValueError(f"Only string content is supported, but got {content}.")
        return cast(str, content)

    prompt_1, prompt_2 = (ensure_str(part) for part in parsed_parts)
    return prompt_1, prompt_2, tracker.all_mm_data()


def _parse_score_content(
    data: Union[str, ScoreContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
) -> Optional[_ContentPart]:
    """
    Parse one side of a score pair into a single content part.

    A plain string is first wrapped as a text content part. If the chat
    content-part parser yields a result, it is returned directly.
    Otherwise the part is treated as multimodal, and the single placeholder
    recorded by ``mm_tracker``'s parser is returned.

    Raises:
        ValueError: if the multimodal path does not produce exactly one
            placeholder (one modality holding exactly one item).
    """
    if isinstance(data, str):
        data = ChatCompletionContentPartTextParam(type="text", text=data)

    mm_parser = mm_tracker.create_parser()

    parse_res = _parse_chat_message_content_part(
        data,
        mm_parser,
        wrap_dicts=False,
        interleave_strings=False,
    )

    # `None` (per the Optional return) means "no direct content". Compare
    # against it explicitly rather than by truthiness. Otherwise an
    # empty-string text result falls through to the multimodal branch and
    # fails with a misleading "Only one multi-modal item" error.
    if parse_res is not None:
        return parse_res

    mm_placeholder_storage = mm_parser.mm_placeholder_storage()
    placeholder_lists = list(mm_placeholder_storage.values())

    if len(placeholder_lists) != 1 or len(placeholder_lists[0]) != 1:
        raise ValueError("Only one multi-modal item is supported")

    return placeholder_lists[0][0]


def apply_score_template(
    model_config: ModelConfig,
    prompt_1: str,
    prompt_2: str,
) -> str:
    """
    Render the model-defined score template for a prompt pair.

    Raises:
        ValueError: if the model class does not support score templates,
            or if its ``get_score_template`` returns ``None``.
    """
    # Deferred import: importing the model loader eagerly would bring in
    # all of its dependencies (e.g. gguf).
    from vllm.model_executor.model_loader import get_model_cls

    model_cls = get_model_cls(model_config)
    if not supports_score_template(model_cls):
        raise ValueError(
            f"Unsupported model architecture: {model_config.architecture}"
        )

    rendered = model_cls.get_score_template(prompt_1, prompt_2)
    if rendered is None:
        raise ValueError("Get empty score template from model")
    return rendered


def post_process_tokens(
    model_config: ModelConfig,
    prompt: TokensPrompt,
) -> None:
    """
    Apply architecture-specific adjustments to a tokenized prompt.

    Only model classes that pass ``supports_score_template`` are consulted;
    for those, the work is delegated to the class's ``post_process_tokens``.
    For any other architecture, ``prompt`` is left untouched.

    Note:
        ``prompt`` is modified in place; nothing is returned.
    """
    # Deferred import: importing the model loader eagerly would bring in
    # all of its dependencies (e.g. gguf).
    from vllm.model_executor.model_loader import get_model_cls

    model_cls = get_model_cls(model_config)
    if not supports_score_template(model_cls):
        return
    model_cls.post_process_tokens(prompt)


def get_score_prompt(
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    tokenization_kwargs: dict[str, Any],
    data_1: Union[str, ScoreContentPartParam],
    data_2: Union[str, ScoreContentPartParam],
) -> tuple[str, TokensPrompt]:
    """
    Build the text prompt and the engine prompt for one score pair.

    The two sides are joined in one of three ways:

    * model supports a score template: the rendered template is tokenized;
    * otherwise, if ``model_config.use_pad_token`` is set (cross-encoder
      style): the sides are tokenized as a ``text``/``text_pair`` pair and
      the returned prompt is the decoded token ids;
    * otherwise ("LLM as reranker" style): the sides are concatenated with
      no separator and tokenized as a single text.

    Returns:
        ``(full_prompt, engine_prompt)``. ``engine_prompt`` holds the input
        ids, ``token_type_ids`` when the tokenizer produced them, and any
        multimodal data. ``post_process_tokens`` is applied to it before
        the multimodal data is attached.
    """
    prompt_1, prompt_2, mm_data = parse_score_data(
        data_1, data_2, model_config, tokenizer
    )

    # Deferred import: importing the model loader eagerly would bring in
    # all of its dependencies.
    from vllm.model_executor.model_loader import get_model_cls

    model_cls = get_model_cls(model_config)
    uses_template = supports_score_template(model_cls)

    if uses_template:
        full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
        encoded = tokenizer(full_prompt, **tokenization_kwargs)
    elif model_config.use_pad_token:
        # Cross-encoder models default to using the pad token; the
        # tokenizer's pair encoding takes care of joining the two sides.
        encoded = tokenizer(
            text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
        )
        full_prompt = tokenizer.decode(encoded["input_ids"])
    else:
        # "LLM as reranker" models default to not using the pad token.
        full_prompt = prompt_1 + prompt_2
        encoded = tokenizer(text=full_prompt, **tokenization_kwargs)

    engine_prompt = TokensPrompt(prompt_token_ids=encoded["input_ids"])

    token_type_ids = encoded.get("token_type_ids")
    if token_type_ids is not None:
        engine_prompt["token_type_ids"] = token_type_ids

    # Architecture-specific token tweaks run before multimodal data is added.
    post_process_tokens(model_config, engine_prompt)

    if mm_data is not None:
        engine_prompt["multi_modal_data"] = mm_data
    return full_prompt, engine_prompt


def compress_token_type_ids(token_type_ids: list[int]) -> int:
    """
    Return position of the first 1 or the length of the list
    if not found.

    ``token_type_ids`` must have the shape ``[0, ..., 0, 1, ..., 1]``
    (either run may be empty), so that this single index fully describes
    the list.

    Raises:
        ValueError: if any id is not 0 or 1, or if a 0 appears after a 1.
    """
    err_msg = (
        "Token type ids are expected to be a sequence"
        " of zeros followed by a sequence of ones"
    )
    first_one = len(token_type_ids)

    for i, type_id in enumerate(token_type_ids):
        # Reject anything outside {0, 1}, including negative ids.
        if type_id not in (0, 1):
            raise ValueError(err_msg)
        if type_id == 1:
            # Only the earliest 1 is recorded.
            first_one = min(first_one, i)
        elif first_one < i:
            # A 0 after the first 1 breaks the zeros-then-ones shape.
            raise ValueError(err_msg)

    return first_one