"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Union

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Runtime
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER

DEFAULT_PROMPTS = [
    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
    "The capital of the United Kingdom is",
    "Today is a sunny day and I like",
    "AI is a field of computer science focused on",
    # the output of gemma-2-2b from SRT is unstable on the commented prompt
    # "The capital of France is",
]

dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
    long_prompt = f.read()
DEFAULT_PROMPTS.append(long_prompt)

NUM_TOP_LOGPROBS = 5


def get_dtype_str(torch_dtype):
    if torch_dtype is torch.float16:
        return "float16"
    else:
        raise NotImplementedError()


def get_top_logprobs(logits, k):
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    del logits
    logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
    return logprobs


@dataclass
class ModelOutput:
    output_strs: List[str] = None
    output_ids: List[int] = None
    top_input_logprobs: List[torch.Tensor] = None
    top_output_logprobs: List[torch.Tensor] = None
    embed_logits: List[torch.Tensor] = None
    scores: List[float] = None


class HFRunner:
    def __init__(
        self,
        model_path,
        torch_dtype,
        model_type="generation",
        output_str_only=False,
    ):
        self.model_type = model_type
        self.output_str_only = output_str_only

        self.in_queue = mp.Queue()
        self.out_queue = mp.Queue()

        self.model_proc = mp.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
            ),
        )
        self.model_proc.start()

    def needs_trust_remote_code(self, model_path):
        models_needs_trust_remote = [
            "LxzGordon/URM-LLaMa-3.1-8B",
        ]
        if model_path in models_needs_trust_remote:
            return True
        return False

    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch_dtype)

        if self.model_type == "generation":
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=False,
                low_cpu_mem_usage=True,
            ).cuda()
        elif self.model_type == "embedding":
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_path,
                model_kwargs={"torch_dtype": torch_dtype},
            ).cuda()
        elif self.model_type == "reward":
            from transformers import AutoModelForSequenceClassification

            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.needs_trust_remote_code(model_path),
            ).cuda()
        else:
            raise Exception(f"Unrecognized model type {self.model_type}")

        while True:
            prompts, max_new_tokens, lora_paths = in_queue.get()
            if lora_paths is not None:
                assert len(prompts) == len(lora_paths)

            if prompts is not None:
                if self.model_type == "generation":
                    output_strs = []
                    top_input_logprobs = []
                    top_output_logprobs = []
                    for i, p in enumerate(prompts):
                        if isinstance(p, str):
                            input_ids = self.tokenizer.encode(
                                p, return_tensors="pt"
                            ).cuda()
                        else:
                            input_ids = torch.tensor([p], device="cuda")

                        if lora_paths is not None and lora_paths[i] is not None:
                            from peft import PeftModel

                            self.model = PeftModel.from_pretrained(
                                self.base_model,
                                lora_paths[i],
                                torch_dtype=torch_dtype,
                                is_trainable=False,
                            )
                        else:
                            self.model = self.base_model

                        outputs = self.model.generate(
                            input_ids,
                            do_sample=False,
                            temperature=None,
                            top_p=None,
                            max_new_tokens=max_new_tokens,
                            return_dict_in_generate=True,
                            output_scores=(not self.output_str_only),
                        )
                        output_strs.append(
                            self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
                        )
                        if not self.output_str_only:
                            # outputs.scores: (num_token, 1, vocab_size)
                            top_output_logprobs.append(
                                [
                                    get_top_logprobs(
                                        logits[0], NUM_TOP_LOGPROBS
                                    ).tolist()
                                    for logits in outputs.scores
                                ]
                            )
                            del outputs

                            input_logits = self.model.forward(input_ids).logits[0]
                            top_input_logprobs.append(
                                get_top_logprobs(
                                    input_logits, NUM_TOP_LOGPROBS
                                ).tolist()
                            )
                            del input_logits

                    out_queue.put(
                        ModelOutput(
                            output_strs=output_strs,
                            top_input_logprobs=top_input_logprobs,
                            top_output_logprobs=top_output_logprobs,
                        )
                    )

                elif self.model_type == "embedding":
                    assert not self.output_str_only
                    logits = self.model.encode(prompts).tolist()
                    out_queue.put(ModelOutput(embed_logits=logits))

                elif self.model_type == "reward":
                    scores = []
                    for conv in prompts:
                        conv_formatted = self.tokenizer.apply_chat_template(
                            conv, tokenize=False
                        )
                        conv_tokenized = self.tokenizer(
                            conv_formatted, return_tensors="pt"
                        ).to("cuda")
                        scores.append(
                            float(self.model(**conv_tokenized).logits[0][0].item())
                        )
                    out_queue.put(ModelOutput(scores=scores))
                else:
                    raise Exception(f"Unrecognized model type {self.model_type}")

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        self.in_queue.put((prompts, max_new_tokens, lora_paths))
        return self.out_queue.get()

    def terminate(self):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None
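
# Illustrative usage sketch for HFRunner (not part of the test suite; the model
# path and LoRA adapter path below are placeholders, not values this module ships with):
#
#     with HFRunner("your-org/your-model", torch.float16) as hf_runner:
#         out = hf_runner.forward(
#             ["Today is a sunny day and I like"],
#             max_new_tokens=8,
#             lora_paths=["/path/to/lora_adapter"],  # optional; one entry per prompt
#         )
#         print(out.output_strs, out.top_output_logprobs)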


class SRTRunner:
    def __init__(
        self,
        model_path,
        torch_dtype,
        model_type,
        tp_size=1,
        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        lora_paths=None,
        max_loras_per_batch=4,
        disable_cuda_graph=False,
        disable_radix_cache=False,
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
            port=port,
            mem_fraction_static=0.65,
            trust_remote_code=False,
            is_embedding=not self.is_generation,
            lora_paths=lora_paths,
            max_loras_per_batch=max_loras_per_batch,
            disable_cuda_graph=disable_cuda_graph,
            disable_radix_cache=disable_radix_cache,
        )

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        if self.is_generation:
            # the return value contains logprobs from prefill
            output_strs = []
            top_input_logprobs = []
            top_output_logprobs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
            for i, prompt in enumerate(prompts):
                response = self.runtime.generate(
                    prompt,
                    lora_path=lora_paths[i] if lora_paths else None,
                    sampling_params=sampling_params,
                    return_logprob=True,
                    logprob_start_len=0,
                    top_logprobs_num=NUM_TOP_LOGPROBS,
                )
                response = json.loads(response)
                output_strs.append(response["text"])
                top_input_logprobs.append(
                    [
                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                        for x in response["meta_info"]["input_top_logprobs"][1:]
                    ]
                    + [
                        [
                            tup[0]
                            for tup in response["meta_info"]["output_top_logprobs"][0][
                                :NUM_TOP_LOGPROBS
                            ]
                        ]
                    ]
                )
                top_output_logprobs.append(
                    [
                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                        for x in response["meta_info"]["output_top_logprobs"]
                    ]
                )

            return ModelOutput(
                output_strs=output_strs,
                top_input_logprobs=top_input_logprobs,
                top_output_logprobs=top_output_logprobs,
            )
        else:
            response = self.runtime.encode(prompts)
            response = json.loads(response)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
            else:
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def batch_forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """
        testing serving by sending all prompts once
        only return output strings and no logprobs
        """
        if self.is_generation:
            output_strs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
            response = self.runtime.generate(
                prompts,
                lora_path=lora_paths if lora_paths else None,
                sampling_params=sampling_params,
            )
            response = json.loads(response)
            output_strs = [r["text"] for r in response]

            return ModelOutput(
                output_strs=output_strs,
            )
        else:
            response = self.runtime.encode(prompts)
            response = json.loads(response)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
            else:
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.runtime.shutdown()
        del self.runtime
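

# The guarded block below is a minimal usage sketch rather than part of the test
# suite: it shows how the two runners are typically driven together to compare
# greedy outputs. It assumes a CUDA GPU and a locally available model; the model
# path is only an illustrative placeholder.
if __name__ == "__main__":
    example_model = "your-org/your-model"  # placeholder model path (assumption)

    # Reference outputs from the HuggingFace runner (runs the model in a subprocess).
    with HFRunner(example_model, torch.float16) as hf_runner:
        hf_outputs = hf_runner.forward(DEFAULT_PROMPTS[1:3], max_new_tokens=8)

    # Outputs from the SRT runtime for the same prompts.
    with SRTRunner(example_model, torch.float16, "generation") as srt_runner:
        srt_outputs = srt_runner.forward(DEFAULT_PROMPTS[1:3], max_new_tokens=8)

    for hf_str, srt_str in zip(hf_outputs.output_strs, srt_outputs.output_strs):
        print(f"HF : {hf_str!r}")
        print(f"SRT: {srt_str!r}")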