Commit 6406408a (unverified)
Clean up server_args.py (#7037)
Authored by Lianmin Zheng on Jun 10, 2025; committed via GitHub on Jun 10, 2025.
Parent: 019851d0
Showing 7 changed files with 278 additions and 215 deletions (+278 -215).
python/sglang/srt/layers/quantization/deep_gemm.py                     +1   -1
python/sglang/srt/managers/schedule_batch.py                           +7   -6
python/sglang/srt/model_executor/cuda_graph_runner.py                  +68  -37
python/sglang/srt/server_args.py                                       +183 -171
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py         +3   -0
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py  +3   -0
python/sglang/srt/utils.py                                             +13  -0
python/sglang/srt/layers/quantization/deep_gemm.py
@@ -118,7 +118,7 @@ def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
             "Entering DeepGEMM JIT Pre-Compile session. "
-            "And it may takes a long time (Typically 10-20 mins) "
+            "It may takes a long time (typically 10-20 mins) "
             "if you have not run `sglang.compile_deep_gemm`. "
             "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
             " for pre-compilation to reduce the overhead if you have not run it before. "
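A note on the guard: the warning above is emitted at most once, by the first rank on each node, and is suppressed during the dedicated pre-compile stage. A minimal sketch of that gating pattern (the flag names come from deep_gemm.py; the values assigned here are illustrative assumptions, since the real ones are derived from the launch mode and process rank):

import logging

logger = logging.getLogger(__name__)

# Illustrative values; in deep_gemm.py these module-level flags are computed,
# not hard-coded.
_IN_PRECOMPILE_STAGE = False   # True while `sglang.compile_deep_gemm` runs
_IS_FIRST_RANK_ON_NODE = True  # only one rank per node should log

def _compile_warning_1():
    # Warn only outside the pre-compile stage, and only once per node.
    if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
        logger.warning("DeepGEMM JIT compilation may take 10-20 minutes on first use.")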
python/sglang/srt/managers/schedule_batch.py
@@ -72,32 +72,33 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 GLOBAL_SERVER_ARGS_KEYS = [
     "attention_backend",
+    "mm_attention_backend",
     "debug_tensor_dump_inject",
     "debug_tensor_dump_output_folder",
     "chunked_prefill_size",
+    "deepep_mode",
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
+    "enable_deepep_moe",
     "enable_dp_attention",
     "enable_two_batch_overlap",
     "enable_dp_lm_head",
-    "enable_deepep_moe",
-    "deepep_mode",
     "enable_ep_moe",
+    "moe_dense_tp_size",
+    "ep_dispatch_algorithm",
     "deepep_config",
+    "ep_num_redundant_experts",
     "enable_nan_detection",
     "flashinfer_mla_disable_ragged",
     "max_micro_batch_size",
-    "moe_dense_tp_size",
-    "ep_dispatch_algorithm",
     "disable_shared_experts_fusion",
     "sampling_backend",
     "speculative_accept_threshold_acc",
     "speculative_accept_threshold_single",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
-    "ep_num_redundant_experts",
+    "num_reserved_decode_tokens",
-    "mm_attention_backend",
 ]

 # Put some global args for easy access
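The trailing comment ("Put some global args for easy access") suggests these keys are snapshotted from ServerArgs into a process-global dict. A minimal sketch of that pattern, assuming a ServerArgs-like object whose attribute names match the keys (snapshot_global_args is a hypothetical helper, not sglang API):

def snapshot_global_args(server_args, keys=None):
    # Copy only the whitelisted fields into a plain dict so hot paths can
    # read them cheaply without passing the full ServerArgs object around.
    keys = keys if keys is not None else GLOBAL_SERVER_ARGS_KEYS
    return {k: getattr(server_args, k) for k in keys}

# e.g. global_server_args_dict = snapshot_global_args(server_args)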
python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -17,12 +17,14 @@ from __future__ import annotations

 import bisect
 import inspect
+import logging
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Optional, Union

 import torch
 import tqdm
+from torch.profiler import ProfilerActivity, profile

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
@@ -40,11 +42,14 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
 from sglang.srt.utils import (
+    empty_context,
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
 )

+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -207,6 +212,9 @@ class CudaGraphRunner:
             model_runner.server_args.enable_two_batch_overlap
         )
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
+        self.enable_profile_cuda_graph = (
+            model_runner.server_args.enable_profile_cuda_graph
+        )
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size
@@ -339,44 +347,67 @@ class CudaGraphRunner:
         return is_bs_supported and is_encoder_lens_supported and is_tbo_supported

-    def capture(self):
-        with graph_capture() as graph_capture_context:
-            self.stream = graph_capture_context.stream
-            avail_mem = get_available_gpu_memory(
-                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
-            )
-            # Reverse the order to enable better memory sharing across cuda graphs.
-            capture_range = (
-                tqdm.tqdm(list(reversed(self.capture_bs)))
-                if get_tensor_model_parallel_rank() == 0
-                else reversed(self.capture_bs)
-            )
-            for bs in capture_range:
-                if get_tensor_model_parallel_rank() == 0:
-                    avail_mem = get_available_gpu_memory(
-                        self.model_runner.device,
-                        self.model_runner.gpu_id,
-                        empty_cache=False,
-                    )
-                    capture_range.set_description(
-                        f"Capturing batches ({avail_mem=:.2f} GB)"
-                    )
-
-                with patch_model(
-                    self.model_runner.model,
-                    bs in self.compile_bs,
-                    num_tokens=bs * self.num_tokens_per_bs,
-                    tp_group=self.model_runner.tp_group,
-                ) as forward:
-                    (
-                        graph,
-                        output_buffers,
-                    ) = self.capture_one_batch_size(bs, forward)
-                    self.graphs[bs] = graph
-                    self.output_buffers[bs] = output_buffers
-
-                # Save gemlite cache after each capture
-                save_gemlite_cache()
+    def capture(self) -> None:
+        profile_context = empty_context()
+        if self.enable_profile_cuda_graph:
+            profile_context = profile(
+                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                record_shapes=True,
+            )
+
+        with graph_capture() as graph_capture_context:
+            with profile_context as prof:
+                self.stream = graph_capture_context.stream
+                avail_mem = get_available_gpu_memory(
+                    self.model_runner.device,
+                    self.model_runner.gpu_id,
+                    empty_cache=False,
+                )
+                # Reverse the order to enable better memory sharing across cuda graphs.
+                capture_range = (
+                    tqdm.tqdm(list(reversed(self.capture_bs)))
+                    if get_tensor_model_parallel_rank() == 0
+                    else reversed(self.capture_bs)
+                )
+                for i, bs in enumerate(capture_range):
+                    if get_tensor_model_parallel_rank() == 0:
+                        avail_mem = get_available_gpu_memory(
+                            self.model_runner.device,
+                            self.model_runner.gpu_id,
+                            empty_cache=False,
+                        )
+                        capture_range.set_description(
+                            f"Capturing batches ({avail_mem=:.2f} GB)"
+                        )
+
+                    with patch_model(
+                        self.model_runner.model,
+                        bs in self.compile_bs,
+                        num_tokens=bs * self.num_tokens_per_bs,
+                        tp_group=self.model_runner.tp_group,
+                    ) as forward:
+                        (
+                            graph,
+                            output_buffers,
+                        ) = self.capture_one_batch_size(bs, forward)
+                        self.graphs[bs] = graph
+                        self.output_buffers[bs] = output_buffers
+
+                    # Save gemlite cache after each capture
+                    save_gemlite_cache()
+
+        if self.enable_profile_cuda_graph:
+            log_message = (
+                "Sorted by CUDA Time:\n"
+                + prof.key_averages(group_by_input_shape=True).table(
+                    sort_by="cuda_time_total", row_limit=10
+                )
+                + "\n\nSorted by CPU Time:\n"
+                + prof.key_averages(group_by_input_shape=True).table(
+                    sort_by="cpu_time_total", row_limit=10
+                )
+            )
+            logger.info(log_message)

     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
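The core of this refactor is swapping between a no-op context manager and torch.profiler.profile, so the capture loop body stays identical whether or not the new enable_profile_cuda_graph server arg is set. A standalone sketch of the pattern (empty_context mirrors the helper imported from sglang.srt.utils; capture_with_optional_profile and run_capture are hypothetical stand-ins for the method and its loop body):

from contextlib import contextmanager

from torch.profiler import ProfilerActivity, profile

@contextmanager
def empty_context(*args, **kwargs):
    # No-op stand-in used when profiling is disabled; yields None so
    # `with ... as prof` still binds a name.
    yield None

def capture_with_optional_profile(run_capture, enable_profile):
    profile_context = empty_context()
    if enable_profile:
        profile_context = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
        )
    with profile_context as prof:
        run_capture()  # the capture loop body is identical in both modes
    if enable_profile:
        print(
            prof.key_averages(group_by_input_shape=True).table(
                sort_by="cuda_time_total", row_limit=10
            )
        )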
@@ -443,7 +474,7 @@ class CudaGraphRunner:
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
             out_cache_loc=out_cache_loc,
-            seq_lens_sum=seq_lens.sum(),
+            seq_lens_sum=seq_lens.sum().item(),
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
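The one-line change above matters because seq_lens.sum() returns a zero-dimensional tensor, not a number; .item() converts it to a plain Python int (at the cost of a host-device sync when seq_lens lives on the GPU), which is what scalar consumers of seq_lens_sum expect. A small illustration:

import torch

seq_lens = torch.tensor([3, 5, 2])  # per-request sequence lengths

as_tensor = seq_lens.sum()          # 0-dim tensor; stays on the device
as_int = seq_lens.sum().item()      # plain Python int (forces a sync on GPU)

assert isinstance(as_tensor, torch.Tensor) and as_tensor.dim() == 0
assert isinstance(as_int, int) and as_int == 10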
python/sglang/srt/server_args.py
(+183 -171; this diff is collapsed in the page view and not shown here.)
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -41,6 +41,9 @@ class EAGLEDraftCudaGraphRunner:
         self.tp_size = self.model_runner.tp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
+        self.enable_profile_cuda_graph = (
+            model_runner.server_args.enable_profile_cuda_graph
+        )
         server_args = model_runner.server_args

         # Batch sizes to capture
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -39,6 +39,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.dp_size = model_runner.server_args.dp_size
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         self.topk = model_runner.server_args.speculative_eagle_topk
+        self.enable_profile_cuda_graph = (
+            model_runner.server_args.enable_profile_cuda_graph
+        )
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.padded_static_len = -1
python/sglang/srt/utils.py
@@ -837,6 +837,7 @@ class CustomCacheManager(FileCacheManager):

 def set_ulimit(target_soft_limit=65535):
+    # number of open files
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -846,6 +847,18 @@ def set_ulimit(target_soft_limit=65535):
     except ValueError as e:
         logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")

+    # stack size
+    resource_type = resource.RLIMIT_STACK
+    current_soft, current_hard = resource.getrlimit(resource_type)
+    target_soft_limit_stack_size = 1024 * target_soft_limit
+    if current_soft < target_soft_limit_stack_size:
+        try:
+            resource.setrlimit(
+                resource_type, (target_soft_limit_stack_size, current_hard)
+            )
+        except ValueError as e:
+            logger.warning(f"Fail to set RLIMIT_STACK: {e}")
+

 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
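For context, RLIMIT_STACK is measured in bytes, so 1024 * target_soft_limit with the default of 65535 requests roughly 64 MiB of stack. An illustrative way to observe the effect (assumes set_ulimit, the function patched above, is in scope):

import resource

soft, hard = resource.getrlimit(resource.RLIMIT_STACK)
print(f"stack soft limit before: {soft} bytes (hard: {hard})")

set_ulimit()  # tries to raise the RLIMIT_NOFILE and RLIMIT_STACK soft limits

soft, _ = resource.getrlimit(resource.RLIMIT_STACK)
print(f"stack soft limit after: {soft} bytes")  # ~64 MiB if the raise succeeded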