"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.global_config import global_config
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_config import AttentionArch, ModelConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_generation_model,
    is_llama3_405b_fp8_head_16,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
    monkey_patch_vllm_qvk_linear_loader,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(
            self.model_config.hf_config.architectures
        )
        global_server_args_dict.update(
            {
                "disable_flashinfer": server_args.disable_flashinfer,
                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "enable_mla": server_args.enable_mla,
            }
        )

        if self.is_multimodal_model:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95

        min_per_gpu_memory = self.init_torch_distributed()
        self.load_model()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_num_reqs,
            server_args.max_total_tokens,
        )
        self.init_cublas()
        self.init_flashinfer()
        self.init_cuda_graphs()

    def init_torch_distributed(self):
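        """Initialize the distributed environment and the tensor-parallel group.

        Returns the minimum available GPU memory across all TP ranks, which is
        later used to size the KV cache pools.
        """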
        # Init torch distributed
        torch.cuda.set_device(self.gpu_id)
        logger.info("Init nccl begin.")

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)

        if self.server_args.nccl_init_addr:
            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
        else:
            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=nccl_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def load_model(self):
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )
        if torch.cuda.get_device_capability()[0] < 8:
            logger.info(
                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
            )
            self.server_args.dtype = "float16"
            if torch.cuda.get_device_capability()[1] < 5:
                raise RuntimeError("SGLang only supports sm75 and above.")

        monkey_patch_vllm_dummy_weight_loader()
        self.device_config = DeviceConfig()
        self.load_config = LoadConfig(load_format=self.server_args.load_format)
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=42,
            skip_tokenizer_init=True,
        )

        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
        # Drop this after Sept, 2024.
        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
            self.model_config.hf_config.num_key_value_heads = 8
            self.vllm_model_config.hf_config.num_key_value_heads = 8
            monkey_patch_vllm_qvk_linear_loader()

        self.dtype = self.vllm_model_config.dtype
        if self.model_config.model_overide_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_overide_args
            )

        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place."""
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device_config.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=42,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            logger.error(f"Failed to load model config: {e}")
            return False, "Failed to update model weights"

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            logger.error("Failed to get weights iterator: Unsupported loader")
            return False, "Failed to update model weights"

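        # Build an iterator over the checkpoint weights described by the given vLLM model config.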
        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                config.model,
                config.revision,
                fall_back_to_pt=getattr(
                    self.model, "fall_back_to_pt_during_load", True
                ),
            )
            return iter

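        # Load the new weights into the existing model, then re-run each module's
        # quantization post-processing on the target device.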
        def model_load_weights(model, iter):
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}"
                logger.error(message)
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = f"Failed to update weights: {e}. \n Rolling back to original weights"
                logger.error(message)
                del iter
                gc.collect()
                iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights"

    def profile_max_num_token(self, total_gpu_memory: int):
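        """Estimate how many tokens the KV cache can hold within the configured memory budget."""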
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
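        # cell_size is the number of bytes needed to cache one token across all layers:
        # MLA stores a single compressed latent of size (kv_lora_rank + qk_rope_head_dim)
        # per layer, while standard MHA stores separate K and V tensors of size
        # num_kv_heads * head_dim per layer.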
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: int = None,
        max_total_tokens: int = None,
    ):
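        """Create the request-to-token and token-to-KV-cache memory pools."""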
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
                logger.warning(
                    "FP8 KV cache is not supported by the Triton kernels yet, using auto kv cache dtype"
                )
                self.kv_cache_dtype = self.dtype
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                5120,
            )

        self.req_to_token_pool = ReqToTokenPool(
            max_num_reqs,
            self.model_config.context_len + 8,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
            logger.info("Using MLA Triton implementation; flashinfer is disabled.")
            # FIXME: temporarily only Triton MLA is supported
            self.server_args.disable_flashinfer = True
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_flashinfer(self):
        """Init flashinfer attention kernel wrappers."""
        if self.server_args.disable_flashinfer:
            assert (
                self.sliding_window_size is None
            ), "turn on flashinfer to support window attention"
            self.flashinfer_prefill_wrapper_ragged = None
            self.flashinfer_prefill_wrapper_paged = None
            self.flashinfer_decode_wrapper = None
            return

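        # flashinfer ships specialized decode kernels only for certain query/KV head
        # group sizes; if this model's grouping is not compiled in, fall back to the
        # tensor-core (prefill-style) decode path.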
        if not _grouped_size_compiled_for_decode_kernels(
            self.model_config.num_attention_heads // self.tp_size,
            self.model_config.get_num_kv_heads(self.tp_size),
        ):
            use_tensor_cores = True
        else:
            use_tensor_cores = False

        if self.sliding_window_size is None:
            self.flashinfer_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device="cuda",
            )
            self.flashinfer_prefill_wrapper_ragged = (
                BatchPrefillWithRaggedKVCacheWrapper(
                    self.flashinfer_workspace_buffer, "NHD"
                )
            )
            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                self.flashinfer_workspace_buffer, "NHD"
            )
            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.flashinfer_workspace_buffer,
                "NHD",
                use_tensor_cores=use_tensor_cores,
            )
        else:
            self.flashinfer_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device="cuda",
            )
            self.flashinfer_prefill_wrapper_ragged = None
            self.flashinfer_prefill_wrapper_paged = []
            self.flashinfer_decode_wrapper = []
            for i in range(2):
                self.flashinfer_prefill_wrapper_paged.append(
                    BatchPrefillWithPagedKVCacheWrapper(
                        self.flashinfer_workspace_buffer, "NHD"
                    )
                )
                self.flashinfer_decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.flashinfer_workspace_buffer,
                        "NHD",
                        use_tensor_cores=use_tensor_cores,
                    )
                )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models.
            return

        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
            self.cuda_graph_runner = None
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")

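        # Batch sizes to capture. Without padding, every small batch size must be
        # captured explicitly; with padding, a coarser grid suffices because batches
        # are padded up to the nearest captured size.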
        if self.server_args.disable_cuda_graph_padding:
            batch_size_list = list(range(1, 32)) + [64, 128]
        else:
            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]

        self.cuda_graph_runner = CudaGraphRunner(
            self,
            max_batch_size_to_capture=max(batch_size_list),
            use_torch_compile=self.server_args.enable_torch_compile,
            disable_padding=self.server_args.disable_cuda_graph_padding,
        )
        try:
            self.cuda_graph_runner.capture(batch_size_list)
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n"
                "Possible solutions:\n"
                "1. disable cuda graph by --disable-cuda-graph\n"
                "2. set --mem-fraction-static to a smaller value\n"
                "3. disable torch compile by not using --enable-torch-compile\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )

    @torch.inference_mode()
    def forward_decode(self, batch: ScheduleBatch):
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
            return self.cuda_graph_runner.replay(batch)

        input_metadata = InputMetadata.from_schedule_batch(
            self,
            batch,
            ForwardMode.DECODE,
        )

        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.from_schedule_batch(
            self,
            batch,
            forward_mode=ForwardMode.EXTEND,
        )
        if self.is_generation:
            return self.model.forward(
                batch.input_ids, input_metadata.positions, input_metadata
            )
        else:
            # Only embedding models have the get_embedding parameter.
            return self.model.forward(
                batch.input_ids,
                input_metadata.positions,
                input_metadata,
                get_embedding=True,
            )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.from_schedule_batch(
            self,
            batch,
            forward_mode=ForwardMode.EXTEND,
        )
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            input_metadata.pixel_values,
            input_metadata.image_sizes,
            input_metadata.image_offsets,
        )

    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
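        """Dispatch to the forward variant matching the forward mode and model type."""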
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode}")


@lru_cache()
def import_model_classes():
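    """Discover model classes under sglang.srt.models via their EntryClass attribute."""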
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert tmp.__name__ not in model_arch_name_to_cls
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert entry.__name__ not in model_arch_name_to_cls
                    model_arch_name_to_cls[entry.__name__] = entry

            # compat: some models such as chatglm have an incorrect class set in config.json
            # usage: [("From_Entry_Class_Name", EntryClass), ...]
            if hasattr(module, "EntryClassRemapping") and isinstance(
                module.EntryClassRemapping, list
            ):
                for remap in module.EntryClassRemapping:
                    if isinstance(remap, tuple) and len(remap) == 2:
                        assert remap[0] not in model_arch_name_to_cls
                        model_arch_name_to_cls[remap[0]] = remap[1]

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)