sglang · commit 2d62af6b (unverified)

Fix metrics and request tracing (TimeStats) (#11123)

Authored Oct 01, 2025 by Lianmin Zheng; committed via GitHub Oct 01, 2025. Parent commit: a28b394f.

Showing 13 changed files with 462 additions and 393 deletions (+462 -393).
python/pyproject.toml                                            +20  -21
python/sglang/srt/disaggregation/decode.py                       +17   -3
python/sglang/srt/disaggregation/prefill.py                       +7   -2
python/sglang/srt/disaggregation/utils.py                         +1   -1
python/sglang/srt/managers/schedule_batch.py                     +10  -10
python/sglang/srt/managers/schedule_policy.py                     +8   -4
python/sglang/srt/managers/scheduler.py                          +82  -89
python/sglang/srt/managers/scheduler_metrics_mixin.py            +96  -33
python/sglang/srt/managers/scheduler_output_processor_mixin.py    +3   -2
python/sglang/srt/managers/tokenizer_communicator_mixin.py       +81   -0
python/sglang/srt/managers/tokenizer_manager.py                  +35 -102
python/sglang/srt/metrics/collector.py                           +70 -120
python/sglang/srt/tracing/trace.py                               +32   -6
python/pyproject.toml (+20 -21)

The dependency arrays are re-sorted alphabetically; the entries themselves are unchanged apart from ordering, and the previously unsorted tail of the main list is folded into place.

@@ -14,18 +14,17 @@ classifiers = [
After the change, the main list begins:

dependencies = [
    "aiohttp",
    "anthropic>=0.20.0",
    "blobfile==3.0.0",
    "build",
    "compressed-tensors",
    "cuda-python",
    "datasets",
    "einops",
    "fastapi",
    "flashinfer_python==0.4.0rc3",
    "hf_transfer",
    "huggingface_hub",
    "interegular",

(The old opening order — "aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle" — is removed; those entries reappear at their sorted positions below.)

@@ -33,8 +32,10 @@ dependencies = [
    "modelscope",
    "msgspec",
    "ninja",
    "numpy",
    "nvidia-cutlass-dsl==4.2.1",
    "openai-harmony==0.0.4",
    "openai==1.99.1",
    "orjson",
    "outlines==0.1.11",
    "packaging",

@@ -42,32 +43,30 @@ dependencies = [
    "pillow",
    "prometheus-client>=0.20.0",
    "psutil",
    "py-spy",
    "pybase64",
    "pydantic",
    "pynvml",
    "python-multipart",
    "pyzmq>=25.1.2",
    "requests",
    "scipy",
    "sentencepiece",
    "setproctitle",
    "sgl-kernel==0.3.13",
    "soundfile==0.13.1",
    "tiktoken",
    "timm==1.0.16",
    "torch==2.8.0",
    "torch_memory_saver==0.0.8",
    "torchao==0.9.0",
    "torchaudio==2.8.0",
    "torchvision",
    "tqdm",
    "transformers==4.56.1",
    "uvicorn",
    "uvloop",
    "xgrammar==0.1.24",
]

(The old unsorted tail — "sgl-kernel==0.3.13", "torch==2.8.0", "torchaudio==2.8.0", "torchvision", "cuda-python", "flashinfer_python==0.4.0rc3", "openai==1.99.1", "tiktoken", "anthropic>=0.20.0", "torch_memory_saver==0.0.8", "nvidia-cutlass-dsl==4.2.1", "xgrammar==0.1.24" — is removed.)

@@ -79,15 +78,15 @@ test = [
    "matplotlib",
    "pandas",
    "peft",
    "pytest",
    "sentence_transformers",
    "tabulate",
]
tracing = [
    "opentelemetry-api",
    "opentelemetry-exporter-otlp",
    "opentelemetry-exporter-otlp-proto-grpc",
    "opentelemetry-sdk",
]
all = ["sglang[test]", "sglang[decord]"]
blackwell = ["sglang[test]", "sglang[decord]"]
python/sglang/srt/disaggregation/decode.py (+17 -3)

@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
 from __future__ import annotations

 import logging
+import time
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus

@@ -422,9 +423,13 @@ class DecodePreallocQueue:
                 kv_indices, self.token_to_kv_pool_allocator.page_size
             )
             decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
-            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
+            decode_req.req.time_stats.decode_transfer_queue_entry_time = (
+                time.perf_counter()
+            )
+            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)

         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove

@@ -625,6 +630,7 @@ class DecodeTransferQueue:
                 decode_req.req.output_topk_p = output_topk_p
                 decode_req.req.output_topk_index = output_topk_index
                 decode_req.req.hidden_states_tensor = output_hidden_states
                 if decode_req.req.return_logprob:
                     decode_req.req.output_token_logprobs_val.append(
                         output_token_logprobs_val[0].item()

@@ -645,10 +651,17 @@ class DecodeTransferQueue:
                 if hasattr(decode_req.kv_receiver, "clear"):
                     decode_req.kv_receiver.clear()
                 decode_req.kv_receiver = None
                 indices_to_remove.add(i)
+                decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()

                 # special handling for sampling_params.max_new_tokens == 1
                 if decode_req.req.sampling_params.max_new_tokens == 1:
                     # finish immediately
+                    decode_req.req.time_stats.forward_entry_time = (
+                        decode_req.req.time_stats.completion_time
+                    ) = time.perf_counter()
                     decode_req.req.check_finished()
                     self.scheduler.stream_output(
                         [decode_req.req], decode_req.req.return_logprob

@@ -656,8 +669,6 @@ class DecodeTransferQueue:
                     self.tree_cache.cache_finished_req(decode_req.req)
                 else:
                     transferred_reqs.append(decode_req.req)
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
                 KVPoll.WaitingForInput,

@@ -877,6 +888,9 @@ class SchedulerDisaggregationDecodeMixin:
         if len(can_run_list) == 0:
             return None

+        for req in can_run_list:
+            req.time_stats.forward_entry_time = time.perf_counter()
+
         # construct a schedule batch with those requests and mark as decode
         new_batch = ScheduleBatch.init_new(
             can_run_list,
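The decode-side change above stamps TimeStats fields with time.perf_counter() as a request crosses queue boundaries, including a two-fields-at-once stamp on the "finish immediately" fast path. A minimal sketch (not sglang's API; StageClock is a hypothetical stand-in) of that stamping pattern:

# perf_counter() is a monotonic clock, so later subtractions of these stamps
# always yield non-negative durations.
import time
from dataclasses import dataclass

@dataclass
class StageClock:
    wait_queue_entry_time: float = 0.0
    forward_entry_time: float = 0.0
    completion_time: float = 0.0

clock = StageClock()
clock.wait_queue_entry_time = time.perf_counter()

# The fast path stamps two fields at once. Python chains assignment targets
# left to right, so both attributes receive the same perf_counter() value.
clock.forward_entry_time = (clock.completion_time) = time.perf_counter()

assert clock.completion_time >= clock.wait_queue_entry_time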
python/sglang/srt/disaggregation/prefill.py (+7 -2)

@@ -21,6 +21,7 @@ from __future__ import annotations
 import logging
 import threading
+import time
 from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Type

@@ -263,9 +264,10 @@ class PrefillBootstrapQueue:
             num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
             req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
-            req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
             bootstrapped_reqs.append(req)
             indices_to_remove.add(i)
+            req.time_stats.wait_queue_entry_time = time.perf_counter()
+            req.add_latency(RequestStage.PREFILL_BOOTSTRAP)

         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove

@@ -407,7 +409,6 @@ class SchedulerDisaggregationPrefillMixin:
         for i, (req, next_token_id) in enumerate(
             zip(batch.reqs, next_token_ids, strict=True)
         ):
             req: Req
             if req.is_chunked <= 0:
                 # There is no output_ids for prefill
                 req.output_ids.append(next_token_id)

@@ -450,6 +451,7 @@ class SchedulerDisaggregationPrefillMixin:
                 )
                 logprob_pt += num_input_logprobs
             self.send_kv_chunk(req, last_chunk=True)
+            req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()

             if req.grammar is not None:
                 # FIXME: this try-except block is for handling unexpected xgrammar issue.

@@ -547,6 +549,9 @@ class SchedulerDisaggregationPrefillMixin:
             else:
                 assert False, f"Unexpected polling state {poll=}"

+        for req in done_reqs:
+            req.time_stats.completion_time = time.perf_counter()
+
         # Stream requests which have finished transfer
         self.stream_output(
             done_reqs,
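With the prefill path now stamping bootstrap-queue entry, wait-queue entry, forward entry, and completion, each stage duration is simply the gap between adjacent stamps. A small sketch with hypothetical timestamps of that arithmetic:

import time

prefill_bootstrap_queue_entry_time = time.perf_counter()
# ... bootstrap handshake with the decode server would happen here ...
wait_queue_entry_time = time.perf_counter()
# ... request waits to be scheduled ...
forward_entry_time = time.perf_counter()
# ... prefill forward pass and KV-chunk send ...
completion_time = time.perf_counter()

bootstrap_duration = wait_queue_entry_time - prefill_bootstrap_queue_entry_time
queue_duration = forward_entry_time - wait_queue_entry_time
forward_duration = completion_time - forward_entry_time
print(f"{bootstrap_duration=:.6f} {queue_duration=:.6f} {forward_duration=:.6f}")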
python/sglang/srt/disaggregation/utils.py (+1 -1)

@@ -5,7 +5,7 @@ import random
 from collections import deque
 from contextlib import nullcontext
 from enum import Enum
-from typing import TYPE_CHECKING, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Optional, Type

 import numpy as np
 import torch
python/sglang/srt/managers/schedule_batch.py (+10 -10)

@@ -41,7 +41,7 @@ import time
 from enum import Enum, auto
 from http import HTTPStatus
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

 import numpy as np
 import torch

@@ -54,6 +54,7 @@ from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,

@@ -452,6 +453,7 @@ class Req:
         bootstrap_host: Optional[str] = None,
         bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
+        disagg_mode: Optional[DisaggregationMode] = None,
         data_parallel_rank: Optional[int] = None,
         vocab_size: Optional[int] = None,
         priority: Optional[int] = None,

@@ -628,10 +630,8 @@ class Req:
         # For metrics
         self.metrics_collector = metrics_collector
-        self.time_stats: TimeStats = TimeStats()
+        self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
         self.has_log_time_stats: bool = False
-        self.queue_time_start = None
-        self.queue_time_end = None
         self.last_tic = time.monotonic()

         # For disaggregation

@@ -668,9 +668,9 @@ class Req:
     def add_latency(self, stage: RequestStage):
         if self.metrics_collector is None:
             return

         assert stage.name in RequestStage.__members__, f"{stage=} is invalid"
         now = time.monotonic()
-        self.metrics_collector.observe_request_latency_seconds(
+        self.metrics_collector.observe_per_stage_req_latency(
             stage.value, now - self.last_tic
         )
         self.last_tic = now

@@ -834,10 +834,10 @@ class Req:
             return

         if self.bootstrap_room is not None:
-            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
         else:
-            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"

-        logger.info(f"{prefix}: {self.time_stats}")
+        logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
         self.has_log_time_stats = True

     def set_finish_with_abort(self, error_msg: str):

@@ -1544,7 +1544,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         ) / total_max_new_tokens
         new_estimate_ratio = min(1.0, new_estimate_ratio)

-        return retracted_reqs, new_estimate_ratio
+        return retracted_reqs, new_estimate_ratio, []

     def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
         req = self.reqs[idx]
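The renamed add_latency hook keeps measuring elapsed time between consecutive stage boundaries on the monotonic clock and attributes each gap to the stage that just finished. A self-contained sketch (assumed names, not sglang's classes) of that pattern:

import time

class StageLatencyRecorder:
    def __init__(self, observe):
        # observe is any callback, e.g. a labeled Prometheus histogram observe.
        self.observe = observe
        self.last_tic = time.monotonic()

    def add_latency(self, stage: str):
        # Attribute the time since the previous boundary to this stage.
        now = time.monotonic()
        self.observe(stage, now - self.last_tic)
        self.last_tic = now

rec = StageLatencyRecorder(lambda s, dt: print(f"{s}: {dt * 1e3:.2f} ms"))
time.sleep(0.01)
rec.add_latency("prefill_waiting")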
python/sglang/srt/managers/schedule_policy.py (+8 -4)

@@ -276,9 +276,13 @@ class SchedulePolicy:
     ) -> None:
         """Sorts the waiting queue based on the request priority then received timestamp."""
         if schedule_low_priority_values_first:
-            waiting_queue.sort(key=lambda x: (x.priority, x.queue_time_start))
+            waiting_queue.sort(
+                key=lambda x: (x.priority, x.time_stats.wait_queue_entry_time)
+            )
         else:
-            waiting_queue.sort(key=lambda x: (-x.priority, x.queue_time_start))
+            waiting_queue.sort(
+                key=lambda x: (-x.priority, x.time_stats.wait_queue_entry_time)
+            )

     @staticmethod
     def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:

@@ -642,12 +646,12 @@ class PrefillAdder:
         if server_args.schedule_low_priority_values_first:
             sorted_running_reqs = sorted(
                 self.running_batch.reqs,
-                key=lambda x: (-x.priority, -x.queue_time_start),
+                key=lambda x: (-x.priority, -x.time_stats.wait_queue_entry_time),
             )
         else:
             sorted_running_reqs = sorted(
                 self.running_batch.reqs,
-                key=lambda x: (x.priority, -x.queue_time_start),
+                key=lambda x: (x.priority, -x.time_stats.wait_queue_entry_time),
             )

         preemptible_reqs = []
         min_tokens_to_remove = (
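These sort keys rely on Python's lexicographic tuple ordering: requests are ordered by priority first, with queue-entry time breaking ties, and negating a component flips its direction. A self-contained illustration (toy data, not sglang's Req objects):

waiting = [
    {"rid": "a", "priority": 2, "entry": 10.0},
    {"rid": "b", "priority": 1, "entry": 12.0},
    {"rid": "c", "priority": 1, "entry": 11.0},
]
# Low priority values first; earlier arrivals first within a priority.
waiting.sort(key=lambda x: (x["priority"], x["entry"]))
assert [r["rid"] for r in waiting] == ["c", "b", "a"]
# High priority values first.
waiting.sort(key=lambda x: (-x["priority"], x["entry"]))
assert [r["rid"] for r in waiting] == ["a", "c", "b"]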
python/sglang/srt/managers/scheduler.py (+82 -89)

@@ -157,10 +157,9 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.tracing.trace import (
     process_tracing_init,
-    trace_event,
     trace_set_proc_propagate_context,
     trace_set_thread_info,
-    trace_slice,
+    trace_slice_batch,
     trace_slice_end,
     trace_slice_start,
 )

@@ -263,6 +262,7 @@ class Scheduler(
             server_args.enable_metrics_for_all_schedulers
         )
         self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
+        self.enable_trace = server_args.enable_trace
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm

@@ -899,10 +899,6 @@ class Scheduler(
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch

-            if batch:
-                for req in batch.reqs:
-                    trace_event("schedule", req.rid)
-
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)

@@ -924,10 +920,6 @@ class Scheduler(
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch

-            if batch:
-                for req in batch.reqs:
-                    trace_event("schedule", req.rid)
-
             if batch:
                 batch.launch_done = threading.Event()
                 result = self.run_batch(batch)

@@ -1192,10 +1184,13 @@ class Scheduler(
                 src=self.tp_group.ranks[0],
             )

-        for req in recv_reqs:
-            if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
-                trace_set_proc_propagate_context(req.rid, req.trace_context)
-                trace_slice_start("", req.rid, anonymous=True)
+        if self.enable_trace:
+            for req in recv_reqs:
+                if isinstance(
+                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                ):
+                    trace_set_proc_propagate_context(req.rid, req.trace_context)
+                    trace_slice_start("", req.rid, anonymous=True)

         return recv_reqs

@@ -1277,6 +1272,7 @@ class Scheduler(
             bootstrap_host=recv_req.bootstrap_host,
             bootstrap_port=recv_req.bootstrap_port,
             bootstrap_room=recv_req.bootstrap_room,
+            disagg_mode=self.disaggregation_mode,
             data_parallel_rank=recv_req.data_parallel_rank,
             vocab_size=self.model_config.vocab_size,
             priority=recv_req.priority,

@@ -1403,7 +1399,6 @@ class Scheduler(
             req.set_finish_with_abort(error_msg)

         if add_to_grammar_queue:
-            req.queue_time_start = time.perf_counter()
             self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)

@@ -1419,23 +1414,6 @@ class Scheduler(
         for tokenized_req in recv_req:
             self.handle_generate_request(tokenized_req)

-    def _add_request_to_queue(self, req: Req):
-        req.queue_time_start = time.perf_counter()
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self._prefetch_kvcache(req)
-            self.disagg_prefill_bootstrap_queue.add(
-                req, self.model_config.num_key_value_heads
-            )
-        elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            self.disagg_decode_prealloc_queue.add(req)
-        else:
-            self._set_or_validate_priority(req)
-            if self._abort_on_queued_limit(req):
-                return
-            self._prefetch_kvcache(req)
-            self.waiting_queue.append(req)
-        trace_slice_end("process req", req.rid, auto_next_anon=True)
-
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
             req.init_next_round_input(self.tree_cache)

@@ -1449,19 +1427,27 @@ class Scheduler(
             req.rid, req.last_host_node, new_input_tokens, last_hash
         )

-    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.disagg_prefill_bootstrap_queue.extend(
-                reqs, self.model_config.num_key_value_heads
-            )
-        elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            # If this is a decode server, we put the request to the decode pending prealloc queue
-            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
-        else:
-            for req in reqs:
-                self._set_or_validate_priority(req)
-                if not self._abort_on_queued_limit(req):
-                    self.waiting_queue.append(req)
+    def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.NULL:
+            self._set_or_validate_priority(req)
+            if self._abort_on_queued_limit(req):
+                return
+            self._prefetch_kvcache(req)
+            self.waiting_queue.append(req)
+            req.time_stats.wait_queue_entry_time = time.perf_counter()
+            trace_slice_end("process req", req.rid, auto_next_anon=True)
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self._prefetch_kvcache(req)
+            self.disagg_prefill_bootstrap_queue.add(
+                req, self.model_config.num_key_value_heads
+            )
+            req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            # If this is a decode server, we put the request to the decode pending prealloc queue
+            self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
+            if not is_retracted:
+                req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
+        else:
+            raise ValueError(f"Invalid {self.disaggregation_mode=}")

     def _set_or_validate_priority(self, req: Req):
         """Set the default priority value, or abort the request based on the priority scheduling mode."""

@@ -1500,7 +1486,7 @@ class Scheduler(
         direction = 1 if self.schedule_low_priority_values_first else -1
         key_fn = lambda item: (
             direction * item[1].priority,
-            item[1].queue_time_start,
+            item[1].time_stats.wait_queue_entry_time,
         )
         idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
         abort_existing_req = (

@@ -1902,14 +1888,14 @@ class Scheduler(
         if self.enable_metrics:
             # only record queue time when enable_metrics is True to avoid overhead
             for req in can_run_list:
-                req.queue_time_end = time.perf_counter()
+                req.add_latency(RequestStage.PREFILL_WAITING)

         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
         if adder.preempt_list:
-            self._extend_requests_to_queue(adder.preempt_list)
+            for req in adder.preempt_list:
+                self._add_request_to_queue(req)

         if adder.new_chunked_req is not None:
             assert self.chunked_req is None

@@ -1920,7 +1906,16 @@ class Scheduler(
         # Print stats
         if self.current_scheduler_metrics_enabled():
-            self.log_prefill_stats(adder, can_run_list, running_bs)
+            self.log_prefill_stats(adder, can_run_list, running_bs, 0)
+
+        for req in can_run_list:
+            if req.time_stats.forward_entry_time == 0:
+                # Avoid update chunked request many times
+                req.time_stats.forward_entry_time = time.perf_counter()
+                if self.enable_metrics:
+                    self.metrics_collector.observe_queue_time(
+                        req.time_stats.get_queueing_time(),
+                    )

         # Create a new batch
         new_batch = ScheduleBatch.init_new(

@@ -1975,19 +1970,25 @@ class Scheduler(
             TEST_RETRACT and batch.batch_size() > 10
         ):
             old_ratio = self.new_token_ratio
-            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
-            num_retracted_reqs = len(retracted_reqs)
+            retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
+                self.server_args
+            )
+            self.num_retracted_reqs = len(retracted_reqs)
             self.new_token_ratio = new_token_ratio
+            for req in reqs_to_abort:
+                self.send_to_tokenizer.send_pyobj(
+                    AbortReq(req.rid, abort_reason=req.to_abort_message)
+                )

             logger.info(
                 "KV cache pool is full. Retract requests. "
-                f"#retracted_reqs: {num_retracted_reqs}, "
-                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
             )
-            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
-            self.total_retracted_reqs += num_retracted_reqs
+            for req in retracted_reqs:
+                self._add_request_to_queue(req, is_retracted=True)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,

@@ -2086,23 +2087,14 @@ class Scheduler(
         ):
             if batch.forward_mode.is_decode():
                 self.process_batch_result_decode(batch, result, launch_done)
-                for req in batch.reqs:
-                    trace_slice(
-                        "decode loop",
-                        req.rid,
-                        auto_next_anon=not req.finished(),
-                        thread_finish_flag=req.finished(),
-                    )
+                if self.enable_trace:
+                    trace_slice_batch("decode loop", batch.reqs)

             elif batch.forward_mode.is_extend():
                 self.process_batch_result_prefill(batch, result, launch_done)
-                for req in batch.reqs:
-                    trace_slice(
-                        "prefill",
-                        req.rid,
-                        auto_next_anon=not req.finished(),
-                        thread_finish_flag=req.finished(),
-                    )
+                if self.enable_trace:
+                    trace_slice_batch("prefill", batch.reqs)

             elif batch.forward_mode.is_idle():
                 if self.enable_overlap:
                     self.tp_worker.resolve_last_batch_result(launch_done)

@@ -2261,12 +2253,13 @@ class Scheduler(
                     if req.finished():  # It is aborted by AbortReq
                         num_ready_reqs += 1
                         continue
                     req.grammar = req.grammar.result(timeout=0.03)
                     self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                     if req.grammar is INVALID_GRAMMAR_OBJ:
-                        req.set_finish_with_abort(
-                            f"Invalid grammar request: {req.grammar_key=}"
-                        )
+                        error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                        req.set_finish_with_abort(error_msg)
                     num_ready_reqs += 1
                 except futures._base.TimeoutError:
                     req.grammar_wait_ct += 1

@@ -2298,9 +2291,8 @@ class Scheduler(
                     req.grammar = req.grammar.result()
                     self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                     if req.grammar is INVALID_GRAMMAR_OBJ:
-                        req.set_finish_with_abort(
-                            f"Invalid grammar request: {req.grammar_key=}"
-                        )
+                        error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                        req.set_finish_with_abort(error_msg)
             else:
                 num_ready_reqs_max = num_ready_reqs
                 num_timeout_reqs_max = num_timeout_reqs

@@ -2308,12 +2300,14 @@ class Scheduler(
         for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
             req = self.grammar_queue[i]
             req.grammar.cancel()
-            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
             error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
             req.set_finish_with_abort(error_msg)
+            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
         num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

-        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
+        for req in self.grammar_queue[:num_ready_reqs]:
+            self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

     def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):

@@ -2795,17 +2789,11 @@ def run_scheduler_process(
     pipe_writer,
     balance_meta: Optional[DPBalanceMeta] = None,
 ):
-    if server_args.enable_trace:
-        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
-        if server_args.disaggregation_mode == "null":
-            thread_label = "Scheduler"
-            trace_set_thread_info(thread_label, tp_rank, dp_rank)
-
-    if (numa_node := server_args.numa_node) is not None:
-        numa_bind_to_node(numa_node[gpu_id])
-
-    # Generate the prefix
+    # Generate the logger prefix
     prefix = ""
+    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
+        # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
+        dp_rank = int(os.environ["SGLANG_DP_RANK"])
     if dp_rank is not None:
         prefix += f" DP{dp_rank}"
     if server_args.tp_size > 1:

@@ -2821,10 +2809,6 @@ def run_scheduler_process(
     kill_itself_when_parent_died()
     parent_process = psutil.Process().parent()

-    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
-    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
-        dp_rank = int(os.environ["SGLANG_DP_RANK"])
-
     # Configure the logger
     configure_logger(server_args, prefix=prefix)
     suppress_other_loggers()

@@ -2832,6 +2816,15 @@ def run_scheduler_process(
     # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+    if (numa_node := server_args.numa_node) is not None:
+        numa_bind_to_node(numa_node[gpu_id])
+
+    # Set up tracing
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Scheduler"
+            trace_set_thread_info(thread_label, tp_rank, dp_rank)

     # Create a scheduler and run the event loop
     try:
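The old list-based _extend_requests_to_queue is folded into a single per-request _add_request_to_queue that dispatches on the disaggregation mode and stamps the matching entry-time field. A simplified, self-contained sketch of that dispatch (dict-based stand-in for Req; not sglang's actual signatures):

import time
from enum import Enum, auto

class Mode(Enum):
    NULL = auto()
    PREFILL = auto()
    DECODE = auto()

def add_request_to_queue(req: dict, mode: Mode, is_retracted: bool = False):
    now = time.perf_counter()
    if mode == Mode.NULL:
        req["wait_queue_entry_time"] = now
    elif mode == Mode.PREFILL:
        req["prefill_bootstrap_queue_entry_time"] = now
    elif mode == Mode.DECODE:
        # Retracted requests keep their original prealloc-queue stamp.
        if not is_retracted:
            req["decode_prealloc_queue_entry_time"] = now
    else:
        raise ValueError(f"Invalid {mode=}")

req = {}
add_request_to_queue(req, Mode.PREFILL)
print(req)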
python/sglang/srt/managers/scheduler_metrics_mixin.py (+96 -33)

@@ -47,8 +47,11 @@ class SchedulerMetricsMixin:
         self.spec_num_total_forward_ct = 0
         self.cum_spec_accept_length = 0
         self.cum_spec_accept_count = 0
         self.total_retracted_reqs = 0
+        self.kv_transfer_speed_gb_s: float = 0.0
+        self.kv_transfer_latency_ms: float = 0.0
+        self.stats = SchedulerStats()
         if self.enable_metrics:
             engine_type = "unified"
             labels = {

@@ -82,12 +85,14 @@ class SchedulerMetricsMixin:
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
+        running_bs_offline_batch: int,
     ):
         gap_latency = time.perf_counter() - self.last_prefill_stats_tic
         self.last_prefill_stats_tic = time.perf_counter()
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens

         # TODO: generalize this for various memory pools
         if self.is_hybrid:
             (
                 full_num_used,

@@ -101,51 +106,53 @@ class SchedulerMetricsMixin:
             ) = self._get_swa_token_info()
             num_used = max(full_num_used, swa_num_used)
             token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
+            token_usage_msg = (
                 f"full token usage: {full_token_usage:.2f}, "
                 f"swa token usage: {swa_token_usage:.2f}, "
             )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"token usage: {token_usage:.2f}, "
+            token_usage_msg = f"token usage: {token_usage:.2f}, "

-        num_new_seq = len(can_run_list)
         f = (
             f"Prefill batch. "
-            f"#new-seq: {num_new_seq}, "
+            f"#new-seq: {len(can_run_list)}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
-            f"{token_msg}"
+            f"{token_usage_msg}"
+            f"#running-req: {running_bs}, "
+            f"#queue-req: {len(self.waiting_queue)}, "
         )

         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
-            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
-        else:
-            f += f"#running-req: {running_bs}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#prealloc-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#inflight-req: {len(self.disagg_prefill_inflight_queue)}, "

         logger.info(f)

         if self.enable_metrics:
             # Basics
             total_tokens = adder.log_input_tokens + adder.log_hit_tokens
             cache_hit_rate = (
                 adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
             )
             self.stats.num_running_reqs = running_bs
+            self.stats.num_running_reqs_offline_batch = running_bs_offline_batch
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
+            self.stats.token_usage = token_usage
+            if self.is_hybrid:
+                self.stats.swa_token_usage = swa_token_usage
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.cache_hit_rate = cache_hit_rate
-            total_queue_latency = 0
-            for req in can_run_list:
-                total_queue_latency += req.queue_time_end - req.queue_time_start
-            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
+
+            # Retract
+            self.stats.num_retracted_reqs = self.num_retracted_reqs
+            self.stats.num_paused_reqs = self.num_paused_reqs
+            self.num_retracted_reqs = self.num_paused_reqs = 0

             # PD disaggregation
             if self.disaggregation_mode == DisaggregationMode.PREFILL:
                 self.stats.num_prefill_prealloc_queue_reqs = len(
                     self.disagg_prefill_bootstrap_queue.queue

@@ -153,7 +160,18 @@ class SchedulerMetricsMixin:
                 self.stats.num_prefill_inflight_queue_reqs = len(
                     self.disagg_prefill_inflight_queue
                 )
+                self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s
+                self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms
+            elif self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.stats.num_decode_prealloc_queue_reqs = len(
+                    self.disagg_decode_prealloc_queue.queue
+                )
+                self.stats.num_decode_transfer_queue_reqs = len(
+                    self.disagg_decode_transfer_queue.queue
+                )

+            # Others
+            self.calculate_utilization()
             self.metrics_collector.log_stats(self.stats)
             self._emit_kv_metrics()
         self._publish_kv_events()

@@ -166,8 +184,12 @@ class SchedulerMetricsMixin:
         gap_latency = time.perf_counter() - self.last_decode_stats_tic
         self.last_decode_stats_tic = time.perf_counter()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
+        num_running_reqs_offline_batch = 0

         # TODO: generalize this for various memory pools
         if self.is_hybrid:
             (
                 full_num_used,

@@ -181,7 +203,7 @@ class SchedulerMetricsMixin:
             ) = self._get_swa_token_info()
             num_used = max(full_num_used, swa_num_used)
             token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
+            token_usage_msg = (
                 f"#full token: {full_num_used}, "
                 f"full token usage: {full_token_usage:.2f}, "
                 f"#swa token: {swa_num_used}, "

@@ -189,14 +211,14 @@ class SchedulerMetricsMixin:
             )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
+            token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "

         if RECORD_STEP_TIME:
             self.step_time_dict[num_running_reqs].append(
                 gap_latency / self.server_args.decode_log_interval
             )

-        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
+        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"

         if self.spec_algorithm.is_none():
             spec_accept_length = 0

@@ -208,41 +230,66 @@ class SchedulerMetricsMixin:
             self.cum_spec_accept_count += self.spec_num_total_forward_ct
             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
             msg += f"accept len: {spec_accept_length:.2f}, "

-        cache_hit_rate = 0.0
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#prealloc-req: {len(self.disagg_decode_prealloc_queue.queue)}, "
+            msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, "
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

         msg += (
-            f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
+            f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )

         logger.info(msg)
         if self.enable_metrics:
+            # Basics
             self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
-            self.stats.cache_hit_rate = 0.0
+            self.stats.token_usage = token_usage
+            if self.is_hybrid:
+                self.stats.swa_token_usage = swa_token_usage
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
-            self.stats.cache_hit_rate = cache_hit_rate
             self.stats.spec_accept_length = spec_accept_length
-            self.stats.total_retracted_reqs = self.total_retracted_reqs
-            self.stats.avg_request_queue_latency = 0.0
-            if self.disaggregation_mode == DisaggregationMode.DECODE:
+
+            # Retract
+            self.stats.num_retracted_reqs = self.num_retracted_reqs
+            self.stats.num_paused_reqs = self.num_paused_reqs
+            self.num_retracted_reqs = self.num_paused_reqs = 0
+
+            # PD disaggregation
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                self.stats.num_prefill_prealloc_queue_reqs = len(
+                    self.disagg_prefill_bootstrap_queue.queue
+                )
+                self.stats.num_prefill_inflight_queue_reqs = len(
+                    self.disagg_prefill_inflight_queue
+                )
+            elif self.disaggregation_mode == DisaggregationMode.DECODE:
                 self.stats.num_decode_prealloc_queue_reqs = len(
                     self.disagg_decode_prealloc_queue.queue
                 )
                 self.stats.num_decode_transfer_queue_reqs = len(
                     self.disagg_decode_transfer_queue.queue
                 )
+
+            # Others
+            self.calculate_utilization()
             self.metrics_collector.log_stats(self.stats)
             self._emit_kv_metrics()
         self._publish_kv_events()

     def _emit_kv_metrics(self: Scheduler):
         if not self.enable_kv_cache_events:
             return
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests

@@ -259,11 +306,13 @@ class SchedulerMetricsMixin:
         self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

     def _publish_kv_events(self: Scheduler):
-        if self.enable_kv_cache_events:
-            events = self.tree_cache.take_events()
-            if events:
-                batch = KVEventBatch(ts=time.time(), events=events)
-                self.kv_event_publisher.publish(batch)
+        if not self.enable_kv_cache_events:
+            return
+
+        events = self.tree_cache.take_events()
+        if events:
+            batch = KVEventBatch(ts=time.time(), events=events)
+            self.kv_event_publisher.publish(batch)

     def maybe_update_dp_balance_data(
         self: Scheduler, recv_req: TokenizedGenerateReqInput

@@ -349,3 +398,17 @@ class SchedulerMetricsMixin:
         # 2. Atomically write local_tokens and onfly into shm under the mutex
         meta.set_shared_onfly_info(onfly_list)
         meta.set_shared_local_tokens(local_tokens)
+
+    def calculate_utilization(self):
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.stats.utilization = -1
+        else:
+            if (
+                self.stats.max_running_requests_under_SLO is not None
+                and self.stats.max_running_requests_under_SLO > 0
+            ):
+                self.stats.utilization = max(
+                    self.stats.num_running_reqs
+                    / self.stats.max_running_requests_under_SLO,
+                    self.stats.token_usage / 0.9,
+                )
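The new calculate_utilization reports -1 for prefill servers and otherwise takes the larger of request-slot usage and token usage scaled by 0.9. A minimal sketch of the same formula as a free function:

def utilization(num_running_reqs, max_running_under_slo, token_usage, is_prefill):
    if is_prefill:
        return -1
    if max_running_under_slo and max_running_under_slo > 0:
        return max(num_running_reqs / max_running_under_slo, token_usage / 0.9)
    return 0.0

print(utilization(32, 64, 0.81, is_prefill=False))  # -> 0.9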
python/sglang/srt/managers/scheduler_output_processor_mixin.py (+3 -2)

@@ -91,7 +91,7 @@ class SchedulerOutputProcessorMixin:
                     if req.finished():
                         self.tree_cache.cache_finished_req(req)
-                        req.time_stats.completion_time = time.time()
+                        req.time_stats.completion_time = time.perf_counter()
                     elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                         # This updates radix so others can match
                         self.tree_cache.cache_unfinished_req(req)

@@ -257,7 +257,7 @@ class SchedulerOutputProcessorMixin:
             else:
                 self.tree_cache.cache_finished_req(req)
-                req.time_stats.completion_time = time.time()
+                req.time_stats.completion_time = time.perf_counter()

             if req.return_logprob and batch.spec_algorithm.is_none():
                 # speculative worker handles logprob in speculative decoding

@@ -707,6 +707,7 @@ class SchedulerOutputProcessorMixin:
                 and self.tp_rank == 0
                 and self.server_args.enable_request_time_stats_logging
             ):
+                print(f"{req.finished_reason=}")
                 req.log_time_stats()

         # Send to detokenizer
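The switch from time.time() to time.perf_counter() matters because the other TimeStats stamps already use perf_counter(); mixing clocks would make the completion-minus-forward-entry subtraction meaningless (time.time() is wall clock and can even step backwards under NTP adjustment). A minimal illustration of the two clocks:

import time

t0_wall, t0_perf = time.time(), time.perf_counter()
time.sleep(0.01)
wall_delta = time.time() - t0_wall          # wall clock: subject to adjustment
perf_delta = time.perf_counter() - t0_perf  # monotonic: safe for durations
print(f"{wall_delta=:.4f} {perf_delta=:.4f}")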
python/sglang/srt/managers/tokenizer_communicator_mixin.py (+81 -0)

@@ -5,6 +5,7 @@ import copy
 import logging
 import os
 import time
+import uuid
 from collections import deque
 from typing import (
     TYPE_CHECKING,

@@ -24,6 +25,7 @@ import zmq
 from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
     ClearHiCacheReqOutput,
+    CloseSessionReqInput,
     DestroyWeightsUpdateGroupReqInput,
     DestroyWeightsUpdateGroupReqOutput,
     ExpertDistributionReq,

@@ -44,6 +46,7 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqOutput,
     LoRAUpdateResult,
     MultiTokenizerWrapper,
+    OpenSessionReqInput,
     ProfileReq,
     ProfileReqOutput,
     ProfileReqType,

@@ -588,3 +591,81 @@ class TokenizerCommunicatorMixin:
     async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
         req = GetLoadReqInput()
         return await self.get_load_communicator(req)
+
+    async def open_session(
+        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
+    ):
+        self.auto_create_handle_loop()
+
+        if obj.session_id is None:
+            obj.session_id = uuid.uuid4().hex
+        elif obj.session_id in self.session_futures:
+            return None
+
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
+        self.send_to_scheduler.send_pyobj(obj)
+
+        self.session_futures[obj.session_id] = asyncio.Future()
+        session_id = await self.session_futures[obj.session_id]
+        del self.session_futures[obj.session_id]
+        return session_id
+
+    async def close_session(
+        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
+    ):
+        await self.send_to_scheduler.send_pyobj(obj)
+
+    def get_log_request_metadata(self):
+        max_length = None
+        skip_names = None
+        out_skip_names = None
+        if self.log_requests:
+            if self.log_requests_level == 0:
+                max_length = 1 << 30
+                skip_names = set(
+                    [
+                        "text",
+                        "input_ids",
+                        "input_embeds",
+                        "image_data",
+                        "audio_data",
+                        "lora_path",
+                        "sampling_params",
+                    ]
+                )
+                out_skip_names = set(["text", "output_ids", "embedding"])
+            elif self.log_requests_level == 1:
+                max_length = 1 << 30
+                skip_names = set(
+                    [
+                        "text",
+                        "input_ids",
+                        "input_embeds",
+                        "image_data",
+                        "audio_data",
+                        "lora_path",
+                    ]
+                )
+                out_skip_names = set(["text", "output_ids", "embedding"])
+            elif self.log_requests_level == 2:
+                max_length = 2048
+            elif self.log_requests_level == 3:
+                max_length = 1 << 30
+            else:
+                raise ValueError(
+                    f"Invalid --log-requests-level: {self.log_requests_level=}"
+                )
+        return max_length, skip_names, out_skip_names
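The relocated open_session round-trip parks an asyncio.Future keyed by session id; the scheduler's out-of-band reply resolves it and unblocks the await. A distilled, runnable sketch of that handshake (the send-to-scheduler step is simulated):

import asyncio
import uuid

session_futures = {}

async def open_session(session_id=None):
    if session_id is None:
        session_id = uuid.uuid4().hex
    elif session_id in session_futures:
        return None  # already being opened
    session_futures[session_id] = asyncio.get_running_loop().create_future()
    # (the real method sends the request to the scheduler here)
    result = await session_futures[session_id]
    del session_futures[session_id]
    return result

async def main():
    task = asyncio.create_task(open_session("s1"))
    await asyncio.sleep(0)                   # let open_session register its future
    session_futures["s1"].set_result("s1")   # stand-in for the scheduler's reply
    print(await task)                        # -> s1

asyncio.run(main())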
python/sglang/srt/managers/tokenizer_manager.py (+35 -102)

@@ -164,6 +164,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             else None
         )
         self.crash_dump_folder = server_args.crash_dump_folder
+        self.enable_trace = server_args.enable_trace

         # Read model args
         self.model_path = server_args.model_path

@@ -381,23 +382,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             # If it's a single value, add worker_id prefix
             obj.rid = f"{self.worker_id}_{obj.rid}"

-        if obj.is_single:
-            bootstrap_room = (
-                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
-            )
-            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
-            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
-        else:
-            for i in range(len(obj.rid)):
-                bootstrap_room = (
-                    obj.bootstrap_room[i]
-                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
-                    else None
-                )
-                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
-                trace_slice_start("", obj.rid[i], ts=int(created_time * 1e9), anonymous=True)
+        if self.enable_trace:
+            self._trace_request_start(obj, created_time)

         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata

@@ -1055,7 +1041,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         req = AbortReq(rid, abort_all)
         self.send_to_scheduler.send_pyobj(req)
         if self.enable_metrics:
-            self.metrics_collector.observe_one_aborted_request()
+            # TODO: also use custom_labels from the request
+            self.metrics_collector.observe_one_aborted_request(
+                self.metrics_collector.labels
+            )

     async def pause_generation(self):
         async with self.is_pause_cond:

@@ -1117,84 +1106,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         all_paused_requests = [r.num_paused_requests for r in result]
         return all_success, all_message, all_paused_requests

(The open_session, close_session, and get_log_request_metadata methods are deleted here; they move verbatim into TokenizerCommunicatorMixin, shown above.)

@@ -1353,12 +1264,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
+            remaining_rids = list(self.rid_to_state.keys())

             if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
-                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
-                    remain_num_req,
+                    "Signal SIGTERM received while health check failed. Force exiting."
                 )
                 self.dump_requests_before_crash()
                 break

@@ -1366,13 +1277,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
                 # if force shutdown flag set, exit immediately
                 logger.error(
-                    "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
-                    remain_num_req,
+                    "Signal SIGTERM received while force shutdown flag set. Force exiting."
                 )
                 break

             logger.info(
-                f"Gracefully exiting... remaining number of requests {remain_num_req}"
+                f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
             )
             if remain_num_req > 0:
                 await asyncio.sleep(5)

@@ -1888,6 +1798,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         load_udpate_req = WatchLoadUpdateReq(loads=loads)
         self.send_to_scheduler.send_pyobj(load_udpate_req)

+    def _trace_request_start(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        created_time: Optional[float] = None,
+    ):
+        if obj.is_single:
+            bootstrap_room = (
+                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
+            )
+            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
+        else:
+            for i in range(len(obj.rid)):
+                bootstrap_room = (
+                    obj.bootstrap_room[i]
+                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
+                    else None
+                )
+                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                trace_slice_start(
+                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+                )

 class ServerStatus(Enum):
     Up = "Up"

@@ -1933,7 +1866,7 @@ class SignalHandler:
     def running_phase_sigquit_handler(self, signum=None, frame=None):
         logger.error(
-            "Received sigquit from a child process. It usually means the child failed."
+            f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
        )
         self.tokenizer_manager.dump_requests_before_crash()
         kill_process_tree(os.getpid())
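The refactor gates all per-request tracing behind a single enable_trace flag captured from server args and converts the request's float creation time (seconds) to integer nanoseconds for the trace API. A minimal sketch of that gating pattern (illustrative names, not sglang's classes):

import time

class TracingGate:
    def __init__(self, enable_trace: bool):
        self.enable_trace = enable_trace

    def trace_request_start(self, rid: str, created_time: float):
        if not self.enable_trace:
            return  # the hot path pays nothing when tracing is off
        ts_ns = int(created_time * 1e9)  # trace APIs take integer nanoseconds
        print(f"trace start rid={rid} ts={ts_ns}")

gate = TracingGate(enable_trace=False)
gate.trace_request_start("req-1", time.time())  # no-op when disabled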
python/sglang/srt/metrics/collector.py
View file @
2d62af6b
...
...
@@ -14,9 +14,9 @@
"""Utilities for Prometheus Metrics Collection."""
import
time
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
from
typing
import
Dict
,
List
,
Optional
,
Union
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.metrics.utils
import
exponential_buckets
,
generate_buckets
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
get_bool_env_var
...
...
@@ -34,6 +34,7 @@ class TimeStats:
Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
"""
disagg_mode
:
DisaggregationMode
=
DisaggregationMode
.
NULL
lb_entry_time
:
float
=
0.0
wait_queue_entry_time
:
float
=
0.0
forward_entry_time
:
float
=
0.0
...
...
@@ -43,20 +44,11 @@ class TimeStats:
decode_prealloc_queue_entry_time
:
float
=
0.0
decode_transfer_queue_entry_time
:
float
=
0.0
class
RequestType
(
Enum
):
UNIFIED
=
"unified"
PREFILL
=
"prefill"
DECODE
=
"decode"
INVALID
=
"invalid"
def
get_queueing_time
(
self
)
->
float
:
return
self
.
forward_entry_time
-
self
.
wait_queue_entry_time
def
__str__
(
self
)
->
str
:
# if unified
_type
=
self
.
get_type
()
if
_type
==
self
.
RequestType
.
UNIFIED
:
def
convert_to_duration
(
self
)
->
str
:
if
self
.
disagg_mode
==
DisaggregationMode
.
NULL
:
queue_duration
=
self
.
forward_entry_time
-
self
.
wait_queue_entry_time
forward_duration
=
self
.
completion_time
-
self
.
forward_entry_time
...
...
@@ -65,30 +57,28 @@ class TimeStats:
queue_duration
>=
0
and
forward_duration
>=
0
),
f
"queue_duration=
{
queue_duration
}
< 0 or forward_duration=
{
forward_duration
}
< 0"
return
f
"queue_duration=
{
self
.
format_duration
(
queue_duration
)
}
, forward_duration=
{
self
.
format_duration
(
forward_duration
)
}
, start_time=
{
self
.
wait_queue_entry_time
}
"
elif
_type
==
self
.
RequestTyp
e
.
PREFILL
:
return
f
"queue_duration=
{
self
.
format_duration
(
queue_duration
)
}
, forward_duration=
{
self
.
format_duration
(
forward_duration
)
}
, start_time=
{
self
.
wait_queue_entry_time
:.
3
f
}
"
elif
self
.
disagg_mode
==
DisaggregationMod
e
.
PREFILL
:
bootstrap_duration
=
(
self
.
wait_queue_entry_time
-
self
.
prefill_bootstrap_queue_entry_time
)
queue_duration
=
self
.
forward_entry_time
-
self
.
wait_queue_entry_time
forward_duration
=
self
.
completion_time
-
self
.
forward_entry_time
if
SGLANG_TEST_REQUEST_TIME_STATS
:
assert
(
bootstrap_duration
>=
0
and
queue_duration
>=
0
and
forward_duration
>=
0
),
f
"bootstrap_duration=
{
bootstrap_duration
}
< 0 or queue_duration=
{
queue_duration
}
< 0 or forward_duration=
{
forward_duration
}
< 0"
return
f
"bootstrap_duration=
{
self
.
format_duration
(
bootstrap_duration
)
}
, queue_duration=
{
self
.
format_duration
(
queue_duration
)
}
, forward_duration=
{
self
.
format_duration
(
forward_duration
)
}
, start_time=
{
self
.
prefill_bootstrap_queue_entry_time
}
"
# if decode
elif
_type
==
self
.
RequestType
.
DECODE
:
if
self
.
wait_queue_entry_time
>
0
:
assert
(
bootstrap_duration
>=
0
and
queue_duration
>=
0
and
forward_duration
>=
0
),
f
"bootstrap_duration=
{
bootstrap_duration
}
< 0 or queue_duration=
{
queue_duration
}
< 0 or forward_duration=
{
forward_duration
}
< 0"
return
f
"bootstrap_duration=
{
self
.
format_duration
(
bootstrap_duration
)
}
, queue_duration=
{
self
.
format_duration
(
queue_duration
)
}
, forward_duration=
{
self
.
format_duration
(
forward_duration
)
}
, start_time=
{
self
.
prefill_bootstrap_queue_entry_time
:.
3
f
}
"
elif
self
.
disagg_mode
==
DisaggregationMode
.
DECODE
:
prealloc_duration
=
(
self
.
decode_transfer_queue_entry_time
-
self
.
decode_prealloc_queue_entry_time
)
transfer_duration
=
(
self
.
wait_queue_entry_time
-
self
.
decode_transfer_queue_entry_time
)
...
...
@@ -96,42 +86,30 @@ class TimeStats:
forward_duration
=
self
.
completion_time
-
self
.
forward_entry_time
if
SGLANG_TEST_REQUEST_TIME_STATS
:
assert
(
prealloc_duration
>=
0
and
transfer_duration
>=
0
and
queue_duration
>=
0
and
forward_duration
>=
0
),
f
"prealloc_duration=
{
prealloc_duration
}
< 0 or transfer_duration=
{
transfer_duration
}
< 0 or queue_duration=
{
queue_duration
}
< 0 or forward_duration=
{
forward_duration
}
< 0"
return
f
"prealloc_duration=
{
self
.
format_duration
(
prealloc_duration
)
}
, transfer_duration=
{
self
.
format_duration
(
transfer_duration
)
}
, queue_duration=
{
self
.
format_duration
(
queue_duration
)
}
, forward_duration=
{
self
.
format_duration
(
forward_duration
)
}
, start_time=
{
self
.
decode_prealloc_queue_entry_time
}
"
if
self
.
wait_queue_entry_time
>
0
:
assert
(
prealloc_duration
>=
0
and
transfer_duration
>=
0
and
queue_duration
>=
0
and
forward_duration
>=
0
),
f
"prealloc_duration=
{
prealloc_duration
}
< 0 or transfer_duration=
{
transfer_duration
}
< 0 or queue_duration=
{
queue_duration
}
< 0 or forward_duration=
{
forward_duration
}
< 0.
{
self
=
}
"
return
f
"prealloc_duration=
{
self
.
format_duration
(
prealloc_duration
)
}
, transfer_duration=
{
self
.
format_duration
(
transfer_duration
)
}
, queue_duration=
{
self
.
format_duration
(
queue_duration
)
}
, forward_duration=
{
self
.
format_duration
(
forward_duration
)
}
, start_time=
{
self
.
decode_prealloc_queue_entry_time
:.
3
f
}
"
else
:
return
"
Invalid
Time Stats"
return
"
Unknown
Time Stats"
def
format_duration
(
self
,
duration
:
float
)
->
str
:
return
f
"
{
duration
*
1e3
:.
2
f
}
ms"
def
get_type
(
self
)
->
RequestType
:
"""Determine the type of request based on timestamp values."""
if
(
self
.
prefill_bootstrap_queue_entry_time
==
0.0
and
self
.
prefill_transfer_queue_entry_time
==
0.0
and
self
.
decode_prealloc_queue_entry_time
==
0.0
and
self
.
decode_transfer_queue_entry_time
==
0.0
):
return
self
.
RequestType
.
UNIFIED
elif
(
self
.
prefill_bootstrap_queue_entry_time
>
0.0
and
self
.
prefill_transfer_queue_entry_time
>
0.0
):
return
self
.
RequestType
.
PREFILL
elif
(
self
.
decode_prealloc_queue_entry_time
>
0.0
and
self
.
decode_transfer_queue_entry_time
>
0.0
and
self
.
wait_queue_entry_time
>
0.0
):
return
self
.
RequestType
.
DECODE
def
disagg_mode_str
(
self
)
->
str
:
if
self
.
disagg_mode
==
DisaggregationMode
.
NULL
:
return
"unified"
elif
self
.
disagg_mode
==
DisaggregationMode
.
DECODE
:
return
"decode"
elif
self
.
disagg_mode
==
DisaggregationMode
.
PREFILL
:
return
"prefill"
else
:
return
self
.
RequestType
.
INVALID
return
"unknown"
@
dataclass
...
...
@@ -145,12 +123,15 @@ class SchedulerStats:
num_queue_reqs
:
int
=
0
num_grammar_queue_reqs
:
int
=
0
num_running_reqs_offline_batch
:
int
=
0
avg_request_queue_latency
:
float
=
0.0
cache_hit_rate
:
float
=
0.0
# Speculative decoding
spec_accept_length
:
float
=
0.0
# Retract
num_retracted_reqs
:
int
=
0
num_paused_reqs
:
int
=
0
# PD disaggregation
num_prefill_prealloc_queue_reqs
:
int
=
0
num_prefill_inflight_queue_reqs
:
int
=
0
...
...
@@ -159,11 +140,6 @@ class SchedulerStats:
kv_transfer_speed_gb_s
:
float
=
0.0
kv_transfer_latency_ms
:
float
=
0.0
# Retract
total_retracted_reqs
:
int
=
0
num_retracted_reqs
:
int
=
0
num_paused_reqs
:
int
=
0
# Utilization
utilization
:
float
=
0.0
max_running_requests_under_SLO
:
Optional
[
int
]
=
None
...
...
@@ -230,12 +206,6 @@ class SchedulerMetricsCollector:
labelnames
=
labels
.
keys
(),
multiprocess_mode
=
"mostrecent"
,
)
self
.
avg_request_queue_latency
=
Gauge
(
name
=
"sglang:avg_request_queue_latency"
,
documentation
=
"The average request queue latency for the last batch of requests in seconds."
,
labelnames
=
labels
.
keys
(),
multiprocess_mode
=
"mostrecent"
,
)
self
.
cache_hit_rate
=
Gauge
(
name
=
"sglang:cache_hit_rate"
,
documentation
=
"The prefix cache hit rate."
,
...
...
@@ -251,6 +221,18 @@ class SchedulerMetricsCollector:
multiprocess_mode
=
"mostrecent"
,
)
# Retract
self
.
num_retracted_reqs
=
Gauge
(
name
=
"sglang:num_retracted_reqs"
,
documentation
=
"The number of retracted requests."
,
labelnames
=
labels
.
keys
(),
)
self
.
num_paused_reqs
=
Gauge
(
name
=
"sglang:num_paused_reqs"
,
documentation
=
"The number of paused requests by async weight sync."
,
labelnames
=
labels
.
keys
(),
)
# PD disaggregation
self
.
num_prefill_prealloc_queue_reqs
=
Gauge
(
name
=
"sglang:num_prefill_prealloc_queue_reqs"
,
...
...
@@ -299,24 +281,6 @@ class SchedulerMetricsCollector:
multiprocess_mode
=
"mostrecent"
,
)
# Retract
self
.
total_retracted_reqs
=
Gauge
(
name
=
"sglang:total_retracted_reqs"
,
documentation
=
"The total number of retracted requests due to kvcache full."
,
labelnames
=
labels
.
keys
(),
multiprocess_mode
=
"mostrecent"
,
)
self
.
num_retracted_reqs
=
Gauge
(
name
=
"sglang:num_retracted_reqs"
,
documentation
=
"The number of retracted requests."
,
labelnames
=
labels
.
keys
(),
)
self
.
num_paused_reqs
=
Gauge
(
name
=
"sglang:num_paused_reqs"
,
documentation
=
"The number of paused requests by async weight sync."
,
labelnames
=
labels
.
keys
(),
)
# Utilization
self
.
utilization
=
Gauge
(
name
=
"sglang:utilization"
,
...
...
@@ -347,7 +311,7 @@ class SchedulerMetricsCollector:
# Additional queueing time histogram
self
.
queue_time
=
Histogram
(
name
=
"sglang:queue_time_s"
,
name
=
"sglang:queue_time_s
econds
"
,
documentation
=
"Histogram of queueing time in seconds."
,
labelnames
=
labels
.
keys
(),
buckets
=
[
...
...
@@ -513,8 +477,8 @@ class SchedulerMetricsCollector:
buckets
=
tree_traversal_time_buckets
,
)
self
.
request
_latency_seconds
=
Histogram
(
name
=
"sglang:
request
_latency_seconds"
,
self
.
per_stage_req
_latency_seconds
=
Histogram
(
name
=
"sglang:
per_stage_req
_latency_seconds"
,
documentation
=
"The latency of each stage of requests."
,
# captures latency in range [1ms - ~1191s]
buckets
=
exponential_buckets
(
start
=
0.001
,
width
=
1.62
,
length
=
30
),
...
...
@@ -525,7 +489,7 @@ class SchedulerMetricsCollector:
# Convenience function for logging to gauge.
gauge
.
labels
(
**
self
.
labels
).
set
(
data
)
def
log_histogram
(
self
,
histogram
,
data
:
Union
[
int
,
float
])
->
None
:
def
_
log_histogram
(
self
,
histogram
,
data
:
Union
[
int
,
float
])
->
None
:
histogram
.
labels
(
**
self
.
labels
).
observe
(
data
)
def
increment_bootstrap_failed_reqs
(
self
)
->
None
:
...
...
@@ -534,9 +498,12 @@ class SchedulerMetricsCollector:
    def increment_transfer_failed_reqs(self) -> None:
        self.num_transfer_failed_reqs.labels(**self.labels).inc(1)

-    def observe_request_latency_seconds(self, stage: str, latency: float) -> None:
+    def observe_per_stage_req_latency(self, stage: str, latency: float) -> None:
        labels_with_stage = {**self.labels, "stage": stage}
-        self.request_latency_seconds.labels(**labels_with_stage).observe(latency)
+        self.per_stage_req_latency_seconds.labels(**labels_with_stage).observe(latency)

+    def observe_queue_time(self, latency: float) -> None:
+        self._log_histogram(self.queue_time, latency)

    def log_stats(self, stats: SchedulerStats) -> None:
        self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
...
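The renamed helper folds a per-stage label into the collector's base labels before observing. A runnable reduction of that pattern (metric name shortened, base labels assumed):

from prometheus_client import Histogram

base_labels = {"model_name": "demo"}  # assumed base labels
per_stage = Histogram(
    "demo_per_stage_req_latency_seconds",
    "The latency of each stage of requests.",
    labelnames=[*base_labels.keys(), "stage"],
)

def observe_per_stage_req_latency(stage: str, latency: float) -> None:
    # Same merge as labels_with_stage above: base labels plus the stage name.
    per_stage.labels(**{**base_labels, "stage": stage}).observe(latency)

observe_per_stage_req_latency("prefill", 0.34)
observe_per_stage_req_latency("decode", 2.15)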
@@ -550,7 +517,6 @@ class SchedulerMetricsCollector:
            self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch
        )
        self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
-        self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)

        # Speculative decoding
        self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
...
@@ -572,7 +538,6 @@ class SchedulerMetricsCollector:
        self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms)

        # Retract
-        self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs)
        self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs)
        self._log_gauge(self.num_paused_reqs, stats.num_paused_reqs)
...
@@ -596,19 +561,19 @@ class SchedulerMetricsCollector:
    def log_grammar_stats(self, grammar_stats) -> None:
        # Duck-typed GrammarStats to avoid cross-package dependency
        if getattr(grammar_stats, "compilation_time", None) is not None:
-            self.log_histogram(self.grammar_compilation_time, grammar_stats.compilation_time)
+            self._log_histogram(self.grammar_compilation_time, grammar_stats.compilation_time)
        if getattr(grammar_stats, "schema_count", None) is not None:
-            self.log_histogram(self.grammar_schema_count, grammar_stats.schema_count)
+            self._log_histogram(self.grammar_schema_count, grammar_stats.schema_count)
        if getattr(grammar_stats, "ebnf_size", None) is not None:
-            self.log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size)
+            self._log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size)
        tree_times = getattr(grammar_stats, "tree_traversal_time", None)
        if tree_times:
            max_time = max(tree_times)
            avg_time = sum(tree_times) / len(tree_times)
-            self.log_histogram(self.grammar_tree_traversal_time_max, max_time)
-            self.log_histogram(self.grammar_tree_traversal_time_avg, avg_time)
+            self._log_histogram(self.grammar_tree_traversal_time_max, max_time)
+            self._log_histogram(self.grammar_tree_traversal_time_avg, avg_time)
        if getattr(grammar_stats, "is_cache_hit", False):
            self.num_grammar_cache_hit.labels(**self.labels).inc(1)
        if getattr(grammar_stats, "is_grammar_aborted", False):
...
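Because log_grammar_stats duck-types its argument, any object exposing these attributes can be logged without a cross-package import. A minimal stand-in whose field names are taken from the getattr calls above (the defaults are assumptions, not the real GrammarStats definition):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class GrammarStatsLike:  # hypothetical stand-in, not the real GrammarStats
    compilation_time: Optional[float] = None
    schema_count: Optional[int] = None
    ebnf_size: Optional[int] = None
    tree_traversal_time: List[float] = field(default_factory=list)
    is_cache_hit: bool = False
    is_grammar_aborted: bool = False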
@@ -714,7 +679,7 @@ class TokenizerMetricsCollector:
        )

        self.num_aborted_requests_total = Counter(
-            name="sglang:num_aborted_requests",
+            name="sglang:num_aborted_requests_total",
            documentation="Number of requests aborted.",
            labelnames=labels.keys(),
        )
...
@@ -801,7 +766,7 @@ class TokenizerMetricsCollector:
            buckets=bucket_time_to_first_token,
        )

-        self.histogram_inter_token_latency_seconds = Histogram(
+        self.histogram_inter_token_latency = Histogram(
            name="sglang:inter_token_latency_seconds",
            documentation="Histogram of inter-token latency in seconds.",
            labelnames=labels.keys(),
...
@@ -815,14 +780,6 @@ class TokenizerMetricsCollector:
            buckets=bucket_e2e_request_latency,
        )

-        # Offline batch specific TTFB histogram
-        self.histogram_time_to_first_token_offline_batch = Histogram(
-            name="sglang:time_to_first_token_seconds_offline_batch",
-            documentation="Histogram of time to first token in seconds for offline batch requests.",
-            labelnames=labels.keys(),
-            buckets=bucket_time_to_first_token,
-        )

    def observe_one_finished_request(
        self,
        labels: Dict[str, str],
...
@@ -846,15 +803,8 @@ class TokenizerMetricsCollector:
            float(generation_tokens)
        )

-    def observe_time_to_first_token(self, labels: Dict[str, str], value: float, type: str = ""):
-        if type == "batch":
-            self.histogram_time_to_first_token_offline_batch.labels(**labels).observe(value)
-        else:
-            self.histogram_time_to_first_token.labels(**labels).observe(value)
+    def observe_time_to_first_token(self, labels: Dict[str, str], value: float):
+        self.histogram_time_to_first_token.labels(**labels).observe(value)

    def check_time_to_first_token_straggler(self, value: float) -> bool:
        his = self.histogram_time_to_first_token.labels(**self.labels)
...
@@ -876,7 +826,7 @@ class TokenizerMetricsCollector:
        # A faster version of the Histogram::observe which observes multiple values at the same time.
        # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
-        his = self.histogram_inter_token_latency_seconds.labels(**labels)
+        his = self.histogram_inter_token_latency.labels(**labels)
        his._sum.inc(internval)

        for i, bound in enumerate(his._upper_bounds):
...
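For context on the snippet above: it bypasses Histogram.observe() so that a run of decode steps costs one bucket update instead of one observe() call per token. A self-contained sketch of the same trick, relying on the same prometheus_client private fields (_sum, _upper_bounds, _buckets) the real code touches; internval and num_new_tokens are parameters of the elided enclosing method, spelled here as in the source:

from prometheus_client import Histogram

his = Histogram("demo_inter_token_latency_seconds", "Inter-token latency.")

def observe_many(interval: float, num_new_tokens: int) -> None:
    # Add the whole interval to the running sum, then credit every new token
    # to the single bucket matching the average per-token latency.
    his._sum.inc(interval)
    per_token = interval / num_new_tokens
    for i, bound in enumerate(his._upper_bounds):
        if per_token <= bound:
            his._buckets[i].inc(num_new_tokens)
            break

observe_many(0.5, 10)  # 10 tokens at ~50 ms each, one bucket increment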
@@ -884,8 +834,8 @@ class TokenizerMetricsCollector:
                his._buckets[i].inc(num_new_tokens)
                break

-    def observe_one_aborted_request(self):
-        self.num_aborted_requests_total.labels(**self.labels).inc(1)
+    def observe_one_aborted_request(self, labels: Dict[str, str]):
+        self.num_aborted_requests_total.labels(**labels).inc(1)
@dataclass
...
python/sglang/srt/tracing/trace.py
View file @ 2d62af6b
...
@@ -15,7 +15,6 @@
from __future__ import annotations

-import ctypes
import logging
import os
import random
...
@@ -23,7 +22,10 @@ import threading
import time
import uuid
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Req

logger = logging.getLogger(__name__)

opentelemetry_imported = False
...
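The TYPE_CHECKING guard keeps the Req import out of the runtime path (sidestepping a circular import between trace.py and the scheduler) while still letting the new trace_slice_batch helper be annotated; with from __future__ import annotations, the annotation is never evaluated at runtime. The pattern in isolation:

from __future__ import annotations

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:  # seen by type checkers only, never executed at runtime
    from sglang.srt.managers.scheduler import Req

def takes_reqs(reqs: List[Req]) -> None:  # annotation stays an unevaluated string
    ...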
@@ -407,9 +409,11 @@ def trace_slice_start(
    ts: Optional[int] = None,
    anonymous: bool = False,
):
-    if not tracing_enabled or rid not in reqs_context:
+    if not tracing_enabled:
+        return
+
+    rid = str(rid)
+    if rid not in reqs_context:
        return

    pid = threading.get_native_id()
...
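The same guard refactor repeats in every trace entry point below: request ids can arrive as ints or UUID objects, so each function now coerces with str() before the reqs_context lookup. A short illustration of why the old combined check could silently drop a slice (assuming reqs_context keys are stringified rids, as the str() calls suggest):

reqs_context = {"12345": "ctx"}   # keys are stringified request ids

rid = 12345                       # a non-string rid from a caller
assert rid not in reqs_context    # old combined check: bails out, slice lost
assert str(rid) in reqs_context   # new order: normalize first, lookup hits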
@@ -458,8 +462,11 @@ def trace_slice_end(
    auto_next_anon: bool = False,
    thread_finish_flag: bool = False,
):
-    if not tracing_enabled or rid not in reqs_context:
+    if not tracing_enabled:
+        return
+
+    rid = str(rid)
+    if rid not in reqs_context:
        return

    pid = threading.get_native_id()
...
@@ -512,10 +519,13 @@ trace_slice = trace_slice_end
# Add event to the current slice on the same thread with the same rid.
def trace_event(name: str, rid: str, ts: Optional[int] = None):
-    if not tracing_enabled or rid not in reqs_context:
+    if not tracing_enabled:
+        return
+
+    rid = str(rid)
+    if rid not in reqs_context:
        return

    pid = threading.get_native_id()
    if pid not in reqs_context[rid].threads_context:
        return
...
@@ -534,10 +544,13 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None):
# Add attrs to the current slice on the same thread with the same rid.
def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
-    if not tracing_enabled or rid not in reqs_context:
+    if not tracing_enabled:
+        return
+
+    rid = str(rid)
+    if rid not in reqs_context:
        return

    pid = threading.get_native_id()
    if pid not in reqs_context[rid].threads_context:
        return
...
@@ -550,3 +563,16 @@ def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
    slice_info = thread_context.cur_slice_stack[-1]
    slice_info.span.set_attributes(attrs)


+def trace_slice_batch(
+    name: str,
+    reqs: List[Req],
+):
+    for req in reqs:
+        trace_slice(
+            name,
+            req.rid,
+            auto_next_anon=not req.finished(),
+            thread_finish_flag=req.finished(),
+        )
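A hedged usage sketch for the new helper: reqs is duck-typed, so any object with a rid and a finished() method works. DummyReq below is hypothetical and only mirrors those two members; with tracing disabled, trace_slice is a no-op, so the call is safe to run standalone (assumes sglang is installed for the import):

from sglang.srt.tracing.trace import trace_slice_batch

class DummyReq:
    def __init__(self, rid: str, done: bool):
        self.rid = rid
        self._done = done

    def finished(self) -> bool:
        return self._done

# Close the "decode" slice for a whole batch; unfinished requests roll into
# an anonymous follow-up slice, finished ones tear down their thread context.
trace_slice_batch("decode", [DummyReq("req-0", False), DummyReq("req-1", True)])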