# Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ModelRunner runs the forward passes of the models.""" import gc import json import logging import time from typing import List, Optional, Tuple import torch import torch.distributed as dist from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.distributed import ( get_tp_group, init_distributed_environment, initialize_model_parallel, set_custom_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.dp_attention import ( get_attention_tp_group, get_attention_tp_size, initialize_dp_attention, ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, MHATokenToKVPool, MLATokenToKVPool, ReqToTokenPool, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader import get_model from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( enable_show_time_cost, get_available_gpu_memory, init_custom_process_group, is_cuda, is_hip, monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, set_cpu_offload_max_bytes, ) logger = logging.getLogger(__name__) class ModelRunner: """ModelRunner runs the forward passes of the models.""" def __init__( self, model_config: ModelConfig, mem_fraction_static: float, gpu_id: int, tp_rank: int, tp_size: int, nccl_port: int, server_args: ServerArgs, is_draft_worker: bool = False, ): # Parse args self.model_config = model_config self.mem_fraction_static = mem_fraction_static self.device = server_args.device self.gpu_id = gpu_id self.tp_rank = tp_rank self.tp_size = tp_size self.dist_port = nccl_port self.server_args = server_args self.is_draft_worker = is_draft_worker self.is_generation = model_config.is_generation self.is_multimodal = model_config.is_multimodal self.should_log = tp_rank == 0 self.spec_algorithm = SpeculativeAlgorithm.from_string( server_args.speculative_algorithm ) # Model-specific adjustment if ( self.model_config.attention_arch == AttentionArch.MLA and not self.server_args.disable_mla ): # TODO: add MLA optimization on CPU if self.server_args.device != "cpu": logger.info("MLA optimization is turned on. Use triton backend.") self.server_args.attention_backend = "triton" if self.server_args.enable_double_sparsity: logger.info( "Double sparsity optimization is turned on. Use triton backend without CUDA graph." ) self.server_args.attention_backend = "triton" self.server_args.disable_cuda_graph = True if self.server_args.ds_heavy_channel_type is None: raise ValueError( "Please specify the heavy channel type for double sparsity optimization." ) self.init_double_sparsity_channel_config( self.server_args.ds_heavy_channel_type ) if self.is_multimodal: self.mem_fraction_static *= 0.95 logger.info( f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} " f"because this is a multimodal model." ) if self.model_config.hf_config.architectures == [ "MllamaForConditionalGeneration" ]: logger.info("Automatically turn off --chunked-prefill-size for mllama.") server_args.chunked_prefill_size = -1 if self.model_config.hf_config.architectures == [ "Qwen2VLForConditionalGeneration" ]: # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically logger.info( "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl." ) server_args.chunked_prefill_size = -1 server_args.disable_radix_cache = True # Global vars if server_args.show_time_cost: enable_show_time_cost() if server_args.disable_outlines_disk_cache: from outlines.caching import disable_cache disable_cache() global_server_args_dict.update( { "attention_backend": server_args.attention_backend, "sampling_backend": server_args.sampling_backend, "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, "disable_mla": server_args.disable_mla, "torchao_config": server_args.torchao_config, "enable_nan_detection": server_args.enable_nan_detection, "enable_dp_attention": server_args.enable_dp_attention, "enable_ep_moe": server_args.enable_ep_moe, "device": server_args.device, } ) set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() self.memory_saver_adapter = TorchMemorySaverAdapter.create( enable=self.server_args.enable_memory_saver ) # Load the model self.sampler = Sampler() self.load_model() # Apply torchao quantization torchao_applied = getattr(self.model, "torchao_applied", False) # In layered loading, torchao may have been applied if not torchao_applied: apply_torchao_config_to_model( self.model, global_server_args_dict["torchao_config"] ) # Apply torch TP if the model supports it supports_torch_tp = getattr(self.model, "supports_torch_tp", False) if self.tp_size > 1 and supports_torch_tp: self.apply_torch_tp() self.torch_tp_applied = True else: self.torch_tp_applied = False # Init memory pool and attention backends if server_args.lora_paths is not None: self.init_lora_manager() self.init_memory_pool( min_per_gpu_memory, server_args.max_running_requests, server_args.max_total_tokens, ) if self.device == "cuda": self.init_cublas() self.init_attention_backend() self.init_cuda_graphs() else: self.cuda_graph_runner = None self.init_attention_backend() def init_torch_distributed(self): logger.info("Init torch distributed begin.") torch.get_device_module(self.device).set_device(self.gpu_id) if self.device == "cuda": backend = "nccl" elif self.device == "xpu": # TODO(liangan1): Just use gloo to bypass the initilization fail # Need to use xccl for xpu backend in the future backend = "gloo" elif self.device == "hpu": backend = "hccl" elif self.device == "cpu": backend = "gloo" if not self.server_args.enable_p2p_check: monkey_patch_p2p_access_check() if self.server_args.dist_init_addr: dist_init_method = f"tcp://{self.server_args.dist_init_addr}" else: dist_init_method = f"tcp://127.0.0.1:{self.dist_port}" set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) if not self.is_draft_worker: # Only initialize the distributed environment on the target model worker. init_distributed_environment( backend=backend, world_size=self.tp_size, rank=self.tp_rank, local_rank=self.gpu_id, distributed_init_method=dist_init_method, ) initialize_model_parallel(tensor_model_parallel_size=self.tp_size) initialize_dp_attention( enable_dp_attention=self.server_args.enable_dp_attention, tp_rank=self.tp_rank, tp_size=self.tp_size, dp_size=self.server_args.dp_size, ) min_per_gpu_memory = get_available_gpu_memory( self.device, self.gpu_id, distributed=self.tp_size > 1 ) self.tp_group = get_tp_group() self.attention_tp_group = get_attention_tp_group() # Check memory for tensor parallelism if self.tp_size > 1: local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id) if min_per_gpu_memory < local_gpu_memory * 0.9: raise ValueError( "The memory capacity is unbalanced. Some GPUs may be occupied by other processes." ) return min_per_gpu_memory def load_model(self): logger.info( f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) # This can reduce thread conflicts and speed up weight loading. if self.device != "cpu": torch.set_num_threads(1) if self.device == "cuda": if torch.cuda.get_device_capability()[0] < 8: logger.info( "Compute capability below sm80. Use float16 due to lack of bfloat16 support." ) self.server_args.dtype = "float16" self.model_config.dtype = torch.float16 if torch.cuda.get_device_capability()[1] < 5: raise RuntimeError("SGLang only supports sm75 and above.") # Prepare the model config self.load_config = LoadConfig( load_format=self.server_args.load_format, download_dir=self.server_args.download_dir, ) if self.server_args.load_format == "gguf": monkey_patch_vllm_gguf_config() # Load the model # Remove monkey_patch when linear.py quant remove dependencies with vllm monkey_patch_vllm_parallel_state() with self.memory_saver_adapter.region(): self.model = get_model( model_config=self.model_config, load_config=self.load_config, device_config=DeviceConfig(self.device), ) monkey_patch_vllm_parallel_state(reverse=True) if self.server_args.kv_cache_dtype == "fp8_e4m3": if self.server_args.quantization_param_path is not None: if callable(getattr(self.model, "load_kv_cache_scales", None)): self.model.load_kv_cache_scales( self.server_args.quantization_param_path ) logger.info( "Loaded KV cache scaling factors from %s", self.server_args.quantization_param_path, ) else: raise RuntimeError( "Using FP8 KV cache and scaling factors provided but " "model %s does not support loading scaling factors.", self.model.__class__, ) else: logger.warning( "Using FP8 KV cache but no scaling factors " "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!" ) # Parse other args self.sliding_window_size = ( self.model.get_attention_sliding_window_size() if hasattr(self.model, "get_attention_sliding_window_size") else None ) self.dtype = self.model_config.dtype logger.info( f"Load weight end. " f"type={type(self.model).__name__}, " f"dtype={self.dtype}, " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) def update_weights_from_disk( self, model_path: str, load_format: str ) -> tuple[bool, str]: """Update engine weights in-place from the disk.""" from sglang.srt.model_loader.loader import ( DefaultModelLoader, device_loading_context, get_model_loader, ) from sglang.srt.model_loader.utils import set_default_torch_dtype logger.info( f"Update engine weights online from disk begin. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) target_device = torch.device(self.device) self.model_config.model_path = model_path load_config = LoadConfig(load_format=load_format) # Only support vllm DefaultModelLoader for now loader = get_model_loader(load_config) if not isinstance(loader, DefaultModelLoader): message = f"Failed to get model loader: {loader}." return False, message def get_weight_iter(config): iter = loader._get_weights_iterator( DefaultModelLoader.Source( config.model_path, revision=config.revision, fall_back_to_pt=getattr( self.model, "fall_back_to_pt_during_load", True ), ) ) return iter def model_load_weights(model, iter): model.load_weights(iter) for _, module in self.model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) return model with set_default_torch_dtype(self.model_config.dtype): try: iter = get_weight_iter(self.model_config) except Exception as e: message = f"Failed to get weights iterator: {e}." return False, message try: model = model_load_weights(self.model, iter) except Exception as e: message = ( f"Failed to update weights: {e}.\nRolling back to original weights." ) del iter gc.collect() iter = get_weight_iter(self.model_config) self.model = model_load_weights(self.model, iter) return False, message self.model = model self.server_args.model_path = model_path self.server_args.load_format = load_format self.load_config = load_config logger.info("Update weights end.") return True, "Succeeded to update model weights." def init_weights_update_group( self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl", ): """Initialize the Torch process group for model parameter updates. `_model_update_group` is used in the RLHF workflow, where rank 0 is the actor model in the training engine, and the other ranks are the inference engine, which is used for rollout. In the RLHF workflow, the training engine updates the model weights/parameters online, and broadcasts them to the inference engine through the `_model_update_group` process group. """ assert ( torch.distributed.is_initialized() ), "Default torch process group must be initialized" assert group_name != "", "Group name cannot be empty" rank = rank_offset + self.tp_rank logger.info( f"init custom process group: master_address={master_address}, master_port={master_port}, " f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}" ) try: self._model_update_group = init_custom_process_group( backend=backend, init_method=f"tcp://{master_address}:{master_port}", world_size=world_size, rank=rank, group_name=group_name, ) dist.barrier(group=self._model_update_group, device_ids=[rank]) return True, "Succeeded to initialize custom process group." except Exception as e: message = f"Failed to initialize custom process group: {e}." logger.error(message) return False, message def update_weights_from_distributed(self, name, dtype, shape): """ Update specific parameter in the model weights online through `_model_update_group` process group. Args: name: the name of the parameter to be updated. dtype: the data type of the parameter to be updated. shape: the shape of the parameter to be updated. """ target_dtype = ( dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) ) assert ( self._model_update_group is not None ), "model update group must be initialized" try: weights = torch.empty(shape, dtype=target_dtype, device=self.device) torch.distributed.broadcast(weights, src=0, group=self._model_update_group) self.model.load_weights([(name, weights)]) return True, f"Succeeded to update parameter {name} online." except Exception as e: error_msg = ( f"Failed to update parameter online: {e}. " f"The full weights of the ModelRunner are partially updated. " f"Please discard the whole weights." ) logger.error(error_msg) return False, error_msg def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): self.model.load_weights(named_tensors) return True, "Success" def get_weights_by_name( self, name: str, truncate_size: int = 100 ) -> Optional[torch.Tensor]: """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face. Only used for unit test with an unoptimized performance. For optimized performance, please use torch.save and torch.load. """ # TODO: (chenyang) Add support for Qwen models. try: return self.model.get_weights_by_name( name, truncate_size, tp_size=self.tp_size ) except Exception as e: logger.error(f"Error when getting parameter {name}: {e}") return None def init_lora_manager(self): self.lora_manager = LoRAManager( base_model=self.model, lora_paths=self.server_args.lora_paths, base_hf_config=self.model_config.hf_config, max_loras_per_batch=self.server_args.max_loras_per_batch, load_config=self.load_config, dtype=self.dtype, ) logger.info("LoRA manager ready.") def profile_max_num_token(self, total_gpu_memory: int): available_gpu_memory = get_available_gpu_memory( self.device, self.gpu_id, distributed=self.tp_size > 1 ) if ( self.model_config.attention_arch == AttentionArch.MLA and not self.server_args.disable_mla ): cell_size = ( (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) * self.model_config.num_hidden_layers * torch._utils._element_size(self.kv_cache_dtype) ) else: cell_size = ( self.model_config.get_num_kv_heads(get_attention_tp_size()) * self.model_config.head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.kv_cache_dtype) ) rest_memory = available_gpu_memory - total_gpu_memory * ( 1 - self.mem_fraction_static ) max_num_token = int(rest_memory * (1 << 30) // cell_size) return max_num_token def init_memory_pool( self, total_gpu_memory: int, max_num_reqs: Optional[int] = None, max_total_tokens: Optional[int] = None, ): if self.server_args.kv_cache_dtype == "auto": self.kv_cache_dtype = self.dtype elif self.server_args.kv_cache_dtype == "fp8_e5m2": if is_hip(): # Using natively supported format self.kv_cache_dtype = torch.float8_e5m2fnuz else: self.kv_cache_dtype = torch.float8_e5m2 elif self.server_args.kv_cache_dtype == "fp8_e4m3": if is_cuda(): self.kv_cache_dtype = torch.float8_e4m3fn else: raise ValueError( f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." ) self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) if max_num_reqs is None: max_num_reqs = min( max( int( self.max_total_num_tokens / self.model_config.context_len * 512 ), 2048, ), 4096, ) if not self.spec_algorithm.is_none(): if self.is_draft_worker: self.max_total_num_tokens = self.server_args.draft_runner_cache_size else: self.server_args.draft_runner_cache_size = ( self.max_total_num_tokens + max_num_reqs * self.server_args.speculative_num_steps + 100 ) if max_total_tokens is not None: if max_total_tokens > self.max_total_num_tokens: logging.warning( f"max_total_tokens={max_total_tokens} is larger than the profiled value " f"{self.max_total_num_tokens}. " f"Use the profiled value instead." ) self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens) if self.max_total_num_tokens <= 0: raise RuntimeError( "Not enough memory. Please try to increase --mem-fraction-static." ) self.req_to_token_pool = ReqToTokenPool( size=max_num_reqs + 1, max_context_len=self.model_config.context_len + 4, device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, ) if ( self.model_config.attention_arch == AttentionArch.MLA and not self.server_args.disable_mla ): self.token_to_kv_pool = MLATokenToKVPool( self.max_total_num_tokens, dtype=self.kv_cache_dtype, kv_lora_rank=self.model_config.kv_lora_rank, qk_rope_head_dim=self.model_config.qk_rope_head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, ) elif self.server_args.enable_double_sparsity: self.token_to_kv_pool = DoubleSparseTokenToKVPool( self.max_total_num_tokens, dtype=self.kv_cache_dtype, head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, heavy_channel_num=self.server_args.ds_heavy_channel_num, enable_memory_saver=self.server_args.enable_memory_saver, ) else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, dtype=self.kv_cache_dtype, head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, ) logger.info( f"Memory pool end. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) def init_cublas(self): """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later.""" dtype = torch.float16 device = "cuda" a = torch.ones((16, 16), dtype=dtype, device=device) b = torch.ones((16, 16), dtype=dtype, device=device) c = a @ b return c def init_attention_backend(self): """Init attention kernel backend.""" if self.server_args.attention_backend == "flashinfer": self.attn_backend = FlashInferAttnBackend(self) elif self.server_args.attention_backend == "triton": assert self.sliding_window_size is None, ( "Window attention is not supported in the triton attention backend. " "Please use `--attention-backend flashinfer`." ) assert not self.model_config.is_encoder_decoder, ( "Cross attention is not supported in the triton attention backend. " "Please use `--attention-backend flashinfer`." ) if self.server_args.enable_double_sparsity: self.attn_backend = DoubleSparseAttnBackend(self) else: self.attn_backend = TritonAttnBackend(self) elif self.server_args.attention_backend == "torch_native": self.attn_backend = TorchNativeAttnBackend(self) else: raise ValueError( f"Invalid attention backend: {self.server_args.attention_backend}" ) def init_double_sparsity_channel_config(self, selected_channel): selected_channel = "." + selected_channel + "_proj" self.sorted_channels = [] # load channel config with open(self.server_args.ds_channel_config_path, "r") as f: channel_config = json.load(f) for i in range(self.model_config.num_hidden_layers): key = "model.layers." + str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ :, : self.server_args.ds_heavy_channel_num ] .contiguous() .cuda() ) def init_cuda_graphs(self): """Capture cuda graphs.""" from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner self.cuda_graph_runner = None if not self.is_generation: # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models return if self.server_args.disable_cuda_graph: return tic = time.time() logger.info("Capture cuda graph begin. This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s") def apply_torch_tp(self): logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") from sglang.srt.model_parallel import tensor_parallel device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,)) tensor_parallel(self.model, device_mesh) def forward_decode(self, forward_batch: ForwardBatch): self.attn_backend.init_forward_metadata(forward_batch) return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch ) def forward_extend(self, forward_batch: ForwardBatch): self.attn_backend.init_forward_metadata(forward_batch) if self.is_generation: if forward_batch.input_embeds is None: return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch ) else: return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch, input_embeds=forward_batch.input_embeds.bfloat16(), ) else: # Only embedding models have get_embedding parameter return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch, get_embedding=True, ) def forward_idle(self, forward_batch: ForwardBatch): return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch ) def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput: if ( forward_batch.forward_mode.is_cuda_graph() and self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch) ): return self.cuda_graph_runner.replay(forward_batch) if forward_batch.forward_mode.is_decode(): return self.forward_decode(forward_batch) elif forward_batch.forward_mode.is_extend(): return self.forward_extend(forward_batch) elif forward_batch.forward_mode.is_idle(): return self.forward_idle(forward_batch) else: raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") def sample( self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch ) -> torch.Tensor: # Apply logit bias sampling_info = forward_batch.sampling_info if sampling_info.sampling_info_done: # Overlap mode: the function update_regex_vocab_mask was executed # in process_batch_result of the last batch. if sampling_info.grammars: sampling_info.sampling_info_done.wait() else: # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass. sampling_info.update_regex_vocab_mask() sampling_info.update_penalties() sampling_info.apply_logits_bias(logits_output.next_token_logits) # Sample the next tokens next_token_ids = self.sampler( logits_output, sampling_info, forward_batch.return_logprob, forward_batch.top_logprobs_nums, ) return next_token_ids @property def model_is_mrope(self) -> bool: """Detect if the model has "mrope" rope_scaling type. mrope requires keep "rope_deltas" between prompt and decoding phases.""" rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) if rope_scaling is None: return False return rope_scaling.get("type", None) == "mrope"