"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
    get_tp_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_generation_model,
    is_llama3_405b_fp8_head_16,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
    monkey_patch_vllm_qvk_linear_loader,
)

logger = logging.getLogger(__name__)


class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

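    # Rough usage sketch (hypothetical values; the server normally constructs this):
    #   runner = ModelRunner(model_config, mem_fraction_static=0.88, gpu_id=0,
    #                        tp_rank=0, tp_size=1, nccl_port=28888, server_args=server_args)
    #   sample_output, logits_output = runner.forward(batch)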
    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(
            self.model_config.hf_config.architectures
        )
        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "enable_mla": server_args.enable_mla,
                "torchao_config": server_args.torchao_config,
            }
        )

        # Model-specific adjustment
        if self.is_multimodal_model:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.load_model()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        self.init_cublas()
        self.init_attention_backend()
        self.init_cuda_graphs()

    def init_torch_distributed(self):
        # Init torch distributed
        torch.cuda.set_device(self.gpu_id)
        logger.info("Init nccl begin.")

        if not self.server_args.enable_p2p_check:
            monkey_patch_vllm_p2p_access_check(self.gpu_id)

        if self.server_args.nccl_init_addr:
            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
        else:
            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=nccl_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        min_per_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
        # so we disable padding in cuda graph.
        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory

    def load_model(self):
        torch.set_num_threads(1)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )
        if torch.cuda.get_device_capability()[0] < 8:
            logger.info(
                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
            )
            self.server_args.dtype = "float16"
            if torch.cuda.get_device_capability()[1] < 5:
                raise RuntimeError("SGLang only supports sm75 and above.")

        monkey_patch_vllm_dummy_weight_loader()
        self.device_config = DeviceConfig()
        self.load_config = LoadConfig(load_format=self.server_args.load_format)
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=42,
            skip_tokenizer_init=True,
        )

        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
        # Drop this after Sept, 2024.
        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
            self.model_config.hf_config.num_key_value_heads = 8
            self.vllm_model_config.hf_config.num_key_value_heads = 8
            monkey_patch_vllm_qvk_linear_loader()

        self.dtype = self.vllm_model_config.dtype
        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )

        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place."""
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device_config.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=42,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            logger.error(f"Failed to load model config: {e}")
            return False, "Failed to update model weights"

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            logger.error("Failed to get weights iterator: Unsupported loader")
            return False, "Failed to update model weights"

        def get_weight_iter(config):
            iter = loader._get_weights_iterator(
                config.model,
                config.revision,
                fall_back_to_pt=getattr(
                    self.model, "fall_back_to_pt_during_load", True
                ),
            )
            return iter

        def model_load_weights(model, iter):
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

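        # Load the new weights under the target dtype. If loading fails partway,
        # reload the original weights so the runner is left in a usable state.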
        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}"
                logger.error(message)
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                message = f"Failed to update weights: {e}. \n Rolling back to original weights"
                logger.error(message)
                del iter
                gc.collect()
                iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights"

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
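        # cell_size: bytes of KV cache needed per token across all layers.
        # MLA caches only the compressed latent (kv_lora_rank) plus the rope head dim;
        # other models cache full per-head K and V (hence the factor of 2).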
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * self.model_config.num_hidden_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
                * self.model_config.head_dim
                * self.model_config.num_hidden_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
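        # Reserve (1 - mem_fraction_static) of total GPU memory for activations and
        # other buffers; the rest of the available memory (GB -> bytes) divided by
        # cell_size gives the KV cache capacity in tokens.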
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logging.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Use the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        if max_num_reqs is None:
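            # Heuristic default: scale the request pool with the KV cache capacity,
            # clamped to the range [2048, 4096].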
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )

        self.req_to_token_pool = ReqToTokenPool(
            max_num_reqs,
            self.model_config.context_len + 8,
        )
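        # MLA models cache a single compressed latent (+ rope dims) per token, so they
        # use a dedicated pool; all other models cache full per-head K/V tensors.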
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and self.server_args.enable_mla
        ):
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
            )
        logger.info(
            f"Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def init_cublas(self):
        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.attention_backend == "flashinfer":
            self.attn_backend = FlashInferAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

        self.cuda_graph_runner = None

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
            return

        if self.server_args.disable_cuda_graph:
            return

        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)

    @torch.inference_mode()
    def forward_decode(self, batch: ScheduleBatch):
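        # Replay a captured CUDA graph when the batch size fits and the sampling
        # parameters are graph-compatible; otherwise fall back to the eager path.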
        if (
            self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(len(batch.reqs))
            and batch.sampling_info.can_run_in_cuda_graph()
        ):
            return self.cuda_graph_runner.replay(batch)

        input_metadata = InputMetadata.from_schedule_batch(self, batch)

        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.from_schedule_batch(self, batch)
        if self.is_generation:
            return self.model.forward(
                batch.input_ids, input_metadata.positions, input_metadata
            )
        else:
            # Only embedding models have get_embedding parameter
            return self.model.forward(
                batch.input_ids,
                input_metadata.positions,
                input_metadata,
                get_embedding=True,
            )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.from_schedule_batch(self, batch)
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            input_metadata.pixel_values,
            input_metadata.image_sizes,
            input_metadata.image_offsets,
        )

    def forward(
        self, batch: ScheduleBatch
    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
        assert batch.forward_mode is not None

        if self.is_multimodal_model and batch.forward_mode.is_extend():
            return self.forward_extend_multi_modal(batch)
        elif batch.forward_mode.is_decode():
            return self.forward_decode(batch)
        elif batch.forward_mode.is_extend():
            return self.forward_extend(batch)
        else:
            raise ValueError(f"Invalid forward mode: {batch.forward_mode}")


@lru_cache()
def import_model_classes():
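    # Walk the sglang.srt.models package and map each model class name
    # (exported via a module-level EntryClass) to the class itself.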
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(
                    entry, list
                ):  # To support multiple model classes in one module
                    for tmp in entry:
                        assert tmp.__name__ not in model_arch_name_to_cls
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    assert entry.__name__ not in model_arch_name_to_cls
                    model_arch_name_to_cls[entry.__name__] = entry

    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()

    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)