"docs/vscode:/vscode.git/clone" did not exist on "138d891d7f42004c417561050a6813792316b13b"
score_utils.py 7.82 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any, Optional, Union, cast
4
5

from torch.nn import CosineSimilarity
6
from typing_extensions import Required, TypeAlias, TypedDict
7

8
9
10
11
12
13
14
15
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    BaseMultiModalItemTracker, ChatCompletionContentPartImageEmbedsParam,
    ChatCompletionContentPartImageParam, ChatCompletionContentPartTextParam,
    MultiModalItemTracker, _ContentPart, _parse_chat_message_content_part)
from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict
16
from vllm.outputs import PoolingRequestOutput
17
18
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               PreTrainedTokenizer,
19
20
                                               PreTrainedTokenizerFast)

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
ScoreContentPartParam: TypeAlias = Union[
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartImageEmbedsParam]


class ScoreMultiModalParam(TypedDict, total=False):
    """
    Parameter type carrying multimodal content for a scoring request.

    `CustomChatCompletionMessageParam` is deliberately not reused here:
    1. Score tasks have no use for the 'role' field
       (user/assistant/system) that chat completions require.
    2. Exposing chat-specific fields would leave users unsure what they
       mean in a scoring context.
    3. A narrower interface exposes only what scoring actually needs.
    """
    content: Required[list[ScoreContentPartParam]]
    """The multimodal content parts (required key)."""

38
39
40

def _cosine_similarity(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    embed_1: list[PoolingRequestOutput],
    embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
    """
    Score embedding outputs pairwise with cosine similarity.

    `embed_1` and `embed_2` are walked in lockstep with `zip`, so any
    surplus items in the longer list are ignored.

    Args:
        tokenizer: Tokenizer whose `pad_token_id` (when present and not
            `None`) is inserted between the two token id sequences.
        embed_1: Pooling outputs for the first element of each pair.
        embed_2: Pooling outputs for the second element of each pair.

    Returns:
        One `PoolingRequestOutput` per pair. Its `outputs` is the value
        returned by `CosineSimilarity(0)`, its `request_id` joins both
        request ids with an underscore, and its `prompt_token_ids` is the
        concatenation of both prompts' token ids (optionally separated by
        the pad token id).
    """
    scorer = CosineSimilarity(0)

    # The separator only depends on the tokenizer, so look it up once
    # instead of on every iteration.
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    padding = [pad_token_id] if pad_token_id is not None else []

    scores: list[PoolingRequestOutput] = []
    for emb_1, emb_2 in zip(embed_1, embed_2):
        pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)

        # List concatenation builds a new list, so `padding` is never
        # shared between the resulting outputs.
        tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

        scores.append(
            PoolingRequestOutput(
                request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                outputs=pair_score,
                prompt_token_ids=tokens,
                finished=True))

    return scores


def _validate_score_input_lens(
    data_1: Union[list[str], list[ScoreContentPartParam]],
    data_2: Union[list[str], list[ScoreContentPartParam]],
) -> None:
    """
    Check that the two score inputs have compatible lengths.

    Accepted shapes are 1:1, 1:N and N:N. The checks run in a fixed order,
    so a length mismatch is reported before an empty input.

    Args:
        data_1: First inputs ("text" side).
        data_2: Second inputs ("text_pair" side).

    Raises:
        ValueError: If `data_1` has more than one element and its length
            differs from `data_2`, or if either input is empty.
    """
    len_1 = len(data_1)
    len_2 = len(data_2)

    if len_1 > 1 and len_1 != len_2:
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if len_1 == 0:
        raise ValueError("At least one text element must be given")
    if len_2 == 0:
        raise ValueError("At least one text_pair element must be given")


def parse_score_data(
    data_1: Union[str, ScoreContentPartParam],
    data_2: Union[str, ScoreContentPartParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> tuple[str, str, Optional[MultiModalDataDict]]:
    """
    Parse both score inputs into string prompts plus multimodal data.

    Both inputs are parsed against one shared `MultiModalItemTracker`, and
    both are parsed before either result is checked.

    Returns:
        `(prompt_1, prompt_2, mm_data)`, where `mm_data` is whatever the
        tracker's `all_mm_data()` returns.

    Raises:
        ValueError: If either parsed content is not a string.
    """

    def ensure_str(content: Optional[_ContentPart]) -> str:
        # `None` is not a `str`, so one isinstance check covers both cases.
        if not isinstance(content, str):
            raise ValueError(
                f"Only string content is supported, but got {content}.")
        return content

    tracker = MultiModalItemTracker(model_config, tokenizer)

    parsed_1 = _parse_score_content(data_1, tracker)
    parsed_2 = _parse_score_content(data_2, tracker)

    first_prompt = ensure_str(parsed_1)
    second_prompt = ensure_str(parsed_2)

    return first_prompt, second_prompt, tracker.all_mm_data()


def _parse_score_content(
    data: Union[str, ScoreContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
) -> Optional[_ContentPart]:
    """
    Parse one score input with the chat content-part parser.

    A plain string is first wrapped as a text content part. If the parser
    returns a truthy result, that result is returned as-is. Otherwise the
    single placeholder recorded in the parser's placeholder storage is
    returned.

    Raises:
        ValueError: If the placeholder storage does not hold exactly one
            entry containing exactly one placeholder.
    """
    part = (ChatCompletionContentPartTextParam(type="text", text=data)
            if isinstance(data, str) else data)

    parser = mm_tracker.create_parser()
    parsed = _parse_chat_message_content_part(
        part,
        parser,
        wrap_dicts=False,
        interleave_strings=False,
    )
    if parsed:
        return parsed

    storage = parser.mm_placeholder_storage()

    # Only look inside the storage once it is known to hold a single entry.
    if len(storage) == 1:
        placeholders = next(iter(storage.values()))
        if len(placeholders) == 1:
            return placeholders[0]

    raise ValueError("Only one multi-modal item is supported")


def apply_score_template(
    model_config: ModelConfig,
    prompt_1: str,
    prompt_2: str,
) -> str:
    """
    Combine two prompts using the model class's score template.

    Args:
        model_config: Config used to resolve the model class.
        prompt_1: First prompt passed to the template.
        prompt_2: Second prompt passed to the template.

    Returns:
        The full prompt produced by the model's `get_score_template`.

    Raises:
        ValueError: If the model class does not support score templates,
            or if its template returns `None`.
    """
    # NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf)
    from vllm.model_executor.model_loader import get_model_cls

    model = get_model_cls(model_config)
    if supports_score_template(model):
        full_prompt = model.get_score_template(prompt_1, prompt_2)
        if full_prompt is None:
            raise ValueError("Get empty score template from model")
        return full_prompt

    raise ValueError(
        f"Unsupported model architecture: {model_config.architecture}")


def post_process_tokens(
    model_config: ModelConfig,
    prompt: TokensPrompt,
) -> None:
    """
    Perform architecture-specific manipulations on the input tokens.

    The prompt is handed to the model class's `post_process_tokens` hook
    only when that class supports score templates; otherwise this is a
    no-op.

    Note:
        This is an in-place operation.
    """
    # NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf)
    from vllm.model_executor.model_loader import get_model_cls

    model = get_model_cls(model_config)
    if supports_score_template(model):
        model.post_process_tokens(prompt)


def get_score_prompt(
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
    tokenization_kwargs: dict[str, Any],
    data_1: Union[str, ScoreContentPartParam],
    data_2: Union[str, ScoreContentPartParam],
) -> tuple[str, TokensPrompt]:
    """
    Build the text prompt and the engine prompt for one scoring pair.

    The prompt is assembled in one of three ways:
    1. If the model class supports score templates, the template output
       is tokenized.
    2. Otherwise, if `model_config.use_pad_token` is set, the two prompts
       are tokenized as a text / text_pair and the full prompt is decoded
       back from the resulting input ids.
    3. Otherwise the two prompts are concatenated and tokenized as one
       text.

    Args:
        model_config: Config used to resolve the model class and flags.
        tokenizer: Tokenizer used to encode (and, in case 2, decode).
        tokenization_kwargs: Extra keyword arguments forwarded to every
            tokenizer call.
        data_1: First input (string or multimodal content part).
        data_2: Second input (string or multimodal content part).

    Returns:
        `(full_prompt, engine_prompt)`. `engine_prompt` carries the token
        ids, plus `token_type_ids` and `multi_modal_data` when available.
    """
    prompt_1, prompt_2, mm_data = parse_score_data(
        data_1,
        data_2,
        model_config,
        tokenizer,
    )
    # NOTE: lazy import to avoid bring in all dependencies (e.g. gguf)
    from vllm.model_executor.model_loader import get_model_cls

    model = get_model_cls(model_config)
    if supports_score_template(model):
        full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
        prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
    elif model_config.use_pad_token:
        # cross_encoder models defaults to using pad_token.
        prompt_inputs = tokenizer(text=prompt_1,
                                  text_pair=prompt_2,
                                  **tokenization_kwargs)
        full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
    else:
        # `llm as reranker` models defaults to not using pad_token.
        full_prompt = prompt_1 + prompt_2
        prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)

    engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])

    # Forward token type ids only when the tokenizer produced them.
    if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
        engine_prompt["token_type_ids"] = token_type_ids

    # Runs before multimodal data is attached, so the model hook sees the
    # prompt without the "multi_modal_data" key.
    post_process_tokens(model_config, engine_prompt)

    if mm_data is not None:
        engine_prompt["multi_modal_data"] = mm_data
    return full_prompt, engine_prompt
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232


def compress_token_type_ids(token_type_ids: list[int]) -> int:
    """
    Return position of the first 1 or the length of the list
    if not found.

    The input must be a (possibly empty) run of zeros followed by a
    (possibly empty) run of ones.

    Raises:
        ValueError: If a zero appears after a one, or if any element is
            something other than 0 or 1 (including negative values).
    """
    first_one = len(token_type_ids)
    err_msg = "Token type ids are expected to be a sequence"\
              " of zeros followed by a sequence of ones"
    for i, type_id in enumerate(token_type_ids):
        if type_id == 0:
            # A zero is only valid before the first one has been seen.
            if first_one < i:
                raise ValueError(err_msg)
        elif type_id == 1:
            # Record only the earliest position of a one.
            if first_one > i:
                first_one = i
        else:
            # Anything outside {0, 1} breaks the expected layout.
            raise ValueError(err_msg)

    return first_one