"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constrained import disable_cache
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    enable_show_time_cost,
    get_available_gpu_memory,
    is_generation_model,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(
            self.model_config.hf_config.architectures
        )

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.is_multimodal_model:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_disk_cache:
            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
            }
        )

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        self.init_cublas()
        self.init_attention_backend()
        self.init_cuda_graphs()

    def init_torch_distributed(self):
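        """Initialize torch.distributed (NCCL) and tensor parallelism, then
        return the minimum available GPU memory across the TP ranks."""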
        # Init torch distributed
        torch.cuda.set_device(self.gpu_id)
        logger.info("Init nccl begin.")

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)

        if self.server_args.dist_init_addr:
            nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=nccl_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def load_model(self):
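        """Load the model weights through vLLM's model loader and record model
        properties such as dtype, the sliding window size, and whether it is a
        generation model."""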
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)

        if torch.cuda.get_device_capability()[0] < 8:
            logger.info(
                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
            )
            self.server_args.dtype = "float16"
            if torch.cuda.get_device_capability()[1] < 5:
                raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        monkey_patch_vllm_dummy_weight_loader()
        self.device_config = DeviceConfig()
        self.load_config = LoadConfig(load_format=self.server_args.load_format)
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=self.server_args.random_seed,
            skip_tokenizer_init=True,
        )
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )
        self.dtype = self.vllm_model_config.dtype

        # Load the model
        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place."""
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device_config.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=self.server_args.random_seed,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only vLLM's DefaultModelLoader is supported for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
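            # Build a weights iterator for the given vllm model config.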
            iter = loader._get_weights_iterator(
                config.model,
                config.revision,
                fall_back_to_pt=getattr(
                    self.model, "fall_back_to_pt_during_load", True
                ),
            )
            return iter

        def model_load_weights(model, iter):
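            # Load the weights, then let each quantization method post-process
            # its modules on the target device.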
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_lora_manager(self):
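        """Create the LoRAManager that loads the adapters given in server_args.lora_paths."""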
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
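        """Estimate how many KV-cache tokens fit in the memory budget implied
        by mem_fraction_static."""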
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
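        """Create the request-to-token pool and the token-to-KV-cache pool
        (MLA or MHA layout), sized to the profiled token capacity."""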
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        device = "cuda"
        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=device,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=device,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.has_cross_attention, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)

    def forward_decode(self, forward_batch: ForwardBatch):
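        """Run a decode forward pass, replaying a captured CUDA graph when the
        batch size allows it."""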
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch.batch_size
        ):
            return self.cuda_graph_runner.replay(forward_batch)

        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(self, forward_batch: ForwardBatch):
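        """Run an extend (prefill) forward pass; embedding models additionally
        request pooled embeddings via get_embedding=True."""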
        if self.is_generation:
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
        else:
            # Only embedding models have the get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
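        """Dispatch the batch to forward_decode or forward_extend based on its forward mode."""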
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        else:
            raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")

    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
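        """Sample the next token ids from the logits after applying penalties and constraints."""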
        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
        sampling_info = forward_batch.sampling_info
        sampling_info.update_regex_vocab_mask()
        sampling_info.update_penalties()
        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

        # Sample the next tokens.
        next_token_ids = self.sampler(logits, sampling_info)
        return next_token_ids

    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
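        """Apply the logit bias, penalties, and the regex vocabulary mask to the logits."""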
        # Apply logit_bias
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        # Apply the min-token, presence, and frequency penalties
        if sampling_info.linear_penalties is not None:
            logits.add_(sampling_info.linear_penalties)

        # Apply the repetition penalty
        if sampling_info.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )

        # Apply regex vocab_mask
        if sampling_info.vocab_mask is not None:
            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))

        return logits


@lru_cache()
def import_model_classes():
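    """Import every module under sglang.srt.models and collect the model
    classes they expose via EntryClass."""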
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            try:
                module = importlib.import_module(name)
            except Exception as e:
                logger.warning(f"Ignore import error when loading {name}. " f"{e}")
                continue
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert (
                            tmp.__name__ not in model_arch_name_to_cls
                        ), f"Duplicated model implementation for {tmp.__name__}"
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert (
                        entry.__name__ not in model_arch_name_to_cls
                    ), f"Duplicated model implementation for {entry.__name__}"
                    model_arch_name_to_cls[entry.__name__] = entry

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
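    """Resolve a HuggingFace architecture name to its SGLang model class."""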
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch vLLM's model registry so it resolves architectures to SGLang model classes
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)