# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# python -m unittest test_encoder_embedding_models.TestEncoderEmbeddingModels.test_prefill_logits

import multiprocessing as mp
import random
import time
import unittest

import torch
from transformers import AutoConfig, AutoTokenizer

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci

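# Each entry is (model_path, tp_size, prefill_tolerance), unpacked in test_prefill_logits.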
MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]

ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"]
BATCH_SIZE = [1, 2]
TORCH_DTYPES = [torch.float32, torch.float16]
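
# Ratio of SGLang forward time to HF forward time, appended once per tested configuration.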
sgl_to_st_ratio = []


class TestEncoderEmbeddingModels(CustomTestCase):
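    """Compare embeddings from SGLang's SRTRunner against HuggingFace's HFRunner
    for encoder embedding models, across attention backends, batch sizes, and dtypes."""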

    @classmethod
    def setUpClass(cls):
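        # Use the "spawn" start method so subprocesses do not inherit CUDA state
        # from the parent process (fork is unsafe once CUDA has been initialized).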
        mp.set_start_method("spawn", force=True)

    def _truncate_prompts(self, prompts, model_path):
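        """Truncate prompts that would exceed the model's max_position_embeddings
        (minus a small safety margin) so every prompt fits in the encoder."""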
        config = AutoConfig.from_pretrained(model_path)
        max_length = getattr(config, "max_position_embeddings", 512) - 20

        tokenizer = AutoTokenizer.from_pretrained(model_path)

        truncated_prompts = []
        for prompt in prompts:
            tokens = tokenizer(prompt, return_tensors="pt", truncation=False)
            if len(tokens.input_ids[0]) > max_length:
                truncated_text = tokenizer.decode(
                    tokens.input_ids[0][: max_length - 1], skip_special_tokens=True
                )
                truncated_prompts.append(truncated_text)
            else:
                truncated_prompts.append(prompt)

        return truncated_prompts

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        prefill_tolerance,
        attention_backend,
        batch_size,
    ) -> None:
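        """Run the same truncated prompts through HFRunner and SRTRunner, record the
        SGLang-to-HF timing ratio, and check that the embeddings are close."""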
        truncated_prompts = self._truncate_prompts(prompts, model_path)
        truncated_prompts = truncated_prompts * batch_size

        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as hf_runner:
            # warm up
            hf_outputs = hf_runner.forward(truncated_prompts)

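            # Time the second (post-warm-up) forward pass.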
            st_start_time = time.perf_counter()
            hf_outputs = hf_runner.forward(truncated_prompts)
            st_end_time = time.perf_counter()

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="embedding",
            attention_backend=attention_backend,
            chunked_prefill_size=-1,
            disable_radix_cache=True,
        ) as srt_runner:
            # warm up
            srt_outputs = srt_runner.forward(truncated_prompts)

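            # Time the second (post-warm-up) forward pass.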
            sgl_start_time = time.perf_counter()
            srt_outputs = srt_runner.forward(truncated_prompts)
            sgl_end_time = time.perf_counter()

        transformer_time = st_end_time - st_start_time
        sgl_time = sgl_end_time - sgl_start_time
        sgl_to_st_ratio.append(sgl_time / transformer_time)

        for i in range(len(truncated_prompts)):
            hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
            srt_logits = torch.Tensor(srt_outputs.embed_logits[i])

            similarity = torch.tensor(get_similarities(hf_logits, srt_logits))
            # If something is wrong, uncomment this to observe similarity.
            # print("similarity diff", abs(similarity - 1))

            if len(truncated_prompts[i]) <= 1000:
                assert torch.all(
                    abs(similarity - 1) < prefill_tolerance
                ), "embeddings are not all close"

    def test_prefill_logits(self):
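        """Sweep models, attention backends, batch sizes, and dtypes; in CI, only one
        randomly chosen model is tested."""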
        models_to_test = MODELS

        if is_in_ci():
            models_to_test = [random.choice(MODELS)]

        for model, tp_size, prefill_tolerance in models_to_test:
            for attention_backend in ATTENTION_BACKEND:
                for batch_size in BATCH_SIZE:
                    for torch_dtype in TORCH_DTYPES:
                        # NOTE: FlashInfer currently has limitations with head_dim = 32 or
                        # other dimensions.
                        # The FlashInfer head_dim limitation itself is tracked here:
                        # https://github.com/flashinfer-ai/flashinfer/issues/1048
                        #
                        # FlashInfer does not support torch.float32 for dtype_q, so skip it.
                        if attention_backend == "flashinfer":
                            if (
                                model == "BAAI/bge-small-en"
                                or torch_dtype == torch.float32
                            ):
                                continue

                        self.assert_close_prefill_logits(
                            DEFAULT_PROMPTS,
                            model,
                            tp_size,
                            torch_dtype,
                            prefill_tolerance,
                            attention_backend,
                            batch_size,
                        )

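        # NOTE: the printed batch size assumes DEFAULT_PROMPTS contains 5 prompts
        # (hence the "* 5" below).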
        for i in range(len(BATCH_SIZE)):
            print(
                "bacth size: ",
                BATCH_SIZE[i] * 5,
                "sgl_time/st_time",
                round(sgl_to_st_ratio[i], 3),
            )


if __name__ == "__main__":
    unittest.main()