# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

15
import multiprocessing as mp
16
import random
17
18
19
import unittest

import torch
20
from transformers import AutoConfig, AutoTokenizer
21
22

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
23
24
25
26
27
28
from sglang.test.test_utils import (
    CustomTestCase,
    get_similarities,
    is_in_amd_ci,
    is_in_ci,
)
29

30
31
32
# (model_path, tp_size, prefill_tolerance) triples to exercise.
# tolerance bounds |cosine_similarity - 1| between HF and SRT embeddings.
MODELS = [
    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
    ("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
    ("marco/mcdse-2b-v1", 1, 1e-5),
    ("Qwen/Qwen3-Embedding-8B", 1, 1e-5),
    # Temporarily disable before this model is fixed
    # ("jason9693/Qwen2.5-1.5B-apeach", 1, 1e-5),
]

# Dtypes swept for every model in the list above.
TORCH_DTYPES = [torch.float16]
class TestEmbeddingModels(CustomTestCase):
    """Check that SRT embedding outputs match the HuggingFace reference.

    For each (model, tp_size, tolerance) entry the same truncated prompts are
    embedded by both runners and the cosine similarity of the two embedding
    vectors is asserted to be close to 1.
    """

    @classmethod
    def setUpClass(cls):
        # The runner subprocesses require the "spawn" start method.
        mp.set_start_method("spawn", force=True)

    def _truncate_prompts(self, prompts, model_path):
        """Return prompts clipped to fit the model's position limit.

        Prompts whose token count exceeds ``max_position_embeddings``
        (default 2048 when the config lacks the attribute) are decoded back
        from their first ``limit - 1`` tokens; shorter prompts pass through
        unchanged.
        """
        model_config = AutoConfig.from_pretrained(model_path)
        limit = getattr(model_config, "max_position_embeddings", 2048)
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        clipped = []
        for text in prompts:
            ids = tokenizer(text, return_tensors="pt", truncation=False).input_ids[0]
            if len(ids) > limit:
                # Decode from token ids so the clipped text re-tokenizes
                # within the limit; drop special tokens from the output.
                clipped.append(
                    tokenizer.decode(ids[: limit - 1], skip_special_tokens=True)
                )
            else:
                clipped.append(text)
        return clipped

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        prefill_tolerance,
    ) -> None:
        """Embed prompts with both runners and assert similarity ~ 1.

        Raises AssertionError when |similarity - 1| >= prefill_tolerance for
        any prompt of length <= 1000 characters (longer prompts are only
        printed, not asserted).
        """
        safe_prompts = self._truncate_prompts(prompts, model_path)

        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as hf_runner:
            hf_outputs = hf_runner.forward(safe_prompts)

        # AMD CI runs with the triton attention backend; elsewhere use default.
        backend = "triton" if is_in_amd_ci() else None

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="embedding",
            attention_backend=backend,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(safe_prompts)

        for idx, prompt in enumerate(prompts):
            reference = torch.Tensor(hf_outputs.embed_logits[idx])
            candidate = torch.Tensor(srt_outputs.embed_logits[idx])

            similarity = torch.tensor(get_similarities(reference, candidate))
            print("similarity diff", abs(similarity - 1))

            # Only short prompts (by original character length) are asserted.
            if len(prompt) <= 1000:
                assert torch.all(
                    abs(similarity - 1) < prefill_tolerance
                ), "embeddings are not all close"

    def test_prefill_logits(self):
        """Run the HF/SRT comparison over the configured model matrix."""
        # In CI, sample a single model to keep runtime bounded.
        if is_in_ci():
            models_to_test = [random.choice(MODELS)]
        else:
            models_to_test = MODELS

        for model, tp_size, prefill_tolerance in models_to_test:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits(
                    DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
                )

# Allow running this test module directly with `python test_embedding_models.py`.
if __name__ == "__main__":
    unittest.main()