Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
1bdd0102
Unverified
Commit
1bdd0102
authored
Oct 12, 2025
by
Cheng Wan
Committed by
GitHub
Oct 12, 2025
Browse files
Revert "Deprecate `global_server_args_dict`" (#11520)
parent
6cd29694
Changes
54
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
73 additions
and
72 deletions
+73
-72
python/sglang/srt/models/qwen3_next.py
python/sglang/srt/models/qwen3_next.py
+2
-2
python/sglang/srt/models/qwen3_next_mtp.py
python/sglang/srt/models/qwen3_next_mtp.py
+4
-3
python/sglang/srt/models/qwen3_vl_moe.py
python/sglang/srt/models/qwen3_vl_moe.py
+11
-3
python/sglang/srt/models/step3_vl.py
python/sglang/srt/models/step3_vl.py
+3
-2
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+13
-6
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+2
-19
python/sglang/srt/single_batch_overlap.py
python/sglang/srt/single_batch_overlap.py
+1
-0
python/sglang/srt/speculative/eagle_info.py
python/sglang/srt/speculative/eagle_info.py
+7
-4
python/sglang/srt/speculative/eagle_info_v2.py
python/sglang/srt/speculative/eagle_info_v2.py
+7
-3
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+6
-6
python/sglang/srt/speculative/ngram_info.py
python/sglang/srt/speculative/ngram_info.py
+5
-5
python/sglang/srt/two_batch_overlap.py
python/sglang/srt/two_batch_overlap.py
+5
-6
test/srt/rl/test_fp32_lm_head.py
test/srt/rl/test_fp32_lm_head.py
+3
-9
test/srt/test_gptqmodel_dynamic.py
test/srt/test_gptqmodel_dynamic.py
+4
-4
No files found.
python/sglang/srt/models/qwen3_next.py
View file @
1bdd0102
...
@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
...
@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
)
)
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.cuda_graph_runner
import
get_is_capture_mode
from
sglang.srt.model_executor.cuda_graph_runner
import
get_is_capture_mode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
(
from
sglang.srt.model_loader.weight_utils
import
(
...
@@ -46,7 +47,6 @@ from sglang.srt.model_loader.weight_utils import (
...
@@ -46,7 +47,6 @@ from sglang.srt.model_loader.weight_utils import (
sharded_weight_loader
,
sharded_weight_loader
,
)
)
from
sglang.srt.models.qwen2_moe
import
Qwen2MoeMLP
,
Qwen2MoeSparseMoeBlock
from
sglang.srt.models.qwen2_moe
import
Qwen2MoeMLP
,
Qwen2MoeSparseMoeBlock
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
LazyValue
,
LazyValue
,
add_prefix
,
add_prefix
,
...
@@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module):
...
@@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
org_num_embeddings
=
config
.
vocab_size
,
org_num_embeddings
=
config
.
vocab_size
,
prefix
=
add_prefix
(
"lm_head"
,
prefix
),
prefix
=
add_prefix
(
"lm_head"
,
prefix
),
use_attn_tp_group
=
get_
global_server_args
().
enable_dp_lm_head
,
use_attn_tp_group
=
global_server_args
_dict
[
"
enable_dp_lm_head
"
]
,
)
)
self
.
lm_head
=
self
.
lm_head
.
float
()
self
.
lm_head
=
self
.
lm_head
.
float
()
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
...
...
python/sglang/srt/models/qwen3_next_mtp.py
View file @
1bdd0102
...
@@ -21,13 +21,14 @@ from torch import nn
...
@@ -21,13 +21,14 @@ from torch import nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
sglang.srt.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
sglang.srt.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
sglang.srt.layers.layernorm
import
GemmaRMSNorm
from
sglang.srt.layers.layernorm
import
GemmaRMSNorm
,
RMSNorm
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.vocab_parallel_embedding
import
ParallelLMHead
from
sglang.srt.layers.vocab_parallel_embedding
import
ParallelLMHead
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.models.qwen3_moe
import
Qwen3MoeModel
from
sglang.srt.models.qwen3_next
import
Qwen3NextForCausalLM
,
Qwen3NextModel
from
sglang.srt.models.qwen3_next
import
Qwen3NextForCausalLM
,
Qwen3NextModel
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.utils
import
add_prefix
from
sglang.srt.utils
import
add_prefix
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -68,7 +69,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
...
@@ -68,7 +69,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
add_prefix
(
"model.shared_head.head"
,
prefix
),
prefix
=
add_prefix
(
"model.shared_head.head"
,
prefix
),
use_attn_tp_group
=
get_
global_server_args
().
enable_dp_lm_head
,
use_attn_tp_group
=
global_server_args
_dict
[
"
enable_dp_lm_head
"
]
,
)
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
...
...
python/sglang/srt/models/qwen3_vl_moe.py
View file @
1bdd0102
...
@@ -38,12 +38,20 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
...
@@ -38,12 +38,20 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
from
sglang.srt.layers.pooler
import
Pooler
,
PoolingType
from
sglang.srt.layers.pooler
import
Pooler
,
PoolingType
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.utils
import
get_layer_id
from
sglang.srt.layers.vocab_parallel_embedding
import
ParallelLMHead
from
sglang.srt.layers.vocab_parallel_embedding
import
ParallelLMHead
from
sglang.srt.managers.mm_utils
import
general_mm_embed_routine
from
sglang.srt.managers.mm_utils
import
(
from
sglang.srt.managers.schedule_batch
import
MultimodalDataItem
MultiModalityDataPaddingPatternMultimodalTokens
,
general_mm_embed_routine
,
)
from
sglang.srt.managers.schedule_batch
import
(
MultimodalDataItem
,
MultimodalInputs
,
global_server_args_dict
,
)
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
PPProxyTensors
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
PPProxyTensors
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.qwen3_moe
import
Qwen3MoeModel
from
sglang.srt.models.qwen3_moe
import
Qwen3MoeForCausalLM
,
Qwen3MoeModel
from
sglang.srt.models.qwen3_vl
import
(
from
sglang.srt.models.qwen3_vl
import
(
Qwen3_VisionTransformer
,
Qwen3_VisionTransformer
,
Qwen3VLForConditionalGeneration
,
Qwen3VLForConditionalGeneration
,
...
...
python/sglang/srt/models/step3_vl.py
View file @
1bdd0102
...
@@ -57,6 +57,7 @@ from sglang.srt.managers.schedule_batch import (
...
@@ -57,6 +57,7 @@ from sglang.srt.managers.schedule_batch import (
Modality
,
Modality
,
MultimodalDataItem
,
MultimodalDataItem
,
MultimodalInputs
,
MultimodalInputs
,
global_server_args_dict
,
)
)
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
...
@@ -299,7 +300,7 @@ class Step3TextDecoderLayer(nn.Module):
...
@@ -299,7 +300,7 @@ class Step3TextDecoderLayer(nn.Module):
# self.n_shared_experts = 1
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# self.num_fused_shared_experts = (
# 0
# 0
# if global_server_args
.
disable_shared_experts_fusion
# if global_server_args
_dict["
disable_shared_experts_fusion
"]
# else self.n_shared_experts
# else self.n_shared_experts
# )
# )
self
.
num_fused_shared_experts
=
0
self
.
num_fused_shared_experts
=
0
...
@@ -773,7 +774,7 @@ class Step3VLForConditionalGeneration(nn.Module):
...
@@ -773,7 +774,7 @@ class Step3VLForConditionalGeneration(nn.Module):
# self.n_shared_experts = 1
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# self.num_fused_shared_experts = (
# 0
# 0
# if global_server_args
.
disable_shared_experts_fusion
# if global_server_args
_dict["
disable_shared_experts_fusion
"]
# else self.n_shared_experts
# else self.n_shared_experts
# )
# )
self
.
num_fused_shared_experts
=
0
self
.
num_fused_shared_experts
=
0
...
...
python/sglang/srt/sampling/sampling_batch_info.py
View file @
1bdd0102
...
@@ -2,6 +2,7 @@ from __future__ import annotations
...
@@ -2,6 +2,7 @@ from __future__ import annotations
import
dataclasses
import
dataclasses
import
logging
import
logging
import
threading
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -9,7 +10,6 @@ import torch
...
@@ -9,7 +10,6 @@ import torch
import
sglang.srt.sampling.penaltylib
as
penaltylib
import
sglang.srt.sampling.penaltylib
as
penaltylib
from
sglang.srt.sampling.custom_logit_processor
import
CustomLogitProcessor
from
sglang.srt.sampling.custom_logit_processor
import
CustomLogitProcessor
from
sglang.srt.sampling.sampling_params
import
TOP_K_ALL
from
sglang.srt.sampling.sampling_params
import
TOP_K_ALL
from
sglang.srt.server_args
import
get_global_server_args
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
...
@@ -66,10 +66,16 @@ class SamplingBatchInfo:
...
@@ -66,10 +66,16 @@ class SamplingBatchInfo:
# Handle logit bias
# Handle logit bias
logit_bias
:
Optional
[
torch
.
Tensor
]
=
None
logit_bias
:
Optional
[
torch
.
Tensor
]
=
None
@
classmethod
def
_get_global_server_args_dict
(
cls
):
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
return
global_server_args_dict
@
classmethod
@
classmethod
def
from_schedule_batch
(
cls
,
batch
:
ScheduleBatch
,
vocab_size
:
int
):
def
from_schedule_batch
(
cls
,
batch
:
ScheduleBatch
,
vocab_size
:
int
):
global_server_args
=
get_global_server_args
()
global_server_args
_dict
=
cls
.
_
get_global_server_args
_dict
()
enable_deterministic
=
global_server_args
.
enable_deterministic_inference
enable_deterministic
=
global_server_args
_dict
[
"
enable_deterministic_inference
"
]
reqs
=
batch
.
reqs
reqs
=
batch
.
reqs
device
=
batch
.
device
device
=
batch
.
device
...
@@ -106,9 +112,10 @@ class SamplingBatchInfo:
...
@@ -106,9 +112,10 @@ class SamplingBatchInfo:
logit_bias
[
i
,
int
(
key
)]
=
value
logit_bias
[
i
,
int
(
key
)]
=
value
# Check if any request has custom logit processor
# Check if any request has custom logit processor
has_custom_logit_processor
=
(
has_custom_logit_processor
=
global_server_args_dict
[
global_server_args
.
enable_custom_logit_processor
"enable_custom_logit_processor"
and
any
(
r
.
custom_logit_processor
for
r
in
reqs
)
# check the flag first.
]
and
any
(
# check the flag first.
r
.
custom_logit_processor
for
r
in
reqs
)
# then check the requests.
)
# then check the requests.
if
has_custom_logit_processor
:
if
has_custom_logit_processor
:
...
...
python/sglang/srt/server_args.py
View file @
1bdd0102
...
@@ -53,7 +53,6 @@ from sglang.utils import is_in_ci
...
@@ -53,7 +53,6 @@ from sglang.utils import is_in_ci
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# Define constants
# Define constants
LOAD_FORMAT_CHOICES
=
[
LOAD_FORMAT_CHOICES
=
[
"auto"
,
"auto"
,
...
@@ -3324,22 +3323,6 @@ class ServerArgs:
...
@@ -3324,22 +3323,6 @@ class ServerArgs:
)
)
# NOTE: This is a global variable to hold the server args for scheduler.
_global_server_args
:
Optional
[
ServerArgs
]
=
None
def
set_global_server_args_for_scheduler
(
server_args
:
ServerArgs
):
global
_global_server_args
_global_server_args
=
server_args
def
get_global_server_args
()
->
ServerArgs
:
if
_global_server_args
is
None
:
raise
ValueError
(
"Global server args is not set yet!"
)
return
_global_server_args
def
prepare_server_args
(
argv
:
List
[
str
])
->
ServerArgs
:
def
prepare_server_args
(
argv
:
List
[
str
])
->
ServerArgs
:
"""
"""
Prepare the server arguments from the command line arguments.
Prepare the server arguments from the command line arguments.
...
@@ -3374,8 +3357,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
...
@@ -3374,8 +3357,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
ServerArgs
.
add_cli_args
(
parser
)
ServerArgs
.
add_cli_args
(
parser
)
raw_args
=
parser
.
parse_args
(
argv
)
raw_args
=
parser
.
parse_args
(
argv
)
server_args
=
ServerArgs
.
from_cli_args
(
raw_args
)
return
S
erver
Args
.
from_cli_args
(
raw_args
)
return
s
erver
_args
ZMQ_TCP_PORT_DELTA
=
233
ZMQ_TCP_PORT_DELTA
=
233
...
...
python/sglang/srt/single_batch_overlap.py
View file @
1bdd0102
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
from
sglang.srt.layers.moe
import
get_moe_runner_backend
from
sglang.srt.layers.moe
import
get_moe_runner_backend
from
sglang.srt.layers.moe.utils
import
is_sbo_enabled
from
sglang.srt.layers.moe.utils
import
is_sbo_enabled
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.utils
import
get_int_env_var
from
sglang.srt.utils
import
get_int_env_var
...
...
python/sglang/srt/speculative/eagle_info.py
View file @
1bdd0102
...
@@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
...
@@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.managers.overlap_utils
import
FutureIndices
from
sglang.srt.managers.overlap_utils
import
FutureIndices
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
,
global_server_args_dict
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.common
import
(
from
sglang.srt.mem_cache.common
import
(
alloc_paged_token_slots_extend
,
alloc_paged_token_slots_extend
,
...
@@ -19,7 +19,6 @@ from sglang.srt.mem_cache.common import (
...
@@ -19,7 +19,6 @@ from sglang.srt.mem_cache.common import (
get_last_loc
,
get_last_loc
,
)
)
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.speculative.eagle_info_v2
import
(
from
sglang.srt.speculative.eagle_info_v2
import
(
EagleDraftInputV2Mixin
,
EagleDraftInputV2Mixin
,
EagleVerifyInputV2Mixin
,
EagleVerifyInputV2Mixin
,
...
@@ -333,8 +332,12 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
...
@@ -333,8 +332,12 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
target_probs
=
target_probs
,
target_probs
=
target_probs
,
draft_probs
=
draft_probs
,
draft_probs
=
draft_probs
,
threshold_single
=
get_global_server_args
().
speculative_accept_threshold_single
,
threshold_single
=
global_server_args_dict
[
threshold_acc
=
get_global_server_args
().
speculative_accept_threshold_acc
,
"speculative_accept_threshold_single"
],
threshold_acc
=
global_server_args_dict
[
"speculative_accept_threshold_acc"
],
deterministic
=
True
,
deterministic
=
True
,
)
)
...
...
python/sglang/srt/speculative/eagle_info_v2.py
View file @
1bdd0102
...
@@ -11,6 +11,7 @@ import triton.language as tl
...
@@ -11,6 +11,7 @@ import triton.language as tl
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
from
sglang.srt.managers.scheduler
import
global_server_args_dict
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
from
sglang.srt.model_executor.forward_batch_info
import
(
from
sglang.srt.model_executor.forward_batch_info
import
(
CaptureHiddenMode
,
CaptureHiddenMode
,
...
@@ -18,7 +19,6 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -18,7 +19,6 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode
,
ForwardMode
,
)
)
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.speculative.build_eagle_tree
import
TreeMaskMode
from
sglang.srt.speculative.build_eagle_tree
import
TreeMaskMode
from
sglang.srt.speculative.spec_utils
import
(
from
sglang.srt.speculative.spec_utils
import
(
SIMULATE_ACC_LEN
,
SIMULATE_ACC_LEN
,
...
@@ -265,8 +265,12 @@ class EagleVerifyInputV2Mixin:
...
@@ -265,8 +265,12 @@ class EagleVerifyInputV2Mixin:
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
target_probs
=
target_probs
,
target_probs
=
target_probs
,
draft_probs
=
draft_probs
,
draft_probs
=
draft_probs
,
threshold_single
=
get_global_server_args
().
speculative_accept_threshold_single
,
threshold_single
=
global_server_args_dict
[
threshold_acc
=
get_global_server_args
().
speculative_accept_threshold_acc
,
"speculative_accept_threshold_single"
],
threshold_acc
=
global_server_args_dict
[
"speculative_accept_threshold_acc"
],
deterministic
=
True
,
deterministic
=
True
,
)
)
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
1bdd0102
...
@@ -14,7 +14,7 @@ from sglang.srt.distributed import (
...
@@ -14,7 +14,7 @@ from sglang.srt.distributed import (
)
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
get_token_ids_logprobs
,
get_top_logprobs
from
sglang.srt.layers.sampler
import
get_token_ids_logprobs
,
get_top_logprobs
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
,
global_server_args_dict
from
sglang.srt.managers.scheduler
import
GenerationBatchResult
from
sglang.srt.managers.scheduler
import
GenerationBatchResult
from
sglang.srt.managers.tp_worker
import
TpModelWorker
from
sglang.srt.managers.tp_worker
import
TpModelWorker
from
sglang.srt.mem_cache.common
import
(
from
sglang.srt.mem_cache.common
import
(
...
@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch
,
ForwardBatch
,
ForwardMode
,
ForwardMode
,
)
)
from
sglang.srt.server_args
import
ServerArgs
,
get_global_server_args
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel_efficient
from
sglang.srt.speculative.eagle_draft_cuda_graph_runner
import
(
from
sglang.srt.speculative.eagle_draft_cuda_graph_runner
import
(
EAGLEDraftCudaGraphRunner
,
EAGLEDraftCudaGraphRunner
,
...
@@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker):
)
)
def
_create_flashinfer_decode_backend
(
self
):
def
_create_flashinfer_decode_backend
(
self
):
if
not
get_
global_server_args
().
use_mla_backend
:
if
not
global_server_args
_dict
[
"
use_mla_backend
"
]
:
from
sglang.srt.layers.attention.flashinfer_backend
import
(
from
sglang.srt.layers.attention.flashinfer_backend
import
(
FlashInferMultiStepDraftBackend
,
FlashInferMultiStepDraftBackend
,
)
)
...
@@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker):
)
)
def
_create_trtllm_mla_decode_backend
(
self
):
def
_create_trtllm_mla_decode_backend
(
self
):
if
not
get_
global_server_args
().
use_mla_backend
:
if
not
global_server_args
_dict
[
"
use_mla_backend
"
]
:
raise
ValueError
(
raise
ValueError
(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
)
...
@@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker):
)
)
def
_create_flashinfer_prefill_backend
(
self
):
def
_create_flashinfer_prefill_backend
(
self
):
if
not
get_
global_server_args
().
use_mla_backend
:
if
not
global_server_args
_dict
[
"
use_mla_backend
"
]
:
from
sglang.srt.layers.attention.flashinfer_backend
import
(
from
sglang.srt.layers.attention.flashinfer_backend
import
(
FlashInferAttnBackend
,
FlashInferAttnBackend
,
)
)
...
@@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker):
return
TRTLLMHAAttnBackend
(
self
.
draft_model_runner
,
skip_prefill
=
False
)
return
TRTLLMHAAttnBackend
(
self
.
draft_model_runner
,
skip_prefill
=
False
)
def
_create_trtllm_mla_prefill_backend
(
self
):
def
_create_trtllm_mla_prefill_backend
(
self
):
if
not
get_
global_server_args
().
use_mla_backend
:
if
not
global_server_args
_dict
[
"
use_mla_backend
"
]
:
raise
ValueError
(
raise
ValueError
(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
)
...
...
python/sglang/srt/speculative/ngram_info.py
View file @
1bdd0102
...
@@ -7,8 +7,6 @@ from typing import Optional, Tuple
...
@@ -7,8 +7,6 @@ from typing import Optional, Tuple
import
torch
import
torch
import
triton
import
triton
from
sglang.srt.server_args
import
get_global_server_args
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -18,7 +16,7 @@ import torch.nn.functional as F
...
@@ -18,7 +16,7 @@ import torch.nn.functional as F
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
,
global_server_args_dict
from
sglang.srt.mem_cache.common
import
(
from
sglang.srt.mem_cache.common
import
(
alloc_paged_token_slots_extend
,
alloc_paged_token_slots_extend
,
alloc_token_slots
,
alloc_token_slots
,
...
@@ -352,8 +350,10 @@ class NgramVerifyInput(SpecInput):
...
@@ -352,8 +350,10 @@ class NgramVerifyInput(SpecInput):
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
uniform_samples_for_final_sampling
=
coins_for_final_sampling
,
target_probs
=
target_probs
,
target_probs
=
target_probs
,
draft_probs
=
draft_probs
,
draft_probs
=
draft_probs
,
threshold_single
=
get_global_server_args
().
speculative_accept_threshold_single
,
threshold_single
=
global_server_args_dict
[
threshold_acc
=
get_global_server_args
().
speculative_accept_threshold_acc
,
"speculative_accept_threshold_single"
],
threshold_acc
=
global_server_args_dict
[
"speculative_accept_threshold_acc"
],
deterministic
=
True
,
deterministic
=
True
,
)
)
...
...
python/sglang/srt/two_batch_overlap.py
View file @
1bdd0102
...
@@ -22,7 +22,7 @@ from sglang.srt.layers.moe import (
...
@@ -22,7 +22,7 @@ from sglang.srt.layers.moe import (
)
)
from
sglang.srt.layers.moe.token_dispatcher
import
DeepEPDispatcher
from
sglang.srt.layers.moe.token_dispatcher
import
DeepEPDispatcher
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
,
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
(
from
sglang.srt.model_executor.forward_batch_info
import
(
ForwardBatch
,
ForwardBatch
,
ForwardMode
,
ForwardMode
,
...
@@ -30,7 +30,6 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -30,7 +30,6 @@ from sglang.srt.model_executor.forward_batch_info import (
)
)
from
sglang.srt.operations
import
execute_operations
,
execute_overlapped_operations
from
sglang.srt.operations
import
execute_operations
,
execute_overlapped_operations
from
sglang.srt.operations_strategy
import
OperationsStrategy
from
sglang.srt.operations_strategy
import
OperationsStrategy
from
sglang.srt.server_args
import
get_global_server_args
from
sglang.srt.speculative.spec_info
import
SpecInput
from
sglang.srt.speculative.spec_info
import
SpecInput
from
sglang.srt.utils
import
BumpAllocator
,
empty_context
,
get_bool_env_var
,
is_hip
from
sglang.srt.utils
import
BumpAllocator
,
empty_context
,
get_bool_env_var
,
is_hip
...
@@ -154,7 +153,7 @@ def _update_device_and_sum_field_from_cpu_field(
...
@@ -154,7 +153,7 @@ def _update_device_and_sum_field_from_cpu_field(
cpu_value
cpu_value
if
isinstance
(
cpu_value
,
torch
.
Tensor
)
if
isinstance
(
cpu_value
,
torch
.
Tensor
)
else
torch
.
tensor
(
cpu_value
,
dtype
=
old_device_value
.
dtype
)
else
torch
.
tensor
(
cpu_value
,
dtype
=
old_device_value
.
dtype
)
).
to
(
device
=
get_
global_server_args
().
device
,
non_blocking
=
True
)
).
to
(
device
=
global_server_args
_dict
[
"
device
"
]
,
non_blocking
=
True
)
setattr
(
batch
,
device_field
,
new_device_value
)
setattr
(
batch
,
device_field
,
new_device_value
)
if
sum_field
is
not
None
:
if
sum_field
is
not
None
:
...
@@ -583,7 +582,7 @@ class TboForwardBatchPreparer:
...
@@ -583,7 +582,7 @@ class TboForwardBatchPreparer:
sum_field
=
None
,
sum_field
=
None
,
)
)
_
,
child_b
.
extend_start_loc
=
compute_position
(
_
,
child_b
.
extend_start_loc
=
compute_position
(
get_
global_server_args
().
attention_backend
,
global_server_args
_dict
[
"
attention_backend
"
]
,
child_b
.
extend_prefix_lens
,
child_b
.
extend_prefix_lens
,
child_b
.
extend_seq_lens
,
child_b
.
extend_seq_lens
,
child_b
.
extend_num_tokens
,
child_b
.
extend_num_tokens
,
...
@@ -688,7 +687,7 @@ class TboForwardBatchPreparer:
...
@@ -688,7 +687,7 @@ class TboForwardBatchPreparer:
# TODO improve, e.g. unify w/ `init_raw`
# TODO improve, e.g. unify w/ `init_raw`
if
(
if
(
get_
global_server_args
().
moe_dense_tp_size
==
1
global_server_args
_dict
[
"
moe_dense_tp_size
"
]
==
1
and
batch
.
global_dp_buffer_len
is
not
None
and
batch
.
global_dp_buffer_len
is
not
None
):
):
sum_len
=
end_token_index
-
start_token_index
sum_len
=
end_token_index
-
start_token_index
...
@@ -756,7 +755,7 @@ class TboForwardBatchPreparer:
...
@@ -756,7 +755,7 @@ class TboForwardBatchPreparer:
value_a
=
min
(
tbo_split_token_index
,
num_token_non_padded
)
value_a
=
min
(
tbo_split_token_index
,
num_token_non_padded
)
value_b
=
max
(
0
,
num_token_non_padded
-
tbo_split_token_index
)
value_b
=
max
(
0
,
num_token_non_padded
-
tbo_split_token_index
)
return
torch
.
tensor
([
value_a
,
value_b
],
dtype
=
torch
.
int32
).
to
(
return
torch
.
tensor
([
value_a
,
value_b
],
dtype
=
torch
.
int32
).
to
(
device
=
get_
global_server_args
().
device
,
non_blocking
=
True
device
=
global_server_args
_dict
[
"
device
"
]
,
non_blocking
=
True
)
)
@
classmethod
@
classmethod
...
...
test/srt/rl/test_fp32_lm_head.py
View file @
1bdd0102
...
@@ -7,11 +7,7 @@ import torch.nn as nn
...
@@ -7,11 +7,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.server_args
import
(
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
ServerArgs
,
get_global_server_args
,
set_global_server_args_for_scheduler
,
)
class
LMHeadStub
(
nn
.
Module
):
class
LMHeadStub
(
nn
.
Module
):
...
@@ -36,10 +32,8 @@ class TestLMHeadFP32(unittest.TestCase):
...
@@ -36,10 +32,8 @@ class TestLMHeadFP32(unittest.TestCase):
raise
unittest
.
SkipTest
(
"needs CUDA GPU"
)
raise
unittest
.
SkipTest
(
"needs CUDA GPU"
)
def
_make_logprocessor
(
self
,
vocab_size
,
enable_fp32
):
def
_make_logprocessor
(
self
,
vocab_size
,
enable_fp32
):
ServerArgs
.
__post_init__
=
lambda
self
:
None
# disable validation
global_server_args_dict
[
"enable_dp_lm_head"
]
=
False
set_global_server_args_for_scheduler
(
ServerArgs
(
model_path
=
"dummy"
))
global_server_args_dict
[
"enable_fp32_lm_head"
]
=
enable_fp32
get_global_server_args
().
enable_dp_lm_head
=
False
get_global_server_args
().
enable_fp32_lm_head
=
enable_fp32
cfg
=
SimpleNamespace
(
vocab_size
=
vocab_size
,
final_logit_softcapping
=
None
)
cfg
=
SimpleNamespace
(
vocab_size
=
vocab_size
,
final_logit_softcapping
=
None
)
return
LogitsProcessor
(
cfg
,
skip_all_gather
=
True
,
logit_scale
=
None
)
return
LogitsProcessor
(
cfg
,
skip_all_gather
=
True
,
logit_scale
=
None
)
...
...
test/srt/test_gptqmodel_dynamic.py
View file @
1bdd0102
...
@@ -4,7 +4,6 @@ import unittest
...
@@ -4,7 +4,6 @@ import unittest
import
requests
import
requests
import
torch
import
torch
from
sglang.srt.server_args
import
set_global_server_args_for_scheduler
from
sglang.srt.utils
import
kill_process_tree
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
...
@@ -17,15 +16,17 @@ from sglang.test.test_utils import (
...
@@ -17,15 +16,17 @@ from sglang.test.test_utils import (
def
check_quant_method
(
model_path
:
str
,
use_marlin_kernel
:
bool
):
def
check_quant_method
(
model_path
:
str
,
use_marlin_kernel
:
bool
):
from
sglang.srt.configs.device_config
import
DeviceConfig
from
sglang.srt.configs.device_config
import
DeviceConfig
from
sglang.srt.configs.load_config
import
LoadConfig
from
sglang.srt.configs.load_config
import
LoadConfig
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.configs.model_config
import
AttentionArch
,
ModelConfig
from
sglang.srt.distributed
import
(
from
sglang.srt.distributed
import
(
get_tp_group
,
init_distributed_environment
,
init_distributed_environment
,
initialize_model_parallel
,
initialize_model_parallel
,
set_custom_all_reduce
,
)
)
from
sglang.srt.distributed.parallel_state
import
monkey_patch_vllm_parallel_state
from
sglang.srt.distributed.parallel_state
import
monkey_patch_vllm_parallel_state
from
sglang.srt.layers.quantization.utils
import
get_dynamic_override
from
sglang.srt.layers.quantization.utils
import
get_dynamic_override
from
sglang.srt.model_loader
import
get_model
from
sglang.srt.model_loader
import
get_model
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
try
:
try
:
init_distributed_environment
(
init_distributed_environment
(
...
@@ -42,7 +43,6 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
...
@@ -42,7 +43,6 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
pass
pass
server_args
=
ServerArgs
(
model_path
=
model_path
,
dtype
=
torch
.
float16
)
server_args
=
ServerArgs
(
model_path
=
model_path
,
dtype
=
torch
.
float16
)
set_global_server_args_for_scheduler
(
server_args
)
model_config
=
ModelConfig
.
from_server_args
(
server_args
)
model_config
=
ModelConfig
.
from_server_args
(
server_args
)
load_config
=
LoadConfig
()
load_config
=
LoadConfig
()
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment