Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a79cc68b
Unverified
Commit
a79cc68b
authored
Apr 01, 2025
by
Mark McLoughlin
Committed by
GitHub
Apr 01, 2025
Browse files
[V1][Metrics] Initial speculative decoding metrics (#15151)
Signed-off-by:
Mark McLoughlin
<
markmc@redhat.com
>
parent
7e3f7a4e
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
204 additions
and
2 deletions
+204
-2
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+95
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+13
-2
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+33
-0
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+4
-0
vllm/v1/spec_decode/metrics.py
vllm/v1/spec_decode/metrics.py
+59
-0
No files found.
tests/v1/core/test_scheduler.py
View file @
a79cc68b
...
...
@@ -611,3 +611,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
prompt_logprobs_dict
=
{},
)
scheduler
.
update_from_output
(
scheduler_output1
,
model_runner_output
)
# Note - these test cases mirror some of those in test_rejection_sampler.py
@
pytest
.
mark
.
parametrize
(
"spec_tokens,output_tokens,expected"
,
[
([[
1
,
2
,
3
]],
[[
1
,
2
,
3
,
4
]],
(
3
,
3
)),
# perfect match
([[
1
,
2
,
3
]],
[[
1
,
5
]],
(
3
,
1
)),
# early mismatch
([[
1
,
2
],
[
3
]],
[[
1
,
2
,
5
],
[
3
,
4
]],
(
3
,
3
)),
# multiple sequences
([[
1
]],
[[
1
,
2
]],
(
1
,
1
)),
# single token sequence
([[]],
[[
5
]],
(
0
,
0
)),
# empty sequence
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
[[
1
,
2
,
7
],
[
4
,
8
]],
(
6
,
3
)),
# multiple mismatches
])
def
test_schedule_spec_decoding_stats
(
spec_tokens
,
output_tokens
,
expected
):
"""Test scheduling behavior with speculative decoding.
This test verifies that:
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
scheduler
=
create_scheduler
()
requests
=
create_requests
(
num_requests
=
len
(
spec_tokens
),
num_tokens
=
1
)
req_ids
=
[]
req_to_index
=
{}
for
i
,
request
in
enumerate
(
requests
):
scheduler
.
add_request
(
request
)
req_ids
.
append
(
request
.
request_id
)
req_to_index
[
request
.
request_id
]
=
i
# Schedule a decode, which will also draft speculative tokens
output
=
scheduler
.
schedule
()
assert
len
(
output
.
scheduled_new_reqs
)
==
len
(
requests
)
assert
output
.
total_num_scheduled_tokens
==
len
(
requests
)
for
i
in
range
(
len
(
requests
)):
req_id
=
requests
[
i
].
request_id
assert
output
.
num_scheduled_tokens
[
req_id
]
==
1
assert
req_id
not
in
output
.
scheduled_spec_decode_tokens
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
req_ids
,
req_id_to_index
=
req_to_index
,
sampled_token_ids
=
[[
0
]
for
_
in
range
(
len
(
requests
))],
spec_token_ids
=
spec_tokens
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
)
engine_core_outputs
=
scheduler
.
update_from_output
(
output
,
model_runner_output
)
for
i
in
range
(
len
(
requests
)):
running_req
=
scheduler
.
running
[
i
]
# The prompt token
assert
running_req
.
num_computed_tokens
==
1
# The prompt token and the sampled token
assert
running_req
.
num_tokens
==
2
# The prompt token, the sampled token, and the speculated tokens
assert
running_req
.
num_tokens_with_spec
==
2
+
len
(
spec_tokens
[
i
])
# No draft or accepted tokens counted yet
assert
engine_core_outputs
.
scheduler_stats
.
spec_decoding_stats
is
not
None
stats
=
engine_core_outputs
.
scheduler_stats
.
spec_decoding_stats
assert
stats
.
num_draft_tokens
==
0
assert
stats
.
num_accepted_tokens
==
0
# Schedule the speculated tokens for validation
output
=
scheduler
.
schedule
()
assert
len
(
output
.
scheduled_new_reqs
)
==
0
# The sampled token and speculated tokens
assert
output
.
total_num_scheduled_tokens
==
\
len
(
requests
)
+
sum
(
len
(
ids
)
for
ids
in
spec_tokens
)
for
i
in
range
(
len
(
requests
)):
req_id
=
requests
[
i
].
request_id
assert
output
.
num_scheduled_tokens
[
req_id
]
==
1
+
len
(
spec_tokens
[
i
])
if
spec_tokens
[
i
]:
assert
len
(
output
.
scheduled_spec_decode_tokens
[
req_id
])
==
\
len
(
spec_tokens
[
i
])
else
:
assert
req_id
not
in
output
.
scheduled_spec_decode_tokens
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
req_ids
,
req_id_to_index
=
req_to_index
,
sampled_token_ids
=
output_tokens
,
spec_token_ids
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
)
engine_core_outputs
=
scheduler
.
update_from_output
(
output
,
model_runner_output
)
assert
engine_core_outputs
.
scheduler_stats
.
spec_decoding_stats
is
not
None
stats
=
engine_core_outputs
.
scheduler_stats
.
spec_decoding_stats
assert
stats
.
num_draft_tokens
==
expected
[
0
]
assert
stats
.
num_accepted_tokens
==
expected
[
1
]
vllm/v1/core/sched/scheduler.py
View file @
a79cc68b
...
...
@@ -23,6 +23,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.structured_output
import
StructuredOutputManager
logger
=
init_logger
(
__name__
)
...
...
@@ -552,6 +553,7 @@ class Scheduler(SchedulerInterface):
spec_token_ids
=
model_runner_output
.
spec_token_ids
logprobs
=
model_runner_output
.
logprobs
prompt_logprobs_dict
=
model_runner_output
.
prompt_logprobs_dict
spec_decoding_stats
=
SpecDecodingStats
()
if
self
.
log_stats
else
None
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
new_running
:
list
[
Request
]
=
[]
...
...
@@ -584,6 +586,11 @@ class Scheduler(SchedulerInterface):
len
(
generated_token_ids
))
request
.
num_computed_tokens
-=
num_tokens_rejected
if
spec_decoding_stats
is
not
None
:
spec_decoding_stats
.
observe
(
num_draft_tokens
=
len
(
scheduled_spec_token_ids
),
num_accepted_tokens
=
len
(
generated_token_ids
)
-
1
)
cached_encoder_input_ids
=
(
self
.
encoder_cache_manager
.
get_cached_input_ids
(
request
))
# OPTIMIZATION: Avoid list(set) if the set is empty.
...
...
@@ -657,7 +664,7 @@ class Scheduler(SchedulerInterface):
self
.
running
=
new_running
engine_core_outputs
=
EngineCoreOutputs
(
outputs
=
outputs
,
scheduler_stats
=
self
.
make_stats
(),
scheduler_stats
=
self
.
make_stats
(
spec_decoding_stats
),
)
if
self
.
include_finished_set
:
#TODO currently sending duplicates here, improve this
...
...
@@ -724,7 +731,10 @@ class Scheduler(SchedulerInterface):
def
reset_prefix_cache
(
self
)
->
bool
:
return
self
.
kv_cache_manager
.
reset_prefix_cache
()
def
make_stats
(
self
)
->
Optional
[
SchedulerStats
]:
def
make_stats
(
self
,
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
,
)
->
Optional
[
SchedulerStats
]:
if
not
self
.
log_stats
:
return
None
return
SchedulerStats
(
...
...
@@ -732,4 +742,5 @@ class Scheduler(SchedulerInterface):
num_waiting_reqs
=
len
(
self
.
waiting
),
gpu_cache_usage
=
self
.
kv_cache_manager
.
usage
,
prefix_cache_stats
=
self
.
kv_cache_manager
.
make_prefix_cache_stats
(),
spec_decoding_stats
=
spec_decoding_stats
,
)
vllm/v1/metrics/loggers.py
View file @
a79cc68b
...
...
@@ -12,6 +12,7 @@ from vllm.logger import init_logger
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingMetrics
logger
=
init_logger
(
__name__
)
...
...
@@ -38,6 +39,7 @@ class LoggingStatLogger(StatLoggerBase):
# Prefix cache metrics. This cannot be reset.
# TODO: Make the interval configurable.
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
self
.
spec_decoding_metrics
=
SpecDecodingMetrics
()
def
_reset
(
self
,
now
):
self
.
last_log_time
=
now
...
...
@@ -65,6 +67,10 @@ class LoggingStatLogger(StatLoggerBase):
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_metrics
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
self
.
last_scheduler_stats
=
scheduler_stats
def
log
(
self
):
...
...
@@ -94,6 +100,9 @@ class LoggingStatLogger(StatLoggerBase):
self
.
prefix_caching_metrics
.
hit_rate
*
100
,
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_metrics
.
log
()
class
PrometheusStatLogger
(
StatLoggerBase
):
...
...
@@ -302,6 +311,24 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
labelname_running_lora_adapters
,
])
#
# Speculative Decoding metrics
# The acceptance rate can be calculated using a PromQL query:
#
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
#
self
.
counter_spec_decode_num_draft_tokens
=
\
prometheus_client
.
Counter
(
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
documentation
=
"Number of draft tokens."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_spec_decode_num_accepted_tokens
=
\
prometheus_client
.
Counter
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
#
# Cache config info metric
#
...
...
@@ -338,6 +365,12 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
counter_gpu_prefix_cache_hits
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
hits
)
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
counter_spec_decode_num_draft_tokens
.
inc
(
scheduler_stats
.
spec_decoding_stats
.
num_draft_tokens
)
self
.
counter_spec_decode_num_accepted_tokens
.
inc
(
scheduler_stats
.
spec_decoding_stats
.
num_accepted_tokens
)
if
iteration_stats
is
None
:
return
...
...
vllm/v1/metrics/stats.py
View file @
a79cc68b
...
...
@@ -4,6 +4,8 @@ import time
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Optional
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
if
TYPE_CHECKING
:
from
vllm.v1.engine
import
EngineCoreEvent
,
EngineCoreOutput
,
FinishReason
from
vllm.v1.engine.output_processor
import
RequestState
...
...
@@ -35,6 +37,8 @@ class SchedulerStats:
prefix_cache_stats
:
PrefixCacheStats
=
field
(
default_factory
=
PrefixCacheStats
)
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
@
dataclass
class
LoRAStats
:
...
...
vllm/v1/spec_decode/metrics.py
0 → 100644
View file @
a79cc68b
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
import
numpy
as
np
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
@
dataclass
class
SpecDecodingStats
:
num_draft_tokens
:
int
=
0
num_accepted_tokens
:
int
=
0
def
take
(
self
):
copied
=
SpecDecodingStats
(
self
.
num_draft_tokens
,
self
.
num_accepted_tokens
)
self
.
reset
()
return
copied
def
reset
(
self
):
self
.
num_draft_tokens
=
0
self
.
num_accepted_tokens
=
0
def
observe
(
self
,
num_draft_tokens
:
int
,
num_accepted_tokens
:
int
):
self
.
num_draft_tokens
+=
num_draft_tokens
self
.
num_accepted_tokens
+=
num_accepted_tokens
class
SpecDecodingMetrics
:
def
__init__
(
self
):
self
.
reset
()
def
reset
(
self
):
self
.
num_draft_tokens
:
list
[
int
]
=
[]
self
.
num_accepted_tokens
:
list
[
int
]
=
[]
def
observe
(
self
,
spec_decoding_stats
:
SpecDecodingStats
):
self
.
num_draft_tokens
.
append
(
spec_decoding_stats
.
num_draft_tokens
)
self
.
num_accepted_tokens
.
append
(
spec_decoding_stats
.
num_accepted_tokens
)
def
log
(
self
):
num_draft_tokens
=
np
.
sum
(
self
.
num_draft_tokens
)
num_accepted_tokens
=
np
.
sum
(
self
.
num_accepted_tokens
)
draft_acceptance_rate
=
(
num_accepted_tokens
/
num_draft_tokens
if
num_draft_tokens
>
0
else
float
(
"nan"
))
logger
.
info
(
"Speculative metrics: "
"Draft acceptance rate: %.3f, "
"Number of accepted tokens: %d, "
"Number of draft tokens: %d, "
,
draft_acceptance_rate
,
num_accepted_tokens
,
num_draft_tokens
)
self
.
reset
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment