# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
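"""Test runners that compare HuggingFace Transformers outputs against SGLang (SRT) outputs.

HFRunner executes a reference model with HuggingFace Transformers in a child
process, SRTRunner serves the same model through the SGLang Engine, and
check_close_model_outputs compares the two results (ROUGE-L over the generated
strings plus element-wise differences of the top logprobs).

A minimal usage sketch (the model_path and tolerances below are illustrative
only):

    with HFRunner(model_path, torch.float16, "generation") as hf_runner:
        hf_outputs = hf_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)
    with SRTRunner(model_path, torch.float16, "generation") as srt_runner:
        srt_outputs = srt_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)
    check_close_model_outputs(
        hf_outputs,
        srt_outputs,
        prefill_tolerance=5e-2,
        decode_tolerance=5e-2,
        rouge_l_tolerance=1,
    )
"""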

import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
)

from sglang.srt.entrypoints.engine import Engine
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import load_image
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l

DEFAULT_PROMPTS = [
    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
    "The capital of the United Kingdom is",
    "Today is a sunny day and I like",
    "AI is a field of computer science focused on",
    # the output of gemma-2-2b from SRT is unstable on the commented prompt
    # "The capital of France is",
]

dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
    long_prompt = f.read()
DEFAULT_PROMPTS.append(long_prompt)

NUM_TOP_LOGPROBS = 5


def get_dtype_str(torch_dtype):
    if torch_dtype is torch.float16:
        return "float16"
    elif torch_dtype is torch.float32:
        return "float32"
    else:
        raise NotImplementedError()


def get_top_logprobs(logits, k):
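    """Return the top-k log probabilities over the vocabulary for each position."""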
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    del logits
    logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
    return logprobs


def get_token_ids_logprobs(logits, token_ids):
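    """Return the log probabilities of the given token ids for each position."""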
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    del logits
    logprobs = logprobs[..., token_ids]
    return logprobs


def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
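    """Load an embedding model via sentence-transformers, adding last-token pooling for plain checkpoints."""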
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import is_sentence_transformer_model

    if is_sentence_transformer_model(model_path):
        model = SentenceTransformer(
            model_path,
            model_kwargs={"torch_dtype": torch_dtype},
        )
    else:  # if no pre-trained sentence-transformers model
        from sentence_transformers import models

        word_embedding_model = models.Transformer(model_path).to(dtype=torch_dtype)
        pooling_model = models.Pooling(
            word_embedding_model.get_word_embedding_dimension(),
            pooling_mode="lasttoken",
        )
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    return model.cuda()


@dataclass
class ModelOutput:
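    """Container for the outputs produced by HFRunner and SRTRunner forward calls."""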
    output_strs: List[str] = None
    output_ids: List[int] = None
    top_input_logprobs: List[torch.Tensor] = None
    top_output_logprobs: List[torch.Tensor] = None
    top_output_logprob_idx: List[List[int]] = None
    embed_logits: List[torch.Tensor] = None
    scores: List[float] = None
    input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
    output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
    token_ids_input_logprobs: List[torch.Tensor] = None
    token_ids_output_logprobs: List[torch.Tensor] = None


class HFRunner:
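    """Reference runner that executes a HuggingFace Transformers model.

    The model runs in a child process (see start_model_process); requests and
    results are exchanged through multiprocessing queues. Use it as a context
    manager or call terminate() to stop the worker process.
    """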
    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str = "generation",
        output_str_only: bool = False,
        trust_remote_code: bool = False,
    ):
        self.model_type = model_type
        self.output_str_only = output_str_only
        self.trust_remote_code = trust_remote_code

        self.in_queue = mp.Queue()
        self.out_queue = mp.Queue()

        self.model_proc = mp.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
            ),
        )
        self.model_proc.start()

    def needs_trust_remote_code(self, model_path):
        models_needs_trust_remote = [
            "LxzGordon/URM-LLaMa-3.1-8B",
        ]
        if model_path in models_needs_trust_remote:
            return True
        return False

    # Copied from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py
    def _get_gme_qwen2_vl_embeddings(
        self, prompts, image_data: Optional[List[str]] = None
    ):
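        """Encode prompts (and optional images) with the GME Qwen2-VL model and return embeddings."""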

        images = None
        if image_data is not None:
            images = [load_image(image)[0] for image in image_data]

        inputs = self.processor(
            text=prompts,
            images=images,
            padding=True,
            truncation=True,
            max_length=1800,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            embeddings = self._forward_gme_qwen2_vl(**inputs)
        return embeddings.tolist()

    def _forward_gme_qwen2_vl(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pooling_mask: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.model.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.model.visual.get_dtype())
                image_embeds = self.model.visual(
                    pixel_values, grid_thw=image_grid_thw
                ).to(inputs_embeds.device)
                image_mask = input_ids == self.model.config.image_token_id
                inputs_embeds[image_mask] = image_embeds
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.model.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )

        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
        if left_padding:
            embeddings = outputs.last_hidden_state[:, -1]
        else:
            sequence_lengths = pooling_mask.sum(dim=1) - 1
            batch_size = outputs.last_hidden_state.shape[0]
            embeddings = outputs.last_hidden_state[
                torch.arange(batch_size, device=outputs.last_hidden_state.device),
                sequence_lengths,
            ]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.contiguous()

    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
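        """Worker-process loop: load the requested model, then serve forward requests from in_queue."""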
        # Apply model-specific patches
        monkey_patch_gemma2_sdpa()

        # Load the model and tokenizer
        if self.model_type == "generation":
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.trust_remote_code,
                low_cpu_mem_usage=True,
            ).cuda()
        elif self.model_type == "embedding":
            if "gme-qwen2-vl" in model_path.lower():
                self.model = AutoModelForVision2Seq.from_pretrained(
                    model_path,
                    torch_dtype=torch_dtype,
                    trust_remote_code=False,
                    low_cpu_mem_usage=True,
                ).cuda()
                self.processor = AutoProcessor.from_pretrained(model_path)
            elif "clip" in model_path.lower():
                self.model = AutoModel.from_pretrained(model_path).cuda()
                self.processor = AutoProcessor.from_pretrained(model_path)
            else:
                self.model = _get_sentence_transformer_embedding_model(
                    model_path, torch_dtype
                )
        elif self.model_type == "reward":
            from transformers import AutoModelForSequenceClassification

            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.needs_trust_remote_code(model_path),
            ).cuda()
        else:
            raise Exception(f"Unrecognized model type {self.model_type}")
        self.tokenizer = get_tokenizer(
            model_path,
            torch_dtype=torch_dtype,
            trust_remote_code=self.trust_remote_code,
        )

        # Run forward
        while True:
            prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = (
                in_queue.get()
            )
            if lora_paths is not None:
                assert len(prompts) == len(lora_paths)

            if prompts is not None:
                if self.model_type == "generation":
                    out_queue.put(
                        self.forward_generation_raw(
                            base_model=self.base_model,
                            prompts=prompts,
                            max_new_tokens=max_new_tokens,
                            tokenizer=self.tokenizer,
                            lora_paths=lora_paths,
                            torch_dtype=torch_dtype,
                            output_str_only=self.output_str_only,
                            token_ids_logprob=token_ids_logprob,
                        )
                    )
                elif self.model_type == "embedding":
                    assert not self.output_str_only
                    if "gme-qwen2-vl" in model_path.lower():
                        logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
                    elif "clip" in model_path.lower():
                        if image_data is not None:
                            image = load_image(image_data)
                            inputs = self.processor(
                                images=image[0], return_tensors="pt"
                            )
                            logits = self.model.get_image_features(
                                pixel_values=inputs.data["pixel_values"].cuda(),
                            ).tolist()
                        else:
                            inputs = self.tokenizer(
                                prompts, padding=True, return_tensors="pt"
                            )
                            logits = self.model.get_text_features(
                                input_ids=inputs.data["input_ids"].cuda(),
                                attention_mask=inputs.data["attention_mask"].cuda(),
                            ).tolist()
                    else:
                        logits = self.model.encode(prompts).tolist()
                    out_queue.put(ModelOutput(embed_logits=logits))

                elif self.model_type == "reward":
                    scores = []
                    for conv in prompts:
                        conv_formatted = self.tokenizer.apply_chat_template(
                            conv, tokenize=False
                        )
                        conv_tokenized = self.tokenizer(
                            conv_formatted, return_tensors="pt"
                        ).to("cuda")
                        scores.append(
                            float(self.model(**conv_tokenized).logits[0][0].item())
                        )
                    out_queue.put(ModelOutput(scores=scores))
                else:
                    raise Exception(f"Unrecognized model type {self.model_type}")

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        token_ids_logprob: Optional[int] = None,
    ):
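        """Send a forward request to the worker process and block until its result is available."""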
        self.in_queue.put(
            (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
        )
        return self.out_queue.get()

    def terminate(self):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    @staticmethod
    def forward_generation_raw(
        base_model,
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens: int,
        tokenizer,
        torch_dtype: torch.dtype,
        lora_paths: Optional[List[str]] = None,
        output_str_only: bool = False,
        token_ids_logprob: Optional[int] = None,
    ) -> ModelOutput:
        output_strs = []
        top_input_logprobs = []
        top_output_logprobs = []
        if token_ids_logprob is not None:
            token_ids_input_logprobs = []
            token_ids_output_logprobs = []
        else:
            token_ids_input_logprobs = token_ids_output_logprobs = None

        for i, p in enumerate(prompts):
            if isinstance(p, str):
                input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
            else:
                input_ids = torch.tensor([p], device="cuda")

            if lora_paths is not None and lora_paths[i] is not None:
                from peft import PeftModel

                model = PeftModel.from_pretrained(
                    base_model,
                    lora_paths[i],
                    torch_dtype=torch_dtype,
                    is_trainable=False,
                )
            else:
                model = base_model

            outputs = model.generate(
                input_ids,
                do_sample=False,
                temperature=None,
                top_p=None,
                max_new_tokens=max_new_tokens,
                return_dict_in_generate=True,
                output_scores=(not output_str_only),
            )

            text = tokenizer.decode(
                outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
            )
            # Check if the text is empty or only whitespace.
            if not text.strip():
                raise ValueError(
                    "Received an empty text response. Please verify your input or model configuration."
                )
            output_strs.append(text)

            if not output_str_only:
                # outputs.scores: (num_token, 1, vocab_size)
                top_output_logprobs.append(
                    [
                        get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
                        for logits in outputs.scores
                    ]
                )
                if token_ids_logprob is not None:
                    token_ids_output_logprobs.append(
                        [
                            get_token_ids_logprobs(
                                logits[0], token_ids_logprob
                            ).tolist()
                            for logits in outputs.scores
                        ]
                    )
                del outputs

                input_logits = model.forward(input_ids).logits[0]
                top_input_logprobs.append(
                    get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
                )
                if token_ids_logprob is not None:
                    token_ids_input_logprobs.append(
                        get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
                    )
                del input_logits

        return ModelOutput(
            output_strs=output_strs,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
            token_ids_input_logprobs=token_ids_input_logprobs,
            token_ids_output_logprobs=token_ids_output_logprobs,
        )


class SRTRunner:
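    """Runner that serves a model through the SGLang engine (SRT) for comparison against HFRunner."""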
    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str,
        tp_size: int = 1,
        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        lora_paths: List[str] = None,
        max_loras_per_batch: int = 4,
        attention_backend: Optional[str] = None,
        lora_backend: str = "triton",
        disable_cuda_graph: bool = False,
        disable_radix_cache: bool = False,
        chunked_prefill_size: Optional[int] = None,
        dp_size: int = 1,
        tokenizer_path: Optional[str] = None,
        enable_ep_moe: bool = False,
        mem_fraction_static: float = 0.65,
        trust_remote_code: bool = False,
        speculative_draft_model_path: Optional[str] = None,
        speculative_algorithm: Optional[str] = None,
        speculative_num_steps: Optional[int] = None,
        speculative_eagle_topk: Optional[int] = None,
        speculative_num_draft_tokens: Optional[int] = None,
        disable_overlap_schedule: bool = False,
        disable_custom_all_reduce: bool = False,
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
        enable_dp_attention = dp_size > 1

        spec_kwargs = {}
        if speculative_draft_model_path:
            spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
            spec_kwargs["speculative_algorithm"] = speculative_algorithm
            spec_kwargs["speculative_num_steps"] = speculative_num_steps
            spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
            spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens

        self.engine = Engine(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
            port=port,
            mem_fraction_static=mem_fraction_static,
            trust_remote_code=trust_remote_code,
            is_embedding=not self.is_generation,
            lora_paths=lora_paths,
            max_loras_per_batch=max_loras_per_batch,
            lora_backend=lora_backend,
            attention_backend=attention_backend,
            disable_cuda_graph=disable_cuda_graph,
            disable_radix_cache=disable_radix_cache,
            chunked_prefill_size=chunked_prefill_size,
            enable_dp_attention=enable_dp_attention,
            dp_size=dp_size,
            tokenizer_path=tokenizer_path,
            enable_ep_moe=enable_ep_moe,
            disable_overlap_schedule=disable_overlap_schedule,
            cuda_graph_max_bs=4,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **spec_kwargs,
        )

        if tokenizer_path is None:
            self.tokenizer = get_tokenizer(
                model_path, trust_remote_code=trust_remote_code
            )
        else:
            self.tokenizer = None

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        logprob_start_len: int = 0,
        top_k: Optional[int] = None,
        token_ids_logprob: Optional[List[int]] = None,
    ):
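        """Run generation with logprobs, or encoding for embedding/reward models, through the SGLang engine."""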
        if self.is_generation:
            return self.forward_generation_raw(
                engine=self.engine,
                prompts=prompts,
                max_new_tokens=max_new_tokens,
                lora_paths=lora_paths,
                logprob_start_len=logprob_start_len,
                top_k=top_k,
                token_ids_logprob=token_ids_logprob,
            )
        else:
            if self.model_type == "embedding":
                response = self.engine.encode(prompt=prompts, image_data=image_data)
                if isinstance(response, list):
                    logits = [x["embedding"] for x in response]
                else:
                    logits = [response["embedding"]]
                return ModelOutput(embed_logits=logits)
            # reward model
            else:
                response = self.engine.encode(prompts)
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def batch_forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """
        testing serving by sending all prompts once
        only return output strings and no logprobs
        """
        if self.is_generation:
            return self.batch_forward_generation_raw(
                engine=self.engine,
                prompts=prompts,
                max_new_tokens=max_new_tokens,
                lora_paths=lora_paths,
            )
        else:
            response = self.engine.encode(prompts, image_data)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
            else:
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.engine.shutdown()
        del self.engine

    @staticmethod
    def forward_generation_raw(
        engine: Engine,
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        logprob_start_len: int = 0,
        top_k: Optional[int] = None,
        token_ids_logprob: Optional[List[int]] = None,
    ):
        # the return value contains logprobs from prefill
        output_strs = []
        output_ids = []
        # Input logprobs. Note that the last item of the input logprobs is
        # equivalent to the first item of the output logprobs.
        top_input_logprobs = []
        input_token_logprobs_lst = []
        top_output_logprobs = []
        output_token_logprobs_lst = []
        top_output_logprob_idx = []
        if token_ids_logprob is not None:
            token_ids_input_logprobs = []
            token_ids_output_logprobs = []
        else:
            token_ids_input_logprobs = token_ids_output_logprobs = None

        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        if top_k:
            sampling_params["top_k"] = top_k

        for i, prompt in enumerate(prompts):
            response = engine.generate(
                prompt,
                lora_path=lora_paths[i] if lora_paths else None,
                sampling_params=sampling_params,
                return_logprob=True,
                logprob_start_len=logprob_start_len,
                top_logprobs_num=NUM_TOP_LOGPROBS,
                token_ids_logprob=token_ids_logprob,
            )
            text = response["text"]

            # Check if the text is empty or only whitespace.
            if not text.strip():
                raise ValueError(
                    "Received an empty text response. Please verify your input or model configuration."
                )
            output_strs.append(text)
            # output_ids.append(response["output_ids"])

            input_token_logprobs = response["meta_info"]["input_token_logprobs"]
            output_token_logprobs = response["meta_info"]["output_token_logprobs"]
            # print(i, input_token_logprobs)
            # print(i, output_token_logprobs)
            logprobs = response["meta_info"]["input_top_logprobs"]
            if token_ids_logprob is not None:
                input_token_ids_logprobs = response["meta_info"][
                    "input_token_ids_logprobs"
                ][1:]
            else:
                input_token_ids_logprobs = None

            num_prompt_tokens = response["meta_info"]["prompt_tokens"]
            assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
            assert len(logprobs) == num_prompt_tokens - logprob_start_len

            # The first token logprob has no meaning in sglang.
            input_token_logprobs = input_token_logprobs[1:]
            logprobs = logprobs[1:]
            assert len(input_token_logprobs) == len(logprobs)

            input_token_logprobs_lst.append(
                input_token_logprobs + [output_token_logprobs[0]]
            )
            output_token_logprobs_lst.append(output_token_logprobs)

            top_input_logprobs.append(
                [[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs]
                + [
                    [
                        tup[0]
                        for tup in response["meta_info"]["output_top_logprobs"][0][
                            :NUM_TOP_LOGPROBS
                        ]
                    ]
                ]
            )
            top_output_logprobs.append(
                [
                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                    for x in response["meta_info"]["output_top_logprobs"]
                ]
            )
            top_output_logprob_idx.append(
                [
                    [tup[1] for tup in x[:NUM_TOP_LOGPROBS]]
                    for x in response["meta_info"]["output_top_logprobs"]
                ]
            )
            if token_ids_logprob is not None:
                token_ids_input_logprobs.append(
                    [[tup[0] for tup in x] for x in input_token_ids_logprobs]
                    + [
                        [
                            tup[0]
                            for tup in response["meta_info"][
                                "output_token_ids_logprobs"
                            ][0]
                        ]
                    ]
                )
                token_ids_output_logprobs.append(
                    [
                        [tup[0] for tup in x]
                        for x in response["meta_info"]["output_token_ids_logprobs"]
                    ]
                )

        return ModelOutput(
            output_strs=output_strs,
            output_ids=output_ids,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
            input_token_logprobs_lst=input_token_logprobs_lst,
            output_token_logprobs_lst=output_token_logprobs_lst,
            top_output_logprob_idx=top_output_logprob_idx,
            token_ids_input_logprobs=token_ids_input_logprobs,
            token_ids_output_logprobs=token_ids_output_logprobs,
        )

    @staticmethod
    def batch_forward_generation_raw(
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens,
        lora_paths,
        engine,
    ):
        # Only output strings are returned; no logprobs.
        output_strs = []
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        response = engine.generate(
            prompts,
            lora_path=lora_paths if lora_paths else None,
            sampling_params=sampling_params,
        )
        output_strs = [r["text"] for r in response]

        return ModelOutput(
            output_strs=output_strs,
        )


def monkey_patch_gemma2_sdpa():
    """
    Use sdpa by default to fix the OOM issue.
    Revert this commit:
    https://github.com/huggingface/transformers/commit/975b988bfe6e7ebb47390cd9a1556c6888804883#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834R660
    """
    from transformers.models.gemma2.modeling_gemma2 import Gemma2PreTrainedModel

    def _check_and_enable_sdpa(config, hard_check_only: bool = False):
        config._attn_implementation = "sdpa"
        return config

    setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)


def check_close_model_outputs(
    hf_outputs: ModelOutput,
    srt_outputs: ModelOutput,
    prefill_tolerance: float,
    decode_tolerance: float,
    rouge_l_tolerance: float,
    debug_text: str = "",
    check_logprobs: bool = True,
):
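    """Assert that HF and SRT outputs agree.

    Output strings are compared with ROUGE-L against rouge_l_tolerance; the top
    input/output logprobs are compared element-wise against prefill_tolerance
    and decode_tolerance (only for prompts with at most 100 input tokens).
    """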
    # Compare output strings
    print(f"{hf_outputs.output_strs=}")
    print(f"{srt_outputs.output_strs=}")
    rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
    print(f"{rouge_l_scores=}")
    assert all(
        score >= rouge_l_tolerance for score in rouge_l_scores
    ), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"

    if check_logprobs:
        for i in range(len(hf_outputs.output_strs)):
            # Compare input logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            input_len = hf_logprobs.shape[0]
            print(
                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
                    f"prefill logprobs are not all close with {debug_text} "
                    f"prefill_tolerance={prefill_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

            # Compare output logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])

            print(
                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
                    f"decode logprobs are not all close with {debug_text} "
                    f"decode_tolerance={decode_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )