# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Union

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

from sglang.srt.entrypoints.engine import Engine
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER

DEFAULT_PROMPTS = [
    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
    "The capital of the United Kingdom is",
    "Today is a sunny day and I like",
    "AI is a field of computer science focused on",
    # The output of gemma-2-2b from SRT is unstable for the following prompt,
    # so it is commented out:
    # "The capital of France is",
]

dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
    long_prompt = f.read()
DEFAULT_PROMPTS.append(long_prompt)

NUM_TOP_LOGPROBS = 5


def get_dtype_str(torch_dtype):
    if torch_dtype is torch.float16:
        return "float16"
    else:
        raise NotImplementedError(f"Unsupported torch_dtype: {torch_dtype}")


def get_top_logprobs(logits, k):
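    """Return only the top-k log probabilities for each position in `logits`."""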
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    del logits
    logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
    return logprobs


def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import is_sentence_transformer_model

    if is_sentence_transformer_model(model_path):
        model = SentenceTransformer(
            model_path,
            model_kwargs={"torch_dtype": torch_dtype},
        )
    else:  # if no pre-trained sentence-transformers model
        from sentence_transformers import models

        word_embedding_model = models.Transformer(model_path).to(dtype=torch_dtype)
        pooling_model = models.Pooling(
            word_embedding_model.get_word_embedding_dimension(),
            pooling_mode="lasttoken",
        )
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    return model.cuda()
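

# Illustrative sketch (not executed on import) of how the loader above can be used.
# The model path here is a placeholder assumption, not necessarily one this suite tests.
def _example_embedding_model_usage():
    model_path = "intfloat/e5-mistral-7b-instruct"  # placeholder model path
    model = _get_sentence_transformer_embedding_model(model_path, torch.float16)
    # encode() returns one embedding vector per prompt.
    embeddings = model.encode(DEFAULT_PROMPTS)
    return embeddings.shape  # (num_prompts, hidden_dim)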


@dataclass
class ModelOutput:
    output_strs: List[str] = None
    output_ids: List[int] = None
    top_input_logprobs: List[torch.Tensor] = None
    top_output_logprobs: List[torch.Tensor] = None
    embed_logits: List[torch.Tensor] = None
    scores: List[float] = None


class HFRunner:
    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str = "generation",
        output_str_only: bool = False,
    ):
        self.model_type = model_type
        self.output_str_only = output_str_only

        self.in_queue = mp.Queue()
        self.out_queue = mp.Queue()

        self.model_proc = mp.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
            ),
        )
        self.model_proc.start()

    def needs_trust_remote_code(self, model_path):
        models_needs_trust_remote = [
            "LxzGordon/URM-LLaMa-3.1-8B",
        ]
        return model_path in models_needs_trust_remote

    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        # Apply model-specific patches
        monkey_patch_gemma2_sdpa()

        # Load the model and tokenizer
        if self.model_type == "generation":
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=False,
                low_cpu_mem_usage=True,
            ).cuda()
        elif self.model_type == "embedding":
            self.model = _get_sentence_transformer_embedding_model(
                model_path, torch_dtype
            )
        elif self.model_type == "reward":
            from transformers import AutoModelForSequenceClassification

            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.needs_trust_remote_code(model_path),
            ).cuda()
        else:
            raise Exception(f"Unrecognized model type {self.model_type}")
        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch_dtype)

        # Run forward
        while True:
            prompts, max_new_tokens, lora_paths = in_queue.get()
            if lora_paths is not None:
                assert len(prompts) == len(lora_paths)

            if prompts is not None:
                if self.model_type == "generation":
                    output_strs = []
                    top_input_logprobs = []
                    top_output_logprobs = []
                    for i, p in enumerate(prompts):
                        if isinstance(p, str):
                            input_ids = self.tokenizer.encode(
                                p, return_tensors="pt"
                            ).cuda()
                        else:
                            input_ids = torch.tensor([p], device="cuda")

                        if lora_paths is not None and lora_paths[i] is not None:
                            from peft import PeftModel

                            self.model = PeftModel.from_pretrained(
                                self.base_model,
                                lora_paths[i],
                                torch_dtype=torch_dtype,
                                is_trainable=False,
                            )
                        else:
                            self.model = self.base_model

                        outputs = self.model.generate(
                            input_ids,
                            do_sample=False,
                            temperature=None,
                            top_p=None,
                            max_new_tokens=max_new_tokens,
                            return_dict_in_generate=True,
                            output_scores=(not self.output_str_only),
                        )

                        text = self.tokenizer.decode(
                            outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
                        )
                        # Check if the text is empty or only whitespace.
                        if not text.strip():
                            raise ValueError(
                                "Received an empty text response. Please verify your input or model configuration."
                            )
                        output_strs.append(text)

                        if not self.output_str_only:
                            # outputs.scores: (num_token, 1, vocab_size)
                            top_output_logprobs.append(
                                [
                                    get_top_logprobs(
                                        logits[0], NUM_TOP_LOGPROBS
                                    ).tolist()
                                    for logits in outputs.scores
                                ]
                            )
                            del outputs

                            input_logits = self.model.forward(input_ids).logits[0]
                            top_input_logprobs.append(
                                get_top_logprobs(
                                    input_logits, NUM_TOP_LOGPROBS
                                ).tolist()
                            )
                            del input_logits

                    out_queue.put(
                        ModelOutput(
                            output_strs=output_strs,
                            top_input_logprobs=top_input_logprobs,
                            top_output_logprobs=top_output_logprobs,
                        )
                    )

                elif self.model_type == "embedding":
                    assert not self.output_str_only
                    logits = self.model.encode(prompts).tolist()
                    out_queue.put(ModelOutput(embed_logits=logits))

                elif self.model_type == "reward":
                    scores = []
                    for conv in prompts:
                        conv_formatted = self.tokenizer.apply_chat_template(
                            conv, tokenize=False
                        )
                        conv_tokenized = self.tokenizer(
                            conv_formatted, return_tensors="pt"
                        ).to("cuda")
                        scores.append(
                            float(self.model(**conv_tokenized).logits[0][0].item())
                        )
                    out_queue.put(ModelOutput(scores=scores))
                else:
                    raise Exception(f"Unrecognized model type {self.model_type}")

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        self.in_queue.put((prompts, max_new_tokens, lora_paths))
        return self.out_queue.get()

    def terminate(self):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None
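

# Illustrative sketch (not executed on import) of typical HFRunner usage as a context
# manager. The model path below is a placeholder assumption; real tests pass in the
# models they actually want to compare against SRTRunner.
def _example_hf_runner_usage():
    with HFRunner(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
        torch_dtype=torch.float16,
        model_type="generation",
    ) as hf_runner:
        outputs = hf_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)
    # `outputs` is a ModelOutput with output_strs plus top input/output logprobs.
    return outputs.output_strs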


class SRTRunner:
    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str,
        tp_size: int = 1,
        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        lora_paths: List[str] = None,
        max_loras_per_batch: int = 4,
        lora_backend: str = "triton",
        disable_cuda_graph: bool = False,
        disable_radix_cache: bool = False,
        mem_fraction_static: float = 0.65,
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
        self.engine = Engine(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
            port=port,
            mem_fraction_static=mem_fraction_static,
            trust_remote_code=False,
            is_embedding=not self.is_generation,
            lora_paths=lora_paths,
            max_loras_per_batch=max_loras_per_batch,
            lora_backend=lora_backend,
            disable_cuda_graph=disable_cuda_graph,
            disable_radix_cache=disable_radix_cache,
        )
        self.tokenizer = get_tokenizer(model_path)

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        if self.is_generation:
            # the return value contains logprobs from prefill
            output_strs = []
            top_input_logprobs = []
            top_output_logprobs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
            for i, prompt in enumerate(prompts):
                response = self.engine.generate(
                    prompt,
                    lora_path=lora_paths[i] if lora_paths else None,
                    sampling_params=sampling_params,
                    return_logprob=True,
                    logprob_start_len=0,
                    top_logprobs_num=NUM_TOP_LOGPROBS,
                )
                text = response["text"]

                # Check if the text is empty or only whitespace.
                if not text.strip():
                    raise ValueError(
                        "Received an empty text response. Please verify your input or model configuration."
                    )
                output_strs.append(text)

                top_input_logprobs.append(
                    [
                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                        for x in response["meta_info"]["input_top_logprobs"][1:]
                    ]
                    + [
                        [
                            tup[0]
                            for tup in response["meta_info"]["output_top_logprobs"][0][
                                :NUM_TOP_LOGPROBS
                            ]
                        ]
                    ]
                )
                top_output_logprobs.append(
                    [
                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                        for x in response["meta_info"]["output_top_logprobs"]
                    ]
                )

            return ModelOutput(
                output_strs=output_strs,
                top_input_logprobs=top_input_logprobs,
                top_output_logprobs=top_output_logprobs,
            )
        else:
            response = self.engine.encode(prompts)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
            else:
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def batch_forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """
        testing serving by sending all prompts once
        only return output strings and no logprobs
        """
        if self.is_generation:
            # the return value contains logprobs from prefill
            output_strs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
            response = self.engine.generate(
                prompts,
                lora_path=lora_paths if lora_paths else None,
                sampling_params=sampling_params,
            )
            output_strs = [r["text"] for r in response]

            return ModelOutput(
                output_strs=output_strs,
            )
        else:
            response = self.engine.encode(prompts)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
            else:
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.engine.shutdown()
        del self.engine
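

# Illustrative sketch (not executed on import) of typical SRTRunner usage. forward()
# returns output strings plus top input/output logprobs, while batch_forward() sends
# all prompts in one call and returns strings only. The model path below is a
# placeholder assumption.
def _example_srt_runner_usage():
    with SRTRunner(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
        torch_dtype=torch.float16,
        model_type="generation",
    ) as srt_runner:
        per_prompt = srt_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)
        batched = srt_runner.batch_forward(DEFAULT_PROMPTS, max_new_tokens=8)
    return per_prompt.output_strs, batched.output_strs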


def monkey_patch_gemma2_sdpa():
    """
    Use sdpa by default to fix the OOM issue.
    Revert this commit:
    https://github.com/huggingface/transformers/commit/975b988bfe6e7ebb47390cd9a1556c6888804883#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834R660
    """
    from transformers.models.gemma2.modeling_gemma2 import Gemma2PreTrainedModel

    def _check_and_enable_sdpa(config, hard_check_only: bool = False):
        config._attn_implementation = "sdpa"
        return config

    setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
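

# Illustrative sketch (not executed on import): the patch above is called inside
# HFRunner.start_model_process before any HF model is loaded, forcing Gemma-2
# checkpoints onto the SDPA attention backend. The checkpoint name below is a
# placeholder assumption.
def _example_gemma2_sdpa_patch_usage():
    monkey_patch_gemma2_sdpa()
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b",  # placeholder Gemma-2 checkpoint
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    # With the patch applied, transformers selects "sdpa" instead of eager attention.
    return model.config._attn_implementation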