# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union
4
5
6
7
8
9
10
11
12
13

from torch.nn import CosineSimilarity

from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
                                               PreTrainedTokenizerFast)


def _cosine_similarity(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    embed_1: list[PoolingRequestOutput],
    embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
    """Score paired embedding outputs with cosine similarity.

    The two lists are walked in lockstep with ``zip``, so iteration stops at
    the end of the shorter list. For each pair a new
    ``PoolingRequestOutput`` is built whose ``outputs`` is the similarity
    computed by ``CosineSimilarity(0)`` (i.e. along dimension 0 of each
    ``outputs.data``).
    NOTE(review): this presumably expects 1-D embedding tensors so that the
    result is a scalar score -- confirm against the callers.

    Args:
        tokenizer: Tokenizer whose ``pad_token_id`` (if it has one that is
            not ``None``) is inserted between the two prompts' token ids.
        embed_1: Pooling outputs for the first element of each pair.
        embed_2: Pooling outputs for the second element of each pair.

    Returns:
        One ``PoolingRequestOutput`` per pair, with
        ``request_id`` set to ``"<id_1>_<id_2>"``, ``prompt_token_ids`` set
        to the first prompt's ids, the optional pad token, then the second
        prompt's ids, and ``finished=True``.
    """
    scorer = CosineSimilarity(0)

    # The separator depends only on the tokenizer, so resolve it once
    # instead of on every pair. It is never mutated below: list
    # concatenation builds a fresh list for each output.
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    padding = [] if pad_token_id is None else [pad_token_id]

    scores: list[PoolingRequestOutput] = []
    for emb_1, emb_2 in zip(embed_1, embed_2):
        pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)

        tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

        scores.append(
            PoolingRequestOutput(
                request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                outputs=pair_score,
                prompt_token_ids=tokens,
                finished=True))

    return scores


def _validate_score_input_lens(
42
43
    texts_1: Union[list[str], list[dict]],
    texts_2: Union[list[str], list[dict]],
44
45
46
47
48
49
):
    if len(texts_1) > 1 and len(texts_1) != len(texts_2):
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if len(texts_1) == 0:
        raise ValueError("At least one text element must be given")
    if len(texts_2) == 0:
50
        raise ValueError("At least one text_pair element must be given")