runners.py 11.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from sglang.srt.server import Runtime
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
28
29

# Prompts used by both runners when the caller does not pass any.
DEFAULT_PROMPTS = [
    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
    "The capital of the United Kingdom is",
    "Today is a sunny day and I like",
    "AI is a field of computer science focused on",
    # the output of gemma-2-2b from SRT is unstable on the commented prompt
    # "The capital of France is",
]

# The last default prompt is the content of long_prompt.txt, which sits next
# to this module. Decode explicitly as UTF-8 so the result does not depend on
# the platform's locale encoding.
dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r", encoding="utf-8") as f:
    long_prompt = f.read()
DEFAULT_PROMPTS.append(long_prompt)

# Number of top log-probabilities collected per position by both runners.
NUM_TOP_LOGPROBS = 5


def get_dtype_str(torch_dtype):
    """Map a torch dtype to the dtype string passed to the SRT ``Runtime``.

    Args:
        torch_dtype: a ``torch.dtype``, e.g. ``torch.float16``.

    Returns:
        The dtype name understood by the runtime (``"float16"``,
        ``"bfloat16"`` or ``"float32"``).

    Raises:
        NotImplementedError: if ``torch_dtype`` is not one of the supported
            dtypes.
    """
    # Identity comparison (as in the original) avoids hashing arbitrary input.
    for dtype, name in (
        (torch.float16, "float16"),
        (torch.bfloat16, "bfloat16"),
        (torch.float32, "float32"),
    ):
        if torch_dtype is dtype:
            return name
    raise NotImplementedError(f"Unsupported torch dtype: {torch_dtype}")


53
54
def get_top_logprobs(logits, k):
    """Return the ``k`` largest log-probabilities along the last dimension.

    The logits are normalized with a log-softmax computed in float32, so
    half-precision inputs still yield float32 results.

    Args:
        logits: tensor of unnormalized scores, normalized over its last dim.
        k: number of top entries to keep per row.

    Returns:
        A float32 tensor with the leading dims of ``logits`` and a last dim of
        size ``k``, sorted in descending order.
    """
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    # Drop this frame's reference to the raw logits; the caller may still
    # hold its own reference.
    del logits
    return torch.topk(logprobs, k=k, dim=-1).values


60
61
@dataclass
class ModelOutput:
    """Results returned by ``HFRunner`` and ``SRTRunner``.

    Every field defaults to ``None``; a runner fills only the fields relevant
    to its mode (generation vs. embedding).
    """

    # Decoded completion text, one entry per prompt.
    output_strs: Optional[List[str]] = None
    # Generated token ids; not populated by the runners in this file.
    output_ids: Optional[List[int]] = None
    # Per prompt: one row per position, each row holding the top-k logprobs
    # as Python floats.
    top_input_logprobs: Optional[List[List[List[float]]]] = None
    # Per prompt: one row per generated token, each row holding the top-k
    # logprobs as Python floats.
    top_output_logprobs: Optional[List[List[List[float]]]] = None
    # Per prompt: the embedding vector as a list of floats.
    embed_logits: Optional[List[List[float]]] = None
67
68
69
70
71
72


class HFRunner:
    """Run a HuggingFace model in a child process and return ``ModelOutput``s.

    The model is loaded inside a ``multiprocessing.Process``. Requests and
    results travel over ``in_queue`` / ``out_queue``. ``is_generation``
    selects a causal LM (``AutoModelForCausalLM``) or an embedding model
    (``sentence_transformers.SentenceTransformer``).
    """

    def __init__(
        self,
        model_path,
        torch_dtype,
        is_generation,
        output_str_only=False,
    ):
        self.is_generation = is_generation
        # When True, generation returns decoded strings only (no logprobs).
        self.output_str_only = output_str_only

        self.in_queue = mp.Queue()
        self.out_queue = mp.Queue()

        self.model_proc = mp.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
            ),
        )
        self.model_proc.start()

    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        """Child-process entry point: load the model, then serve requests forever.

        Each request read from ``in_queue`` is a
        ``(prompts, max_new_tokens, lora_paths)`` tuple. For every request whose
        ``prompts`` is not ``None``, exactly one ``ModelOutput`` is put on
        ``out_queue``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
        )

        if self.is_generation:
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=False,
                low_cpu_mem_usage=True,
            ).cuda()
        else:
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_path,
                model_kwargs={"torch_dtype": torch_dtype},
            )

        while True:
            prompts, max_new_tokens, lora_paths = in_queue.get()
            if prompts is None:
                # Nothing to run; no reply is sent for such a request.
                continue
            if lora_paths is not None:
                assert len(prompts) == len(lora_paths)

            if self.is_generation:
                out_queue.put(
                    self._generate(prompts, max_new_tokens, lora_paths, torch_dtype)
                )
            else:
                assert not self.output_str_only
                logits = self.model.encode(prompts).tolist()
                out_queue.put(ModelOutput(embed_logits=logits))

    def _generate(self, prompts, max_new_tokens, lora_paths, torch_dtype):
        """Greedy-decode each prompt; collect text and (optionally) top-k logprobs."""
        output_strs = []
        top_input_logprobs = []
        top_output_logprobs = []
        for i, p in enumerate(prompts):
            if isinstance(p, str):
                input_ids = self.tokenizer.encode(p, return_tensors="pt").cuda()
            else:
                input_ids = torch.tensor([p], device="cuda")

            if lora_paths is not None and lora_paths[i] is not None:
                from peft import PeftModel

                # NOTE(review): this wraps self.base_model again for every LoRA
                # prompt; confirm adapter state does not leak into later prompts
                # that use the plain base model.
                self.model = PeftModel.from_pretrained(
                    self.base_model,
                    lora_paths[i],
                    torch_dtype=torch_dtype,
                    is_trainable=False,
                )
            else:
                self.model = self.base_model

            outputs = self.model.generate(
                input_ids,
                do_sample=False,
                temperature=None,
                top_p=None,
                max_new_tokens=max_new_tokens,
                return_dict_in_generate=True,
                output_scores=(not self.output_str_only),
            )
            output_strs.append(
                self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
            )

            if not self.output_str_only:
                # outputs.scores: (num_token, 1, vocab_size)
                top_output_logprobs.append(
                    [
                        get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
                        for logits in outputs.scores
                    ]
                )
                del outputs

                # Run the explicit forward pass without autograd so no graph
                # (and no activations for backward) is retained.
                with torch.no_grad():
                    input_logits = self.model.forward(input_ids).logits[0]
                    top_input_logprobs.append(
                        get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
                    )
                del input_logits

        return ModelOutput(
            output_strs=output_strs,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
        )

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """Send one request to the model process and block until it replies."""
        self.in_queue.put((prompts, max_new_tokens, lora_paths))
        return self.out_queue.get()

    def terminate(self):
        """Stop the model process, reap it, and drop the queues."""
        self.model_proc.terminate()
        # Wait for the child to exit so it does not linger as a zombie.
        self.model_proc.join()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.terminate()


class SRTRunner:
    """Run prompts through an SGLang ``Runtime`` and return ``ModelOutput``s."""

    def __init__(
        self,
        model_path,
        torch_dtype,
        is_generation,
        tp_size=1,
        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        lora_paths=None,
        max_loras_per_batch=4,
        disable_cuda_graph=False,
        disable_radix_cache=False,
    ):
        self.is_generation = is_generation
        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
            port=port,
            mem_fraction_static=0.69,
            trust_remote_code=False,
            is_embedding=not self.is_generation,
            lora_paths=lora_paths,
            max_loras_per_batch=max_loras_per_batch,
            disable_cuda_graph=disable_cuda_graph,
            disable_radix_cache=disable_radix_cache,
        )

    @staticmethod
    def _top_logprob_values(entries):
        """Keep the first element of each of the first NUM_TOP_LOGPROBS entries."""
        return [tup[0] for tup in entries[:NUM_TOP_LOGPROBS]]

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """Run each prompt as its own request.

        For generation models, collects text plus top-k logprobs for both the
        prompt (prefill) and the generated tokens. For embedding models,
        returns the embeddings.
        """
        if not self.is_generation:
            return self._encode(prompts)

        output_strs = []
        top_input_logprobs = []
        top_output_logprobs = []
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        for i, prompt in enumerate(prompts):
            response = self.runtime.generate(
                prompt,
                lora_path=lora_paths[i] if lora_paths else None,
                sampling_params=sampling_params,
                return_logprob=True,
                logprob_start_len=0,
                top_logprobs_num=NUM_TOP_LOGPROBS,
            )
            response = json.loads(response)
            meta_info = response["meta_info"]
            output_top = meta_info["output_top_logprobs"]

            output_strs.append(response["text"])
            # Skip the first input entry and append the first output entry.
            # This appears intended to align the rows with HFRunner's
            # per-position input logprobs — confirm against the comparing test.
            top_input_logprobs.append(
                [
                    self._top_logprob_values(x)
                    for x in meta_info["input_top_logprobs"][1:]
                ]
                + [self._top_logprob_values(output_top[0])]
            )
            top_output_logprobs.append(
                [self._top_logprob_values(x) for x in output_top]
            )

        return ModelOutput(
            output_strs=output_strs,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
        )

    def batch_forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """Test serving by sending all prompts at once in a single request.

        For generation models only the output strings are returned (no
        logprobs are requested). For embedding models, returns the embeddings.
        """
        if not self.is_generation:
            return self._encode(prompts)

        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        response = self.runtime.generate(
            prompts,
            lora_path=lora_paths if lora_paths else None,
            sampling_params=sampling_params,
        )
        response = json.loads(response)
        return ModelOutput(output_strs=[r["text"] for r in response])

    def _encode(self, prompts):
        """Embed ``prompts`` via the runtime and wrap the vectors in a ModelOutput."""
        response = json.loads(self.runtime.encode(prompts))
        return ModelOutput(embed_logits=[x["embedding"] for x in response])

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.runtime.shutdown()
        del self.runtime