"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Type

import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constrained import disable_cache
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    enable_show_time_cost,
    get_available_gpu_memory,
    is_generation_model,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(
            self.model_config.hf_config.architectures
        )

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.is_multimodal_model:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_disk_cache:
            disable_cache()

        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
            }
        )

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.init_attention_backend()

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        if self.device == "cuda":
            torch.cuda.set_device(self.gpu_id)
            backend = "nccl"
        # TODO(liangan1): Just use gloo to bypass the initialization failure.
        # Need to use xccl for the xpu backend in the future.
        elif self.device == "xpu":
            torch.xpu.set_device(self.gpu_id)
            backend = "gloo"

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if self.device == "cuda" and not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        ):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def load_model(self):
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        monkey_patch_vllm_dummy_weight_loader()
        self.load_config = LoadConfig(load_format=self.server_args.load_format)
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=self.server_args.random_seed,
            skip_tokenizer_init=True,
        )
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )
        self.dtype = self.vllm_model_config.dtype

        # Load the model
        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place."""
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=self.server_args.random_seed,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

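        # Helper closures: stream weights from the target checkpoint and re-run the
        # quantization post-processing hooks once the new weights are in place.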
        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                config.model,
                config.revision,
                fall_back_to_pt=getattr(
                    self.model, "fall_back_to_pt_during_load", True
                ),
            )
            return iter

        def model_load_weights(model, iter):
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            lora_paths=self.server_args.lora_paths,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
        )
        logger.info("LoRA manager ready.")

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
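        # Per-token KV cache size in bytes. MLA caches one compressed latent of width
        # (kv_lora_rank + qk_rope_head_dim) per layer; standard MHA caches separate
        # K and V tensors (hence the factor of 2) for each local KV head per layer.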
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

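        # Heuristic: scale the request pool with the token budget (about 512 requests
        # per full context length worth of tokens), clamped to the range [2048, 4096].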
        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        self.req_to_token_pool = ReqToTokenPool(
            size=max_num_reqs + 1,
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.has_cross_attention, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)

    def forward_decode(self, forward_batch: ForwardBatch):
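        # Replay a captured CUDA graph when the runner can handle this batch size;
        # otherwise fall back to eager execution of the model.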
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch.batch_size
        ):
            return self.cuda_graph_runner.replay(forward_batch)

        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

    def forward_extend(self, forward_batch: ForwardBatch):
        if self.is_generation:
            return self.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
        else:
            # Only embedding models have get_embedding parameter
            return self.model.forward(
                forward_batch.input_ids,
                forward_batch.positions,
                forward_batch,
                get_embedding=True,
            )

    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

    def sample(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
        sampling_info = forward_batch.sampling_info
        sampling_info.update_regex_vocab_mask()
        sampling_info.update_penalties()
        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

        # Sample the next tokens.
        next_token_ids = self.sampler(logits, sampling_info)
        return next_token_ids

    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        # Apply logit_bias
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        # min-token, presence, frequency
        if sampling_info.linear_penalties is not None:
            logits.add_(sampling_info.linear_penalties)

        # repetition
        if sampling_info.scaling_penalties is not None:
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )

        # Apply regex vocab_mask
        if sampling_info.vocab_mask is not None:
            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))

        return logits


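# Scan every module under `sglang.srt.models` once and collect the model classes
# each module exports via its `EntryClass` attribute, keyed by class name.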
@lru_cache()
def import_model_classes():
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            try:
                module = importlib.import_module(name)
            except Exception as e:
                logger.warning(f"Ignore import error when loading {name}. " f"{e}")
                continue
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert (
                            tmp.__name__ not in model_arch_name_to_cls
                        ), f"Duplicated model implementation for {tmp.__name__}"
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert (
                        entry.__name__ not in model_arch_name_to_cls
                    ), f"Duplicated model implementation for {entry.__name__}"
                    model_arch_name_to_cls[entry.__name__] = entry

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
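# Point vLLM's architecture-to-class lookup at SGLang's own model implementations.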
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)