test_protocol.py 4.77 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for Cohere embed protocol: build_typed_embeddings and its
underlying packing helpers, plus Cohere-specific serving helpers."""

import struct

import numpy as np
9
import pybase64 as base64
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import pytest

from vllm.entrypoints.pooling.embed.protocol import (
    build_typed_embeddings,
)


@pytest.fixture
def sample_embeddings() -> list[list[float]]:
    return [
        [0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8],
        [-0.05, 0.15, -0.25, 0.35, -0.45, 0.55, -0.65, 0.75],
    ]


class TestBuildTypedEmbeddingsFloat:
    def test_float_passthrough(self, sample_embeddings: list[list[float]]):
        result = build_typed_embeddings(sample_embeddings, ["float"])
        assert result.float == sample_embeddings
        assert result.binary is None

    def test_empty_input(self):
        result = build_typed_embeddings([], ["float"])
        assert result.float == []


class TestBuildTypedEmbeddingsBinary:
    def test_binary_packing(self):
        # 8 values: positive->1, negative->0 => bits: 10101010 = 0xAA = 170
        # signed: 170 - 128 = 42
        embs = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]]
        result = build_typed_embeddings(embs, ["binary"])
        assert result.binary is not None
        assert result.binary[0] == [42]

    def test_ubinary_packing(self):
        embs = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]]
        result = build_typed_embeddings(embs, ["ubinary"])
        assert result.ubinary is not None
        assert result.ubinary[0] == [170]  # 0b10101010

    def test_binary_all_positive(self):
        embs = [[0.1] * 8]
        result = build_typed_embeddings(embs, ["binary"])
        assert result.binary is not None
        # all bits = 1 => 0xFF = 255, signed: 255 - 128 = 127
        assert result.binary[0] == [127]

    def test_binary_all_negative(self):
        embs = [[-0.1] * 8]
        result = build_typed_embeddings(embs, ["binary"])
        assert result.binary is not None
        # all bits = 0, signed: 0 - 128 = -128
        assert result.binary[0] == [-128]

    def test_binary_dimension_is_eighth(self, sample_embeddings: list[list[float]]):
        result = build_typed_embeddings(sample_embeddings, ["binary"])
        assert result.binary is not None
        for orig, packed in zip(sample_embeddings, result.binary):
            assert len(packed) == len(orig) // 8

    def test_zero_treated_as_positive(self):
        embs = [[0.0] * 8]
        result = build_typed_embeddings(embs, ["binary"])
        assert result.binary is not None
        # 0.0 >= 0 is True, so bit=1 for all => 127 (signed)
        assert result.binary[0] == [127]

    def test_non_multiple_of_8_raises(self):
        embs = [[0.1] * 7]
        with pytest.raises(ValueError, match="multiple of 8"):
            build_typed_embeddings(embs, ["binary"])

    def test_ubinary_non_multiple_of_8_raises(self):
        embs = [[0.1] * 10]
        with pytest.raises(ValueError, match="multiple of 8"):
            build_typed_embeddings(embs, ["ubinary"])


class TestBuildTypedEmbeddingsBase64:
    def test_base64_roundtrip(self, sample_embeddings: list[list[float]]):
        result = build_typed_embeddings(sample_embeddings, ["base64"])
        assert result.base64 is not None
        assert len(result.base64) == 2

        for orig, b64_str in zip(sample_embeddings, result.base64):
            decoded = base64.b64decode(b64_str)
            n = len(orig)
            values = struct.unpack(f"<{n}f", decoded)
            np.testing.assert_allclose(orig, values, rtol=1e-5)

    def test_base64_byte_length(self):
        embs = [[0.1, 0.2, 0.3]]
        result = build_typed_embeddings(embs, ["base64"])
        assert result.base64 is not None
        raw = base64.b64decode(result.base64[0])
        assert len(raw) == 3 * 4  # 3 floats * 4 bytes each


class TestBuildTypedEmbeddingsMultiple:
    def test_all_types_at_once(self, sample_embeddings: list[list[float]]):
        result = build_typed_embeddings(
            sample_embeddings,
            ["float", "binary", "ubinary", "base64"],
        )
        assert result.float is not None
        assert result.binary is not None
        assert result.ubinary is not None
        assert result.base64 is not None

    def test_subset_types(self, sample_embeddings: list[list[float]]):
        result = build_typed_embeddings(sample_embeddings, ["float", "binary"])
        assert result.float is not None
        assert result.binary is not None
        assert result.ubinary is None
        assert result.base64 is None

    def test_unknown_type_ignored(self, sample_embeddings: list[list[float]]):
        result = build_typed_embeddings(sample_embeddings, ["float", "unknown_type"])
        assert result.float is not None