Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
1083e7e3
Unverified
Commit
1083e7e3
authored
Oct 13, 2025
by
Liangsheng Yin
Committed by
GitHub
Oct 13, 2025
Browse files
Deprecate `global_server_args_dict` (#11331)
parent
2157d12a
Changes
54
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
72 additions
and
73 deletions
+72
-73
python/sglang/srt/models/qwen3_next.py
python/sglang/srt/models/qwen3_next.py
+2
-2
python/sglang/srt/models/qwen3_next_mtp.py
python/sglang/srt/models/qwen3_next_mtp.py
+3
-4
python/sglang/srt/models/qwen3_vl_moe.py
python/sglang/srt/models/qwen3_vl_moe.py
+3
-11
python/sglang/srt/models/step3_vl.py
python/sglang/srt/models/step3_vl.py
+2
-3
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+6
-13
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+19
-2
python/sglang/srt/single_batch_overlap.py
python/sglang/srt/single_batch_overlap.py
+0
-1
python/sglang/srt/speculative/eagle_info.py
python/sglang/srt/speculative/eagle_info.py
+4
-7
python/sglang/srt/speculative/eagle_info_v2.py
python/sglang/srt/speculative/eagle_info_v2.py
+3
-7
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+6
-6
python/sglang/srt/speculative/ngram_info.py
python/sglang/srt/speculative/ngram_info.py
+5
-5
python/sglang/srt/two_batch_overlap.py
python/sglang/srt/two_batch_overlap.py
+6
-5
test/srt/rl/test_fp32_lm_head.py
test/srt/rl/test_fp32_lm_head.py
+9
-3
test/srt/test_gptqmodel_dynamic.py
test/srt/test_gptqmodel_dynamic.py
+4
-4
No files found.
python/sglang/srt/models/qwen3_next.py
View file @
1083e7e3
...
@@ -39,7 +39,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
...
@@ -39,7 +39,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
)
)
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.cuda_graph_runner
import
get_is_capture_mode
from
sglang.srt.model_executor.cuda_graph_runner
import
get_is_capture_mode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
(
from
sglang.srt.model_loader.weight_utils
import
(
...
@@ -47,6 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
...
@@ -47,6 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
sharded_weight_loader
,
sharded_weight_loader
,
)
)
from
sglang.srt.models.qwen2_moe
import
Qwen2MoeMLP
,
Qwen2MoeSparseMoeBlock
from
sglang.srt.models.qwen2_moe
import
Qwen2MoeMLP
,
Qwen2MoeSparseMoeBlock
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
LazyValue
,
LazyValue
,
add_prefix
,
add_prefix
,
...
@@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module):
...
@@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
org_num_embeddings
=
config
.
vocab_size
,
org_num_embeddings
=
config
.
vocab_size
,
prefix
=
add_prefix
(
"lm_head"
,
prefix
),
prefix
=
add_prefix
(
"lm_head"
,
prefix
),
use_attn_tp_group
=
global_server_args
_dict
[
"
enable_dp_lm_head
"
]
,
use_attn_tp_group
=
get_
global_server_args
().
enable_dp_lm_head
,
)
)
self
.
lm_head
=
self
.
lm_head
.
float
()
self
.
lm_head
=
self
.
lm_head
.
float
()
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
...
...
python/sglang/srt/models/qwen3_next_mtp.py
View file @
1083e7e3
...
@@ -21,14 +21,13 @@ from torch import nn
...
@@ -21,14 +21,13 @@ from torch import nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
sglang.srt.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
sglang.srt.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
sglang.srt.layers.layernorm
import
GemmaRMSNorm
,
RMSNorm
from
sglang.srt.layers.layernorm
import
GemmaRMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.vocab_parallel_embedding
import
ParallelLMHead
from
sglang.srt.layers.vocab_parallel_embedding
import
ParallelLMHead
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.models.qwen3_moe
import
Qwen3MoeModel
from
sglang.srt.models.qwen3_next
import
Qwen3NextForCausalLM
,
Qwen3NextModel
from
sglang.srt.models.qwen3_next
import
Qwen3NextForCausalLM
,
Qwen3NextModel
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.utils
import
add_prefix
from
sglang.srt.utils
import
add_prefix
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -69,7 +68,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
...
@@ -69,7 +68,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
add_prefix
(
"model.shared_head.head"
,
prefix
),
prefix
=
add_prefix
(
"model.shared_head.head"
,
prefix
),
use_attn_tp_group
=
global_server_args
_dict
[
"
enable_dp_lm_head
"
]
,
use_attn_tp_group
=
get_
global_server_args
().
enable_dp_lm_head
,
)
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
...
...
python/sglang/srt/models/qwen3_vl_moe.py
View file @
1083e7e3
...
@@ -38,20 +38,12 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
...
@@ -38,20 +38,12 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
from
sglang.srt.layers.pooler
import
Pooler
,
PoolingType
from
sglang.srt.layers.pooler
import
Pooler
,
PoolingType
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.utils
import
get_layer_id
from
sglang.srt.layers.vocab_parallel_embedding
import
ParallelLMHead
from
sglang.srt.layers.vocab_parallel_embedding
import
ParallelLMHead
from
sglang.srt.managers.mm_utils
import
(
from
sglang.srt.managers.mm_utils
import
general_mm_embed_routine
MultiModalityDataPaddingPatternMultimodalTokens
,
from
sglang.srt.managers.schedule_batch
import
MultimodalDataItem
general_mm_embed_routine
,
)
from
sglang.srt.managers.schedule_batch
import
(
MultimodalDataItem
,
MultimodalInputs
,
global_server_args_dict
,
)
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
PPProxyTensors
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
PPProxyTensors
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.qwen3_moe
import
Qwen3MoeForCausalLM
,
Qwen3MoeModel
from
sglang.srt.models.qwen3_moe
import
Qwen3MoeModel
from
sglang.srt.models.qwen3_vl
import
(
from
sglang.srt.models.qwen3_vl
import
(
Qwen3_VisionTransformer
,
Qwen3_VisionTransformer
,
Qwen3VLForConditionalGeneration
,
Qwen3VLForConditionalGeneration
,
...
...
python/sglang/srt/models/step3_vl.py
View file @
1083e7e3
...
@@ -57,7 +57,6 @@ from sglang.srt.managers.schedule_batch import (
...
@@ -57,7 +57,6 @@ from sglang.srt.managers.schedule_batch import (
Modality
,
Modality
,
MultimodalDataItem
,
MultimodalDataItem
,
MultimodalInputs
,
MultimodalInputs
,
global_server_args_dict
,
)
)
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
...
@@ -300,7 +299,7 @@ class Step3TextDecoderLayer(nn.Module):
...
@@ -300,7 +299,7 @@ class Step3TextDecoderLayer(nn.Module):
# self.n_shared_experts = 1
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# self.num_fused_shared_experts = (
# 0
# 0
# if global_server_args
_dict["
disable_shared_experts_fusion
"]
# if global_server_args
.
disable_shared_experts_fusion
# else self.n_shared_experts
# else self.n_shared_experts
# )
# )
self
.
num_fused_shared_experts
=
0
self
.
num_fused_shared_experts
=
0
...
@@ -774,7 +773,7 @@ class Step3VLForConditionalGeneration(nn.Module):
...
@@ -774,7 +773,7 @@ class Step3VLForConditionalGeneration(nn.Module):
# self.n_shared_experts = 1
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# self.num_fused_shared_experts = (
# 0
# 0
# if global_server_args
_dict["
disable_shared_experts_fusion
"]
# if global_server_args
.
disable_shared_experts_fusion
# else self.n_shared_experts
# else self.n_shared_experts
# )
# )
self
.
num_fused_shared_experts
=
0
self
.
num_fused_shared_experts
=
0
...
...
python/sglang/srt/sampling/sampling_batch_info.py
View file @
1083e7e3
...
@@ -2,7 +2,6 @@ from __future__ import annotations
...
@@ -2,7 +2,6 @@ from __future__ import annotations
import
dataclasses
import
dataclasses
import
logging
import
logging
import
threading
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -10,6 +9,7 @@ import torch
...
@@ -10,6 +9,7 @@ import torch
import
sglang.srt.sampling.penaltylib
as
penaltylib
import
sglang.srt.sampling.penaltylib
as
penaltylib
from
sglang.srt.sampling.custom_logit_processor
import
CustomLogitProcessor
from
sglang.srt.sampling.custom_logit_processor
import
CustomLogitProcessor
from
sglang.srt.sampling.sampling_params
import
TOP_K_ALL
from
sglang.srt.sampling.sampling_params
import
TOP_K_ALL
from
sglang.srt.server_args
import
get_global_server_args
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
...
@@ -66,16 +66,10 @@ class SamplingBatchInfo:
...
@@ -66,16 +66,10 @@ class SamplingBatchInfo:
# Handle logit bias
# Handle logit bias
logit_bias
:
Optional
[
torch
.
Tensor
]
=
None
logit_bias
:
Optional
[
torch
.
Tensor
]
=
None
@
classmethod
def
_get_global_server_args_dict
(
cls
):
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
return
global_server_args_dict
@
classmethod
@
classmethod
def
from_schedule_batch
(
cls
,
batch
:
ScheduleBatch
,
vocab_size
:
int
):
def
from_schedule_batch
(
cls
,
batch
:
ScheduleBatch
,
vocab_size
:
int
):
global_server_args
_dict
=
cls
.
_
get_global_server_args
_dict
()
global_server_args
=
get_global_server_args
()
enable_deterministic
=
global_server_args
_dict
[
"
enable_deterministic_inference
"
]
enable_deterministic
=
global_server_args
.
enable_deterministic_inference
reqs
=
batch
.
reqs
reqs
=
batch
.
reqs
device
=
batch
.
device
device
=
batch
.
device
...
@@ -112,10 +106,9 @@ class SamplingBatchInfo:
...
@@ -112,10 +106,9 @@ class SamplingBatchInfo:
logit_bias
[
i
,
int
(
key
)]
=
value
logit_bias
[
i
,
int
(
key
)]
=
value
# Check if any request has custom logit processor
# Check if any request has custom logit processor
has_custom_logit_processor
=
global_server_args_dict
[
has_custom_logit_processor
=
(
"enable_custom_logit_processor"
global_server_args
.
enable_custom_logit_processor
]
and
any
(
# check the flag first.
and
any
(
r
.
custom_logit_processor
for
r
in
reqs
)
# check the flag first.
r
.
custom_logit_processor
for
r
in
reqs
)
# then check the requests.
)
# then check the requests.
if
has_custom_logit_processor
:
if
has_custom_logit_processor
:
...
...
python/sglang/srt/server_args.py
View file @
1083e7e3
...
@@ -53,6 +53,7 @@ from sglang.utils import is_in_ci
...
@@ -53,6 +53,7 @@ from sglang.utils import is_in_ci
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# Define constants
# Define constants
LOAD_FORMAT_CHOICES
=
[
LOAD_FORMAT_CHOICES
=
[
"auto"
,
"auto"
,
...
@@ -3329,6 +3330,22 @@ class ServerArgs:
...
@@ -3329,6 +3330,22 @@ class ServerArgs:
)
)
# NOTE: This is a global variable to hold the server args for scheduler.
_global_server_args
:
Optional
[
ServerArgs
]
=
None
def
set_global_server_args_for_scheduler
(
server_args
:
ServerArgs
):
global
_global_server_args
_global_server_args
=
server_args
def
get_global_server_args
()
->
ServerArgs
:
if
_global_server_args
is
None
:
raise
ValueError
(
"Global server args is not set yet!"
)
return
_global_server_args
def
prepare_server_args
(
argv
:
List
[
str
])
->
ServerArgs
:
def
prepare_server_args
(
argv
:
List
[
str
])
->
ServerArgs
:
"""
"""
Prepare the server arguments from the command line arguments.
Prepare the server arguments from the command line arguments.
...
@@ -3363,8 +3380,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
...
@@ -3363,8 +3380,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
ServerArgs
.
add_cli_args
(
parser
)
ServerArgs
.
add_cli_args
(
parser
)
raw_args
=
parser
.
parse_args
(
argv
)
raw_args
=
parser
.
parse_args
(
argv
)
server_args
=
ServerArgs
.
from_cli_args
(
raw_args
)
return
s
erver
_args
return
S
erver
Args
.
from_cli_args
(
raw_args
)
ZMQ_TCP_PORT_DELTA
=
233
ZMQ_TCP_PORT_DELTA
=
233
...
...
python/sglang/srt/single_batch_overlap.py
View file @
1083e7e3
...
@@ -6,7 +6,6 @@ import torch
...
@@ -6,7 +6,6 @@ import torch
from
sglang.srt.layers.moe
import
get_moe_runner_backend
from
sglang.srt.layers.moe
import
get_moe_runner_backend
from
sglang.srt.layers.moe.utils
import
is_sbo_enabled
from
sglang.srt.layers.moe.utils
import
is_sbo_enabled
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.utils
import
get_int_env_var
from
sglang.srt.utils
import
get_int_env_var
...
...
python/sglang/srt/speculative/eagle_info.py
View file @
1083e7e3
...
@@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
...
@@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.managers.overlap_utils
import
FutureIndices
from
sglang.srt.managers.overlap_utils
import
FutureIndices
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
,
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.common
import
(
from
sglang.srt.mem_cache.common
import
(
alloc_paged_token_slots_extend
,
alloc_paged_token_slots_extend
,
...
@@ -19,6 +19,7 @@ from sglang.srt.mem_cache.common import (
...
@@ -19,6 +19,7 @@ from sglang.srt.mem_cache.common import (
get_last_loc
,
get_last_loc
,
)
)
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.speculative.eagle_info_v2
import
(
from
sglang.srt.speculative.eagle_info_v2
import
(
EagleDraftInputV2Mixin
,
EagleDraftInputV2Mixin
,
EagleVerifyInputV2Mixin
,
EagleVerifyInputV2Mixin
,
...
@@ -332,12 +333,8 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
...
@@ -332,12 +333,8 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
target_probs
=
target_probs
,
target_probs
=
target_probs
,
draft_probs
=
draft_probs
,
draft_probs
=
draft_probs
,
threshold_single
=
global_server_args_dict
[
threshold_single
=
get_global_server_args
().
speculative_accept_threshold_single
,
"speculative_accept_threshold_single"
threshold_acc
=
get_global_server_args
().
speculative_accept_threshold_acc
,
],
threshold_acc
=
global_server_args_dict
[
"speculative_accept_threshold_acc"
],
deterministic
=
True
,
deterministic
=
True
,
)
)
...
...
python/sglang/srt/speculative/eagle_info_v2.py
View file @
1083e7e3
...
@@ -11,7 +11,6 @@ import triton.language as tl
...
@@ -11,7 +11,6 @@ import triton.language as tl
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
from
sglang.srt.managers.scheduler
import
global_server_args_dict
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
from
sglang.srt.model_executor.forward_batch_info
import
(
from
sglang.srt.model_executor.forward_batch_info
import
(
CaptureHiddenMode
,
CaptureHiddenMode
,
...
@@ -19,6 +18,7 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -19,6 +18,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode
,
ForwardMode
,
)
)
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.speculative.build_eagle_tree
import
TreeMaskMode
from
sglang.srt.speculative.build_eagle_tree
import
TreeMaskMode
from
sglang.srt.speculative.spec_utils
import
(
from
sglang.srt.speculative.spec_utils
import
(
SIMULATE_ACC_LEN
,
SIMULATE_ACC_LEN
,
...
@@ -265,12 +265,8 @@ class EagleVerifyInputV2Mixin:
...
@@ -265,12 +265,8 @@ class EagleVerifyInputV2Mixin:
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
target_probs
=
target_probs
,
target_probs
=
target_probs
,
draft_probs
=
draft_probs
,
draft_probs
=
draft_probs
,
threshold_single
=
global_server_args_dict
[
threshold_single
=
get_global_server_args
().
speculative_accept_threshold_single
,
"speculative_accept_threshold_single"
threshold_acc
=
get_global_server_args
().
speculative_accept_threshold_acc
,
],
threshold_acc
=
global_server_args_dict
[
"speculative_accept_threshold_acc"
],
deterministic
=
True
,
deterministic
=
True
,
)
)
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
1083e7e3
...
@@ -14,7 +14,7 @@ from sglang.srt.distributed import (
...
@@ -14,7 +14,7 @@ from sglang.srt.distributed import (
)
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
get_token_ids_logprobs
,
get_top_logprobs
from
sglang.srt.layers.sampler
import
get_token_ids_logprobs
,
get_top_logprobs
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
,
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.managers.scheduler
import
GenerationBatchResult
from
sglang.srt.managers.scheduler
import
GenerationBatchResult
from
sglang.srt.managers.tp_worker
import
TpModelWorker
from
sglang.srt.managers.tp_worker
import
TpModelWorker
from
sglang.srt.mem_cache.common
import
(
from
sglang.srt.mem_cache.common
import
(
...
@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch
,
ForwardBatch
,
ForwardMode
,
ForwardMode
,
)
)
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
,
get_global_server_args
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.speculative.eagle_draft_cuda_graph_runner
import
(
from
sglang.srt.speculative.eagle_draft_cuda_graph_runner
import
(
EAGLEDraftCudaGraphRunner
,
EAGLEDraftCudaGraphRunner
,
...
@@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker):
)
)
def
_create_flashinfer_decode_backend
(
self
):
def
_create_flashinfer_decode_backend
(
self
):
if
not
global_server_args
_dict
[
"
use_mla_backend
"
]
:
if
not
get_
global_server_args
().
use_mla_backend
:
from
sglang.srt.layers.attention.flashinfer_backend
import
(
from
sglang.srt.layers.attention.flashinfer_backend
import
(
FlashInferMultiStepDraftBackend
,
FlashInferMultiStepDraftBackend
,
)
)
...
@@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker):
)
)
def
_create_trtllm_mla_decode_backend
(
self
):
def
_create_trtllm_mla_decode_backend
(
self
):
if
not
global_server_args
_dict
[
"
use_mla_backend
"
]
:
if
not
get_
global_server_args
().
use_mla_backend
:
raise
ValueError
(
raise
ValueError
(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
)
...
@@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker):
)
)
def
_create_flashinfer_prefill_backend
(
self
):
def
_create_flashinfer_prefill_backend
(
self
):
if
not
global_server_args
_dict
[
"
use_mla_backend
"
]
:
if
not
get_
global_server_args
().
use_mla_backend
:
from
sglang.srt.layers.attention.flashinfer_backend
import
(
from
sglang.srt.layers.attention.flashinfer_backend
import
(
FlashInferAttnBackend
,
FlashInferAttnBackend
,
)
)
...
@@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker):
return
TRTLLMHAAttnBackend
(
self
.
draft_model_runner
,
skip_prefill
=
False
)
return
TRTLLMHAAttnBackend
(
self
.
draft_model_runner
,
skip_prefill
=
False
)
def
_create_trtllm_mla_prefill_backend
(
self
):
def
_create_trtllm_mla_prefill_backend
(
self
):
if
not
global_server_args
_dict
[
"
use_mla_backend
"
]
:
if
not
get_
global_server_args
().
use_mla_backend
:
raise
ValueError
(
raise
ValueError
(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
)
...
...
python/sglang/srt/speculative/ngram_info.py
View file @
1083e7e3
...
@@ -7,6 +7,8 @@ from typing import Optional, Tuple
...
@@ -7,6 +7,8 @@ from typing import Optional, Tuple
import
torch
import
torch
import
triton
import
triton
from
sglang.srt.server_args
import
get_global_server_args
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -16,7 +18,7 @@ import torch.nn.functional as F
...
@@ -16,7 +18,7 @@ import torch.nn.functional as F
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
,
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.mem_cache.common
import
(
from
sglang.srt.mem_cache.common
import
(
alloc_paged_token_slots_extend
,
alloc_paged_token_slots_extend
,
alloc_token_slots
,
alloc_token_slots
,
...
@@ -350,10 +352,8 @@ class NgramVerifyInput(SpecInput):
...
@@ -350,10 +352,8 @@ class NgramVerifyInput(SpecInput):
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
target_probs
=
target_probs
,
target_probs
=
target_probs
,
draft_probs
=
draft_probs
,
draft_probs
=
draft_probs
,
threshold_single
=
global_server_args_dict
[
threshold_single
=
get_global_server_args
().
speculative_accept_threshold_single
,
"speculative_accept_threshold_single"
threshold_acc
=
get_global_server_args
().
speculative_accept_threshold_acc
,
],
threshold_acc
=
global_server_args_dict
[
"speculative_accept_threshold_acc"
],
deterministic
=
True
,
deterministic
=
True
,
)
)
...
...
python/sglang/srt/two_batch_overlap.py
View file @
1083e7e3
...
@@ -22,7 +22,7 @@ from sglang.srt.layers.moe import (
...
@@ -22,7 +22,7 @@ from sglang.srt.layers.moe import (
)
)
from
sglang.srt.layers.moe.token_dispatcher
import
DeepEPDispatcher
from
sglang.srt.layers.moe.token_dispatcher
import
DeepEPDispatcher
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
,
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.model_executor.forward_batch_info
import
(
from
sglang.srt.model_executor.forward_batch_info
import
(
ForwardBatch
,
ForwardBatch
,
ForwardMode
,
ForwardMode
,
...
@@ -30,6 +30,7 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -30,6 +30,7 @@ from sglang.srt.model_executor.forward_batch_info import (
)
)
from
sglang.srt.operations
import
execute_operations
,
execute_overlapped_operations
from
sglang.srt.operations
import
execute_operations
,
execute_overlapped_operations
from
sglang.srt.operations_strategy
import
OperationsStrategy
from
sglang.srt.operations_strategy
import
OperationsStrategy
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.speculative.spec_info
import
SpecInput
from
sglang.srt.speculative.spec_info
import
SpecInput
from
sglang.srt.utils
import
BumpAllocator
,
empty_context
,
get_bool_env_var
,
is_hip
from
sglang.srt.utils
import
BumpAllocator
,
empty_context
,
get_bool_env_var
,
is_hip
...
@@ -153,7 +154,7 @@ def _update_device_and_sum_field_from_cpu_field(
...
@@ -153,7 +154,7 @@ def _update_device_and_sum_field_from_cpu_field(
cpu_value
cpu_value
if
isinstance
(
cpu_value
,
torch
.
Tensor
)
if
isinstance
(
cpu_value
,
torch
.
Tensor
)
else
torch
.
tensor
(
cpu_value
,
dtype
=
old_device_value
.
dtype
)
else
torch
.
tensor
(
cpu_value
,
dtype
=
old_device_value
.
dtype
)
).
to
(
device
=
global_server_args
_dict
[
"
device
"
]
,
non_blocking
=
True
)
).
to
(
device
=
get_
global_server_args
().
device
,
non_blocking
=
True
)
setattr
(
batch
,
device_field
,
new_device_value
)
setattr
(
batch
,
device_field
,
new_device_value
)
if
sum_field
is
not
None
:
if
sum_field
is
not
None
:
...
@@ -582,7 +583,7 @@ class TboForwardBatchPreparer:
...
@@ -582,7 +583,7 @@ class TboForwardBatchPreparer:
sum_field
=
None
,
sum_field
=
None
,
)
)
_
,
child_b
.
extend_start_loc
=
compute_position
(
_
,
child_b
.
extend_start_loc
=
compute_position
(
global_server_args
_dict
[
"
attention_backend
"
]
,
get_
global_server_args
().
attention_backend
,
child_b
.
extend_prefix_lens
,
child_b
.
extend_prefix_lens
,
child_b
.
extend_seq_lens
,
child_b
.
extend_seq_lens
,
child_b
.
extend_num_tokens
,
child_b
.
extend_num_tokens
,
...
@@ -687,7 +688,7 @@ class TboForwardBatchPreparer:
...
@@ -687,7 +688,7 @@ class TboForwardBatchPreparer:
# TODO improve, e.g. unify w/ `init_raw`
# TODO improve, e.g. unify w/ `init_raw`
if
(
if
(
global_server_args
_dict
[
"
moe_dense_tp_size
"
]
==
1
get_
global_server_args
().
moe_dense_tp_size
==
1
and
batch
.
global_dp_buffer_len
is
not
None
and
batch
.
global_dp_buffer_len
is
not
None
):
):
sum_len
=
end_token_index
-
start_token_index
sum_len
=
end_token_index
-
start_token_index
...
@@ -755,7 +756,7 @@ class TboForwardBatchPreparer:
...
@@ -755,7 +756,7 @@ class TboForwardBatchPreparer:
value_a
=
min
(
tbo_split_token_index
,
num_token_non_padded
)
value_a
=
min
(
tbo_split_token_index
,
num_token_non_padded
)
value_b
=
max
(
0
,
num_token_non_padded
-
tbo_split_token_index
)
value_b
=
max
(
0
,
num_token_non_padded
-
tbo_split_token_index
)
return
torch
.
tensor
([
value_a
,
value_b
],
dtype
=
torch
.
int32
).
to
(
return
torch
.
tensor
([
value_a
,
value_b
],
dtype
=
torch
.
int32
).
to
(
device
=
global_server_args
_dict
[
"
device
"
]
,
non_blocking
=
True
device
=
get_
global_server_args
().
device
,
non_blocking
=
True
)
)
@
classmethod
@
classmethod
...
...
test/srt/rl/test_fp32_lm_head.py
View file @
1083e7e3
...
@@ -7,7 +7,11 @@ import torch.nn as nn
...
@@ -7,7 +7,11 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.server_args
import
(
ServerArgs
,
get_global_server_args
,
set_global_server_args_for_scheduler
,
)
class
LMHeadStub
(
nn
.
Module
):
class
LMHeadStub
(
nn
.
Module
):
...
@@ -32,8 +36,10 @@ class TestLMHeadFP32(unittest.TestCase):
...
@@ -32,8 +36,10 @@ class TestLMHeadFP32(unittest.TestCase):
raise
unittest
.
SkipTest
(
"needs CUDA GPU"
)
raise
unittest
.
SkipTest
(
"needs CUDA GPU"
)
def
_make_logprocessor
(
self
,
vocab_size
,
enable_fp32
):
def
_make_logprocessor
(
self
,
vocab_size
,
enable_fp32
):
global_server_args_dict
[
"enable_dp_lm_head"
]
=
False
ServerArgs
.
__post_init__
=
lambda
self
:
None
# disable validation
global_server_args_dict
[
"enable_fp32_lm_head"
]
=
enable_fp32
set_global_server_args_for_scheduler
(
ServerArgs
(
model_path
=
"dummy"
))
get_global_server_args
().
enable_dp_lm_head
=
False
get_global_server_args
().
enable_fp32_lm_head
=
enable_fp32
cfg
=
SimpleNamespace
(
vocab_size
=
vocab_size
,
final_logit_softcapping
=
None
)
cfg
=
SimpleNamespace
(
vocab_size
=
vocab_size
,
final_logit_softcapping
=
None
)
return
LogitsProcessor
(
cfg
,
skip_all_gather
=
True
,
logit_scale
=
None
)
return
LogitsProcessor
(
cfg
,
skip_all_gather
=
True
,
logit_scale
=
None
)
...
...
test/srt/test_gptqmodel_dynamic.py
View file @
1083e7e3
...
@@ -4,6 +4,7 @@ import unittest
...
@@ -4,6 +4,7 @@ import unittest
import
requests
import
requests
import
torch
import
torch
from
sglang.srt.server_args
import
set_global_server_args_for_scheduler
from
sglang.srt.utils
import
kill_process_tree
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
...
@@ -16,17 +17,15 @@ from sglang.test.test_utils import (
...
@@ -16,17 +17,15 @@ from sglang.test.test_utils import (
def
check_quant_method
(
model_path
:
str
,
use_marlin_kernel
:
bool
):
def
check_quant_method
(
model_path
:
str
,
use_marlin_kernel
:
bool
):
from
sglang.srt.configs.device_config
import
DeviceConfig
from
sglang.srt.configs.device_config
import
DeviceConfig
from
sglang.srt.configs.load_config
import
LoadConfig
from
sglang.srt.configs.load_config
import
LoadConfig
from
sglang.srt.configs.model_config
import
AttentionArch
,
ModelConfig
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.distributed
import
(
from
sglang.srt.distributed
import
(
get_tp_group
,
init_distributed_environment
,
init_distributed_environment
,
initialize_model_parallel
,
initialize_model_parallel
,
set_custom_all_reduce
,
)
)
from
sglang.srt.distributed.parallel_state
import
monkey_patch_vllm_parallel_state
from
sglang.srt.distributed.parallel_state
import
monkey_patch_vllm_parallel_state
from
sglang.srt.layers.quantization.utils
import
get_dynamic_override
from
sglang.srt.layers.quantization.utils
import
get_dynamic_override
from
sglang.srt.model_loader
import
get_model
from
sglang.srt.model_loader
import
get_model
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
try
:
try
:
init_distributed_environment
(
init_distributed_environment
(
...
@@ -43,6 +42,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
...
@@ -43,6 +42,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
pass
pass
server_args
=
ServerArgs
(
model_path
=
model_path
,
dtype
=
torch
.
float16
)
server_args
=
ServerArgs
(
model_path
=
model_path
,
dtype
=
torch
.
float16
)
set_global_server_args_for_scheduler
(
server_args
)
model_config
=
ModelConfig
.
from_server_args
(
server_args
)
model_config
=
ModelConfig
.
from_server_args
(
server_args
)
load_config
=
LoadConfig
()
load_config
=
LoadConfig
()
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment