"vllm/model_executor/models/terratorch.py" did not exist on "974dfd497149e871e59e35b677a85cca66ec3bae"
utils.py 1.82 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from collections.abc import Sequence
4
from typing import NamedTuple, Optional
5
6
7
8
9
10
11

import torch
import torch.nn.functional as F


def check_embeddings_close(
    *,
    embeddings_0_lst: Sequence[list[float]],
    embeddings_1_lst: Sequence[list[float]],
    name_0: str,
    name_1: str,
    tol: float = 1e-3,
) -> None:
    """Assert that two batches of embeddings agree pairwise.

    The i-th embedding of each batch is compared by cosine similarity,
    which must be at least ``1 - tol``.

    Args:
        embeddings_0_lst: First batch, one embedding vector per prompt.
        embeddings_1_lst: Second batch, in the same prompt order.
        name_0: Label used for the first batch in failure messages.
        name_1: Label used for the second batch in failure messages.
        tol: Maximum allowed shortfall of the cosine similarity from 1.

    Raises:
        AssertionError: If the batch sizes differ, if a pair of embeddings
            differs in length, or if a pair is not similar enough.
    """
    assert len(embeddings_0_lst) == len(embeddings_1_lst), (
        f"Batch size mismatch: {name_0} has {len(embeddings_0_lst)} "
        f"embeddings vs. {name_1} has {len(embeddings_1_lst)}")

    for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
            zip(embeddings_0_lst, embeddings_1_lst)):
        assert len(embeddings_0) == len(embeddings_1), (
            f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")

        # as_tensor accepts lists/arrays and passes existing tensors through
        # unchanged, avoiding the copy-construct warning that torch.tensor()
        # emits when it is handed a tensor.
        sim = F.cosine_similarity(torch.as_tensor(embeddings_0),
                                  torch.as_tensor(embeddings_1),
                                  dim=0)

        # Report the measured similarity so a failure shows how far off the
        # two embeddings are; only the first 16 values are printed.
        fail_msg = (f"Test{prompt_idx}:"
                    f"\nCosine similarity: \t{float(sim):.4f}"
                    f"\n{name_0}:\t{embeddings_0[:16]!r}"
                    f"\n{name_1}:\t{embeddings_1[:16]!r}")

        assert sim >= 1 - tol, fail_msg
34
35
36
37
38
39
40


def matryoshka_fy(tensor, dimensions):
    """Truncate embeddings to their leading features and L2-normalize them.

    Args:
        tensor: Embedding data accepted by ``torch.as_tensor`` (nested list,
            array or tensor) with the feature values on the last axis.
        dimensions: Number of leading features to keep on the last axis.

    Returns:
        A tensor holding the first ``dimensions`` entries of the last axis,
        with each vector along that axis divided by its L2 norm.
    """
    # as_tensor reuses an existing tensor rather than copy-constructing it,
    # which torch.tensor() warns about.
    tensor = torch.as_tensor(tensor)
    tensor = tensor[..., :dimensions]
    # Normalize along the same (last) axis that was truncated, so a single
    # 1-D embedding and inputs with extra leading axes are handled as well.
    return F.normalize(tensor, p=2, dim=-1)
41
42
43
44
45


class EmbedModelInfo(NamedTuple):
    """Immutable description of an embedding model used by the tests.

    Only ``name`` is required; every other field has a default.
    """

    # Model identifier.
    name: str
    # Whether the model is treated as a matryoshka model. Defaults to False,
    # consistent with ``matryoshka_dimensions`` defaulting to None.
    is_matryoshka: bool = False
    # Optional list of dimensions for a matryoshka model
    # (presumably the supported truncation sizes; confirm with callers).
    matryoshka_dimensions: Optional[list[int]] = None
    # Architecture name; empty string when not specified.
    architecture: str = ""
    # Flag defaulting to True; how it gates tests is decided by callers
    # outside this block.
    enable_test: bool = True
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66


def correctness_test(hf_model,
                     inputs,
                     vllm_outputs: Sequence[list[float]],
                     dimensions: Optional[int] = None):
    """Compare vLLM embeddings against reference embeddings from ``hf_model``.

    The reference embeddings come from ``hf_model.encode(inputs)``. When
    ``dimensions`` is truthy they are first passed through
    ``matryoshka_fy`` with that value. The two batches are then checked
    with ``check_embeddings_close`` using a tolerance of ``1e-2``.

    Args:
        hf_model: Object exposing an ``encode(inputs)`` method.
        inputs: Passed unchanged to ``hf_model.encode``.
        vllm_outputs: Embeddings to validate, one per input.
        dimensions: Optional matryoshka dimension applied to the reference
            embeddings; ``None`` (or 0) leaves them untouched.
    """
    reference = hf_model.encode(inputs)
    # Truthiness test on purpose: both None and 0 skip the truncation step.
    reference = (matryoshka_fy(reference, dimensions)
                 if dimensions else reference)

    check_embeddings_close(embeddings_0_lst=reference,
                           embeddings_1_lst=vllm_outputs,
                           name_0="hf",
                           name_1="vllm",
                           tol=1e-2)