Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b1721edb
Unverified
Commit
b1721edb
authored
Sep 16, 2025
by
Yingchun Lai
Committed by
GitHub
Sep 16, 2025
Browse files
[PD metrics] Add latency Histogram metrics of each stage for generate requests (#8710)
parent
57234d0c
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
77 additions
and
11 deletions
+77
-11
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+5
-1
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+11
-1
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+34
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+5
-0
python/sglang/srt/metrics/collector.py
python/sglang/srt/metrics/collector.py
+13
-1
python/sglang/srt/metrics/func_timer.py
python/sglang/srt/metrics/func_timer.py
+2
-7
python/sglang/srt/metrics/utils.py
python/sglang/srt/metrics/utils.py
+7
-0
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
b1721edb
...
@@ -45,7 +45,7 @@ from sglang.srt.disaggregation.utils import (
...
@@ -45,7 +45,7 @@ from sglang.srt.disaggregation.utils import (
prepare_abort
,
prepare_abort
,
)
)
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.managers.schedule_batch
import
FINISH_ABORT
,
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
FINISH_ABORT
,
RequestStage
,
ScheduleBatch
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
ReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
ReqToTokenPool
...
@@ -253,6 +253,7 @@ class DecodePreallocQueue:
...
@@ -253,6 +253,7 @@ class DecodePreallocQueue:
prefill_dp_rank
=
req
.
data_parallel_rank
,
prefill_dp_rank
=
req
.
data_parallel_rank
,
)
)
req
.
add_latency
(
RequestStage
.
DECODE_PREPARE
)
self
.
queue
.
append
(
self
.
queue
.
append
(
DecodeRequest
(
req
=
req
,
kv_receiver
=
kv_receiver
,
waiting_for_input
=
False
)
DecodeRequest
(
req
=
req
,
kv_receiver
=
kv_receiver
,
waiting_for_input
=
False
)
)
)
...
@@ -421,6 +422,7 @@ class DecodePreallocQueue:
...
@@ -421,6 +422,7 @@ class DecodePreallocQueue:
kv_indices
,
self
.
token_to_kv_pool_allocator
.
page_size
kv_indices
,
self
.
token_to_kv_pool_allocator
.
page_size
)
)
decode_req
.
kv_receiver
.
init
(
page_indices
,
decode_req
.
metadata_buffer_index
)
decode_req
.
kv_receiver
.
init
(
page_indices
,
decode_req
.
metadata_buffer_index
)
decode_req
.
req
.
add_latency
(
RequestStage
.
DECODE_BOOTSTRAP
)
preallocated_reqs
.
append
(
decode_req
)
preallocated_reqs
.
append
(
decode_req
)
indices_to_remove
.
add
(
i
)
indices_to_remove
.
add
(
i
)
...
@@ -662,6 +664,7 @@ class DecodeTransferQueue:
...
@@ -662,6 +664,7 @@ class DecodeTransferQueue:
for
i
in
indices_to_remove
:
for
i
in
indices_to_remove
:
idx
=
self
.
queue
[
i
].
metadata_buffer_index
idx
=
self
.
queue
[
i
].
metadata_buffer_index
assert
idx
!=
-
1
assert
idx
!=
-
1
self
.
queue
[
i
].
req
.
add_latency
(
RequestStage
.
DECODE_TRANSFERRED
)
self
.
req_to_metadata_buffer_idx_allocator
.
free
(
idx
)
self
.
req_to_metadata_buffer_idx_allocator
.
free
(
idx
)
self
.
queue
=
[
self
.
queue
=
[
...
@@ -853,6 +856,7 @@ class SchedulerDisaggregationDecodeMixin:
...
@@ -853,6 +856,7 @@ class SchedulerDisaggregationDecodeMixin:
# we can only add at least `num_not_used_batch` new batch to the running queue
# we can only add at least `num_not_used_batch` new batch to the running queue
if
i
<
num_not_used_batch
:
if
i
<
num_not_used_batch
:
can_run_list
.
append
(
req
)
can_run_list
.
append
(
req
)
req
.
add_latency
(
RequestStage
.
DECODE_WAITING
)
req
.
init_next_round_input
(
self
.
tree_cache
)
req
.
init_next_round_input
(
self
.
tree_cache
)
else
:
else
:
waiting_queue
.
append
(
req
)
waiting_queue
.
append
(
req
)
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
b1721edb
...
@@ -42,7 +42,12 @@ from sglang.srt.disaggregation.utils import (
...
@@ -42,7 +42,12 @@ from sglang.srt.disaggregation.utils import (
poll_and_all_reduce
,
poll_and_all_reduce
,
prepare_abort
,
prepare_abort
,
)
)
from
sglang.srt.managers.schedule_batch
import
FINISH_LENGTH
,
Req
,
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
(
FINISH_LENGTH
,
Req
,
RequestStage
,
ScheduleBatch
,
)
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
,
PPProxyTensors
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
,
PPProxyTensors
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
DynamicGradMode
,
DynamicGradMode
,
...
@@ -170,6 +175,7 @@ class PrefillBootstrapQueue:
...
@@ -170,6 +175,7 @@ class PrefillBootstrapQueue:
pp_rank
=
self
.
pp_rank
,
pp_rank
=
self
.
pp_rank
,
)
)
self
.
_process_req
(
req
)
self
.
_process_req
(
req
)
req
.
add_latency
(
RequestStage
.
PREFILL_PREPARE
)
self
.
queue
.
append
(
req
)
self
.
queue
.
append
(
req
)
def
extend
(
self
,
reqs
:
List
[
Req
],
num_kv_heads
:
int
)
->
None
:
def
extend
(
self
,
reqs
:
List
[
Req
],
num_kv_heads
:
int
)
->
None
:
...
@@ -256,6 +262,8 @@ class PrefillBootstrapQueue:
...
@@ -256,6 +262,8 @@ class PrefillBootstrapQueue:
num_pages
=
kv_to_page_num
(
num_kv_indices
,
self
.
token_to_kv_pool
.
page_size
)
num_pages
=
kv_to_page_num
(
num_kv_indices
,
self
.
token_to_kv_pool
.
page_size
)
req
.
disagg_kv_sender
.
init
(
num_pages
,
req
.
metadata_buffer_index
)
req
.
disagg_kv_sender
.
init
(
num_pages
,
req
.
metadata_buffer_index
)
req
.
add_latency
(
RequestStage
.
PREFILL_BOOTSTRAP
)
bootstrapped_reqs
.
append
(
req
)
bootstrapped_reqs
.
append
(
req
)
indices_to_remove
.
add
(
i
)
indices_to_remove
.
add
(
i
)
...
@@ -404,6 +412,7 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -404,6 +412,7 @@ class SchedulerDisaggregationPrefillMixin:
# There is no output_ids for prefill
# There is no output_ids for prefill
req
.
output_ids
.
append
(
next_token_id
)
req
.
output_ids
.
append
(
next_token_id
)
self
.
tree_cache
.
cache_unfinished_req
(
req
)
# update the tree and lock
self
.
tree_cache
.
cache_unfinished_req
(
req
)
# update the tree and lock
req
.
add_latency
(
RequestStage
.
PREFILL_FORWARD
)
self
.
disagg_prefill_inflight_queue
.
append
(
req
)
self
.
disagg_prefill_inflight_queue
.
append
(
req
)
if
(
if
(
logits_output
is
not
None
logits_output
is
not
None
...
@@ -539,6 +548,7 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -539,6 +548,7 @@ class SchedulerDisaggregationPrefillMixin:
)
)
for
req
in
done_reqs
:
for
req
in
done_reqs
:
req
:
Req
req
:
Req
req
.
add_latency
(
RequestStage
.
PREFILL_TRANSFER_KV_CACHE
)
self
.
req_to_metadata_buffer_idx_allocator
.
free
(
req
.
metadata_buffer_index
)
self
.
req_to_metadata_buffer_idx_allocator
.
free
(
req
.
metadata_buffer_index
)
req
.
metadata_buffer_index
=
-
1
req
.
metadata_buffer_index
=
-
1
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
b1721edb
from
__future__
import
annotations
from
__future__
import
annotations
import
enum
# Copyright 2023-2024 SGLang Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -35,6 +37,7 @@ import copy
...
@@ -35,6 +37,7 @@ import copy
import
dataclasses
import
dataclasses
import
logging
import
logging
import
threading
import
threading
import
time
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
itertools
import
chain
from
itertools
import
chain
...
@@ -61,7 +64,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
...
@@ -61,7 +64,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from
sglang.srt.mem_cache.lora_radix_cache
import
LoRAKey
,
LoRARadixCache
from
sglang.srt.mem_cache.lora_radix_cache
import
LoRAKey
,
LoRARadixCache
from
sglang.srt.mem_cache.memory_pool
import
HybridReqToTokenPool
,
ReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
HybridReqToTokenPool
,
ReqToTokenPool
from
sglang.srt.mem_cache.swa_radix_cache
import
SWARadixCache
from
sglang.srt.mem_cache.swa_radix_cache
import
SWARadixCache
from
sglang.srt.metrics.collector
import
TimeStats
from
sglang.srt.metrics.collector
import
SchedulerMetricsCollector
,
TimeStats
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardMode
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.sampling.sampling_params
import
SamplingParams
...
@@ -407,6 +410,23 @@ class MultimodalInputs:
...
@@ -407,6 +410,23 @@ class MultimodalInputs:
# other args would be kept intact
# other args would be kept intact
class
RequestStage
(
str
,
enum
.
Enum
):
# prefill
PREFILL_WAITING
=
"prefill_waiting"
# disaggregation prefill
PREFILL_PREPARE
=
"prefill_prepare"
PREFILL_BOOTSTRAP
=
"prefill_bootstrap"
PREFILL_FORWARD
=
"prefill_forward"
PREFILL_TRANSFER_KV_CACHE
=
"prefill_transfer_kv_cache"
# disaggregation decode
DECODE_PREPARE
=
"decode_prepare"
DECODE_BOOTSTRAP
=
"decode_bootstrap"
DECODE_WAITING
=
"decode_waiting"
DECODE_TRANSFERRED
=
"decode_transferred"
class
Req
:
class
Req
:
"""The input and output status of a request."""
"""The input and output status of a request."""
...
@@ -433,6 +453,7 @@ class Req:
...
@@ -433,6 +453,7 @@ class Req:
bootstrap_room
:
Optional
[
int
]
=
None
,
bootstrap_room
:
Optional
[
int
]
=
None
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
vocab_size
:
Optional
[
int
]
=
None
,
vocab_size
:
Optional
[
int
]
=
None
,
metrics_collector
:
Optional
[
SchedulerMetricsCollector
]
=
None
,
):
):
# Input and output info
# Input and output info
self
.
rid
=
rid
self
.
rid
=
rid
...
@@ -590,10 +611,12 @@ class Req:
...
@@ -590,10 +611,12 @@ class Req:
self
.
spec_verify_ct
=
0
self
.
spec_verify_ct
=
0
# For metrics
# For metrics
self
.
metrics_collector
=
metrics_collector
self
.
time_stats
:
TimeStats
=
TimeStats
()
self
.
time_stats
:
TimeStats
=
TimeStats
()
self
.
has_log_time_stats
:
bool
=
False
self
.
has_log_time_stats
:
bool
=
False
self
.
queue_time_start
=
None
self
.
queue_time_start
=
None
self
.
queue_time_end
=
None
self
.
queue_time_end
=
None
self
.
last_tic
=
time
.
monotonic
()
# For disaggregation
# For disaggregation
self
.
bootstrap_host
:
str
=
bootstrap_host
self
.
bootstrap_host
:
str
=
bootstrap_host
...
@@ -626,6 +649,16 @@ class Req:
...
@@ -626,6 +649,16 @@ class Req:
"""Check if this request is prefill-only (no token generation needed)."""
"""Check if this request is prefill-only (no token generation needed)."""
return
self
.
sampling_params
.
max_new_tokens
==
0
return
self
.
sampling_params
.
max_new_tokens
==
0
def
add_latency
(
self
,
stage
:
RequestStage
):
if
self
.
metrics_collector
is
None
:
return
assert
stage
.
name
in
RequestStage
.
__members__
,
f
"
{
stage
=
}
is invalid"
now
=
time
.
monotonic
()
self
.
metrics_collector
.
observe_request_latency_seconds
(
stage
.
value
,
now
-
self
.
last_tic
)
self
.
last_tic
=
now
def
extend_image_inputs
(
self
,
image_inputs
):
def
extend_image_inputs
(
self
,
image_inputs
):
if
self
.
multimodal_inputs
is
None
:
if
self
.
multimodal_inputs
is
None
:
self
.
multimodal_inputs
=
image_inputs
self
.
multimodal_inputs
=
image_inputs
...
...
python/sglang/srt/managers/scheduler.py
View file @
b1721edb
...
@@ -116,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
...
@@ -116,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
FINISH_ABORT
,
FINISH_ABORT
,
MultimodalInputs
,
MultimodalInputs
,
Req
,
Req
,
RequestStage
,
ScheduleBatch
,
ScheduleBatch
,
global_server_args_dict
,
global_server_args_dict
,
)
)
...
@@ -1232,6 +1233,9 @@ class Scheduler(
...
@@ -1232,6 +1233,9 @@ class Scheduler(
bootstrap_room
=
recv_req
.
bootstrap_room
,
bootstrap_room
=
recv_req
.
bootstrap_room
,
data_parallel_rank
=
recv_req
.
data_parallel_rank
,
data_parallel_rank
=
recv_req
.
data_parallel_rank
,
vocab_size
=
self
.
model_config
.
vocab_size
,
vocab_size
=
self
.
model_config
.
vocab_size
,
metrics_collector
=
(
self
.
metrics_collector
if
self
.
enable_metrics
else
None
),
)
)
req
.
tokenizer
=
self
.
tokenizer
req
.
tokenizer
=
self
.
tokenizer
...
@@ -1768,6 +1772,7 @@ class Scheduler(
...
@@ -1768,6 +1772,7 @@ class Scheduler(
# only record queue time when enable_metrics is True to avoid overhead
# only record queue time when enable_metrics is True to avoid overhead
for
req
in
can_run_list
:
for
req
in
can_run_list
:
req
.
queue_time_end
=
time
.
perf_counter
()
req
.
queue_time_end
=
time
.
perf_counter
()
req
.
add_latency
(
RequestStage
.
PREFILL_WAITING
)
self
.
waiting_queue
=
[
self
.
waiting_queue
=
[
x
for
x
in
self
.
waiting_queue
if
x
not
in
set
(
can_run_list
)
x
for
x
in
self
.
waiting_queue
if
x
not
in
set
(
can_run_list
)
...
...
python/sglang/srt/metrics/collector.py
View file @
b1721edb
...
@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
...
@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
from
sglang.srt.metrics.utils
import
generate_buckets
from
sglang.srt.metrics.utils
import
exponential_buckets
,
generate_buckets
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
get_bool_env_var
from
sglang.srt.utils
import
get_bool_env_var
...
@@ -513,6 +513,14 @@ class SchedulerMetricsCollector:
...
@@ -513,6 +513,14 @@ class SchedulerMetricsCollector:
buckets
=
tree_traversal_time_buckets
,
buckets
=
tree_traversal_time_buckets
,
)
)
self
.
request_latency_seconds
=
Histogram
(
name
=
"sglang:request_latency_seconds"
,
documentation
=
"The latency of each stage of requests."
,
# captures latency in range [1ms - ~1191s]
buckets
=
exponential_buckets
(
start
=
0.001
,
width
=
1.62
,
length
=
30
),
labelnames
=
list
(
labels
.
keys
())
+
[
"stage"
],
)
def
_log_gauge
(
self
,
gauge
,
data
:
Union
[
int
,
float
])
->
None
:
def
_log_gauge
(
self
,
gauge
,
data
:
Union
[
int
,
float
])
->
None
:
# Convenience function for logging to gauge.
# Convenience function for logging to gauge.
gauge
.
labels
(
**
self
.
labels
).
set
(
data
)
gauge
.
labels
(
**
self
.
labels
).
set
(
data
)
...
@@ -526,6 +534,10 @@ class SchedulerMetricsCollector:
...
@@ -526,6 +534,10 @@ class SchedulerMetricsCollector:
def
increment_transfer_failed_reqs
(
self
)
->
None
:
def
increment_transfer_failed_reqs
(
self
)
->
None
:
self
.
num_transfer_failed_reqs
.
labels
(
**
self
.
labels
).
inc
(
1
)
self
.
num_transfer_failed_reqs
.
labels
(
**
self
.
labels
).
inc
(
1
)
def
observe_request_latency_seconds
(
self
,
stage
:
str
,
latency
:
float
)
->
None
:
labels_with_stage
=
{
**
self
.
labels
,
"stage"
:
stage
}
self
.
request_latency_seconds
.
labels
(
**
labels_with_stage
).
observe
(
latency
)
def
log_stats
(
self
,
stats
:
SchedulerStats
)
->
None
:
def
log_stats
(
self
,
stats
:
SchedulerStats
)
->
None
:
self
.
_log_gauge
(
self
.
num_running_reqs
,
stats
.
num_running_reqs
)
self
.
_log_gauge
(
self
.
num_running_reqs
,
stats
.
num_running_reqs
)
self
.
_log_gauge
(
self
.
num_used_tokens
,
stats
.
num_used_tokens
)
self
.
_log_gauge
(
self
.
num_used_tokens
,
stats
.
num_used_tokens
)
...
...
python/sglang/srt/metrics/func_timer.py
View file @
b1721edb
...
@@ -20,6 +20,8 @@ import time
...
@@ -20,6 +20,8 @@ import time
from
functools
import
wraps
from
functools
import
wraps
from
typing
import
Any
,
Callable
,
List
,
Optional
from
typing
import
Any
,
Callable
,
List
,
Optional
from
sglang.srt.metrics.utils
import
exponential_buckets
enable_metrics
=
False
enable_metrics
=
False
...
@@ -42,13 +44,6 @@ def enable_func_timer():
...
@@ -42,13 +44,6 @@ def enable_func_timer():
FUNC_LATENCY
=
None
FUNC_LATENCY
=
None
def
exponential_buckets
(
start
:
float
,
width
:
float
,
length
:
int
)
->
List
[
float
]:
buckets
=
[]
for
i
in
range
(
length
):
buckets
.
append
(
start
*
(
width
**
i
))
return
buckets
def
time_func_latency
(
def
time_func_latency
(
func
:
Callable
=
None
,
name
:
Optional
[
str
]
=
None
func
:
Callable
=
None
,
name
:
Optional
[
str
]
=
None
)
->
Callable
[...,
Any
]:
)
->
Callable
[...,
Any
]:
...
...
python/sglang/srt/metrics/utils.py
View file @
b1721edb
...
@@ -46,3 +46,10 @@ def generate_buckets(
...
@@ -46,3 +46,10 @@ def generate_buckets(
return
sorted
(
set
(
default_buckets
))
return
sorted
(
set
(
default_buckets
))
assert
rule
==
"customer"
assert
rule
==
"customer"
return
sorted
(
set
([
float
(
x
)
for
x
in
buckets_rule
[
1
:]]))
return
sorted
(
set
([
float
(
x
)
for
x
in
buckets_rule
[
1
:]]))
def
exponential_buckets
(
start
:
float
,
width
:
float
,
length
:
int
)
->
List
[
float
]:
buckets
=
[]
for
i
in
range
(
length
):
buckets
.
append
(
start
*
(
width
**
i
))
return
buckets
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment