# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import random
import unittest
from typing import Optional

import torch
from transformers import AutoConfig, AutoTokenizer

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import (
    CustomTestCase,
    get_similarities,
    is_in_amd_ci,
    is_in_ci,
)
# (model_path, tp_size, prefill_tolerance) triples; tolerance bounds the
# allowed deviation of the HF-vs-SRT embedding cosine similarity from 1.
MODELS = [
    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
    ("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
    ("marco/mcdse-2b-v1", 1, 1e-5),
    ("Qwen/Qwen3-Embedding-8B", 1, 1e-5),
    # Temporarily disable before this model is fixed
    # ("jason9693/Qwen2.5-1.5B-apeach", 1, 1e-5),
]

# Dtypes each model/runner pair is exercised with.
TORCH_DTYPES = [torch.float16]


class TestEmbeddingModels(CustomTestCase):
    """Parity tests: embeddings from SGLang's SRT runner must match HuggingFace."""

    @classmethod
    def setUpClass(cls):
        # Runners launch worker processes that initialize CUDA; "spawn"
        # avoids the fork-after-CUDA-init failure mode.
        mp.set_start_method("spawn", force=True)

    def _truncate_prompts(self, prompts, model_path):
        """Truncate prompts that exceed the model's context window.

        Prompts whose tokenization is longer than the model's
        ``max_position_embeddings`` (default 2048 if absent from the config)
        are decoded back from their first ``max_length - 1`` tokens; shorter
        prompts pass through unchanged.

        Args:
            prompts: iterable of prompt strings.
            model_path: HF model identifier used to load config/tokenizer.

        Returns:
            List of (possibly truncated) prompt strings, same order as input.
        """
        config = AutoConfig.from_pretrained(model_path)
        max_length = getattr(config, "max_position_embeddings", 2048)

        tokenizer = AutoTokenizer.from_pretrained(model_path)

        truncated_prompts = []
        for prompt in prompts:
            tokens = tokenizer(prompt, return_tensors="pt", truncation=False)
            if len(tokens.input_ids[0]) > max_length:
                truncated_text = tokenizer.decode(
                    tokens.input_ids[0][: max_length - 1], skip_special_tokens=True
                )
                truncated_prompts.append(truncated_text)
            else:
                truncated_prompts.append(prompt)
        return truncated_prompts

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        prefill_tolerance,
        matryoshka_dim: Optional[int] = None,
    ) -> None:
        """Run the same prompts through HF and SRT and compare embeddings.

        Asserts that, for each short prompt, the cosine similarity between
        the HF and SRT embeddings deviates from 1 by less than
        ``prefill_tolerance``.

        Args:
            prompts: prompt strings to embed.
            model_path: HF model identifier.
            tp_size: tensor-parallel size for the SRT runner.
            torch_dtype: dtype used by both runners.
            prefill_tolerance: allowed |similarity - 1| deviation.
            matryoshka_dim: if set, both runners produce embeddings
                truncated to this dimensionality.
        """
        truncated_prompts = self._truncate_prompts(prompts, model_path)

        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="embedding",
            matryoshka_dim=matryoshka_dim,
        ) as hf_runner:
            hf_outputs = hf_runner.forward(truncated_prompts)

        # AMD CI uses the triton attention backend; elsewhere keep the default.
        attention_backend = "triton" if is_in_amd_ci() else None

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="embedding",
            attention_backend=attention_backend,
            json_model_override_args=(
                {"matryoshka_dimensions": [matryoshka_dim]} if matryoshka_dim else None
            ),
        ) as srt_runner:
            srt_outputs = srt_runner.forward(
                truncated_prompts, dimensions=matryoshka_dim
            )

        for i in range(len(prompts)):
            hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
            srt_logits = torch.Tensor(srt_outputs.embed_logits[i])

            similarity = torch.tensor(get_similarities(hf_logits, srt_logits))
            print("similarity diff", abs(similarity - 1))

            # Only enforce the tolerance on short prompts; longer ones are
            # exempted (presumably due to accumulated numerical error —
            # NOTE(review): threshold of 1000 chars is a heuristic).
            if len(prompts[i]) <= 1000:
                assert torch.all(
                    abs(similarity - 1) < prefill_tolerance
                ), "embeddings are not all close"

    def test_prefill_logits(self):
        """Check HF/SRT embedding parity for the configured models."""
        models_to_test = MODELS

        # In CI, test a single randomly chosen model to bound runtime.
        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        for model, tp_size, prefill_tolerance in models_to_test:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits(
                    DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
                )

    def test_matryoshka_embedding(self):
        """Check parity when embeddings are truncated to a matryoshka dim."""
        models_to_test = [
            model
            for model in MODELS
            if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == model[0]
        ]
        assert len(models_to_test) == 1

        for model, tp_size, prefill_tolerance in models_to_test:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits(
                    DEFAULT_PROMPTS,
                    model,
                    tp_size,
                    torch_dtype,
                    prefill_tolerance,
                    matryoshka_dim=128,
                )

143
144

if __name__ == "__main__":
Mingyi's avatar
Mingyi committed
145
    unittest.main()