Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
160e1d8c
Unverified
Commit
160e1d8c
authored
Jul 16, 2024
by
Cody Yu
Committed by
GitHub
Jul 16, 2024
Browse files
[Misc] Log spec decode metrics (#6454)
parent
94162beb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
137 additions
and
14 deletions
+137
-14
tests/metrics/test_metrics.py
tests/metrics/test_metrics.py
+49
-0
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+36
-8
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+12
-6
vllm/engine/metrics.py
vllm/engine/metrics.py
+40
-0
No files found.
tests/metrics/test_metrics.py
View file @
160e1d8c
...
...
@@ -168,6 +168,55 @@ def test_engine_log_metrics_regression(
assert_metrics
(
engine
,
disable_log_stats
,
len
(
example_prompts
))
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [10])
def test_metric_spec_decode(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Check that speculative decoding metrics are logged with sane values.

    The same model is used as both the target and the draft model, a single
    prompt is generated greedily, and every spec decode gauge/counter on the
    'prometheus' stat logger is then validated against a loose predicate.

    The purpose of this test is to verify that the metrics are populated,
    not functional correctness, so the accepted ranges are intentionally
    loose.
    """
    k = 5

    def is_ratio(value) -> bool:
        # Gauges reported as a fraction must lie in [0, 1].
        return 0 <= value <= 1

    def at_most(upper):
        # Build a predicate accepting any value in [0, upper].
        return lambda value: 0 <= value <= upper

    # Attribute name on the metrics object -> predicate on its current value.
    metric_name_to_expected_fn = {
        "gauge_spec_decode_draft_acceptance_rate": is_ratio,
        "gauge_spec_decode_efficiency": is_ratio,
        "counter_spec_decode_num_accepted_tokens": at_most(k),
        "counter_spec_decode_num_draft_tokens": lambda value: value == k,
        "counter_spec_decode_num_emitted_tokens": at_most(k + 1),
    }

    runner_kwargs = dict(
        dtype=dtype,
        disable_log_stats=False,
        gpu_memory_utilization=0.4,
        speculative_model=model,
        num_speculative_tokens=k,
        use_v2_block_manager=True,
    )

    with vllm_runner(model, **runner_kwargs) as vllm_model:
        engine = vllm_model.model.llm_engine
        prometheus_logger = engine.stat_loggers['prometheus']
        # Force log interval to be 0 to catch all metrics.
        prometheus_logger.local_interval = 0

        # Use one request to better inspect the metrics.
        single_prompt = example_prompts[:1]
        vllm_model.generate_greedy(single_prompt, max_tokens)

        for name, predicate in metric_name_to_expected_fn.items():
            metric = getattr(prometheus_logger.metrics, name)
            labelled_metric = metric.labels(**prometheus_logger.labels)
            observed = labelled_metric._value.get()
            assert predicate(observed), (
                f"the value of metric {name} ({observed}) "
                "does not meet expectation")
def
assert_metrics
(
engine
:
LLMEngine
,
disable_log_stats
:
bool
,
num_requests
:
int
)
->
None
:
if
disable_log_stats
:
...
...
tests/spec_decode/e2e/conftest.py
View file @
160e1d8c
...
...
@@ -162,6 +162,11 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
}
test_name
=
request
.
node
.
name
model
=
kwargs
[
"model"
]
draft_model
=
kwargs
.
get
(
"speculative_model"
,
None
)
same_draft_target_model
=
(
draft_model
is
not
None
and
draft_model
==
model
)
def
generator_inner
():
wait_for_gpu_memory_to_clear
(
...
...
@@ -177,6 +182,13 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
print
(
f
'Creating
{
baseline_or_test
=
}
LLM for
{
test_name
=
}
.
{
kwargs
=
}
'
)
llm
=
AsyncLLM
(
**
kwargs
)
if
use_async
else
LLM
(
**
kwargs
)
# Override logging interval to 0 for spec decode test run to
# log all metrics in time.
if
(
baseline_or_test
==
"test"
and
not
use_async
and
llm
.
llm_engine
.
log_stats
):
for
sate_logger
in
llm
.
llm_engine
.
stat_loggers
.
values
():
sate_logger
.
local_interval
=
0
set_random_seed
(
seed
)
yield
llm
...
...
@@ -188,6 +200,9 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
yield
llm
del
llm
# Set an attribute to the generator_outer function to allow us to
# determine whether to further check the acceptance rate in tests.
generator_outer
.
same_draft_target_model
=
same_draft_target_model
# type: ignore
return
generator_outer
...
...
@@ -204,18 +219,26 @@ def maybe_assert_ngram_worker(llm):
def
get_output_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
)
->
Tuple
[
List
[
str
],
List
[
List
[
int
]]]:
sampling_params
)
->
Tuple
[
List
[
str
],
List
[
List
[
int
]]
,
float
]:
tokens
:
List
[
str
]
=
[]
token_ids
:
List
[
List
[
int
]]
=
[]
acceptance_rate
:
float
=
-
1.0
for
llm
in
llm_generator
():
maybe_assert_ngram_worker
(
llm
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
token_ids
=
[
output
.
outputs
[
0
].
token_ids
for
output
in
outputs
]
tokens
=
[
output
.
outputs
[
0
].
text
for
output
in
outputs
]
# Fetch acceptance rate if logging is enabled.
if
stat_loggers
:
=
getattr
(
llm
.
llm_engine
,
"stat_loggers"
,
None
):
stat_logger
=
stat_loggers
[
"prometheus"
]
acceptance_rate
=
(
stat_logger
.
metrics
.
gauge_spec_decode_draft_acceptance_rate
.
labels
(
**
stat_logger
.
labels
).
_value
.
get
())
del
llm
return
tokens
,
token_ids
return
tokens
,
token_ids
,
acceptance_rate
def
get_logprobs_from_llm_generator
(
...
...
@@ -237,7 +260,8 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
batch_size
,
max_output_len
,
force_output_len
:
bool
,
print_tokens
:
bool
=
False
):
print_tokens
:
bool
=
False
,
ensure_all_accepted
:
bool
=
False
):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
...
...
@@ -267,12 +291,13 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
temperature
=
temperature
,
)
spec_batch_tokens
,
spec_batch_token_ids
=
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
(
spec_batch_tokens
,
spec_batch_token_ids
,
acceptance_rate
)
=
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
(
baseline_batch_tokens
,
baseline_batch_token_ids
)
=
get_output_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
(
baseline_batch_tokens
,
baseline_batch_token_ids
,
_
)
=
get_output_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
assert
len
(
baseline_batch_token_ids
)
==
len
(
prompts
)
assert
len
(
spec_batch_token_ids
)
==
len
(
prompts
)
...
...
@@ -287,3 +312,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
print
(
f
'
{
i
=
}
{
baseline_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
assert
baseline_token_ids
==
spec_token_ids
if
ensure_all_accepted
:
assert
acceptance_rate
==
1.0
tests/spec_decode/e2e/test_multistep_correctness.py
View file @
160e1d8c
...
...
@@ -97,7 +97,7 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
temperature
=
temperature
,
)
batch_tokens
,
batch_token_ids
=
get_output_from_llm_generator
(
batch_tokens
,
batch_token_ids
,
_
=
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
# Expect a generation for each prompt in the batch.
...
...
@@ -200,12 +200,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.
When the draft model is the same as the target model, we further check
whether all speculative tokens are accepted.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
ensure_all_accepted
=
test_llm_generator
.
same_draft_target_model
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
,
ensure_all_accepted
=
ensure_all_accepted
)
@
pytest
.
mark
.
parametrize
(
...
...
vllm/engine/metrics.py
View file @
160e1d8c
...
...
@@ -133,6 +133,30 @@ class Metrics:
documentation
=
"Count of successfully processed requests."
,
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
# Speculative decoding stats
self
.
gauge_spec_decode_draft_acceptance_rate
=
self
.
_base_library
.
Gauge
(
name
=
"vllm:spec_decode_draft_acceptance_rate"
,
documentation
=
"Speulative token acceptance rate."
,
labelnames
=
labelnames
)
self
.
gauge_spec_decode_efficiency
=
self
.
_base_library
.
Gauge
(
name
=
"vllm:spec_decode_efficiency"
,
documentation
=
"Speculative decoding system efficiency."
,
labelnames
=
labelnames
)
self
.
counter_spec_decode_num_accepted_tokens
=
(
self
.
_base_library
.
Counter
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
))
self
.
counter_spec_decode_num_draft_tokens
=
self
.
_base_library
.
Counter
(
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
documentation
=
"Number of draft tokens."
,
labelnames
=
labelnames
)
self
.
counter_spec_decode_num_emitted_tokens
=
(
self
.
_base_library
.
Counter
(
name
=
"vllm:spec_decode_num_emitted_tokens_total"
,
documentation
=
"Number of emitted tokens."
,
labelnames
=
labelnames
))
# Deprecated in favor of vllm:prompt_tokens_total
self
.
gauge_avg_prompt_throughput
=
self
.
_base_library
.
Gauge
(
name
=
"vllm:avg_prompt_throughput_toks_per_s"
,
...
...
@@ -454,6 +478,22 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
if
stats
.
spec_decode_metrics
is
not
None
:
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_draft_acceptance_rate
,
stats
.
spec_decode_metrics
.
draft_acceptance_rate
)
self
.
_log_gauge
(
self
.
metrics
.
gauge_spec_decode_efficiency
,
stats
.
spec_decode_metrics
.
system_efficiency
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_accepted_tokens
,
stats
.
spec_decode_metrics
.
accepted_tokens
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_draft_tokens
,
stats
.
spec_decode_metrics
.
draft_tokens
)
self
.
_log_counter
(
self
.
metrics
.
counter_spec_decode_num_emitted_tokens
,
stats
.
spec_decode_metrics
.
emitted_tokens
)
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
"""RayPrometheusStatLogger uses Ray metrics instead."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment