Unverified Commit 1083e7e3 authored by Liangsheng Yin, committed by GitHub

Deprecate `global_server_args_dict` (#11331)

parent 2157d12a
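The change is mechanical across the tree: every dictionary lookup on the deprecated `global_server_args_dict` becomes an attribute read on the object returned by `get_global_server_args()`. A minimal before/after sketch of the pattern (the field name `enable_symm_mem` is taken from the diff below; the surrounding snippet is illustrative only):

```python
# Old pattern (deprecated): a module-level dict seeded from ServerArgs values.
from sglang.srt.managers.schedule_batch import global_server_args_dict

if global_server_args_dict["enable_symm_mem"]:
    ...

# New pattern: read attributes off the process-wide ServerArgs instance.
from sglang.srt.server_args import get_global_server_args

if get_global_server_args().enable_symm_mem:
    ...
```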
@@ -6,9 +6,6 @@
 class GlobalConfig:
     """
     Store some global constants.
-    See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
-    many global runtime arguments as well.
     """
     def __init__(self):
...
@@ -5,7 +5,7 @@ from packaging import version
 from torch.cuda.memory import CUDAPluggableAllocator
 from sglang.srt.distributed.parallel_state import GroupCoordinator
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args
 nccl_allocator_source = """
 #include <nccl.h>
@@ -32,7 +32,7 @@ _graph_pool_id = None
 def is_symmetric_memory_enabled():
-    return global_server_args_dict["enable_symm_mem"]
+    return get_global_server_args().enable_symm_mem
 def set_graph_pool_id(graph_pool_id):
...
@@ -18,7 +18,7 @@ from typing import Literal, Optional
 import torch
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args
 @dataclass
@@ -34,7 +34,7 @@ class ExpertLocationDispatchInfo:
     @classmethod
     def init_new(cls, layer_id: int):
-        ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
+        ep_dispatch_algorithm = get_global_server_args().ep_dispatch_algorithm
         expert_location_metadata = get_global_expert_location_metadata()
         assert expert_location_metadata is not None
...
@@ -24,7 +24,7 @@ from sglang.srt.eplb.expert_location import (
     ExpertLocationMetadata,
     get_global_expert_location_metadata,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import get_bool_env_var
 logger = logging.getLogger(__name__)
@@ -97,7 +97,7 @@ def _update_expert_weights_with_canary(
     canary_tensor = (
         _get_canary_value(old_expert_location_metadata, layer_id)
         .clone()
-        .to(device=global_server_args_dict["device"], non_blocking=True)
+        .to(device=get_global_server_args().device, non_blocking=True)
     )
     routed_experts_weights_of_layer[layer_id].append(canary_tensor)
...
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
 import torch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         # TODO: Change the hard-coded block_seq_num
         self.BLOCK_SEQ = 128
-        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
+        if get_global_server_args().triton_attention_reduce_in_fp32:
             self.reduce_dtype = torch.float32
         else:
             self.reduce_dtype = torch.float16
...
@@ -11,8 +11,8 @@ import triton.language as tl
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 if TYPE_CHECKING:
@@ -830,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend):
         ):
             # Do multi-head attention with chunked prefix cache
             if forward_batch.attn_attend_prefix_cache:
-                assert not global_server_args_dict["disable_chunked_prefix_cache"]
+                assert not get_global_server_args().disable_chunked_prefix_cache
                 # MHA for chunked prefix kv cache when running model with MLA
                 assert forward_batch.prefix_chunk_idx is not None
                 assert forward_batch.prefix_chunk_cu_seq_lens is not None
...
@@ -28,8 +28,8 @@ from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
@@ -193,9 +193,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.skip_prefill = skip_prefill
         self.enable_chunk_kv = (
             not skip_prefill
-            and global_server_args_dict["disaggregation_mode"] != "decode"
-            and not global_server_args_dict["disable_chunked_prefix_cache"]
-            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
+            and get_global_server_args().disaggregation_mode != "decode"
+            and not get_global_server_args().disable_chunked_prefix_cache
+            and not get_global_server_args().flashinfer_mla_disable_ragged
         )
         self.page_size = model_runner.page_size
@@ -306,7 +306,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         prefix_lens = forward_batch.extend_prefix_lens
         extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
         use_ragged = (
-            not global_server_args_dict["flashinfer_mla_disable_ragged"]
+            not get_global_server_args().flashinfer_mla_disable_ragged
            and extend_no_prefix
        )
...
@@ -23,9 +23,9 @@ from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import get_rope_wrapper
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
@@ -162,7 +162,7 @@ class Indexer(CustomOp):
             base=rope_theta,  # type: ignore
             rope_scaling=rope_scaling,
             is_neox_style=False,
-            device=global_server_args_dict["device"],
+            device=get_global_server_args().device,
         )
         self.block_size = block_size
         self.scale_fmt = scale_fmt
...
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cuda, is_hip
 _is_cuda = is_cuda()
@@ -11,7 +11,7 @@ if _is_cuda:
 _is_hip = is_hip()
-if global_server_args_dict.get("attention_reduce_in_fp32", False):
+if get_global_server_args().triton_attention_reduce_in_fp32:
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
...
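Worth noting in the hunk above: the old lookup used the key `"attention_reduce_in_fp32"`, which does not appear in the `GLOBAL_SERVER_ARGS_KEYS` list removed later in this commit, so `.get(..., False)` presumably always fell back to `False`; the new attribute read targets the real `triton_attention_reduce_in_fp32` flag and would fail loudly on a typo instead of hiding it. A tiny illustration of that difference in strictness (toy code, not from the project):

```python
# dict.get() with a default silently masks a misspelled key.
args_dict = {"triton_attention_reduce_in_fp32": True}
assert args_dict.get("attention_reduce_in_fp32", False) is False

class Args:
    triton_attention_reduce_in_fp32 = True

# Attribute access raises instead of returning a default.
try:
    Args().attention_reduce_in_fp32
except AttributeError as exc:
    print("caught:", exc)
```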
@@ -20,8 +20,8 @@ from sglang.srt.layers.attention.utils import (
     create_flashmla_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cuda, is_flashinfer_available
 if is_flashinfer_available():
@@ -123,9 +123,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
         self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
-        self.disable_chunked_prefix_cache = global_server_args_dict[
-            "disable_chunked_prefix_cache"
-        ]
+        self.disable_chunked_prefix_cache = (
+            get_global_server_args().disable_chunked_prefix_cache
+        )
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
...
@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix
 ROTARY_EMBED_CLASSES = {
@@ -468,7 +468,7 @@ class VisionAttention(nn.Module):
         _passed_backend = qkv_backend
         qkv_backend = self._determine_attention_backend(_passed_backend)
         if (
-            global_server_args_dict["mm_attention_backend"] is None
+            get_global_server_args().mm_attention_backend is None
             and _passed_backend is None
         ):
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
@@ -528,7 +528,7 @@ class VisionAttention(nn.Module):
         - CUDA: "triton_attn"
         - Non-CUDA: "sdpa"
         """
-        override_backend = global_server_args_dict["mm_attention_backend"]
+        override_backend = get_global_server_args().mm_attention_backend
         if override_backend is not None:
             backend = override_backend
         elif passed_backend is not None:
...
@@ -40,8 +40,9 @@ from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
@@ -168,7 +169,7 @@ class LayerScatterModes:
 def enable_moe_dense_fully_dp():
-    return global_server_args_dict["moe_dense_tp_size"] == 1
+    return get_global_server_args().moe_dense_tp_size == 1
 class LayerCommunicator:
@@ -314,7 +315,9 @@ class LayerCommunicator:
     def should_fuse_mlp_allreduce_with_next_layer(
         self, forward_batch: ForwardBatch
     ) -> bool:
-        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
+        speculative_algo = SpeculativeAlgorithm.from_string(
+            get_global_server_args().speculative_algorithm
+        )
         if (
             is_dp_attention_enabled()
             and speculative_algo is not None
@@ -333,7 +336,7 @@ class LayerCommunicator:
         static_conditions_met = (
             (not self.is_last_layer)
             and (self._context.tp_size > 1)
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and get_global_server_args().enable_flashinfer_allreduce_fusion
             and _is_flashinfer_available
         )
@@ -531,7 +534,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             (_is_sm100_supported or _is_sm90_supported)
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
-            and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
+            and get_global_server_args().enable_flashinfer_allreduce_fusion
             and hidden_states.shape[0] <= 4096
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
...
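The `should_fuse_mlp_allreduce_with_next_layer` hunk is the one place where the migration is not a plain one-to-one substitution: the old dict entry was apparently populated at runtime (it is not in the `GLOBAL_SERVER_ARGS_KEYS` list removed later in this commit), whereas `ServerArgs.speculative_algorithm` holds the raw CLI value, so the new code converts it through `SpeculativeAlgorithm.from_string(...)` before the `is not None` check. A plausible sketch of such a converter, purely for illustration (the real enum in `sglang.srt.speculative.spec_info` may differ in members and behavior):

```python
from enum import Enum, auto
from typing import Optional

class SpeculativeAlgorithm(Enum):  # illustrative stand-in
    NONE = auto()
    EAGLE = auto()

    @staticmethod
    def from_string(name: Optional[str]) -> "SpeculativeAlgorithm":
        # Map a missing CLI value to the explicit NONE member.
        if name is None:
            return SpeculativeAlgorithm.NONE
        return SpeculativeAlgorithm[name.upper()]
```

If the real converter behaves like this sketch (mapping `None` to a `NONE` member rather than returning `None`), the `speculative_algo is not None` condition kept in the hunk above is always satisfied, and the effective gating comes from the remaining conditions.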
...@@ -38,17 +38,15 @@ from sglang.srt.layers.dp_attention import ( ...@@ -38,17 +38,15 @@ from sglang.srt.layers.dp_attention import (
get_dp_device, get_dp_device,
get_dp_dtype, get_dp_dtype,
get_dp_hidden_size, get_dp_hidden_size,
get_global_dp_buffer,
get_local_attention_dp_size, get_local_attention_dp_size,
set_dp_buffer_len,
) )
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode, CaptureHiddenMode,
ForwardBatch, ForwardBatch,
ForwardMode, ForwardMode,
) )
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -230,8 +228,8 @@ class LogitsProcessor(nn.Module): ...@@ -230,8 +228,8 @@ class LogitsProcessor(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.logit_scale = logit_scale self.logit_scale = logit_scale
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"] self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"] self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
if self.use_attn_tp_group: if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size() self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = ( self.do_tensor_parallel_all_gather = (
...@@ -254,8 +252,8 @@ class LogitsProcessor(nn.Module): ...@@ -254,8 +252,8 @@ class LogitsProcessor(nn.Module):
): ):
self.final_logit_softcapping = None self.final_logit_softcapping = None
self.debug_tensor_dump_output_folder = global_server_args_dict.get( self.debug_tensor_dump_output_folder = (
"debug_tensor_dump_output_folder", None get_global_server_args().debug_tensor_dump_output_folder
) )
def compute_logprobs_for_multi_item_scoring( def compute_logprobs_for_multi_item_scoring(
...@@ -372,9 +370,7 @@ class LogitsProcessor(nn.Module): ...@@ -372,9 +370,7 @@ class LogitsProcessor(nn.Module):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
# Check if multi-item scoring is enabled via server args (only for prefill-only requests) # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
multi_item_delimiter = global_server_args_dict.get( multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
"multi_item_scoring_delimiter"
)
if multi_item_delimiter is not None and logits_metadata.is_prefill_only: if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
return self.compute_logprobs_for_multi_item_scoring( return self.compute_logprobs_for_multi_item_scoring(
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
......
@@ -27,12 +27,10 @@ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
-    QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
     cpu_has_amx_support,
...
@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda,
@@ -265,9 +265,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
-        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
-            "flashinfer_mxfp4_moe_precision"
-        ]
+        self.flashinfer_mxfp4_moe_precision = (
+            get_global_server_args().flashinfer_mxfp4_moe_precision
+        )
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
...
@@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import (
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
 if is_cuda():
@@ -33,7 +33,7 @@ RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
+        self.use_nan_detection = get_global_server_args().enable_nan_detection
         self.tp_sync_group = get_tp_group().device_group
         if is_dp_attention_enabled():
@@ -104,7 +104,7 @@ class Sampler(nn.Module):
             del logits
             if True:  # Keep this redundant check to simplify some internal code sync
-                if global_server_args_dict["sampling_backend"] == "flashinfer":
+                if get_global_server_args().sampling_backend == "flashinfer":
                     if sampling_info.need_min_p_sampling:
                         probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                         probs = top_p_renorm_prob(probs, sampling_info.top_ps)
@@ -119,7 +119,7 @@ class Sampler(nn.Module):
                             filter_apply_order="joint",
                             check_nan=self.use_nan_detection,
                         )
-                elif global_server_args_dict["sampling_backend"] == "pytorch":
+                elif get_global_server_args().sampling_backend == "pytorch":
                     # A slower fallback implementation with torch native operations.
                     batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                         probs,
@@ -132,7 +132,7 @@ class Sampler(nn.Module):
                     )
                 else:
                     raise ValueError(
-                        f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+                        f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
                     )
             if return_logprob:
...
@@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
     MultimodalInputs,
-    global_server_args_dict,
 )
 from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
 from sglang.utils import logger
@@ -428,7 +428,7 @@ def _adjust_embedding_length(
             f"tokens from multimodal embeddings."
         )
     if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
-        chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
+        chunked_prefill_size = get_global_server_args().chunked_prefill_size
         if chunked_prefill_size != -1:
             logger.warning(
                 "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
...
@@ -72,7 +72,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import ServerArgs, get_global_server_args
 from sglang.srt.utils import flatten_nested_list
 from sglang.srt.utils.common import next_power_of_2
@@ -82,47 +82,6 @@ if TYPE_CHECKING:
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
-GLOBAL_SERVER_ARGS_KEYS = [
-    "attention_backend",
-    "mm_attention_backend",
-    "debug_tensor_dump_inject",
-    "debug_tensor_dump_output_folder",
-    "chunked_prefill_size",
-    "device",
-    "disable_chunked_prefix_cache",
-    "disable_flashinfer_cutlass_moe_fp4_allgather",
-    "disable_radix_cache",
-    "enable_dp_lm_head",
-    "enable_fp32_lm_head",
-    "flashinfer_mxfp4_moe_precision",
-    "enable_flashinfer_allreduce_fusion",
-    "moe_dense_tp_size",
-    "ep_dispatch_algorithm",
-    "ep_num_redundant_experts",
-    "enable_nan_detection",
-    "flashinfer_mla_disable_ragged",
-    "pp_max_micro_batch_size",
-    "disable_shared_experts_fusion",
-    "sampling_backend",
-    "speculative_accept_threshold_single",
-    "speculative_accept_threshold_acc",
-    "speculative_attention_mode",
-    "torchao_config",
-    "triton_attention_reduce_in_fp32",
-    "num_reserved_decode_tokens",
-    "weight_loader_disable_mmap",
-    "enable_multimodal",
-    "enable_symm_mem",
-    "enable_custom_logit_processor",
-    "disaggregation_mode",
-    "enable_deterministic_inference",
-    "nsa_prefill",
-    "nsa_decode",
-    "multi_item_scoring_delimiter",
-]
-
-# Put some global args for easy access
-global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
 logger = logging.getLogger(__name__)
@@ -683,12 +642,9 @@ class Req:
     def is_prefill_only(self) -> bool:
         """Check if this request is prefill-only (no token generation needed)."""
         # NOTE: when spec is enabled, prefill_only optimizations are disabled
-        from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-        spec_alg = global_server_args_dict["speculative_algorithm"]
-        return self.sampling_params.max_new_tokens == 0 and (
-            spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
-        )
+        spec_alg = get_global_server_args().speculative_algorithm
+        return self.sampling_params.max_new_tokens == 0 and spec_alg is None
     def add_latency(self, stage: RequestStage):
         if self.metrics_collector is None:
...
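The diff only shows call sites; the accessor itself lives in `sglang.srt.server_args` and is not part of this page. A minimal sketch of the kind of process-wide singleton the call sites imply, written as an assumption for illustration (the setter name, field set, and validation here are hypothetical, not the project's actual code):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ServerArgs:  # toy stand-in for sglang.srt.server_args.ServerArgs
    device: str = "cuda"
    enable_symm_mem: bool = False
    pp_max_micro_batch_size: Optional[int] = None

_GLOBAL_SERVER_ARGS: Optional[ServerArgs] = None

def set_global_server_args(args: ServerArgs) -> None:  # hypothetical setter name
    """Install the single ServerArgs instance for this process."""
    global _GLOBAL_SERVER_ARGS
    _GLOBAL_SERVER_ARGS = args

def get_global_server_args() -> ServerArgs:
    """Return the installed instance; callers read plain attributes off it."""
    assert _GLOBAL_SERVER_ARGS is not None, "server args not initialized"
    return _GLOBAL_SERVER_ARGS
```

Compared with the removed `global_server_args_dict`, a design like this keeps a single source of truth: there is no curated key list to maintain, and call sites such as `get_global_server_args().pp_max_micro_batch_size` read and write the same object the rest of the server configured.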
@@ -122,7 +122,6 @@ from sglang.srt.managers.schedule_batch import (
     Req,
     RequestStage,
     ScheduleBatch,
-    global_server_args_dict,
 )
 from sglang.srt.managers.schedule_policy import (
     AddReqResult,
@@ -150,7 +149,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.parser.reasoning_parser import ReasoningParser
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.eagle_info import EagleDraftInput
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.tracing.trace import (
@@ -447,13 +446,12 @@ class Scheduler(
             self.max_req_input_len,
             self.random_seed,
             self.device,
-            worker_global_server_args_dict,
             _,
             _,
             _,
         ) = self.tp_worker.get_worker_info()
-        if global_server_args_dict["pp_max_micro_batch_size"] is None:
-            global_server_args_dict["pp_max_micro_batch_size"] = max(
+        if get_global_server_args().pp_max_micro_batch_size is None:
+            get_global_server_args().pp_max_micro_batch_size = max(
                 self.max_running_requests // server_args.pp_size, 1
             )
@@ -465,7 +463,6 @@ class Scheduler(
         self.world_group = get_world_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
-        global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
         # Hybrid memory pool
@@ -1866,7 +1863,7 @@ class Scheduler(
         return ret
     def get_num_allocatable_reqs(self, running_bs):
-        res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs
+        res = get_global_server_args().pp_max_micro_batch_size - running_bs
         if self.pp_size > 1:
             res = min(res, self.req_to_token_pool.available_size())
         return res
@@ -2610,7 +2607,7 @@ class Scheduler(
         )
     def get_internal_state(self, recv_req: GetInternalStateReq):
-        ret = dict(global_server_args_dict)
+        ret = vars(get_global_server_args())
         ret["last_gen_throughput"] = self.last_gen_throughput
         ret["memory_usage"] = {
             "weight": round(
@@ -2666,11 +2663,11 @@ class Scheduler(
             logger.info(f"{avg_spec_accept_length=}")
             self.cum_spec_accept_length = self.cum_spec_accept_count = 0
         for k, v in server_args_dict.items():
-            global_server_args_dict[k] = v
-        logger.info(f"Global server args updated! {global_server_args_dict=}")
+            setattr(get_global_server_args(), k, v)
+        logger.info(f"Global server args updated! {get_global_server_args()=}")
         return SetInternalStateReqOutput(
             updated=True,
-            server_args=global_server_args_dict,
+            server_args=vars(get_global_server_args()),
         )
     def handle_rpc_request(self, recv_req: RpcReqInput):
...
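The scheduler's internal-state RPCs keep their dict-shaped wire format: `vars(get_global_server_args())` exports the instance's fields as a mapping, and `setattr(...)` writes individual fields back. A small stand-alone illustration of that round trip (`ServerArgs` here is a toy stand-in; field names are taken from the hunks above):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ServerArgs:  # toy stand-in
    sampling_backend: str = "flashinfer"
    pp_max_micro_batch_size: Optional[int] = None

args = ServerArgs()

# Export: vars() on a plain dataclass instance returns its __dict__.
exported = dict(vars(args))
assert exported["sampling_backend"] == "flashinfer"

# Import: apply an update dict field by field, as set_internal_state does.
for k, v in {"sampling_backend": "pytorch"}.items():
    setattr(args, k, v)
assert args.sampling_backend == "pytorch"
```

One subtle difference from the old `ret = dict(global_server_args_dict)` is worth noting: `vars()` returns the live `__dict__` rather than a copy, so the subsequent `ret["last_gen_throughput"] = ...` style assignments in `get_internal_state` also land on the `ServerArgs` instance unless the result is copied first (the toy example above wraps it in `dict(...)` for that reason).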
@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
@@ -190,7 +190,6 @@ class TpModelWorker:
             self.max_req_input_len,
             self.random_seed,
             self.device,
-            global_server_args_dict,
             self.model_runner.req_to_token_pool.size,
             self.model_runner.req_to_token_pool.max_context_len,
             self.model_runner.token_to_kv_pool.size,
...