change / sglang / Commits / 1bdd0102

Revert "Deprecate `global_server_args_dict`" (#11520)

Unverified commit 1bdd0102, authored Oct 12, 2025 by Cheng Wan, committed via GitHub on Oct 12, 2025.
Parent: 6cd29694
Changes: 54 files in this commit. This page shows 20 changed files with 117 additions and 63 deletions (+117 -63).
python/sglang/global_config.py                                               +3  -0
python/sglang/srt/distributed/device_communicators/pynccl_allocator.py      +2  -2
python/sglang/srt/eplb/expert_location_dispatch.py                          +2  -2
python/sglang/srt/eplb/expert_location_updater.py                           +2  -2
python/sglang/srt/layers/attention/double_sparsity_backend.py               +2  -2
python/sglang/srt/layers/attention/flashattention_backend.py                +2  -2
python/sglang/srt/layers/attention/flashinfer_mla_backend.py                +5  -5
python/sglang/srt/layers/attention/nsa/nsa_indexer.py                       +2  -2
python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py  +2  -2
python/sglang/srt/layers/attention/trtllm_mla_backend.py                    +4  -4
python/sglang/srt/layers/attention/vision.py                                +3  -3
python/sglang/srt/layers/communicator.py                                    +5  -8
python/sglang/srt/layers/logits_processor.py                                +10 -6
python/sglang/srt/layers/moe/fused_moe_triton/layer.py                      +2  -0
python/sglang/srt/layers/quantization/mxfp4.py                              +4  -4
python/sglang/srt/layers/sampler.py                                         +5  -5
python/sglang/srt/managers/mm_utils.py                                      +2  -2
python/sglang/srt/managers/schedule_batch.py                                +47 -3
python/sglang/srt/managers/scheduler.py                                     +11 -8
python/sglang/srt/managers/tp_worker.py                                     +2  -1
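The revert restores the pre-deprecation access pattern: instead of calling the `get_global_server_args()` accessor from `sglang.srt.server_args` and reading attributes, call sites go back to looking keys up in the module-level `global_server_args_dict` defined in `sglang.srt.managers.schedule_batch`. A minimal sketch of the two styles, with made-up stand-ins (the values and the `_ServerArgsLike` class below are illustrative, not the real sglang objects):

# Restored style (this commit): a plain module-level dict keyed by option name.
# Some call sites index directly, others use .get() with a default, e.g.
# global_server_args_dict.get("triton_attention_reduce_in_fp32", False).
global_server_args_dict = {"sampling_backend": "flashinfer", "enable_nan_detection": False}

use_nan_detection = global_server_args_dict["enable_nan_detection"]
backend = global_server_args_dict.get("sampling_backend", "pytorch")


# Style removed by this commit: attribute access on a server-args object
# returned by an accessor function.
class _ServerArgsLike:
    sampling_backend = "flashinfer"
    enable_nan_detection = False


def get_global_server_args() -> "_ServerArgsLike":
    return _ServerArgsLike()


use_nan_detection = get_global_server_args().enable_nan_detection

The per-file hunks below apply this swap mechanically; only schedule_batch.py and scheduler.py also restore the dict's definition and its update path.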
python/sglang/global_config.py

@@ -6,6 +6,9 @@
 class GlobalConfig:
     """
     Store some global constants.
+
+    See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
+    many global runtime arguments as well.
     """

     def __init__(self):
python/sglang/srt/distributed/device_communicators/pynccl_allocator.py

@@ -5,7 +5,7 @@ from packaging import version
 from torch.cuda.memory import CUDAPluggableAllocator

 from sglang.srt.distributed.parallel_state import GroupCoordinator
-from sglang.srt.server_args import get_global_server_args
+from sglang.srt.managers.schedule_batch import global_server_args_dict

 nccl_allocator_source = """
 #include <nccl.h>

@@ -32,7 +32,7 @@ _graph_pool_id = None


 def is_symmetric_memory_enabled():
-    return get_global_server_args().enable_symm_mem
+    return global_server_args_dict["enable_symm_mem"]


 def set_graph_pool_id(graph_pool_id):
python/sglang/srt/eplb/expert_location_dispatch.py

@@ -18,7 +18,7 @@ from typing import Literal, Optional

 import torch

 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
-from sglang.srt.server_args import get_global_server_args
+from sglang.srt.managers.schedule_batch import global_server_args_dict


 @dataclass

@@ -34,7 +34,7 @@ class ExpertLocationDispatchInfo:

     @classmethod
     def init_new(cls, layer_id: int):
-        ep_dispatch_algorithm = get_global_server_args().ep_dispatch_algorithm
+        ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
         expert_location_metadata = get_global_expert_location_metadata()
         assert expert_location_metadata is not None
python/sglang/srt/eplb/expert_location_updater.py

@@ -24,7 +24,7 @@ from sglang.srt.eplb.expert_location import (
     ExpertLocationMetadata,
     get_global_expert_location_metadata,
 )
-from sglang.srt.server_args import get_global_server_args
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_bool_env_var

 logger = logging.getLogger(__name__)

@@ -97,7 +97,7 @@ def _update_expert_weights_with_canary(
     canary_tensor = (
         _get_canary_value(old_expert_location_metadata, layer_id)
         .clone()
-        .to(device=get_global_server_args().device, non_blocking=True)
+        .to(device=global_server_args_dict["device"], non_blocking=True)
     )
     routed_experts_weights_of_layer[layer_id].append(canary_tensor)
python/sglang/srt/layers/attention/double_sparsity_backend.py

@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
 import torch

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.server_args import get_global_server_args

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention

@@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         # TODO: Change the hard-coded block_seq_num
         self.BLOCK_SEQ = 128

-        if get_global_server_args().triton_attention_reduce_in_fp32:
+        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
         else:
            self.reduce_dtype = torch.float16
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -11,8 +11,8 @@ import triton.language as tl
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import AttentionType
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput

 if TYPE_CHECKING:

@@ -830,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend):
         ):
             # Do multi-head attention with chunked prefix cache
             if forward_batch.attn_attend_prefix_cache:
-                assert not get_global_server_args().disable_chunked_prefix_cache
+                assert not global_server_args_dict["disable_chunked_prefix_cache"]
                 # MHA for chunked prefix kv cache when running model with MLA
                 assert forward_batch.prefix_chunk_idx is not None
                 assert forward_batch.prefix_chunk_cu_seq_lens is not None
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -28,8 +28,8 @@ from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,

@@ -193,9 +193,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.skip_prefill = skip_prefill
         self.enable_chunk_kv = (
             not skip_prefill
-            and get_global_server_args().disaggregation_mode != "decode"
-            and not get_global_server_args().disable_chunked_prefix_cache
-            and not get_global_server_args().flashinfer_mla_disable_ragged
+            and global_server_args_dict["disaggregation_mode"] != "decode"
+            and not global_server_args_dict["disable_chunked_prefix_cache"]
+            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
         )
         self.page_size = model_runner.page_size

@@ -306,7 +306,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             prefix_lens = forward_batch.extend_prefix_lens
             extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
             use_ragged = (
-                not get_global_server_args().flashinfer_mla_disable_ragged
+                not global_server_args_dict["flashinfer_mla_disable_ragged"]
                 and extend_no_prefix
             )
python/sglang/srt/layers/attention/nsa/nsa_indexer.py

@@ -23,9 +23,9 @@ from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.server_args import get_global_server_args

 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool

@@ -162,7 +162,7 @@ class Indexer(CustomOp):
             base=rope_theta,  # type: ignore
             rope_scaling=rope_scaling,
             is_neox_style=False,
-            device=get_global_server_args().device,
+            device=global_server_args_dict["device"],
         )
         self.block_size = block_size
         self.scale_fmt = scale_fmt
python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py

@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.server_args import get_global_server_args
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, is_hip

 _is_cuda = is_cuda()

@@ -11,7 +11,7 @@ if _is_cuda:

 _is_hip = is_hip()

-if get_global_server_args().triton_attention_reduce_in_fp32:
+if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
python/sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -20,8 +20,8 @@ from sglang.srt.layers.attention.utils import (
     create_flashmla_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cuda, is_flashinfer_available

 if is_flashinfer_available():

@@ -123,9 +123,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
         self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

-        self.disable_chunked_prefix_cache = (
-            get_global_server_args().disable_chunked_prefix_cache
-        )
+        self.disable_chunked_prefix_cache = global_server_args_dict[
+            "disable_chunked_prefix_cache"
+        ]

         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
python/sglang/srt/layers/attention/vision.py

@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
-from sglang.srt.server_args import get_global_server_args
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import add_prefix

 ROTARY_EMBED_CLASSES = {

@@ -468,7 +468,7 @@ class VisionAttention(nn.Module):
         _passed_backend = qkv_backend
         qkv_backend = self._determine_attention_backend(_passed_backend)
         if (
-            get_global_server_args().mm_attention_backend is None
+            global_server_args_dict["mm_attention_backend"] is None
             and _passed_backend is None
         ):
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")

@@ -528,7 +528,7 @@ class VisionAttention(nn.Module):
             - CUDA: "triton_attn"
             - Non-CUDA: "sdpa"
         """
-        override_backend = get_global_server_args().mm_attention_backend
+        override_backend = global_server_args_dict["mm_attention_backend"]
         if override_backend is not None:
             backend = override_backend
         elif passed_backend is not None:
python/sglang/srt/layers/communicator.py

@@ -40,9 +40,8 @@ from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.server_args import get_global_server_args
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,

@@ -169,7 +168,7 @@ class LayerScatterModes:
 def enable_moe_dense_fully_dp():
-    return get_global_server_args().moe_dense_tp_size == 1
+    return global_server_args_dict["moe_dense_tp_size"] == 1


 class LayerCommunicator:

@@ -315,9 +314,7 @@ class LayerCommunicator:
     def should_fuse_mlp_allreduce_with_next_layer(self, forward_batch: ForwardBatch) -> bool:
-        speculative_algo = SpeculativeAlgorithm.from_string(
-            get_global_server_args().speculative_algorithm
-        )
+        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
         if (
             is_dp_attention_enabled()
             and speculative_algo is not None

@@ -336,7 +333,7 @@ class LayerCommunicator:
         static_conditions_met = (
             (not self.is_last_layer)
             and (self._context.tp_size > 1)
-            and get_global_server_args().enable_flashinfer_allreduce_fusion
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
             and _is_flashinfer_available
         )

@@ -534,7 +531,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             (_is_sm100_supported or _is_sm90_supported)
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
-            and get_global_server_args().enable_flashinfer_allreduce_fusion
+            and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
             and hidden_states.shape[0] <= 4096
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
python/sglang/srt/layers/logits_processor.py

@@ -38,15 +38,17 @@ from sglang.srt.layers.dp_attention import (
     get_dp_device,
     get_dp_dtype,
     get_dp_hidden_size,
     get_global_dp_buffer,
     get_local_attention_dp_size,
     set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend

 logger = logging.getLogger(__name__)

@@ -228,8 +230,8 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
-        self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
+        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
         if self.use_attn_tp_group:
             self.attn_tp_size = get_attention_tp_size()
             self.do_tensor_parallel_all_gather = (

@@ -252,8 +254,8 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None

-        self.debug_tensor_dump_output_folder = (
-            get_global_server_args().debug_tensor_dump_output_folder
+        self.debug_tensor_dump_output_folder = global_server_args_dict.get(
+            "debug_tensor_dump_output_folder", None
         )

     def compute_logprobs_for_multi_item_scoring(

@@ -370,7 +372,9 @@ class LogitsProcessor(nn.Module):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

         # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
-        multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
+        multi_item_delimiter = global_server_args_dict.get(
+            "multi_item_scoring_delimiter"
+        )
         if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
             return self.compute_logprobs_for_multi_item_scoring(
                 input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -27,10 +27,12 @@ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
     cpu_has_amx_support,
python/sglang/srt/layers/quantization/mxfp4.py

@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.server_args import get_global_server_args
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda,

@@ -265,9 +265,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
-        self.flashinfer_mxfp4_moe_precision = (
-            get_global_server_args().flashinfer_mxfp4_moe_precision
-        )
+        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
+            "flashinfer_mxfp4_moe_precision"
+        ]

         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
python/sglang/srt/layers/sampler.py

@@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import (
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda

 if is_cuda():

@@ -33,7 +33,7 @@ RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detection = get_global_server_args().enable_nan_detection
+        self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tp_group().device_group

         if is_dp_attention_enabled():

@@ -103,7 +103,7 @@ class Sampler(nn.Module):
             del logits

             if True:  # Keep this redundant check to simplify some internal code sync
-                if get_global_server_args().sampling_backend == "flashinfer":
+                if global_server_args_dict["sampling_backend"] == "flashinfer":
                     if sampling_info.need_min_p_sampling:
                         probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                         probs = top_p_renorm_prob(probs, sampling_info.top_ps)

@@ -118,7 +118,7 @@ class Sampler(nn.Module):
                         filter_apply_order="joint",
                         check_nan=self.use_nan_detection,
                     )
-                elif get_global_server_args().sampling_backend == "pytorch":
+                elif global_server_args_dict["sampling_backend"] == "pytorch":
                     # A slower fallback implementation with torch native operations.
                     batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                         probs,

@@ -131,7 +131,7 @@ class Sampler(nn.Module):
                     )
                 else:
                     raise ValueError(
-                        f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
+                        f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                     )

         if return_logprob:
python/sglang/srt/managers/mm_utils.py

@@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
     MultimodalInputs,
+    global_server_args_dict,
 )
 from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
 from sglang.utils import logger

@@ -428,7 +428,7 @@ def _adjust_embedding_length(
             f"tokens from multimodal embeddings."
         )
         if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
-            chunked_prefill_size = get_global_server_args().chunked_prefill_size
+            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
             if chunked_prefill_size != -1:
                 logger.warning(
                     "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
python/sglang/srt/managers/schedule_batch.py

@@ -72,7 +72,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
-from sglang.srt.server_args import ServerArgs, get_global_server_args
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import flatten_nested_list
 from sglang.srt.utils.common import next_power_of_2

@@ -82,6 +82,47 @@ if TYPE_CHECKING:

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

+GLOBAL_SERVER_ARGS_KEYS = [
+    "attention_backend",
+    "mm_attention_backend",
+    "debug_tensor_dump_inject",
+    "debug_tensor_dump_output_folder",
+    "chunked_prefill_size",
+    "device",
+    "disable_chunked_prefix_cache",
+    "disable_flashinfer_cutlass_moe_fp4_allgather",
+    "disable_radix_cache",
+    "enable_dp_lm_head",
+    "enable_fp32_lm_head",
+    "flashinfer_mxfp4_moe_precision",
+    "enable_flashinfer_allreduce_fusion",
+    "moe_dense_tp_size",
+    "ep_dispatch_algorithm",
+    "ep_num_redundant_experts",
+    "enable_nan_detection",
+    "flashinfer_mla_disable_ragged",
+    "pp_max_micro_batch_size",
+    "disable_shared_experts_fusion",
+    "sampling_backend",
+    "speculative_accept_threshold_single",
+    "speculative_accept_threshold_acc",
+    "speculative_attention_mode",
+    "torchao_config",
+    "triton_attention_reduce_in_fp32",
+    "num_reserved_decode_tokens",
+    "weight_loader_disable_mmap",
+    "enable_multimodal",
+    "enable_symm_mem",
+    "enable_custom_logit_processor",
+    "disaggregation_mode",
+    "enable_deterministic_inference",
+    "nsa_prefill",
+    "nsa_decode",
+    "multi_item_scoring_delimiter",
+]
+
+# Put some global args for easy access
+global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}

 logger = logging.getLogger(__name__)

@@ -642,9 +683,12 @@ class Req:
     def is_prefill_only(self) -> bool:
         """Check if this request is prefill-only (no token generation needed)."""
         # NOTE: when spec is enabled, prefill_only optimizations are disabled
-        spec_alg = get_global_server_args().speculative_algorithm
-        return self.sampling_params.max_new_tokens == 0 and spec_alg is None
+        from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+        spec_alg = global_server_args_dict["speculative_algorithm"]
+        return self.sampling_params.max_new_tokens == 0 and (
+            spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
+        )

     def add_latency(self, stage: RequestStage):
         if self.metrics_collector is None:
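The restored block above seeds `global_server_args_dict` from the class-level defaults of `ServerArgs` (`getattr(ServerArgs, k)` reads each field's default rather than a constructed instance); the scheduler later overwrites these defaults with the worker's real values. A minimal, self-contained sketch of that seeding step, using a made-up `ServerArgs` stand-in with only three fields:

from dataclasses import dataclass


@dataclass
class ServerArgs:  # stand-in for sglang.srt.server_args.ServerArgs
    device: str = "cuda"
    sampling_backend: str = "flashinfer"
    chunked_prefill_size: int = 8192


GLOBAL_SERVER_ARGS_KEYS = ["device", "sampling_backend", "chunked_prefill_size"]

# getattr on the class (not an instance) picks up each dataclass field's default value
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}

print(global_server_args_dict)
# {'device': 'cuda', 'sampling_backend': 'flashinfer', 'chunked_prefill_size': 8192}

Note that `Req.is_prefill_only` above indexes `"speculative_algorithm"`, which is not in `GLOBAL_SERVER_ARGS_KEYS`; that key apparently only becomes available once the dict has been updated with worker-provided entries (see the scheduler.py hunks below).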
python/sglang/srt/managers/scheduler.py

@@ -122,6 +122,7 @@ from sglang.srt.managers.schedule_batch import (
     Req,
     RequestStage,
     ScheduleBatch,
+    global_server_args_dict,
 )
 from sglang.srt.managers.schedule_policy import (
     AddReqResult,

@@ -149,7 +150,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.parser.reasoning_parser import ReasoningParser
-from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.eagle_info import EagleDraftInput
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.tracing.trace import (

@@ -446,12 +447,13 @@ class Scheduler(
             self.max_req_input_len,
             self.random_seed,
             self.device,
+            worker_global_server_args_dict,
             _,
             _,
             _,
         ) = self.tp_worker.get_worker_info()

-        if get_global_server_args().pp_max_micro_batch_size is None:
-            get_global_server_args().pp_max_micro_batch_size = max(
+        if global_server_args_dict["pp_max_micro_batch_size"] is None:
+            global_server_args_dict["pp_max_micro_batch_size"] = max(
                 self.max_running_requests // server_args.pp_size, 1
             )

@@ -463,6 +465,7 @@ class Scheduler(
         self.world_group = get_world_group()

         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
+        global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)

         # Hybrid memory pool

@@ -1863,7 +1866,7 @@ class Scheduler(
         return ret

     def get_num_allocatable_reqs(self, running_bs):
-        res = get_global_server_args().pp_max_micro_batch_size - running_bs
+        res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs
         if self.pp_size > 1:
             res = min(res, self.req_to_token_pool.available_size())
         return res

@@ -2607,7 +2610,7 @@ class Scheduler(
         )

     def get_internal_state(self, recv_req: GetInternalStateReq):
-        ret = vars(get_global_server_args())
+        ret = dict(global_server_args_dict)
         ret["last_gen_throughput"] = self.last_gen_throughput
         ret["memory_usage"] = {
             "weight": round(

@@ -2663,11 +2666,11 @@ class Scheduler(
             logger.info(f"{avg_spec_accept_length=}")
         self.cum_spec_accept_length = self.cum_spec_accept_count = 0

         for k, v in server_args_dict.items():
-            setattr(get_global_server_args(), k, v)
-        logger.info(f"Global server args updated! {get_global_server_args()=}")
+            global_server_args_dict[k] = v
+        logger.info(f"Global server args updated! {global_server_args_dict=}")
         return SetInternalStateReqOutput(
             updated=True,
-            server_args=vars(get_global_server_args()),
+            server_args=global_server_args_dict,
         )

     def handle_rpc_request(self, recv_req: RpcReqInput):
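Taken together with the tp_worker.py hunk below, the restored scheduler flow is: the TP worker includes its own `global_server_args_dict` in the tuple returned by `get_worker_info()`, the scheduler merges it into its copy with `dict.update()`, fills in derived defaults such as `pp_max_micro_batch_size`, and set-internal-state requests assign individual keys. A minimal sketch of that flow under those assumptions (simplified: the real `get_worker_info()` returns many more fields than shown here):

# Scheduler-side copy, seeded from ServerArgs defaults (see schedule_batch.py above).
global_server_args_dict = {"pp_max_micro_batch_size": None, "device": "cuda"}


def get_worker_info():
    # Stand-in for TpModelWorker.get_worker_info(): the worker ships back its own
    # dict, which may carry extra keys (e.g. "speculative_algorithm").
    return {"pp_max_micro_batch_size": None, "device": "cuda", "speculative_algorithm": None}


worker_global_server_args_dict = get_worker_info()
global_server_args_dict.update(worker_global_server_args_dict)

# Fill in a derived default, mirroring the Scheduler.__init__ hunk above.
max_running_requests, pp_size = 128, 2
if global_server_args_dict["pp_max_micro_batch_size"] is None:
    global_server_args_dict["pp_max_micro_batch_size"] = max(max_running_requests // pp_size, 1)

# set_internal_state-style updates simply assign keys.
for k, v in {"sampling_backend": "pytorch"}.items():
    global_server_args_dict[k] = v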
python/sglang/srt/managers/tp_worker.py

@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

@@ -190,6 +190,7 @@ class TpModelWorker:
             self.max_req_input_len,
             self.random_seed,
             self.device,
+            global_server_args_dict,
             self.model_runner.req_to_token_pool.size,
             self.model_runner.req_to_token_pool.max_context_len,
             self.model_runner.token_to_kv_pool.size,