"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput, Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_generation_model,
    is_llama3_405b_fp8_head_16,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
    monkey_patch_vllm_qvk_linear_loader,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(
            self.model_config.hf_config.architectures
        )
        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "enable_mla": server_args.enable_mla,
                "torchao_config": server_args.torchao_config,
            }
        )

        # Model-specific adjustment
        if self.is_multimodal_model:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        self.init_cublas()
        self.init_attention_backend()
        self.init_cuda_graphs()

    def init_torch_distributed(self):
        # Init torch distributed
        torch.cuda.set_device(self.gpu_id)
        logger.info("Init nccl begin.")

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)

        if self.server_args.nccl_init_addr:
            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
        else:
            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=nccl_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def load_model(self):
        torch.set_num_threads(1)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )
        if torch.cuda.get_device_capability()[0] < 8:
            logger.info(
                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
            )
            self.server_args.dtype = "float16"
            if torch.cuda.get_device_capability()[1] < 5:
                raise RuntimeError("SGLang only supports sm75 and above.")

        monkey_patch_vllm_dummy_weight_loader()
        self.device_config = DeviceConfig()
        self.load_config = LoadConfig(load_format=self.server_args.load_format)
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=42,
            skip_tokenizer_init=True,
        )

        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
        # Drop this after Sept, 2024.
        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
            self.model_config.hf_config.num_key_value_heads = 8
            self.vllm_model_config.hf_config.num_key_value_heads = 8
            monkey_patch_vllm_qvk_linear_loader()

        self.dtype = self.vllm_model_config.dtype
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )

        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place."""
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device_config.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=42,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            logger.error(f"Failed to load model config: {e}")
            return False, "Failed to update model weights"

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            logger.error("Failed to get weights iterator: Unsupported loader")
            return False, "Failed to update model weights"

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                config.model,
                config.revision,
                fall_back_to_pt=getattr(
                    self.model, "fall_back_to_pt_during_load", True
                ),
            )
            return iter

        def model_load_weights(model, iter):
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
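            # Stream the new checkpoint's weights into the existing model in place.
            # If loading fails partway, the except branch below reloads the original
            # weights so the server keeps serving a consistent model.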
            try:
                iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}"
                logger.error(message)
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = f"Failed to update weights: {e}. \n Rolling back to original weights"
                logger.error(message)
                del iter
                gc.collect()
                iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights"

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
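        # cell_size is the KV-cache footprint per token in bytes: MLA caches one
        # compressed latent per layer (kv_lora_rank + qk_rope_head_dim), while
        # standard MHA caches separate K and V tensors per layer (the factor 2).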
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
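        # Pick the KV-cache dtype: "auto" follows the model dtype, while
        # "fp8_e5m2" stores the cache in 8-bit floating point to save memory.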
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
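            # Heuristic default: scale the number of request slots with the token
            # budget (max_total_num_tokens / context_len * 512), clamped to [2048, 4096].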
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        self.req_to_token_pool = ReqToTokenPool(
            max_num_reqs,
            self.model_config.context_len + 8,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)

    @torch.inference_mode()
    def forward_decode(self, batch: ScheduleBatch):
        if self.server_args.lora_paths is not None:
            self.lora_manager.prepare_lora_batch(batch)

        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
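            # Replay a pre-captured CUDA graph for this batch size instead of
            # launching the decode kernels eagerly.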
            return self.cuda_graph_runner.replay(batch)

        input_metadata = InputMetadata.from_schedule_batch(self, batch)

        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.from_schedule_batch(self, batch)
        if self.server_args.lora_paths is not None:
            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)

        if self.is_generation:
            return self.model.forward(
                batch.input_ids, input_metadata.positions, input_metadata
            )
        else:
            # Only embedding models have get_embedding parameter
            return self.model.forward(
                batch.input_ids,
                input_metadata.positions,
                input_metadata,
                get_embedding=True,
            )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.from_schedule_batch(self, batch)
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            input_metadata.pixel_values,
            input_metadata.image_sizes,
            input_metadata.image_offsets,
        )

    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
        assert batch.forward_mode is not None

        if self.is_multimodal_model and batch.forward_mode.is_extend():
            return self.forward_extend_multi_modal(batch)
        elif batch.forward_mode.is_decode():
            return self.forward_decode(batch)
        elif batch.forward_mode.is_extend():
            return self.forward_extend(batch)
        else:
            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")

    def _check_sample_results(self, sample_output: SampleOutput):
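        # For any request whose sampling failed (e.g. NaN probabilities), fall back
        # to a greedy argmax over the NaN-masked probabilities.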
        if not torch.all(sample_output.success):
            probs = sample_output.probs
            batch_next_token_ids = sample_output.batch_next_token_ids
            logging.warning("Sampling failed, fallback to top_k=1 strategy")
            probs = probs.masked_fill(torch.isnan(probs), 0.0)
            argmax_ids = torch.argmax(probs, dim=-1)
            batch_next_token_ids = torch.where(
                sample_output.success, batch_next_token_ids, argmax_ids
            )
            sample_output.probs = probs
            sample_output.batch_next_token_ids = batch_next_token_ids

        return sample_output.batch_next_token_ids

    def _apply_logits_bias(
        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
    ):
        # Apply logit_bias
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        # Apply additive penalties: min-token, presence, and frequency
        if sampling_info.linear_penalties is not None:
            logits += sampling_info.linear_penalties

        # Apply the repetition penalty (scaling_penalties)
        if sampling_info.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )

        # Apply regex vocab_mask
        if sampling_info.vocab_mask is not None:
            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))

        return logits

    def sample(
        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
    ) -> torch.Tensor:
        batch.sampling_info.update_regex_vocab_mask(batch)
        batch.sampling_info.update_penalties()
        logits = self._apply_logits_bias(
            logits_output.next_token_logits, batch.sampling_info
        )
        sample_output = self.sampler(logits, batch.sampling_info)
        return self._check_sample_results(sample_output)


@lru_cache()
def import_model_classes():
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert tmp.__name__ not in model_arch_name_to_cls
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert entry.__name__ not in model_arch_name_to_cls
                    model_arch_name_to_cls[entry.__name__] = entry

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
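# Replace vLLM's model-class resolution so that architecture names from the HF config
# are resolved to the SGLang implementations collected from sglang.srt.models above.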
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)