runners.py 31.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

15
import multiprocessing as mp
16
import os
17
from dataclasses import dataclass
18
from typing import List, Optional, Tuple, Union
19
20
21

import torch
import torch.nn.functional as F
Kiv Chen's avatar
Kiv Chen committed
22
import transformers
uylnap's avatar
uylnap committed
23
from transformers import (
Kiv Chen's avatar
Kiv Chen committed
24
    AutoConfig,
uylnap's avatar
uylnap committed
25
26
27
28
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
29
    GenerationConfig,
uylnap's avatar
uylnap committed
30
)
31

Lianmin Zheng's avatar
Lianmin Zheng committed
32
from sglang.srt.entrypoints.engine import Engine
33
from sglang.srt.hf_transformers_utils import get_tokenizer
uylnap's avatar
uylnap committed
34
from sglang.srt.utils import load_image
35
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
36
37

# Prompts used by both runners when the caller does not supply any.
DEFAULT_PROMPTS = [
    # A very long, highly repetitive prompt: the same two sentences 800 times.
    ("Apple is red. Banana is Yellow. " * 800) + "Apple is",
    "The capital of the United Kingdom is",
    "Today is a sunny day and I like",
    "AI is a field of computer science focused on",
    # Disabled: the output of gemma-2-2b from SRT is unstable on this prompt.
    # "The capital of France is",
]
woodx's avatar
woodx committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Query/document sets for cross-encoder (rerank) tests: one single-document
# case and one two-document case sharing the same query.
TEST_RERANK_QUERY_DOCS = [
    dict(
        query="How many people live in Berlin?",
        documents=["Berlin is well known for its museums."],
    ),
    dict(
        query="How many people live in Berlin?",
        documents=[
            "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
            "Berlin is well known for its museums.",
        ],
    ),
]
60

61
dirpath = os.path.dirname(__file__)
# Append a long prompt stored in a sibling text file. The encoding is given
# explicitly so the decoded text does not depend on the platform locale.
with open(os.path.join(dirpath, "long_prompt.txt"), "r", encoding="utf-8") as f:
    long_prompt = f.read()
DEFAULT_PROMPTS.append(long_prompt)

# Number of top logprobs requested/compared per position by both runners.
NUM_TOP_LOGPROBS = 5


def get_dtype_str(torch_dtype):
    """Map a torch floating-point dtype to the dtype string given to the SRT engine.

    Args:
        torch_dtype: ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``.

    Returns:
        ``"float16"``, ``"bfloat16"`` or ``"float32"``.

    Raises:
        NotImplementedError: for any other dtype.
    """
    if torch_dtype is torch.float16:
        return "float16"
    if torch_dtype is torch.bfloat16:
        return "bfloat16"
    if torch_dtype is torch.float32:
        return "float32"
    raise NotImplementedError(f"Unsupported torch dtype: {torch_dtype}")


78
79
def get_top_logprobs(logits, k):
    """Return the ``k`` largest log-probabilities along the last dim of ``logits``.

    The log-softmax is computed in float32. Only the values are returned
    (sorted descending, as produced by ``torch.topk``); the indices are dropped.
    """
    normalized = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return torch.topk(normalized, k=k, dim=-1).values


85
86
87
88
89
90
91
def get_token_ids_logprobs(logits, token_ids):
    """Return float32 log-probabilities of ``token_ids`` along the last dim of ``logits``."""
    full_logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return full_logprobs[..., token_ids]


92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
    """Build a SentenceTransformer for ``model_path`` in ``torch_dtype`` and move it to CUDA.

    Checkpoints that are already sentence-transformers models are loaded
    directly. Any other checkpoint is wrapped as a Transformer module followed
    by last-token pooling.
    """
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import is_sentence_transformer_model

    if not is_sentence_transformer_model(model_path):
        # No pre-trained sentence-transformers config: assemble one by hand.
        from sentence_transformers import models

        transformer_module = models.Transformer(model_path).to(dtype=torch_dtype)
        pooling_module = models.Pooling(
            transformer_module.get_word_embedding_dimension(),
            pooling_mode="lasttoken",
        )
        return SentenceTransformer(modules=[transformer_module, pooling_module]).cuda()

    return SentenceTransformer(
        model_path,
        model_kwargs={"torch_dtype": torch_dtype},
    ).cuda()


114
115
@dataclass
class ModelOutput:
    """Container for results returned by ``HFRunner`` and ``SRTRunner``.

    Every field defaults to ``None``; each runner/model type fills only the
    fields it produces (e.g. ``output_strs`` for generation, ``embed_logits``
    for embedding models, ``scores`` for reward/cross-encoder models).
    """

    output_strs: Optional[List[str]] = None
    output_ids: Optional[List[int]] = None
    top_input_logprobs: Optional[List[torch.Tensor]] = None
    top_output_logprobs: Optional[List[torch.Tensor]] = None
    top_output_logprob_idx: Optional[List[List[int]]] = None
    embed_logits: Optional[List[torch.Tensor]] = None
    scores: Optional[List[float]] = None
    input_token_logprobs_lst: Optional[List[List[Tuple[float, int, None]]]] = None
    output_token_logprobs_lst: Optional[List[List[Tuple[float, int, None]]]] = None
    token_ids_input_logprobs: Optional[List[torch.Tensor]] = None
    token_ids_output_logprobs: Optional[List[torch.Tensor]] = None
127
128
129
130
131


class HFRunner:
    """Reference runner that serves a HuggingFace model from a child process.

    The model is loaded inside ``start_model_process``, which runs in a
    ``multiprocessing.Process``. Each :meth:`forward` call puts one request tuple
    ``(prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)`` on
    ``in_queue`` and blocks until a :class:`ModelOutput` arrives on ``out_queue``.

    Supported ``model_type`` values: ``"generation"``, ``"embedding"``,
    ``"reward"`` and ``"cross_encoder"``.
    """

    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str = "generation",
        output_str_only: bool = False,
        trust_remote_code: bool = False,
    ):
        self.model_type = model_type
        self.output_str_only = output_str_only
        self.trust_remote_code = trust_remote_code

        self.in_queue = mp.Queue()
        self.out_queue = mp.Queue()

        # The model lives in a separate process; requests and results cross
        # the two queues.
        self.model_proc = mp.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
            ),
        )
        self.model_proc.start()

    def needs_trust_remote_code(self, model_path):
        """Return True if the reward/cross-encoder checkpoint must be loaded with remote code."""
        models_needs_trust_remote = [
            "LxzGordon/URM-LLaMa-3.1-8B",
        ]
        return model_path in models_needs_trust_remote

    # copy from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py

    def _get_gme_qwen2_vl_embeddings(
        self, prompts, image_data: Optional[List[str]] = None
    ):
        """Embed ``prompts`` (optionally with images) with a GME Qwen2-VL model.

        Returns a nested Python list of L2-normalized embeddings.
        """
        images = None
        if image_data is not None:
            images = [load_image(image)[0] for image in image_data]

        inputs = self.processor(
            text=prompts,
            images=images,
            padding=True,
            truncation=True,
            max_length=1800,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            embeddings = self._forward_gme_qwen2_vl(**inputs)
        return embeddings.tolist()

    def _forward_gme_qwen2_vl(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pooling_mask: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the GME forward pass and return L2-normalized last-token hidden states."""
        if inputs_embeds is None:
            inputs_embeds = self.model.model.embed_tokens(input_ids)
            if pixel_values is not None:
                # Splice vision-tower embeddings into the image-token positions.
                pixel_values = pixel_values.type(self.model.visual.get_dtype())
                image_embeds = self.model.visual(
                    pixel_values, grid_thw=image_grid_thw
                ).to(inputs_embeds.device)
                image_mask = input_ids == self.model.config.image_token_id
                inputs_embeds[image_mask] = image_embeds
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
            return_dict=True,
            inputs_embeds=inputs_embeds,
            image_grid_thw=image_grid_thw,
        )

        # Last layer, last position.
        embeddings = outputs.hidden_states[-1][:, -1]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.contiguous()

    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        """Child-process entry point: load the model, then serve requests forever.

        A request whose ``prompts`` is None produces no output.
        """
        # Apply model-specific patches
        monkey_patch_gemma2_sdpa()

        self._load_model(model_path, torch_dtype)
        self.tokenizer = get_tokenizer(
            model_path,
            torch_dtype=torch_dtype,
            trust_remote_code=self.trust_remote_code,
        )

        while True:
            prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = (
                in_queue.get()
            )

            if lora_paths is not None:
                assert len(prompts) == len(lora_paths)

            if prompts is None:
                continue

            # Inference only: never build autograd graphs.
            with torch.no_grad():
                output = self._run_request(
                    model_path,
                    torch_dtype,
                    prompts,
                    image_data,
                    max_new_tokens,
                    lora_paths,
                    token_ids_logprob,
                )
            out_queue.put(output)

    def _load_model(self, model_path, torch_dtype):
        """Load the HF model (and processor, where needed) for ``self.model_type``."""
        if self.model_type == "generation":
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=self.trust_remote_code
            )
            # Prefer the concrete architecture class named in the config. Fall
            # back to AutoModelForCausalLM when the config lists none, or when
            # the name is not exported by ``transformers``.
            model_archs = getattr(config, "architectures", None)
            model_cls = AutoModelForCausalLM
            if model_archs:
                model_cls = getattr(transformers, model_archs[0], AutoModelForCausalLM)
            self.base_model = model_cls.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.trust_remote_code,
                low_cpu_mem_usage=True,
            ).cuda()
        elif self.model_type == "embedding":
            if "gme-qwen2-vl" in model_path.lower():
                self.model = AutoModelForVision2Seq.from_pretrained(
                    model_path,
                    torch_dtype=torch_dtype,
                    trust_remote_code=False,
                    low_cpu_mem_usage=True,
                ).cuda()
                self.processor = AutoProcessor.from_pretrained(model_path)
            elif "clip" in model_path.lower():
                self.model = AutoModel.from_pretrained(model_path).cuda()
                self.processor = AutoProcessor.from_pretrained(model_path)
            else:
                self.model = _get_sentence_transformer_embedding_model(
                    model_path, torch_dtype
                )
        elif self.model_type == "reward" or self.model_type == "cross_encoder":
            from transformers import AutoModelForSequenceClassification

            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.needs_trust_remote_code(model_path),
            ).cuda()
        else:
            raise Exception(f"Unrecognized model type {self.model_type}")

    def _run_request(
        self,
        model_path,
        torch_dtype,
        prompts,
        image_data,
        max_new_tokens,
        lora_paths,
        token_ids_logprob,
    ) -> ModelOutput:
        """Handle one request according to ``self.model_type``."""
        if self.model_type == "generation":
            return self.forward_generation_raw(
                base_model=self.base_model,
                prompts=prompts,
                max_new_tokens=max_new_tokens,
                tokenizer=self.tokenizer,
                lora_paths=lora_paths,
                torch_dtype=torch_dtype,
                output_str_only=self.output_str_only,
                token_ids_logprob=token_ids_logprob,
            )

        if self.model_type == "embedding":
            assert not self.output_str_only
            return ModelOutput(
                embed_logits=self._embed(model_path, prompts, image_data)
            )

        if self.model_type == "cross_encoder":
            inputs = self.tokenizer(prompts, padding=True, return_tensors="pt").to(
                "cuda"
            )
            scores = self.model(**inputs).logits
            scores = scores.squeeze().tolist()
            # A single pair squeezes down to a scalar; keep the list shape.
            if not isinstance(scores, list):
                scores = [scores]
            return ModelOutput(scores=scores)

        if self.model_type == "reward":
            scores = []
            for conv in prompts:
                conv_formatted = self.tokenizer.apply_chat_template(
                    conv, tokenize=False
                )
                conv_tokenized = self.tokenizer(
                    conv_formatted, return_tensors="pt"
                ).to("cuda")
                scores.append(float(self.model(**conv_tokenized).logits[0][0].item()))
            return ModelOutput(scores=scores)

        raise Exception(f"Unrecognized model type {self.model_type}")

    def _embed(self, model_path, prompts, image_data):
        """Return embeddings (as nested Python lists) for an embedding-model request."""
        if "gme-qwen2-vl" in model_path.lower():
            return self._get_gme_qwen2_vl_embeddings(prompts, image_data)

        if "clip" in model_path.lower():
            if image_data is not None:
                image = load_image(image_data)
                inputs = self.processor(images=image[0], return_tensors="pt")
                return self.model.get_image_features(
                    pixel_values=inputs.data["pixel_values"].cuda(),
                ).tolist()
            inputs = self.tokenizer(prompts, padding=True, return_tensors="pt")
            return self.model.get_text_features(
                input_ids=inputs.data["input_ids"].cuda(),
                attention_mask=inputs.data["attention_mask"].cuda(),
            ).tolist()

        return self.model.encode(prompts).tolist()

    def forward(
        self,
        prompts: Union[
            List[List[str]], List[str], List[torch.Tensor]
        ] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        token_ids_logprob: Optional[List[int]] = None,
    ):
        """Send one request to the model process and block until its ModelOutput arrives."""
        self.in_queue.put(
            (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
        )
        return self.out_queue.get()

    def terminate(self):
        """Stop the model process, reap it, and drop the queues."""
        self.model_proc.terminate()
        # Join after terminate so the child does not linger as a zombie.
        self.model_proc.join()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.terminate()

    @staticmethod
    def forward_generation_raw(
        base_model,
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens: int,
        tokenizer,
        torch_dtype: torch.dtype,
        lora_paths: Optional[List[str]] = None,
        output_str_only: bool = False,
        token_ids_logprob: Optional[List[int]] = None,
    ) -> ModelOutput:
        """Greedy-decode each prompt with HF ``generate`` and collect logprobs.

        Args:
            base_model: The loaded HF causal LM (already on CUDA).
            prompts: Strings (tokenized with ``tokenizer``) or token-id sequences.
            max_new_tokens: Decode length per prompt.
            tokenizer: Tokenizer used for encoding/decoding.
            torch_dtype: Dtype passed when loading LoRA adapters.
            lora_paths: Optional per-prompt LoRA adapter path (None entries = base model).
            output_str_only: When True, skip all logprob collection.
            token_ids_logprob: Optional token ids whose logprobs are also recorded.

        Raises:
            ValueError: if a decoded completion is empty or whitespace-only.
        """
        output_strs = []
        top_input_logprobs = []
        top_output_logprobs = []
        if token_ids_logprob is not None:
            token_ids_input_logprobs = []
            token_ids_output_logprobs = []
        else:
            token_ids_input_logprobs = token_ids_output_logprobs = None

        for i, p in enumerate(prompts):
            if isinstance(p, str):
                input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
            else:
                input_ids = torch.tensor([p], device="cuda")

            use_lora = lora_paths is not None and lora_paths[i] is not None
            if use_lora:
                from peft import PeftModel

                model = PeftModel.from_pretrained(
                    base_model,
                    lora_paths[i],
                    torch_dtype=torch_dtype,
                    is_trainable=False,
                )
            else:
                model = base_model

            outputs = model.generate(
                input_ids=input_ids,
                generation_config=GenerationConfig(
                    do_sample=False,
                    temperature=None,
                    top_p=None,
                    max_new_tokens=max_new_tokens,
                    return_dict_in_generate=True,
                    output_scores=(not output_str_only),
                    # make sure to disable compile
                    disable_compile=True,
                ),
            )

            text = tokenizer.decode(
                outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
            )
            # Check if the text is empty or only whitespace.
            if not text.strip():
                raise ValueError(
                    "Received an empty text response. Please verify your input or model configuration."
                )
            output_strs.append(text)

            if not output_str_only:
                # outputs.scores: (num_token, 1, vocab_size)
                top_output_logprobs.append(
                    [
                        get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
                        for logits in outputs.scores
                    ]
                )
                if token_ids_logprob is not None:
                    token_ids_output_logprobs.append(
                        [
                            get_token_ids_logprobs(
                                logits[0], token_ids_logprob
                            ).tolist()
                            for logits in outputs.scores
                        ]
                    )
                del outputs

                # Prefill pass for input logprobs; no autograd graph is needed.
                with torch.no_grad():
                    input_logits = model.forward(input_ids).logits[0]
                top_input_logprobs.append(
                    get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
                )
                if token_ids_logprob is not None:
                    token_ids_input_logprobs.append(
                        get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
                    )
                del input_logits

            if use_lora:
                # Unload the LoRA adapter if it is used
                model.unload()

        return ModelOutput(
            output_strs=output_strs,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
            token_ids_input_logprobs=token_ids_input_logprobs,
            token_ids_output_logprobs=token_ids_output_logprobs,
        )

476
477
478
479

class SRTRunner:
    """Runs prompts through an SGLang ``Engine`` and packs results into ``ModelOutput``.

    ``model_type`` selects the path: ``"generation"`` uses ``generate``,
    ``"embedding"`` uses ``encode``, ``"cross_encoder"`` uses ``rerank``, and any
    other value is treated as a reward model scored via ``encode``.
    """

    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str,
        tp_size: int = 1,
        impl: str = "auto",
        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        lora_paths: Optional[List[str]] = None,
        max_loras_per_batch: int = 4,
        attention_backend: Optional[str] = None,
        lora_backend: str = "triton",
        disable_cuda_graph: bool = False,
        disable_radix_cache: bool = False,
        chunked_prefill_size: Optional[int] = None,
        dp_size: int = 1,
        tokenizer_path: Optional[str] = None,
        enable_ep_moe: bool = False,
        mem_fraction_static: float = 0.65,
        trust_remote_code: bool = False,
        speculative_draft_model_path: Optional[str] = None,
        speculative_algorithm: Optional[str] = None,
        speculative_num_steps: Optional[int] = None,
        speculative_eagle_topk: Optional[int] = None,
        speculative_num_draft_tokens: Optional[int] = None,
        disable_overlap_schedule: bool = False,
        disable_custom_all_reduce: bool = False,
        torchao_config: Optional[str] = None,
        sleep_on_idle: bool = False,
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
        enable_dp_attention = dp_size > 1

        # Speculative-decoding options are forwarded only when a draft model is given.
        spec_kwargs = {}
        if speculative_draft_model_path:
            spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
            spec_kwargs["speculative_algorithm"] = speculative_algorithm
            spec_kwargs["speculative_num_steps"] = speculative_num_steps
            spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
            spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens

        self.engine = Engine(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
            port=port,
            impl=impl,
            torchao_config=torchao_config,
            mem_fraction_static=mem_fraction_static,
            trust_remote_code=trust_remote_code,
            is_embedding=not self.is_generation,
            lora_paths=lora_paths,
            max_loras_per_batch=max_loras_per_batch,
            lora_backend=lora_backend,
            attention_backend=attention_backend,
            disable_cuda_graph=disable_cuda_graph,
            disable_radix_cache=disable_radix_cache,
            chunked_prefill_size=chunked_prefill_size,
            enable_dp_attention=enable_dp_attention,
            dp_size=dp_size,
            tokenizer_path=tokenizer_path,
            enable_ep_moe=enable_ep_moe,
            disable_overlap_schedule=disable_overlap_schedule,
            cuda_graph_max_bs=4,
            disable_custom_all_reduce=disable_custom_all_reduce,
            sleep_on_idle=sleep_on_idle,
            **spec_kwargs,
        )

        if tokenizer_path is None:
            self.tokenizer = get_tokenizer(
                model_path, trust_remote_code=trust_remote_code
            )
        else:
            self.tokenizer = None

    def forward(
        self,
        prompts: Union[
            List[List[str]], List[str], List[torch.Tensor]
        ] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        logprob_start_len: int = 0,
        top_k: Optional[int] = None,
        token_ids_logprob: Optional[List[int]] = None,
    ):
        """Run ``prompts`` one request at a time and return a ``ModelOutput``."""
        if self.is_generation:
            return self.forward_generation_raw(
                engine=self.engine,
                prompts=prompts,
                max_new_tokens=max_new_tokens,
                lora_paths=lora_paths,
                logprob_start_len=logprob_start_len,
                top_k=top_k,
                token_ids_logprob=token_ids_logprob,
            )

        if self.model_type == "embedding":
            response = self.engine.encode(prompt=prompts, image_data=image_data)
            if isinstance(response, list):
                logits = [x["embedding"] for x in response]
            else:
                logits = [response["embedding"]]
            return ModelOutput(embed_logits=logits)

        # cross encoder model
        if self.model_type == "cross_encoder":
            response = self.engine.rerank(prompts)
            if not isinstance(response, list):
                response = [response]
            scores = [x["embedding"] for x in response]
            return ModelOutput(scores=scores)

        # reward model
        response = self.engine.encode(prompts)
        # Normalize a single-result response like the other branches do.
        if not isinstance(response, list):
            response = [response]
        scores = [x["embedding"][0] for x in response]
        return ModelOutput(scores=scores)

    def batch_forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """
        testing serving by sending all prompts once
        only return output strings and no logprobs
        """
        if self.is_generation:
            return self.batch_forward_generation_raw(
                engine=self.engine,
                prompts=prompts,
                max_new_tokens=max_new_tokens,
                lora_paths=lora_paths,
            )

        response = self.engine.encode(prompts, image_data)
        if self.model_type == "embedding":
            logits = [x["embedding"] for x in response]
            return ModelOutput(embed_logits=logits)
        scores = [x["embedding"][0] for x in response]
        return ModelOutput(scores=scores)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.engine.shutdown()
        del self.engine

    @staticmethod
    def forward_generation_raw(
        engine: Engine,
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        logprob_start_len: int = 0,
        top_k: Optional[int] = None,
        token_ids_logprob: Optional[List[int]] = None,
    ):
        """Greedy-generate each prompt separately and collect text and logprobs.

        Input logprob lists drop the first prompt position and append the first
        output position, so they line up with the HF runner's prefill logprobs.

        Raises:
            ValueError: if a generated completion is empty or whitespace-only.
        """

        def top_values(entries):
            # Logprob (element 0) of the first NUM_TOP_LOGPROBS entries.
            return [tup[0] for tup in entries[:NUM_TOP_LOGPROBS]]

        def top_ids(entries):
            # Token id (element 1) of the first NUM_TOP_LOGPROBS entries.
            return [tup[1] for tup in entries[:NUM_TOP_LOGPROBS]]

        def all_values(entries):
            # Logprob (element 0) of every entry.
            return [tup[0] for tup in entries]

        output_strs = []
        # Not populated; kept so callers receive a list.
        output_ids = []
        # Input logprobs. Note that the last item in input logprob is equivalent to
        # the first item in the output logprob.
        top_input_logprobs = []
        input_token_logprobs_lst = []
        top_output_logprobs = []
        output_token_logprobs_lst = []
        top_output_logprob_idx = []
        if token_ids_logprob is not None:
            token_ids_input_logprobs = []
            token_ids_output_logprobs = []
        else:
            token_ids_input_logprobs = token_ids_output_logprobs = None

        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        if top_k:
            sampling_params["top_k"] = top_k

        for i, prompt in enumerate(prompts):
            response = engine.generate(
                prompt,
                lora_path=lora_paths[i] if lora_paths else None,
                sampling_params=sampling_params,
                return_logprob=True,
                logprob_start_len=logprob_start_len,
                top_logprobs_num=NUM_TOP_LOGPROBS,
                token_ids_logprob=token_ids_logprob,
            )
            text = response["text"]

            # Check if the text is empty or only whitespace.
            if not text.strip():
                raise ValueError(
                    "Received an empty text response. Please verify your input or model configuration."
                )
            output_strs.append(text)

            meta_info = response["meta_info"]
            input_token_logprobs = meta_info["input_token_logprobs"]
            output_token_logprobs = meta_info["output_token_logprobs"]
            input_top = meta_info["input_top_logprobs"]

            num_prompt_tokens = meta_info["prompt_tokens"]
            assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
            assert len(input_top) == num_prompt_tokens - logprob_start_len

            # The first token logprob has no meaning in sglang.
            input_token_logprobs = input_token_logprobs[1:]
            input_top = input_top[1:]
            assert len(input_token_logprobs) == len(input_top)

            input_token_logprobs_lst.append(
                input_token_logprobs + [output_token_logprobs[0]]
            )
            output_token_logprobs_lst.append(output_token_logprobs)

            output_top = meta_info["output_top_logprobs"]
            top_input_logprobs.append(
                [top_values(x) for x in input_top] + [top_values(output_top[0])]
            )
            top_output_logprobs.append([top_values(x) for x in output_top])
            top_output_logprob_idx.append([top_ids(x) for x in output_top])

            if token_ids_logprob is not None:
                input_ids_logprobs = meta_info["input_token_ids_logprobs"][1:]
                output_ids_logprobs = meta_info["output_token_ids_logprobs"]
                token_ids_input_logprobs.append(
                    [all_values(x) for x in input_ids_logprobs]
                    + [all_values(output_ids_logprobs[0])]
                )
                token_ids_output_logprobs.append(
                    [all_values(x) for x in output_ids_logprobs]
                )

        return ModelOutput(
            output_strs=output_strs,
            output_ids=output_ids,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
            input_token_logprobs_lst=input_token_logprobs_lst,
            output_token_logprobs_lst=output_token_logprobs_lst,
            top_output_logprob_idx=top_output_logprob_idx,
            token_ids_input_logprobs=token_ids_input_logprobs,
            token_ids_output_logprobs=token_ids_output_logprobs,
        )

    @staticmethod
    def batch_forward_generation_raw(
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens,
        lora_paths,
        engine,
    ):
        """Greedy-generate all prompts in a single ``generate`` call; return texts only."""
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        response = engine.generate(
            prompts,
            lora_path=lora_paths if lora_paths else None,
            sampling_params=sampling_params,
        )
        output_strs = [r["text"] for r in response]

        return ModelOutput(
            output_strs=output_strs,
        )

783
784
785
786
787
788
789
790
791
792
793
794
795
796

def monkey_patch_gemma2_sdpa():
    """Force Gemma2 models onto the SDPA attention implementation.

    Used to work around an OOM issue by undoing the behavior introduced in
    this upstream transformers commit:
    https://github.com/huggingface/transformers/commit/975b988bfe6e7ebb47390cd9a1556c6888804883#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834R660

    The patch replaces ``Gemma2PreTrainedModel._check_and_enable_sdpa`` with a
    hook that sets ``config._attn_implementation = "sdpa"`` on whatever config
    it receives and returns that config. It mutates the class itself, so it
    affects every later use of the class in this process; this function
    offers no way to restore the original hook.
    """
    from transformers.models.gemma2.modeling_gemma2 import Gemma2PreTrainedModel

    def _check_and_enable_sdpa(config, hard_check_only: bool = False):
        # hard_check_only exists only to match the upstream call signature;
        # SDPA is forced unconditionally.
        config._attn_implementation = "sdpa"
        return config

    # Install as a staticmethod: a bare function stored on the class would be
    # bound with `self` as `config` if ever looked up via an instance, while a
    # staticmethod receives `config` correctly from both `cls.` and instance
    # access.
    setattr(
        Gemma2PreTrainedModel,
        "_check_and_enable_sdpa",
        staticmethod(_check_and_enable_sdpa),
    )
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845


def _assert_logprobs_close(
    stage: str,
    hf_logprobs: torch.Tensor,
    srt_logprobs: torch.Tensor,
    tolerance: float,
    enforce: bool,
    debug_text: str,
):
    """Print the max abs difference of two logprob tensors and optionally assert on it.

    Args:
        stage: Label used in the printed line and the assertion message
            (``"prefill"`` or ``"decode"``).
        hf_logprobs: Reference logprobs from the HF run.
        srt_logprobs: Logprobs from the SRT run; subtracted element-wise from
            ``hf_logprobs``.
        tolerance: Every element-wise absolute difference must be strictly
            below this value.
        enforce: When False the difference is only printed, never asserted.
        debug_text: Extra context included in the assertion message.

    Raises:
        AssertionError: If ``enforce`` is True and any difference is not
            strictly below ``tolerance`` (NaN differences also fail).
    """
    abs_diff = abs(hf_logprobs - srt_logprobs)
    print(f"{stage} logprobs max_diff", torch.max(abs_diff))
    if enforce:
        assert torch.all(abs_diff < tolerance), (
            f"{stage} logprobs are not all close with {debug_text} "
            f"{stage}_tolerance={tolerance}. "
            f"{hf_logprobs=}, {srt_logprobs=}"
        )


def check_close_model_outputs(
    hf_outputs: ModelOutput,
    srt_outputs: ModelOutput,
    prefill_tolerance: float,
    decode_tolerance: float,
    rouge_l_tolerance: float,
    debug_text: str = "",
    check_logprobs: bool = True,
):
    """Assert that SRT outputs agree with the HF reference outputs.

    Output strings are always compared with ROUGE-L. When ``check_logprobs``
    is set, each prompt's top input (prefill) and top output (decode)
    logprobs are also compared element-wise. The max difference is always
    printed, but tolerances are only enforced for prompts whose prefill
    logprob list has at most 100 entries.

    Args:
        hf_outputs: Reference outputs from the HF runner.
        srt_outputs: Outputs from the SRT runner under test.
        prefill_tolerance: Max allowed abs difference for prefill logprobs.
        decode_tolerance: Max allowed abs difference for decode logprobs.
        rouge_l_tolerance: Minimum ROUGE-L score required for every output.
        debug_text: Extra context included in assertion messages.
        check_logprobs: Whether to compare logprobs in addition to strings.

    Raises:
        AssertionError: If the two runs produced a different number of
            outputs, any ROUGE-L score is below ``rouge_l_tolerance``, or an
            enforced logprob difference reaches its tolerance.
    """
    print(f"{hf_outputs.output_strs=}")
    print(f"{srt_outputs.output_strs=}")

    # Check counts up front: the logprob loop indexes SRT outputs by HF
    # position, and a short SRT result could otherwise slip past the ROUGE
    # check unnoticed.
    num_outputs = len(hf_outputs.output_strs)
    assert num_outputs == len(srt_outputs.output_strs), (
        f"Output count mismatch with {debug_text}: "
        f"hf={num_outputs}, srt={len(srt_outputs.output_strs)}"
    )

    rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
    print(f"{rouge_l_scores=}")
    assert all(
        score >= rouge_l_tolerance for score in rouge_l_scores
    ), f"Not all ROUGE-L scores are at least rouge_l_tolerance={rouge_l_tolerance}"

    if not check_logprobs:
        return

    for i in range(num_outputs):
        hf_input_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
        srt_input_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
        # Long prompts are only reported, not asserted; the same gate (keyed
        # on the prefill length) applies to the decode comparison too.
        enforce = hf_input_logprobs.shape[0] <= 100

        _assert_logprobs_close(
            "prefill",
            hf_input_logprobs,
            srt_input_logprobs,
            prefill_tolerance,
            enforce,
            debug_text,
        )
        _assert_logprobs_close(
            "decode",
            torch.Tensor(hf_outputs.top_output_logprobs[i]),
            torch.Tensor(srt_outputs.top_output_logprobs[i]),
            decode_tolerance,
            enforce,
            debug_text,
        )