sglang · Commit 22352d47 (Unverified)

Improve streaming, log_level, memory report, weight loading, and benchmark script (#7632)

Authored by Lianmin Zheng on Jun 29, 2025; committed by GitHub on Jun 29, 2025.
Co-authored-by: Kan Wu <wukanustc@gmail.com>
Parent: c5131f7a
Showing 20 of 24 changed files, with 461 additions and 155 deletions (+461, -155).
Changed files shown on this page:

- docs/backend/server_arguments.md (+1, -1)
- python/sglang/bench_one_batch_server.py (+14, -2)
- python/sglang/bench_serving.py (+0, -1)
- python/sglang/srt/configs/internvl.py (+2, -3)
- python/sglang/srt/entrypoints/http_server.py (+23, -6)
- python/sglang/srt/layers/elementwise.py (+76, -12)
- python/sglang/srt/layers/moe/router.py (+60, -22)
- python/sglang/srt/managers/configure_logging.py (+1, -1)
- python/sglang/srt/managers/io_struct.py (+3, -3)
- python/sglang/srt/managers/mm_utils.py (+52, -0)
- python/sglang/srt/managers/multimodal_processor.py (+0, -1)
- python/sglang/srt/managers/schedule_batch.py (+5, -51)
- python/sglang/srt/managers/scheduler.py (+36, -20)
- python/sglang/srt/managers/scheduler_output_processor_mixin.py (+8, -2)
- python/sglang/srt/managers/tokenizer_manager.py (+124, -25)
- python/sglang/srt/mem_cache/memory_pool.py (+3, -0)
- python/sglang/srt/model_executor/model_runner.py (+5, -2)
- python/sglang/srt/model_loader/loader.py (+38, -0)
- python/sglang/srt/models/mistral.py (+1, -1)
- python/sglang/srt/server_args.py (+9, -2)
docs/backend/server_arguments.md

```diff
@@ -116,7 +116,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--log-level` | The logging level of all loggers. | info |
 | `--log-level-http` | The logging level of HTTP server. If not set, reuse --log-level by default. | None |
 | `--log-requests` | Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level. | False |
-| `--log-requests-level` | 0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output. | 0 |
+| `--log-requests-level` | 0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output. | 0 |
 | `--show-time-cost` | Show time cost of custom marks. | False |
 | `--enable-metrics` | Enable log prometheus metrics. | False |
 | `--bucket-time-to-first-token` | The buckets of time to first token, specified as a list of floats. | None |
```
python/sglang/bench_one_batch_server.py

```diff
@@ -38,6 +38,7 @@ class BenchArgs:
     output_len: Tuple[int] = (16,)
     temperature: float = 0.0
     return_logprob: bool = False
+    client_stream_interval: int = 1
     input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
@@ -60,6 +61,11 @@ class BenchArgs:
         )
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
         parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--client-stream-interval",
+            type=int,
+            default=BenchArgs.client_stream_interval,
+        )
         parser.add_argument(
             "--input-len-step-percentage",
             type=float,
@@ -120,6 +126,7 @@ def run_one_case(
     output_len: int,
     temperature: float,
     return_logprob: bool,
+    stream_interval: int,
     input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
@@ -168,6 +175,7 @@ def run_one_case(
             "max_new_tokens": output_len,
             "ignore_eos": True,
             "json_schema": json_schema,
+            "stream_interval": stream_interval,
         },
         "return_logprob": return_logprob,
         "stream": True,
@@ -245,8 +253,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     else:
         proc, base_url = launch_server_process(server_args)

-    tokenizer_id = server_args.tokenizer_path or server_args.model_path
-    tokenizer = get_tokenizer(tokenizer_id)
+    server_info = requests.get(base_url + "/get_server_info")
+    tokenizer_path = server_info.json()["tokenizer_path"]
+    tokenizer = get_tokenizer(tokenizer_path)

     # warmup
     if not bench_args.skip_warmup:
@@ -258,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             output_len=16,
             temperature=bench_args.temperature,
             return_logprob=bench_args.return_logprob,
+            stream_interval=bench_args.client_stream_interval,
             input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
@@ -280,6 +290,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                     ol,
                     temperature=bench_args.temperature,
                     return_logprob=bench_args.return_logprob,
+                    stream_interval=bench_args.client_stream_interval,
                     input_len_step_percentage=bench_args.input_len_step_percentage,
                     run_name=bench_args.run_name,
                     result_filename=bench_args.result_filename,
@@ -301,6 +312,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                     ol,
                     temperature=bench_args.temperature,
                     return_logprob=bench_args.return_logprob,
+                    stream_interval=bench_args.client_stream_interval,
                     input_len_step_percentage=bench_args.input_len_step_percentage,
                     run_name=bench_args.run_name,
                     result_filename=bench_args.result_filename,
```
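The benchmark now forwards a per-request `stream_interval` inside `sampling_params` and fetches the tokenizer path from `/get_server_info` instead of guessing it from server args. A minimal client sketch of the resulting streaming request is below; the host/port, prompt, and the server-sent-event framing are assumptions for illustration, only the payload shape mirrors the diff above.

```python
# Hedged sketch, not part of the commit: a streaming /generate call whose payload
# mirrors the fields added in bench_one_batch_server.py.
import json

import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "temperature": 0.0,
        "max_new_tokens": 16,
        "ignore_eos": True,
        "stream_interval": 1,  # forwarded from --client-stream-interval
    },
    "return_logprob": False,
    "stream": True,
}

with requests.post("http://localhost:30000/generate", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # assumed SSE-style framing: "data: {...}" chunks terminated by "data: [DONE]"
        if line and line.startswith("data:") and line != "data: [DONE]":
            chunk = json.loads(line[len("data:"):])
            print(chunk.get("text", ""), end="", flush=True)
```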
python/sglang/bench_serving.py

```diff
@@ -1678,7 +1678,6 @@ def run_benchmark(args_: argparse.Namespace):
             if args.base_url
             else f"http://{args.host}:{args.port}/generate"
         )
-        args.apply_chat_template = True
     elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
         api_url = (
             f"{args.base_url}/v1/completions"
```
python/sglang/srt/configs/internvl.py

```diff
@@ -147,12 +147,11 @@ class InternLM2Config(PretrainedConfig):
             )
         if (
             rope_scaling_factor is None
-            or not isinstance(rope_scaling_factor, float)
-            or not isinstance(rope_scaling_factor, int)
+            or not isinstance(rope_scaling_factor, (float, int))
             or rope_scaling_factor < 1.0
         ):
             raise ValueError(
-                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor}"
+                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
             )
         if isinstance(rope_scaling_factor, int):
             rope_scaling_factor = float(rope_scaling_factor)
```
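The old condition rejected every value, because no object is simultaneously a `float` and an `int`, so one of the two `not isinstance(...)` clauses was always true. A tiny illustration (values chosen for the example):

```python
# Why the old validation always failed, even for a valid factor:
x = 2.0
old_check = not isinstance(x, float) or not isinstance(x, int)  # True: 2.0 is not an int
new_check = not isinstance(x, (float, int))                     # False: 2.0 is a float
print(old_check, new_check)  # True False
```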
python/sglang/srt/entrypoints/http_server.py

```diff
@@ -126,8 +126,6 @@ def set_global_state(global_state: _GlobalState):
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
-    server_args: ServerArgs = fast_api_app.server_args
-
     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
         _global_state.tokenizer_manager, _global_state.template_manager
@@ -145,9 +143,12 @@ async def lifespan(fast_api_app: FastAPI):
         _global_state.tokenizer_manager
     )

+    server_args: ServerArgs = fast_api_app.server_args
     if server_args.warmups is not None:
         await execute_warmups(
-            server_args.warmups.split(","), _global_state.tokenizer_manager
+            server_args.disaggregation_mode,
+            server_args.warmups.split(","),
+            _global_state.tokenizer_manager,
         )
         logger.info("Warmup ended")
@@ -280,13 +281,17 @@ async def get_model_info():
         "model_path": _global_state.tokenizer_manager.model_path,
         "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
         "is_generation": _global_state.tokenizer_manager.is_generation,
+        "preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
     }
     return result


 @app.get("/get_server_info")
 async def get_server_info():
-    internal_states = await _global_state.tokenizer_manager.get_internal_state()
+    # Returns internal states per DP.
+    internal_states: List[Dict[Any, Any]] = (
+        await _global_state.tokenizer_manager.get_internal_state()
+    )
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
@@ -300,6 +305,8 @@ async def get_load():
     return await _global_state.tokenizer_manager.get_load()


+# example usage:
+# curl -s -X POST http://localhost:30000/set_internal_state -H "Content-Type: application/json" -d '{"server_args": {"max_micro_batch_size": 8}}'
 @app.api_route("/set_internal_state", methods=["POST", "PUT"])
 async def set_internal_state(obj: SetInternalStateReq, request: Request):
     res = await _global_state.tokenizer_manager.set_internal_state(obj)
@@ -886,6 +893,15 @@ def launch_server(
         add_prometheus_middleware(app)
         enable_func_timer()

+    image_token_text = None
+    if (
+        tokenizer_manager.image_token_id is not None
+        and not server_args.skip_tokenizer_init
+    ):
+        image_token_text = tokenizer_manager.tokenizer.decode(
+            [tokenizer_manager.image_token_id]
+        )
+
     # Send a warmup request - we will create the thread launch it
     # in the lifespan after all other warmups have fired.
     warmup_thread = threading.Thread(
@@ -893,7 +909,7 @@ def launch_server(
         args=(
             server_args,
             pipe_finish_writer,
-            _global_state.tokenizer_manager.image_token_id,
+            image_token_text,
             launch_callback,
         ),
     )
@@ -1022,9 +1038,10 @@ def _wait_and_warmup(
             return

     # Debug print
-    # logger.info(f"{res.json()=}")
+    # logger.info(f"warmup request returns: {res.json()=}")
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
```
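A hedged sketch of exercising the two endpoints touched above against a locally running server. The base URL is an assumption; the only response keys used are the ones visible in this diff (server args merged with scheduler info, plus the tokenizer path consumed by `bench_one_batch_server.py`).

```python
# Hedged sketch, not part of the commit.
import requests

base_url = "http://localhost:30000"

info = requests.get(f"{base_url}/get_server_info").json()
print(info.get("tokenizer_path"))  # now used by bench_one_batch_server.py

# /set_internal_state accepts POST or PUT, mirroring the curl example in the diff
resp = requests.post(
    f"{base_url}/set_internal_state",
    json={"server_args": {"max_micro_batch_size": 8}},
)
print(resp.json())
```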
python/sglang/srt/layers/elementwise.py

```diff
@@ -8,6 +8,7 @@ from sglang.srt.utils import is_hip
 _is_hip = is_hip()

 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
-    min_num_warps = 16 if _is_hip else 32
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
         )
     else:
+        max_warps = 16 if _is_hip else 32
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
             ),
         }
@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
-    min_num_warps = 16 if _is_hip else 32
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
@@ -331,6 +325,75 @@ class FusedDualResidualRMSNorm:
         return self.rmsnorm2.forward_native(residual), residual


+@triton.jit
+def experts_combine_kernel(
+    out_hidden_states,
+    moe_hidden_states,
+    mlp_hidden_states,
+    combine_k: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_index_mlp = pid * hidden_dim
+    start_index_rmoe = pid * hidden_dim * combine_k
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+    combine_k_offsets = tl.arange(0, combine_k)
+
+    moe_x = tl.load(
+        moe_hidden_states
+        + start_index_rmoe
+        + combine_k_offsets[:, None] * hidden_dim
+        + offsets[None, :],
+        mask=mask[None, :],
+        other=0.0,
+    )
+    moe_x = tl.sum(moe_x, axis=0)
+    mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
+    combined_x = (moe_x + mlp_x) / 1.4142135623730951
+
+    tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)
+
+
+def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
+    assert moe_hidden_states.is_contiguous()
+    assert mlp_hidden_states.is_contiguous()
+
+    if len(moe_hidden_states.shape) == 2:
+        combine_k = 1  # pre-combined
+    else:
+        combine_k = moe_hidden_states.shape[1]
+
+    if output_buffer is None:
+        out_hidden_states = torch.empty_like(mlp_hidden_states)
+    else:
+        flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
+        assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
+        out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
+            mlp_hidden_states.shape
+        )
+
+    bs, hidden_dim = mlp_hidden_states.shape
+
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
+        ),
+    }
+
+    experts_combine_kernel[(bs,)](
+        out_hidden_states,
+        moe_hidden_states,
+        mlp_hidden_states,
+        combine_k,
+        hidden_dim,
+        **config,
+    )
+    return out_hidden_states
+
+
 # gelu on first half of vector
 @triton.jit
 def gelu_and_mul_kernel(
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
         out_scales = scales
         static_scale = True

+    max_warps = 16 if _is_hip else 32
     config = {
         # 8 ele per thread (not tuned)
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
         ),
     }
```
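The new `experts_combine_triton` helper sums `combine_k` expert outputs per token, adds the MLP output, and scales the result by 1/sqrt(2). A hedged usage sketch (requires a CUDA device and Triton; the shapes and dtype below are assumptions for illustration):

```python
# Hedged sketch, not part of the commit.
import torch

from sglang.srt.layers.elementwise import experts_combine_triton

bs, combine_k, hidden_dim = 8, 2, 4096
moe = torch.randn(bs, combine_k, hidden_dim, dtype=torch.bfloat16, device="cuda")
mlp = torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device="cuda")

out = experts_combine_triton(moe, mlp)

# Plain-PyTorch reference of the same combination for comparison
ref = (moe.sum(dim=1) + mlp) / 2**0.5
print((out.float() - ref.float()).abs().max())
```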
python/sglang/srt/layers/moe/router.py

```diff
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 import triton
@@ -16,6 +16,8 @@ def fused_moe_router_kernel(
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
     topk_ids_ptr,  # output (bs, topk)
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     num_experts: tl.constexpr,
     topk: tl.constexpr,
     moe_softcapping: tl.constexpr,
@@ -49,6 +51,11 @@ def fused_moe_router_kernel(
     bottom = exped + 1
     logits_softcapped = top / bottom * moe_softcapping

+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts))
+        logits_softcapped = logits_softcapped + bias
+
     # topk
     # assert 1 <= topk <= num_experts
@@ -109,6 +116,7 @@ def fused_moe_router_impl(
     router_weight: torch.Tensor,
     topk: int,
     moe_softcapping: float,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -117,23 +125,23 @@ def fused_moe_router_impl(
     # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None

-    grid = lambda meta: (bs,)
-    min_num_warps = 16 if _is_hip else 32
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }

-    fused_moe_router_kernel[grid](
+    fused_moe_router_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
         topk_ids,
+        correction_bias,
+        is_correction_bias=is_correction_bias,
         num_experts=num_experts,
         topk=topk,
         moe_softcapping=moe_softcapping,
@@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel(
     topk_ids_ptr,  # output (bs, topk)
     bs,
     num_experts: tl.constexpr,
-    topk: tl.constexpr,  # only support topk == 1
+    topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
     K: tl.constexpr,
@@ -204,25 +212,53 @@ def fused_moe_router_large_bs_kernel(
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

     # 5. top1
-    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
-    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+    arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
+    cond_top1 = arange_block_size_n < num_experts
+    top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1)
     top1_v = tl.max(
-        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+        tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True
     )
-    invsumexp = 1.0 / tl.sum(
-        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+    top1_invsumexp = 1.0 / tl.sum(
+        tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
     )

-    # 6. store to output
-    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    topk_mask = offs_topk < bs
-    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
+    # 6. store top1 to output
+    offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)
+    top1_mask = offs_top1 < bs * topk
+    tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask)
     tl.store(
-        topk_weights_ptr + offs_topk,
-        invsumexp,
-        mask=topk_mask,
+        topk_weights_ptr + offs_top1,
+        top1_invsumexp,
+        mask=top1_mask,
     )

+    # 7. handle topk == 2
+    if topk == 2:
+        cond_top2 = (arange_block_size_n < num_experts) and (
+            arange_block_size_n != top1[:, None]
+        )
+        top2 = tl.argmax(
+            tl.where(cond_top2, logits_softcapped, float("-inf")),
+            axis=1,
+            keep_dims=True,
+        )
+        top2_v = tl.sum(
+            logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True
+        )
+        top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None]
+
+        # store top2
+        offs_top2 = (
+            pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1
+        )
+        top2_mask = offs_top2 < bs * topk
+        tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask)
+        tl.store(
+            topk_weights_ptr + offs_top2,
+            top2_invsumexp,
+            mask=top2_mask,
+        )
+

 def fused_moe_router_large_bs_impl(
     x: torch.Tensor,
@@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl(
     assert num_experts <= BLOCK_SIZE_N
     assert hidden_dim % BLOCK_SIZE_K == 0
-    assert topk == 1
+    assert topk <= 2

     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
@@ -273,6 +309,7 @@ def fused_moe_router_shim(
     gating_output,
     topk,
     renormalize,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert not renormalize
     assert (
@@ -286,7 +323,7 @@ def fused_moe_router_shim(
     BLOCK_SIZE_K = 256
     if (
         bs >= 512
-        and topk == 1
+        and topk <= 2
         and num_experts <= BLOCK_SIZE_N
         and hidden_dim % BLOCK_SIZE_K == 0
     ):
@@ -305,6 +342,7 @@ def fused_moe_router_shim(
             router_weight=gating_output,
             topk=topk,
             moe_softcapping=moe_softcapping,
+            correction_bias=correction_bias,
         )
```
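A hedged sketch of calling the fused router with the new optional correction bias (requires CUDA and Triton). The expert count, softcapping value, dtypes, and the assumption that `fused_moe_router_impl` returns the `(topk_weights, topk_ids)` pair it allocates are all illustrative guesses, not guaranteed by the diff.

```python
# Hedged sketch, not part of the commit.
import torch

from sglang.srt.layers.moe.router import fused_moe_router_impl

bs, hidden_dim, num_experts = 16, 4096, 8
x = torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device="cuda")
router_weight = torch.randn(num_experts, hidden_dim, dtype=torch.bfloat16, device="cuda")
correction_bias = torch.zeros(num_experts, dtype=torch.float32, device="cuda")

result = fused_moe_router_impl(
    x=x,
    router_weight=router_weight,
    topk=2,                   # the large-batch kernel now supports topk <= 2
    moe_softcapping=30.0,
    correction_bias=correction_bias,  # added to the soft-capped logits before top-k
)
print(result)  # assumed to be the (topk_weights, topk_ids) pair
```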
python/sglang/srt/managers/configure_logging.py

```diff
@@ -28,7 +28,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", type=str, default="http://localhost:30000")
     parser.add_argument("--log-requests", action="store_true")
-    parser.add_argument("--log-requests-level", type=int, default=2)
+    parser.add_argument("--log-requests-level", type=int, default=3)
     parser.add_argument(
         "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
     )
```
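With the expanded level scale, the helper script's default now requests full input/output logging (level 3, previously level 2). A hedged sketch of the kind of runtime reconfiguration request the script drives; the endpoint path and field names are assumptions inferred from the script's flags, not shown in this diff.

```python
# Hedged sketch, not part of the commit.
import requests

requests.post(
    "http://localhost:30000/configure_logging",  # assumed endpoint name
    json={
        "log_requests": True,
        "log_requests_level": 3,  # new default; was 2
        "dump_requests_folder": "/tmp/sglang_request_dump",
    },
)
```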
python/sglang/srt/managers/io_struct.py

```diff
@@ -516,9 +516,6 @@ class EmbeddingReqInput:
     # For cross-encoder requests
     is_cross_encoder_request: bool = False

-    def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
-
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -572,6 +569,9 @@ class EmbeddingReqInput:
             self.rid = uuid.uuid4().hex
         return self.rid

+    def contains_mm_input(self) -> bool:
+        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+
     def __getitem__(self, i):
         if self.is_cross_encoder_request:
             return EmbeddingReqInput(
```
python/sglang/srt/managers/mm_utils.py

```diff
@@ -2,12 +2,15 @@
 Multi-modality utils
 """

+import hashlib
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple

+import numpy as np
 import torch
 from torch import nn

+from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
@@ -678,3 +681,52 @@ def get_multimodal_data_bounds(
     # Convert valid pairs to tensor
     valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor
+
+
+def data_hash(data) -> int:
+    hash_bytes = hashlib.sha256(data).digest()[:8]
+    return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+
+def tensor_hash(tensor_list) -> int:
+    """
+    hash a tensor or a tensor list
+    """
+    tensor = tensor_list
+    if isinstance(tensor_list, list):
+        tensor_list = flatten_nested_list(tensor_list)
+        tensor_list = [
+            x.flatten() if isinstance(x, torch.Tensor) else x for x in tensor_list
+        ]
+        tensor = torch.concat(tensor_list)
+    if tensor.is_cuda:
+        return gpu_tensor_hash(tensor)
+    tensor = tensor.detach().contiguous()
+
+    if tensor.dtype == torch.bfloat16:
+        # memoryview() doesn't support PyTorch's BFloat16 dtype
+        tensor = tensor.float()
+
+    assert isinstance(tensor, torch.Tensor)
+    if tensor.is_cuda:
+        # TODO: improve this
+        tensor_cpu = tensor.cpu()
+    else:
+        tensor_cpu = tensor
+
+    mv = memoryview(tensor_cpu.numpy())
+    return data_hash(mv.tobytes())
+
+
+def hash_feature(f):
+    if isinstance(f, list):
+        if isinstance(f[0], torch.Tensor):
+            return tensor_hash(f)
+        return data_hash(tuple(flatten_nested_list(f)))
+    elif isinstance(f, np.ndarray):
+        arr = np.ascontiguousarray(f)
+        arr_bytes = arr.tobytes()
+        return data_hash(arr_bytes)
+    elif isinstance(f, torch.Tensor):
+        return tensor_hash([f])
+    return data_hash(f)
```
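The hashing helpers are now module-level in `mm_utils.py` (previously nested inside `MultimodalDataItem` in `schedule_batch.py`), so they can be reused directly. A small CPU-only sketch of the branches; CUDA tensors would instead be routed to `gpu_tensor_hash`, and the values here are purely illustrative.

```python
# Hedged sketch, not part of the commit.
import numpy as np
import torch

from sglang.srt.managers.mm_utils import data_hash, hash_feature, tensor_hash

print(data_hash(b"an image byte payload"))           # 64-bit int from a sha256 prefix
print(tensor_hash([torch.ones(4), torch.zeros(2)]))  # flattens, concatenates, hashes
print(hash_feature(np.arange(12).reshape(3, 4)))     # ndarray branch
print(hash_feature(torch.randn(3, 4, dtype=torch.bfloat16)))  # bf16 is upcast first
```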
python/sglang/srt/managers/multimodal_processor.py

```diff
@@ -3,7 +3,6 @@ import importlib
 import inspect
 import logging
 import pkgutil
-from functools import lru_cache

 from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
 from sglang.srt.server_args import ServerArgs
```
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -33,7 +33,6 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
 import copy
 import dataclasses
-import hashlib
 import logging
 import threading
 from enum import Enum, auto
@@ -53,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
@@ -96,8 +94,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "max_micro_batch_size",
     "disable_shared_experts_fusion",
     "sampling_backend",
-    "speculative_accept_threshold_acc",
     "speculative_accept_threshold_single",
+    "speculative_accept_threshold_acc",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
@@ -180,7 +178,9 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or others.
+    One MultimodalDataItem contains all inputs for one modality.
+    For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
+    One for images and one for audio.

     We put the common fields first and the model-specific fields last.
     """
@@ -232,53 +232,7 @@ class MultimodalDataItem:
         """
         Set the pad value after first hashing the data
         """
-
-        def data_hash(data) -> int:
-            hash_bytes = hashlib.sha256(data).digest()[:8]
-            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
-
-        def tensor_hash(tensor_list) -> int:
-            """
-            hash a tensor or a tensor list
-            """
-            tensor = tensor_list
-            if isinstance(tensor_list, list):
-                tensor_list = flatten_nested_list(tensor_list)
-                tensor_list = [
-                    x.flatten() if isinstance(x, torch.Tensor) else x
-                    for x in tensor_list
-                ]
-                tensor = torch.concat(tensor_list)
-            if tensor.is_cuda:
-                return gpu_tensor_hash(tensor)
-            tensor = tensor.detach().contiguous()
-
-            if tensor.dtype == torch.bfloat16:
-                # memoryview() doesn't support PyTorch's BFloat16 dtype
-                tensor = tensor.float()
-
-            assert isinstance(tensor, torch.Tensor)
-            if tensor.is_cuda:
-                # TODO: improve this
-                tensor_cpu = tensor.cpu()
-            else:
-                tensor_cpu = tensor
-
-            mv = memoryview(tensor_cpu.numpy())
-            return data_hash(mv.tobytes())
-
-        def hash_feature(f):
-            if isinstance(f, list):
-                if isinstance(f[0], torch.Tensor):
-                    return tensor_hash(f)
-                return data_hash(tuple(flatten_nested_list(f)))
-            elif isinstance(f, np.ndarray):
-                arr = np.ascontiguousarray(f)
-                arr_bytes = arr.tobytes()
-                return data_hash(arr_bytes)
-            elif isinstance(f, torch.Tensor):
-                return tensor_hash([f])
-            return data_hash(f)
+        from sglang.srt.managers.mm_utils import hash_feature

         if self.precomputed_features is not None:
             self.hash = hash_feature(self.precomputed_features)
```
python/sglang/srt/managers/scheduler.py

```diff
@@ -418,14 +418,16 @@ class Scheduler(
         self.last_decode_stats_tic = time.perf_counter()
         self.last_prefill_stats_tic = time.perf_counter()
         self.return_health_check_ct = 0
         self.num_retracted_reqs: int = 0
         self.num_paused_reqs: int = 0
+        self.kv_transfer_speed_gb_s: float = 0.0
+        self.kv_transfer_latency_ms: float = 0.0
+        self.sessions: Dict[str, Session] = {}
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU
         self.forward_sleep_time = None

-        # Init session info
-        self.sessions: Dict[str, Session] = {}
-
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         if self.chunked_prefill_size <= 0:  # -1 means disable
@@ -473,26 +475,12 @@ class Scheduler(
             t = threading.Thread(target=self.watchdog_thread, daemon=True)
             t.start()
         self.parent_process = psutil.Process().parent()

         # Init memory saver, profiler and metric stats
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
-        # Init profiler
-        self.torch_profiler = None
-        self.torch_profiler_output_dir: Optional[str] = None
-        self.profiler_activities: Optional[List[str]] = None
-        self.profile_id: Optional[str] = None
-        self.profiler_target_forward_ct: Optional[int] = None
-        self.profiler_target_prefill_ct: Optional[int] = None
-        self.profiler_target_decode_ct: Optional[int] = None
-        self.profiler_prefill_ct: Optional[int] = None
-        self.profiler_decode_ct: Optional[int] = None
-        self.profile_by_stage: bool = False
-        self.profile_steps: Optional[int] = None
-        self.profile_in_progress: bool = False
-        self.rpd_profiler = None
-
-        # Init metrics stats
+        self.init_profier()
         self.init_metrics()
         self.init_kv_events(server_args.kv_events_config)
@@ -526,6 +514,7 @@ class Scheduler(
                 ]
             )

+        # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -624,6 +613,21 @@ class Scheduler(
             )
         )

+    def init_profier(self):
+        self.torch_profiler = None
+        self.torch_profiler_output_dir: Optional[str] = None
+        self.profiler_activities: Optional[List[str]] = None
+        self.profile_id: Optional[str] = None
+        self.profiler_target_forward_ct: Optional[int] = None
+        self.profiler_target_prefill_ct: Optional[int] = None
+        self.profiler_target_decode_ct: Optional[int] = None
+        self.profiler_prefill_ct: Optional[int] = None
+        self.profiler_decode_ct: Optional[int] = None
+        self.profile_by_stage: bool = False
+        self.profile_steps: Optional[int] = None
+        self.profile_in_progress: bool = False
+        self.rpd_profiler = None
+
     def init_metrics(self):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
@@ -2107,6 +2111,18 @@ class Scheduler(
     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
         ret["last_gen_throughput"] = self.last_gen_throughput
+        ret["memory_usage"] = {
+            "weight": round(
+                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
+            ),
+            "kvcache": round(
+                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
+            ),
+            "cuda_graph": round(
+                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
+            ),
+            "token_capacity": int(self.max_total_num_tokens),
+        }
         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
                 self.cum_spec_accept_length / self.cum_spec_accept_count
```
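Together with the `mem_usage` fields added in `memory_pool.py` and `model_runner.py`, the scheduler's internal state now carries a consolidated memory report. Purely illustrative excerpt of its shape (numbers are made up; units are GB except `token_capacity`):

```python
# Illustrative only, not actual output.
internal_state_excerpt = {
    "last_gen_throughput": 4210.5,
    "memory_usage": {
        "weight": 13.48,       # weight_load_mem_usage from the model runner
        "kvcache": 52.10,      # KV-cache pool size tracked in memory_pool.py
        "cuda_graph": 1.25,    # cuda_graph_mem_usage from init_cuda_graphs
        "token_capacity": 423000,
    },
}
print(internal_state_excerpt["memory_usage"]["kvcache"])
```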
python/sglang/srt/managers/scheduler_output_processor_mixin.py

```diff
@@ -521,11 +521,17 @@ class SchedulerOutputProcessorMixin:
                     stream_interval = (
                         req.sampling_params.stream_interval or self.stream_interval
                     )
-                    should_output = len(req.output_ids) % stream_interval == 0
+                    should_output = (
+                        len(req.output_ids) % stream_interval == 1
+                        if not self.model_config.is_multimodal_gen
+                        and stream_interval > 1
+                        else len(req.output_ids) % stream_interval == 0
+                    )
                 else:
                     should_output = (
                         len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
-                        and not self.model_config.is_multimodal_gen
+                        if not self.model_config.is_multimodal_gen
+                        else False
                     )

                 if should_output:
```
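The effect of the new condition is that the first streamed chunk no longer waits for a full interval: with a `stream_interval` greater than 1 (and non-multimodal generation), output is flushed when the token count is 1 modulo the interval rather than 0. A tiny worked comparison:

```python
# Illustrative check of the new streaming cadence: with stream_interval=4, the old
# rule flushed at 4, 8, 12, ... tokens, while the new rule flushes at 1, 5, 9, ...,
# so the first token reaches the client immediately.
stream_interval = 4
for n in range(1, 10):
    old_rule = n % stream_interval == 0
    new_rule = n % stream_interval == 1
    print(n, old_rule, new_rule)
```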
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -111,11 +111,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
-from sglang.srt.managers.multimodal_processor import (
-    get_dummy_processor,
-    get_mm_processor,
-    import_processors,
-)
+from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -187,6 +183,8 @@ class TokenizerManager:
             if server_args.preferred_sampling_params
             else None
         )
+        self.crash_dump_folder = server_args.crash_dump_folder
+        self.crash_dump_performed = False  # Flag to ensure dump is only called once

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -251,10 +249,11 @@ class TokenizerManager:
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
+        self.crash_dump_request_list: deque[Tuple] = deque()
         self.log_request_metadata = self.get_log_request_metadata()
+        self.asyncio_tasks = set()
         self.session_futures = {}  # session_id -> asyncio event
         self.max_req_input_len = None
-        self.asyncio_tasks = set()

         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
@@ -266,14 +265,14 @@ class TokenizerManager:
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
-        self.transfer_backend = TransferBackend(
+        self.disaggregation_transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
         # Start kv boostrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
             kv_bootstrap_server_class = get_kv_class(
-                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
+                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
             )
             self.bootstrap_server = kv_bootstrap_server_class(
                 self.server_args.disaggregation_bootstrap_port
@@ -324,7 +323,6 @@ class TokenizerManager:
         self.profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -484,7 +482,7 @@ class TokenizerManager:
             token_type_ids = encoded.get("token_type_ids", [None])[0]

         if self.mm_processor and obj.contains_mm_input():
-            image_inputs = await self.mm_processor.process_mm_data_async(
+            image_inputs: Dict = await self.mm_processor.process_mm_data_async(
                 image_data=obj.image_data,
                 input_text=input_text or input_ids,
                 request_obj=obj,
@@ -547,6 +545,14 @@ class TokenizerManager:
                     "Please set `--enable-custom-logits-processor` to enable this feature."
                 )

+    def _validate_input_ids_in_vocab(
+        self, input_ids: List[int], vocab_size: int
+    ) -> None:
+        if any(id >= vocab_size for id in input_ids):
+            raise ValueError(
+                f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})."
+            )
+
     def _create_tokenized_object(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1096,12 +1102,36 @@ class TokenizerManager:
                         "image_data",
                         "audio_data",
                         "lora_path",
+                        "sampling_params",
                     ]
                 )
-                out_skip_names = set(
-                    [
-                        "text",
-                        "output_ids",
-                    ]
-                )
+                out_skip_names = set(["text", "output_ids", "embedding"])
             elif self.log_requests_level == 1:
-                max_length = 2048
+                max_length = 1 << 30
+                skip_names = set(
+                    [
+                        "text",
+                        "input_ids",
+                        "input_embeds",
+                        "image_data",
+                        "audio_data",
+                        "lora_path",
+                    ]
+                )
+                out_skip_names = set(
+                    [
+                        "text",
+                        "output_ids",
+                    ]
+                )
             elif self.log_requests_level == 2:
+                max_length = 2048
+            elif self.log_requests_level == 3:
                 max_length = 1 << 30
             else:
                 raise ValueError(
@@ -1118,6 +1148,8 @@ class TokenizerManager:
             self.dump_requests_folder = obj.dump_requests_folder
         if obj.dump_requests_threshold is not None:
             self.dump_requests_threshold = obj.dump_requests_threshold
+        if obj.crash_dump_folder is not None:
+            self.crash_dump_folder = obj.crash_dump_folder
         logging.info(f"Config logging: {obj=}")
         self.log_request_metadata = self.get_log_request_metadata()
@@ -1166,6 +1198,52 @@ class TokenizerManager:
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )

+    def dump_requests_before_crash(self):
+        if self.crash_dump_performed:
+            logger.info(
+                "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
+            )
+            return
+        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
+        self.crash_dump_performed = True
+        if not self.crash_dump_folder:
+            return
+
+        data_to_dump = []
+        if self.crash_dump_request_list:
+            data_to_dump.extend(self.crash_dump_request_list)
+
+        # Add unfinished requests from rid_to_state
+        unfinished_requests = []
+        for rid, state in self.rid_to_state.items():
+            if not state.finished:
+                unfinished_requests.append(
+                    (state.obj, {}, state.created_time, time.time())
+                )
+        if unfinished_requests:
+            data_to_dump.extend(unfinished_requests)
+
+        if not data_to_dump:
+            return
+
+        filename = os.path.join(
+            self.crash_dump_folder,
+            os.getenv("HOSTNAME", None),
+            f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
+        )
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        # Include server_args in the dump
+        data_to_dump_with_server_args = {
+            "server_args": self.server_args,
+            "requests": data_to_dump,
+        }
+        with open(filename, "wb") as f:
+            pickle.dump(data_to_dump_with_server_args, f)
+        logger.error(
+            f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
+        )
+
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
             await asyncio.sleep(5)
@@ -1175,11 +1253,12 @@ class TokenizerManager:
             remain_num_req = len(self.rid_to_state)

             if self.health_check_failed:
-                # if health check failed, exit immediately
+                # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                     remain_num_req,
                 )
+                self.dump_requests_before_crash()
                 break

             elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
@@ -1196,6 +1275,7 @@ class TokenizerManager:
             if remain_num_req > 0:
                 await asyncio.sleep(5)
             else:
+                self.dump_requests_before_crash()
                 break

         kill_process_tree(os.getpid(), include_parent=True)
@@ -1273,16 +1353,7 @@ class TokenizerManager:
                 "meta_info": meta_info,
             }
         elif isinstance(recv_obj, BatchMultimodalOut):
-            if isinstance(recv_obj.outputs[i], str):
-                out_dict = {
-                    "text": recv_obj.outputs[i],
-                    "meta_info": meta_info,
-                }
-            else:
-                out_dict = {
-                    "outputs": json.dumps(recv_obj.outputs[i]),
-                    "meta_info": meta_info,
-                }
+            raise NotImplementedError("BatchMultimodalOut not implemented")
         else:
             assert isinstance(recv_obj, BatchEmbeddingOut)
             out_dict = {
@@ -1306,6 +1377,8 @@ class TokenizerManager:
             self.collect_metrics(state, recv_obj, i)
         if self.dump_requests_folder and state.finished and state.obj.log_metrics:
             self.dump_requests(state, out_dict)
+        if self.crash_dump_folder and state.finished and state.obj.log_metrics:
+            self.record_request_for_crash_dump(state, out_dict)

     def convert_logprob_style(
         self,
@@ -1317,6 +1390,9 @@ class TokenizerManager:
         recv_obj: BatchStrOut,
         recv_obj_index: int,
     ):
+        if recv_obj.input_token_logprobs_val is None:
+            return
+
         if len(recv_obj.input_token_logprobs_val) > 0:
             state.input_token_logprobs_val.extend(
                 recv_obj.input_token_logprobs_val[recv_obj_index]
@@ -1436,7 +1512,10 @@ class TokenizerManager:
                 else 0
             )

-            if state.first_token_time == 0.0:
+            if (
+                state.first_token_time == 0.0
+                and self.disaggregation_mode != DisaggregationMode.PREFILL
+            ):
                 state.first_token_time = state.last_time = time.time()
                 state.last_completion_tokens = completion_tokens
                 self.metrics_collector.observe_time_to_first_token(
@@ -1484,14 +1563,31 @@ class TokenizerManager:
         to_dump = self.dump_request_list
         self.dump_request_list = []

+        to_dump_with_server_args = {
+            "server_args": self.server_args,
+            "requests": to_dump,
+        }
+
         def background_task():
             os.makedirs(self.dump_requests_folder, exist_ok=True)
             with open(filename, "wb") as f:
-                pickle.dump(to_dump, f)
+                pickle.dump(to_dump_with_server_args, f)

         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))

+    def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
+        current_time = time.time()
+        self.crash_dump_request_list.append(
+            (state.obj, out_dict, state.created_time, current_time)
+        )
+        # Remove requests older than 5 minutes based on finish time
+        while (
+            self.crash_dump_request_list
+            and current_time - self.crash_dump_request_list[0][3] >= 300
+        ):
+            self.crash_dump_request_list.popleft()
+
     def _handle_abort_req(self, recv_obj):
         self.rid_to_state.pop(recv_obj.rid, None)
@@ -1614,6 +1710,8 @@ async def print_exception_wrapper(func):
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"TokenizerManager hit an exception: {traceback}")
+        if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager):
+            func.__self__.dump_requests_before_crash()
         kill_process_tree(os.getpid(), include_parent=True)
         sys.exit(1)
@@ -1632,6 +1730,7 @@ class SignalHandler:
         logger.error(
             "Received sigquit from a child process. It usually means the child failed."
         )
+        self.tokenizer_manager.dump_requests_before_crash()
         kill_process_tree(os.getpid())
```
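A hedged sketch of inspecting a crash dump written by the new `dump_requests_before_crash()`. The path is an example; the payload layout (`"server_args"` plus `"requests"`) and the per-request tuple `(obj, out_dict, created_time, finished_time)` come from the diff above.

```python
# Hedged sketch, not part of the commit.
import pickle

dump_path = "/tmp/sglang_crash_dumps/myhost/crash_dump_2025-06-29_12-00-00.pkl"  # example path
with open(dump_path, "rb") as f:
    dump = pickle.load(f)

print(dump["server_args"])
for obj, out_dict, created_time, finished_time in dump["requests"]:
    # latency of each buffered request and whether it had produced output
    print(type(obj).__name__, finished_time - created_time, bool(out_dict))
```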
python/sglang/srt/mem_cache/memory_pool.py

```diff
@@ -123,6 +123,7 @@ class KVCache(abc.ABC):
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
+        self.mem_usage = 0

         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
@@ -219,6 +220,7 @@ class MHATokenToKVPool(KVCache):
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )
+        self.mem_usage = (k_size + v_size) / GB

     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -695,6 +697,7 @@ class MLATokenToKVPool(KVCache):
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
         )
+        self.mem_usage = kv_size / GB

     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
```
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -604,12 +604,13 @@ class ModelRunner:
         self.dtype = self.model_config.dtype
         after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
+        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={after_avail_memory:.2f} GB, "
-            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
+            f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )

         # Handle the case where some ranks do not finish loading.
@@ -1250,6 +1251,7 @@ class ModelRunner:
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_mem_usage = 0

         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
@@ -1265,9 +1267,10 @@ class ModelRunner:
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
             f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
+            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def apply_torch_tp(self):
```
python/sglang/srt/model_loader/loader.py

```diff
@@ -534,6 +534,12 @@ class DummyModelLoader(BaseModelLoader):
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+        if get_bool_env_var("SGL_CPU_QUANTIZATION"):
+            return load_model_with_cpu_quantization(
+                self, model_config=model_config, device_config=device_config
+            )
+
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(
@@ -1464,6 +1470,38 @@ class RemoteModelLoader(BaseModelLoader):
         return model.eval()


+def load_model_with_cpu_quantization(
+    self,
+    *,
+    model_config: ModelConfig,
+    device_config: DeviceConfig,
+) -> nn.Module:
+    target_device = torch.device(device_config.device)
+    with set_default_torch_dtype(model_config.dtype):
+        model = _initialize_model(
+            model_config,
+            self.load_config,
+        )
+
+        if not isinstance(self, DummyModelLoader):
+            model.load_weights(self._get_all_weights(model_config, model))
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(module, target_device):
+                    quant_method.process_weights_after_loading(module)
+
+    model.to(target_device)
+
+    return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
```
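The new load path is gated purely on an environment variable, so opting in only requires setting it before the loader runs; in the hunks shown here the gate is added to `DummyModelLoader`, while `load_model_with_cpu_quantization` itself also handles non-dummy loaders. Everything around it (loader construction, configs) is outside this diff.

```python
# Hedged sketch, not part of the commit: enabling the CPU-quantization load path.
import os

# Must be set before the model loader's load_model() is invoked.
os.environ["SGL_CPU_QUANTIZATION"] = "1"
```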
python/sglang/srt/models/mistral.py

```diff
@@ -13,7 +13,7 @@
 # ==============================================================================
 """Inference-only Mistral model."""
-from typing import List, Union
+from typing import List

 import torch
 from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
```
python/sglang/srt/server_args.py

```diff
@@ -99,6 +99,7 @@ class ServerArgs:
     log_level_http: Optional[str] = None
     log_requests: bool = False
     log_requests_level: int = 0
+    crash_dump_folder: Optional[str] = None
     show_time_cost: bool = False
     enable_metrics: bool = False
     bucket_time_to_first_token: Optional[List[float]] = None
@@ -927,8 +928,14 @@ class ServerArgs:
             "--log-requests-level",
             type=int,
             default=0,
-            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
-            choices=[0, 1, 2],
+            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
+            choices=[0, 1, 2, 3],
+        )
+        parser.add_argument(
+            "--crash-dump-folder",
+            type=str,
+            default=ServerArgs.crash_dump_folder,
+            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
         )
         parser.add_argument(
             "--show-time-cost",
```
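A hedged sketch of the new knobs as they would appear on a `ServerArgs` instance. The model path is a placeholder and constructing `ServerArgs` runs its post-init adjustments, so treat this as a shape illustration rather than a launch recipe; only the field names come from the diff.

```python
# Hedged sketch, not part of the commit.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",   # placeholder model
    log_requests=True,
    log_requests_level=3,                            # new maximum verbosity level
    crash_dump_folder="/tmp/sglang_crash_dumps",     # enables the 5-minute crash buffer
)
print(args.log_requests_level, args.crash_dump_folder)
```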