# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    GenerationConfig,
)

from sglang.srt.entrypoints.engine import Engine
from sglang.srt.utils import load_image
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l

DEFAULT_PROMPTS = [
    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
    "The capital of the United Kingdom is",
    "Today is a sunny day and I like",
    "AI is a field of computer science focused on",
    # The output of gemma-2-2b from SRT is unstable on the prompt commented out below.
    # "The capital of France is",
]
TEST_RERANK_QUERY_DOCS = [
    {
        "query": "How many people live in Berlin?",
        "documents": [
            "Berlin is well known for its museums.",
        ],
    },
    {
        "query": "How many people live in Berlin?",
        "documents": [
            "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
            "Berlin is well known for its museums.",
        ],
    },
]

dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
    long_prompt = f.read()
DEFAULT_PROMPTS.append(long_prompt)

NUM_TOP_LOGPROBS = 5


def get_dtype_str(torch_dtype):
    if torch_dtype is torch.float16:
        return "float16"
    elif torch_dtype is torch.float32:
        return "float32"
    else:
        raise NotImplementedError()


def get_top_logprobs(logits, k):
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    del logits
    logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
    return logprobs


def get_token_ids_logprobs(logits, token_ids):
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    del logits
    logprobs = logprobs[..., token_ids]
    return logprobs


def _get_sentence_transformer_embedding_model(
    model_path, torch_dtype, matryoshka_dim: Optional[int] = None
):
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import is_sentence_transformer_model

    if is_sentence_transformer_model(model_path):
        model = SentenceTransformer(
            model_path,
            model_kwargs={"torch_dtype": torch_dtype},
            truncate_dim=matryoshka_dim,
        )
    else:  # if no pre-trained sentence-transformers model
        from sentence_transformers import models

        word_embedding_model = models.Transformer(model_path).to(dtype=torch_dtype)
        pooling_model = models.Pooling(
            word_embedding_model.get_word_embedding_dimension(),
            pooling_mode="lasttoken",
        )
        model = SentenceTransformer(
            modules=[word_embedding_model, pooling_model], truncate_dim=matryoshka_dim
        )

    return model.cuda()


@dataclass
class ModelOutput:
    output_strs: List[str] = None
    output_ids: List[int] = None
    top_input_logprobs: List[torch.Tensor] = None
    top_output_logprobs: List[torch.Tensor] = None
    top_output_logprob_idx: List[List[int]] = None
    embed_logits: List[torch.Tensor] = None
    scores: List[float] = None
    input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
    output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
    token_ids_input_logprobs: List[torch.Tensor] = None
    token_ids_output_logprobs: List[torch.Tensor] = None
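
    # Which fields are populated depends on the runner and model_type:
    #   generation              -> output_strs plus the logprob fields
    #   embedding               -> embed_logits
    #   reward / cross_encoder  -> scores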


class HFRunner:
    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str = "generation",
        output_str_only: bool = False,
        trust_remote_code: bool = False,
        patch_model_do_sample_false: bool = False,
        matryoshka_dim: Optional[int] = None,
    ):
        self.model_type = model_type
        self.output_str_only = output_str_only
        self.trust_remote_code = trust_remote_code
        self.patch_model_do_sample_false = patch_model_do_sample_false

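        # The HF model is loaded and queried in a separate process; prompts and
        # results are exchanged over these queues, which keeps the HF CUDA context
        # out of the parent process that may also host an SRT engine.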
        self.in_queue = mp.Queue()
        self.out_queue = mp.Queue()

        self.model_proc = mp.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
                matryoshka_dim,
            ),
        )
        self.model_proc.start()

    def needs_trust_remote_code(self, model_path):
        models_needs_trust_remote = [
            "LxzGordon/URM-LLaMa-3.1-8B",
        ]
        if model_path in models_needs_trust_remote:
            return True
        return False

    # Copied from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py

    def _get_gme_qwen2_vl_embeddings(
        self, prompts, image_data: Optional[List[str]] = None
    ):

        images = None
        if image_data is not None:
            images = [load_image(image)[0] for image in image_data]

        inputs = self.processor(
            text=prompts,
            images=images,
            padding=True,
            truncation=True,
            max_length=1800,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            embeddings = self._forward_gme_qwen2_vl(**inputs)
        return embeddings.tolist()

    def _forward_gme_qwen2_vl(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pooling_mask: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.model.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.model.visual.get_dtype())
                image_embeds = self.model.visual(
                    pixel_values, grid_thw=image_grid_thw
                ).to(inputs_embeds.device)
                image_mask = input_ids == self.model.config.image_token_id
                inputs_embeds[image_mask] = image_embeds
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
            return_dict=True,
            inputs_embeds=inputs_embeds,
            image_grid_thw=image_grid_thw,
        )

        embeddings = outputs.hidden_states[-1][:, -1]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.contiguous()

    def start_model_process(
        self,
        in_queue,
        out_queue,
        model_path,
        torch_dtype,
        matryoshka_dim: Optional[int] = None,
    ):
        # Apply model-specific patches
        monkey_patch_gemma2_sdpa()

        # Load the model and tokenizer
        if self.model_type == "generation":
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=self.trust_remote_code
            )
            if self.trust_remote_code:
                model_cls = AutoModelForCausalLM
            else:
                model_arch = config.architectures[0]
                model_cls = getattr(transformers, model_arch)
            self.base_model = model_cls.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.trust_remote_code,
                low_cpu_mem_usage=True,
            ).cuda()
        elif self.model_type == "embedding":
            if "gme-qwen2-vl" in model_path.lower():
                self.model = AutoModelForVision2Seq.from_pretrained(
                    model_path,
                    torch_dtype=torch_dtype,
                    trust_remote_code=False,
                    low_cpu_mem_usage=True,
                ).cuda()
                self.processor = AutoProcessor.from_pretrained(model_path)
            elif "clip" in model_path.lower():
                self.model = AutoModel.from_pretrained(model_path).cuda()
                self.processor = AutoProcessor.from_pretrained(model_path)
            else:
                self.model = _get_sentence_transformer_embedding_model(
                    model_path, torch_dtype, matryoshka_dim=matryoshka_dim
                )
        elif self.model_type == "reward" or self.model_type == "cross_encoder":
            from transformers import AutoModelForSequenceClassification

            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.needs_trust_remote_code(model_path),
            ).cuda()
        else:
            raise Exception(f"Unrecognized model type {self.model_type}")
        self.tokenizer = get_tokenizer(
            model_path,
            torch_dtype=torch_dtype,
            trust_remote_code=self.trust_remote_code,
        )

        # Run forward
        while True:
            prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = (
                in_queue.get()
            )
            if lora_paths is not None:
                assert len(prompts) == len(lora_paths)

            if prompts is not None:
                if self.model_type == "generation":
                    out_queue.put(
                        self.forward_generation_raw(
                            base_model=self.base_model,
                            prompts=prompts,
                            max_new_tokens=max_new_tokens,
                            tokenizer=self.tokenizer,
                            lora_paths=lora_paths,
                            torch_dtype=torch_dtype,
                            output_str_only=self.output_str_only,
                            token_ids_logprob=token_ids_logprob,
                            patch_model_do_sample_false=self.patch_model_do_sample_false,
                        )
                    )
                elif self.model_type == "embedding":
                    assert not self.output_str_only
                    if "gme-qwen2-vl" in model_path.lower():
                        logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
                    elif "clip" in model_path.lower():
                        if image_data is not None:
                            image = load_image(image_data)
                            inputs = self.processor(
                                images=image[0], return_tensors="pt"
                            )
                            logits = self.model.get_image_features(
                                pixel_values=inputs.data["pixel_values"].cuda(),
                            ).tolist()
                        else:
                            inputs = self.tokenizer(
                                prompts, padding=True, return_tensors="pt"
                            )
                            logits = self.model.get_text_features(
                                input_ids=inputs.data["input_ids"].cuda(),
                                attention_mask=inputs.data["attention_mask"].cuda(),
                            ).tolist()
                    else:
                        logits = self.model.encode(prompts).tolist()
                    out_queue.put(ModelOutput(embed_logits=logits))
                elif self.model_type == "cross_encoder":
                    inputs = self.tokenizer(
                        prompts, padding=True, return_tensors="pt"
                    ).to("cuda")
                    scores = self.model(**inputs).logits
                    scores = scores.squeeze().tolist()
                    if not isinstance(scores, list):
                        scores = [scores]
                    out_queue.put(ModelOutput(scores=scores))

                elif self.model_type == "reward":
                    scores = []
                    for conv in prompts:
                        conv_formatted = self.tokenizer.apply_chat_template(
                            conv, tokenize=False
                        )
                        conv_tokenized = self.tokenizer(
                            conv_formatted, return_tensors="pt"
                        ).to("cuda")
                        scores.append(
                            float(self.model(**conv_tokenized).logits[0][0].item())
                        )
                    out_queue.put(ModelOutput(scores=scores))
                else:
                    raise Exception(f"Unrecognized model type {self.model_type}")

    def forward(
        self,
        prompts: Union[
            List[List[str]], List[str], List[torch.Tensor]
        ] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        token_ids_logprob: Optional[int] = None,
    ):
        self.in_queue.put(
            (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
        )
        return self.out_queue.get()

    def terminate(self):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    @staticmethod
    def forward_generation_raw(
        base_model,
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens: int,
        tokenizer,
        torch_dtype: torch.dtype,
        lora_paths: Optional[List[str]] = None,
        output_str_only: bool = False,
        token_ids_logprob: Optional[int] = None,
        patch_model_do_sample_false: Optional[bool] = False,
    ) -> ModelOutput:
        output_strs = []
        top_input_logprobs = []
        top_output_logprobs = []
        if token_ids_logprob is not None:
            token_ids_input_logprobs = []
            token_ids_output_logprobs = []
        else:
            token_ids_input_logprobs = token_ids_output_logprobs = None

        for i, p in enumerate(prompts):
            if isinstance(p, str):
                input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
            else:
                input_ids = torch.tensor([p], device="cuda")

            if lora_paths is not None and lora_paths[i] is not None:
                from peft import PeftModel

                model = PeftModel.from_pretrained(
                    base_model,
                    lora_paths[i],
                    torch_dtype=torch_dtype,
                    is_trainable=False,
                )
            else:
                model = base_model
            if patch_model_do_sample_false:
                model.generation_config.do_sample = False
            outputs = model.generate(
                input_ids=input_ids,
                generation_config=GenerationConfig(
                    do_sample=False,
                    temperature=None,
                    top_p=None,
                    max_new_tokens=max_new_tokens,
                    return_dict_in_generate=True,
                    output_scores=(not output_str_only),
                    # make sure to disable compile
                    disable_compile=True,
                ),
            )

            text = tokenizer.decode(
                outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
            )
            # Check if the text is empty or only whitespace.
            if not text.strip():
                raise ValueError(
                    "Received an empty text response. Please verify your input or model configuration."
                )
            output_strs.append(text)

            if not output_str_only:
                # outputs.scores: (num_token, 1, vocab_size)
                top_output_logprobs.append(
                    [
                        get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
                        for logits in outputs.scores
                    ]
                )
                if token_ids_logprob is not None:
                    token_ids_output_logprobs.append(
                        [
                            get_token_ids_logprobs(
                                logits[0], token_ids_logprob
                            ).tolist()
                            for logits in outputs.scores
                        ]
                    )
                del outputs

                input_logits = model.forward(input_ids).logits[0]
                top_input_logprobs.append(
                    get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
                )
                if token_ids_logprob is not None:
                    token_ids_input_logprobs.append(
                        get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
                    )
                del input_logits

            if lora_paths is not None and lora_paths[i] is not None:
                # Unload the LoRA adapter if it is used
                model.unload()

        return ModelOutput(
            output_strs=output_strs,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
            token_ids_input_logprobs=token_ids_input_logprobs,
            token_ids_output_logprobs=token_ids_output_logprobs,
        )
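
# Example usage of HFRunner (an illustrative sketch; the model path below is a
# placeholder, not one this module prescribes):
#
#     with HFRunner("meta-llama/Llama-3.1-8B-Instruct", torch.float16,
#                   model_type="generation") as hf_runner:
#         hf_outputs = hf_runner.forward(DEFAULT_PROMPTS[:2], max_new_tokens=8)
#         print(hf_outputs.output_strs)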


class SRTRunner:
    def __init__(
        self,
        model_path: str,
        torch_dtype: torch.dtype,
        model_type: str,
        tp_size: int = 1,
        model_impl: str = "auto",
        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        lora_paths: Optional[Union[List[str], List[dict[str, str]]]] = None,
        max_loras_per_batch: int = 4,
        attention_backend: Optional[str] = None,
        prefill_attention_backend: Optional[str] = None,
        decode_attention_backend: Optional[str] = None,
        lora_backend: str = "csgmv",
        disable_cuda_graph: bool = False,
        disable_radix_cache: bool = False,
        chunked_prefill_size: Optional[int] = None,
        context_length: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
        page_size: Optional[int] = None,
        dp_size: int = 1,
        tokenizer_path: Optional[str] = None,
        mem_fraction_static: float = 0.65,
        trust_remote_code: bool = False,
        speculative_draft_model_path: Optional[str] = None,
        speculative_draft_model_revision: Optional[str] = None,
        speculative_algorithm: Optional[str] = None,
        speculative_num_steps: Optional[int] = None,
        speculative_eagle_topk: Optional[int] = None,
        speculative_num_draft_tokens: Optional[int] = None,
        disable_overlap_schedule: bool = False,
        disable_custom_all_reduce: bool = False,
        torchao_config: Optional[str] = None,
        cuda_graph_max_bs: int = 4,
        sleep_on_idle=False,
        max_lora_rank: Optional[int] = None,
        lora_target_modules: Optional[List[str]] = None,
        enable_lora: Optional[bool] = None,
        max_loaded_loras: Optional[int] = None,
        json_model_override_args: Optional[dict[str, Any]] = None,
        lora_eviction_policy: str = "lru",
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
        enable_dp_attention = dp_size > 1

        spec_kwargs = {}
        if speculative_draft_model_path:
            spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
            spec_kwargs["speculative_draft_model_revision"] = (
                speculative_draft_model_revision
            )
            spec_kwargs["speculative_algorithm"] = speculative_algorithm
            spec_kwargs["speculative_num_steps"] = speculative_num_steps
            spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
            spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens

        self.engine = Engine(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
            port=port,
            model_impl=model_impl,
            torchao_config=torchao_config,
            mem_fraction_static=mem_fraction_static,
            trust_remote_code=trust_remote_code,
            is_embedding=not self.is_generation,
            lora_paths=lora_paths,
            max_loras_per_batch=max_loras_per_batch,
            lora_backend=lora_backend,
            attention_backend=attention_backend,
            prefill_attention_backend=prefill_attention_backend,
            decode_attention_backend=decode_attention_backend,
            disable_cuda_graph=disable_cuda_graph,
            disable_radix_cache=disable_radix_cache,
            chunked_prefill_size=chunked_prefill_size,
            context_length=context_length,
            max_total_tokens=max_total_tokens,
            page_size=page_size,
            enable_dp_attention=enable_dp_attention,
            dp_size=dp_size,
            tokenizer_path=tokenizer_path,
            disable_overlap_schedule=disable_overlap_schedule,
            cuda_graph_max_bs=cuda_graph_max_bs,
            disable_custom_all_reduce=disable_custom_all_reduce,
            sleep_on_idle=sleep_on_idle,
            max_lora_rank=max_lora_rank,
            lora_target_modules=lora_target_modules,
            enable_lora=enable_lora,
            max_loaded_loras=max_loaded_loras,
            json_model_override_args=(
                json.dumps(json_model_override_args)
                if json_model_override_args
                else "{}"
            ),
            lora_eviction_policy=lora_eviction_policy,
            **spec_kwargs,
        )

        if tokenizer_path is None:
            self.tokenizer = get_tokenizer(
                model_path, trust_remote_code=trust_remote_code
            )
        else:
            self.tokenizer = None

    def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
        return self.engine.load_lora_adapter(lora_name, lora_path, pinned)

    def unload_lora_adapter(self, lora_name: str):
        return self.engine.unload_lora_adapter(lora_name)

    def forward(
        self,
        prompts: Union[
            List[List[str]], List[str], List[torch.Tensor]
        ] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        logprob_start_len: int = 0,
        top_k: Optional[int] = None,
        token_ids_logprob: Optional[List[int]] = None,
        dimensions: Optional[int] = None,
    ):
        if self.is_generation:
            return self.forward_generation_raw(
                engine=self.engine,
                prompts=prompts,
                max_new_tokens=max_new_tokens,
                lora_paths=lora_paths,
                logprob_start_len=logprob_start_len,
                top_k=top_k,
                token_ids_logprob=token_ids_logprob,
            )
        else:
            if self.model_type == "embedding":
                response = self.engine.encode(
                    prompt=prompts, image_data=image_data, dimensions=dimensions
                )
                if isinstance(response, list):
                    logits = [x["embedding"] for x in response]
                else:
                    logits = [response["embedding"]]
                return ModelOutput(embed_logits=logits)
            # cross encoder model
            elif self.model_type == "cross_encoder":
                response = self.engine.rerank(prompts)
                if not isinstance(response, list):
                    response = [response]
                scores = [x["embedding"] for x in response]
                return ModelOutput(scores=scores)
            # reward model
            else:
                response = self.engine.encode(prompts)
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def batch_forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens=8,
        lora_paths=None,
    ):
        """
        Test serving by sending all prompts at once.
        Only output strings are returned; no logprobs.
        """
        if self.is_generation:
            return self.batch_forward_generation_raw(
                engine=self.engine,
                prompts=prompts,
                max_new_tokens=max_new_tokens,
                lora_paths=lora_paths,
            )
        else:
            response = self.engine.encode(prompts, image_data)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
            else:
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.engine.shutdown()
        del self.engine

    @staticmethod
    def forward_generation_raw(
        engine: Engine,
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        logprob_start_len: int = 0,
        top_k: Optional[int] = None,
        token_ids_logprob: Optional[List[int]] = None,
    ):
        # the return value contains logprobs from prefill
        output_strs = []
        output_ids = []
        # Input logprobs. Note that the last item of the input logprobs is equivalent
        # to the first item of the output logprobs.
        top_input_logprobs = []
        input_token_logprobs_lst = []
        top_output_logprobs = []
        output_token_logprobs_lst = []
        top_output_logprob_idx = []
        if token_ids_logprob is not None:
            token_ids_input_logprobs = []
            token_ids_output_logprobs = []
        else:
            token_ids_input_logprobs = token_ids_output_logprobs = None

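        # Greedy decoding (temperature=0) so results are comparable with the HF
        # runner, which generates with do_sample=False.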
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        if top_k:
            sampling_params["top_k"] = top_k

        for i, prompt in enumerate(prompts):
            response = engine.generate(
                prompt,
                lora_path=lora_paths[i] if lora_paths else None,
                sampling_params=sampling_params,
                return_logprob=True,
                logprob_start_len=logprob_start_len,
                top_logprobs_num=NUM_TOP_LOGPROBS,
                token_ids_logprob=token_ids_logprob,
            )
            text = response["text"]

            # Check if the text is empty or only whitespace.
            if not text.strip():
                raise ValueError(
                    "Received an empty text response. Please verify your input or model configuration."
                )
            output_strs.append(text)
            # output_ids.append(response["output_ids"])

            input_token_logprobs = response["meta_info"]["input_token_logprobs"]
            output_token_logprobs = response["meta_info"]["output_token_logprobs"]
            # print(i, input_token_logprobs)
            # print(i, output_token_logprobs)
            logprobs = response["meta_info"]["input_top_logprobs"]
            if token_ids_logprob is not None:
                input_token_ids_logprobs = response["meta_info"][
                    "input_token_ids_logprobs"
                ][1:]
            else:
                input_token_ids_logprobs = None

            num_prompt_tokens = response["meta_info"]["prompt_tokens"]
            assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
            assert len(logprobs) == num_prompt_tokens - logprob_start_len

            # The first token logprob has no meaning in sglang.
            input_token_logprobs = input_token_logprobs[1:]
            logprobs = logprobs[1:]
            assert len(input_token_logprobs) == len(logprobs)

            input_token_logprobs_lst.append(
                input_token_logprobs + [output_token_logprobs[0]]
            )
            output_token_logprobs_lst.append(output_token_logprobs)

            top_input_logprobs.append(
                [[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs]
                + [
                    [
                        tup[0]
                        for tup in response["meta_info"]["output_top_logprobs"][0][
                            :NUM_TOP_LOGPROBS
                        ]
                    ]
                ]
            )
            top_output_logprobs.append(
                [
                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                    for x in response["meta_info"]["output_top_logprobs"]
                ]
            )
            top_output_logprob_idx.append(
                [
                    [tup[1] for tup in x[:NUM_TOP_LOGPROBS]]
                    for x in response["meta_info"]["output_top_logprobs"]
                ]
            )
            if token_ids_logprob is not None:
                token_ids_input_logprobs.append(
                    [[tup[0] for tup in x] for x in input_token_ids_logprobs]
                    + [
                        [
                            tup[0]
                            for tup in response["meta_info"][
                                "output_token_ids_logprobs"
                            ][0]
                        ]
                    ]
                )
                token_ids_output_logprobs.append(
                    [
                        [tup[0] for tup in x]
                        for x in response["meta_info"]["output_token_ids_logprobs"]
                    ]
                )

        return ModelOutput(
            output_strs=output_strs,
            output_ids=output_ids,
            top_input_logprobs=top_input_logprobs,
            top_output_logprobs=top_output_logprobs,
            input_token_logprobs_lst=input_token_logprobs_lst,
            output_token_logprobs_lst=output_token_logprobs_lst,
            top_output_logprob_idx=top_output_logprob_idx,
            token_ids_input_logprobs=token_ids_input_logprobs,
            token_ids_output_logprobs=token_ids_output_logprobs,
        )

    @staticmethod
    def batch_forward_generation_raw(
        prompts: Union[List[str], List[torch.Tensor]],
        max_new_tokens,
        lora_paths,
        engine,
    ):
        # Only output strings are returned; no logprobs are collected here.
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        response = engine.generate(
            prompts,
            lora_path=lora_paths if lora_paths else None,
            sampling_params=sampling_params,
        )
        output_strs = [r["text"] for r in response]

        return ModelOutput(
            output_strs=output_strs,
        )
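
# Example usage of SRTRunner (an illustrative sketch; the model path is a
# placeholder and the module's defaults for port/mem_fraction_static are assumed):
#
#     with SRTRunner("meta-llama/Llama-3.1-8B-Instruct", torch.float16,
#                    model_type="generation", tp_size=1) as srt_runner:
#         srt_outputs = srt_runner.forward(DEFAULT_PROMPTS[:2], max_new_tokens=8)
#         print(srt_outputs.output_strs)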


def monkey_patch_gemma2_sdpa():
    """
    Use sdpa by default to fix the OOM issue.
    Revert this commit:
    https://github.com/huggingface/transformers/commit/975b988bfe6e7ebb47390cd9a1556c6888804883#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834R660
    """
    from transformers.models.gemma2.modeling_gemma2 import Gemma2PreTrainedModel

    def _check_and_enable_sdpa(config, hard_check_only: bool = False):
        config._attn_implementation = "sdpa"
        return config

    setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)


def check_close_model_outputs(
    hf_outputs: ModelOutput,
    srt_outputs: ModelOutput,
    prefill_tolerance: float,
    decode_tolerance: float,
    rouge_l_tolerance: float,
    debug_text: str = "",
    check_logprobs: bool = True,
):
    # Compare output strings
    print(f"{hf_outputs.output_strs=}")
    print(f"{srt_outputs.output_strs=}")
    rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
    print(f"{rouge_l_scores=}")
    assert all(
        score >= rouge_l_tolerance for score in rouge_l_scores
    ), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"

    if check_logprobs:
        for i in range(len(hf_outputs.output_strs)):
            # Compare input logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            input_len = hf_logprobs.shape[0]
            print(
                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
                    f"prefill logprobs are not all close with {debug_text} "
                    f"prefill_tolerance={prefill_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

            # Compare output logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])

            print(
                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
                    f"decode logprobs are not all close with {debug_text} "
                    f"decode_tolerance={decode_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )
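
# Putting it together (an illustrative sketch; the model path and tolerance values
# are arbitrary examples, not recommendations baked into this module):
#
#     model_path = "meta-llama/Llama-3.1-8B-Instruct"
#     with HFRunner(model_path, torch.float16, model_type="generation") as hf:
#         hf_out = hf.forward(max_new_tokens=8)
#     with SRTRunner(model_path, torch.float16, model_type="generation") as srt:
#         srt_out = srt.forward(max_new_tokens=8)
#     check_close_model_outputs(hf_out, srt_out, prefill_tolerance=0.2,
#                               decode_tolerance=0.2, rouge_l_tolerance=1.0)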