"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_generation_model,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(
            self.model_config.hf_config.architectures
        )

        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
            }
        )

        # Model-specific adjustments
        if self.is_multimodal_model:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        self.init_cublas()
        self.init_attention_backend()
        self.init_cuda_graphs()

    def init_torch_distributed(self):
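        """Initialize torch.distributed / NCCL and tensor parallelism, then return the minimum available GPU memory across TP ranks."""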
        # Init torch distributed
        torch.cuda.set_device(self.gpu_id)
        logger.info("Init nccl begin.")

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)

        if self.server_args.nccl_init_addr:
            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
        else:
            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=nccl_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def load_model(self):
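        """Load the model weights via the vLLM model loader and record the dtype, sliding-window size, and whether this is a generation model."""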
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)

        if torch.cuda.get_device_capability()[0] < 8:
            logger.info(
                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
            )
            self.server_args.dtype = "float16"
            if torch.cuda.get_device_capability()[1] < 5:
                raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        monkey_patch_vllm_dummy_weight_loader()
        self.device_config = DeviceConfig()
        self.load_config = LoadConfig(load_format=self.server_args.load_format)
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=self.server_args.random_seed,
            skip_tokenizer_init=True,
        )
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )
        self.dtype = self.vllm_model_config.dtype

        # Load the model
        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place."""
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device_config.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=self.server_args.random_seed,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only the vllm DefaultModelLoader is supported for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                config.model,
                config.revision,
                fall_back_to_pt=getattr(
                    self.model, "fall_back_to_pt_during_load", True
                ),
            )
            return iter

        def model_load_weights(model, iter):
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Successfully updated model weights."

    def init_lora_manager(self):
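        """Create the LoRAManager that serves the adapters given by --lora-paths."""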
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
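        """Estimate how many KV-cache tokens fit in the memory budget, based on the per-token KV-cache cell size (MLA or MHA layout)."""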
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
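        """Choose the KV-cache dtype, profile the token budget, and allocate the request-to-token and token-to-KV-cache pools."""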
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        self.req_to_token_pool = ReqToTokenPool(
            max_num_reqs + 1,
            self.model_config.context_len + 4,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graphs only capture decode steps, which only exist for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)

    def forward_decode(self, batch: ScheduleBatch):
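        """Run a decode forward pass (one new token per request), replaying a captured CUDA graph when the batch size allows it."""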
        if self.server_args.lora_paths is not None:
            self.lora_manager.prepare_lora_batch(batch)

        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
            return self.cuda_graph_runner.replay(batch)

        input_metadata = InputMetadata.from_schedule_batch(self, batch)

        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    def forward_extend(self, batch: ScheduleBatch):
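        """Run an extend (prefill) forward pass over the new input tokens of the batch."""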
        input_metadata = InputMetadata.from_schedule_batch(self, batch)
        if self.server_args.lora_paths is not None:
            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)

        if self.is_generation:
            return self.model.forward(
                batch.input_ids, input_metadata.positions, input_metadata
            )
        else:
            # Only embedding models have the get_embedding parameter
            return self.model.forward(
                batch.input_ids,
                input_metadata.positions,
                input_metadata,
                get_embedding=True,
            )

    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
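        """Dispatch the batch to forward_decode or forward_extend according to its forward mode."""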
        assert batch.forward_mode is not None

        if batch.forward_mode.is_decode():
            return self.forward_decode(batch)
        elif batch.forward_mode.is_extend():
            return self.forward_extend(batch)
        else:
            raise ValueError(f"Invalid forward mode: {batch.forward_mode}")

    def _apply_logits_bias(
        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
    ):
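        """Apply logit bias, linear and scaling penalties, and the regex vocabulary mask to the logits."""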
        # Apply logit_bias
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        # Apply min-token, presence, and frequency penalties
        if sampling_info.linear_penalties is not None:
            logits += sampling_info.linear_penalties

        # Apply the repetition penalty
        if sampling_info.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )

        # Apply regex vocab_mask
        if sampling_info.vocab_mask is not None:
            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))

        return logits

    def sample(
        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
    ) -> torch.Tensor:
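        """Post-process the logits with the batch's sampling info and sample the next token ids."""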
        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
        batch.sampling_info.update_regex_vocab_mask(batch)
        batch.sampling_info.update_penalties()
        logits = self._apply_logits_bias(
            logits_output.next_token_logits, batch.sampling_info
        )

        # Sample the next tokens.
        next_token_ids = self.sampler(logits, batch.sampling_info)
        return next_token_ids


@lru_cache()
def import_model_classes():
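    """Import every module under sglang.srt.models and map architecture names to their EntryClass model classes."""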
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            try:
                module = importlib.import_module(name)
            except Exception as e:
                logger.warning(f"Ignore import error when loading {name}: {e}")
                continue
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert (
                            tmp.__name__ not in model_arch_name_to_cls
                        ), f"Duplicated model implementation for {tmp.__name__}"
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert (
                        entry.__name__ not in model_arch_name_to_cls
                    ), f"Duplicated model implementation for {entry.__name__}"
                    model_arch_name_to_cls[entry.__name__] = entry

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
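    """Resolve a model architecture name to its SGLang model class, raising ValueError if it is unsupported."""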
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)