Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2ad1bc7a
Unverified
Commit
2ad1bc7a
authored
Feb 15, 2025
by
Mark McLoughlin
Committed by
GitHub
Feb 15, 2025
Browse files
[V1][Metrics] Add iteration_tokens_total histogram from V0 (#13288)
parent
7fdaaf48
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
8 deletions
+35
-8
tests/entrypoints/openai/test_metrics.py
tests/entrypoints/openai/test_metrics.py
+10
-3
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+1
-1
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+24
-4
No files found.
tests/entrypoints/openai/test_metrics.py
View file @
2ad1bc7a
...
@@ -96,9 +96,14 @@ EXPECTED_VALUES = {
...
@@ -96,9 +96,14 @@ EXPECTED_VALUES = {
[(
"_sum"
,
_NUM_REQUESTS
*
_NUM_GENERATION_TOKENS_PER_REQUEST
),
[(
"_sum"
,
_NUM_REQUESTS
*
_NUM_GENERATION_TOKENS_PER_REQUEST
),
(
"_count"
,
_NUM_REQUESTS
)],
(
"_count"
,
_NUM_REQUESTS
)],
"vllm:request_params_n"
:
[(
"_count"
,
_NUM_REQUESTS
)],
"vllm:request_params_n"
:
[(
"_count"
,
_NUM_REQUESTS
)],
"vllm:request_params_max_tokens"
:
"vllm:request_params_max_tokens"
:
[
[(
"_sum"
,
_NUM_REQUESTS
*
_NUM_GENERATION_TOKENS_PER_REQUEST
),
(
"_sum"
,
_NUM_REQUESTS
*
_NUM_GENERATION_TOKENS_PER_REQUEST
),
(
"_count"
,
_NUM_REQUESTS
)],
(
"_count"
,
_NUM_REQUESTS
)
],
"vllm:iteration_tokens_total"
:
[(
"_sum"
,
_NUM_REQUESTS
*
(
_NUM_PROMPT_TOKENS_PER_REQUEST
+
_NUM_GENERATION_TOKENS_PER_REQUEST
)),
(
"_count"
,
_NUM_REQUESTS
*
_NUM_GENERATION_TOKENS_PER_REQUEST
)],
"vllm:prompt_tokens"
:
[(
"_total"
,
"vllm:prompt_tokens"
:
[(
"_total"
,
_NUM_REQUESTS
*
_NUM_PROMPT_TOKENS_PER_REQUEST
)],
_NUM_REQUESTS
*
_NUM_PROMPT_TOKENS_PER_REQUEST
)],
"vllm:generation_tokens"
:
[
"vllm:generation_tokens"
:
[
...
@@ -197,6 +202,7 @@ EXPECTED_METRICS = [
...
@@ -197,6 +202,7 @@ EXPECTED_METRICS = [
"vllm:request_params_max_tokens_sum"
,
"vllm:request_params_max_tokens_sum"
,
"vllm:request_params_max_tokens_bucket"
,
"vllm:request_params_max_tokens_bucket"
,
"vllm:request_params_max_tokens_count"
,
"vllm:request_params_max_tokens_count"
,
"vllm:iteration_tokens_total"
,
"vllm:num_preemptions_total"
,
"vllm:num_preemptions_total"
,
"vllm:prompt_tokens_total"
,
"vllm:prompt_tokens_total"
,
"vllm:generation_tokens_total"
,
"vllm:generation_tokens_total"
,
...
@@ -223,6 +229,7 @@ EXPECTED_METRICS_V1 = [
...
@@ -223,6 +229,7 @@ EXPECTED_METRICS_V1 = [
"vllm:gpu_prefix_cache_hits"
,
"vllm:gpu_prefix_cache_hits"
,
"vllm:prompt_tokens_total"
,
"vllm:prompt_tokens_total"
,
"vllm:generation_tokens_total"
,
"vllm:generation_tokens_total"
,
"vllm:iteration_tokens_total"
,
"vllm:request_success_total"
,
"vllm:request_success_total"
,
"vllm:request_prompt_tokens_sum"
,
"vllm:request_prompt_tokens_sum"
,
"vllm:request_prompt_tokens_bucket"
,
"vllm:request_prompt_tokens_bucket"
,
...
...
vllm/v1/engine/async_llm.py
View file @
2ad1bc7a
...
@@ -57,7 +57,7 @@ class AsyncLLM(EngineClient):
...
@@ -57,7 +57,7 @@ class AsyncLLM(EngineClient):
if
self
.
log_stats
:
if
self
.
log_stats
:
self
.
stat_loggers
.
extend
([
self
.
stat_loggers
.
extend
([
LoggingStatLogger
(),
LoggingStatLogger
(),
PrometheusStatLogger
(
vllm_config
.
model_config
),
PrometheusStatLogger
(
vllm_config
),
])
])
# Tokenizer (+ ensure liveness if running in another process).
# Tokenizer (+ ensure liveness if running in another process).
...
...
vllm/v1/metrics/loggers.py
View file @
2ad1bc7a
...
@@ -7,7 +7,7 @@ from typing import Dict, List
...
@@ -7,7 +7,7 @@ from typing import Dict, List
import
numpy
as
np
import
numpy
as
np
import
prometheus_client
import
prometheus_client
from
vllm.config
import
Model
Config
from
vllm.config
import
Vllm
Config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.engine
import
FinishReason
...
@@ -92,13 +92,13 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -92,13 +92,13 @@ class LoggingStatLogger(StatLoggerBase):
class
PrometheusStatLogger
(
StatLoggerBase
):
class
PrometheusStatLogger
(
StatLoggerBase
):
def
__init__
(
self
,
model
_config
:
Model
Config
):
def
__init__
(
self
,
vllm
_config
:
Vllm
Config
):
self
.
_unregister_vllm_metrics
()
self
.
_unregister_vllm_metrics
()
labelnames
=
[
"model_name"
]
labelnames
=
[
"model_name"
]
labelvalues
=
[
model_config
.
served_model_name
]
labelvalues
=
[
vllm_config
.
model_config
.
served_model_name
]
max_model_len
=
model_config
.
max_model_len
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
gauge_scheduler_running
=
prometheus_client
.
Gauge
(
self
.
gauge_scheduler_running
=
prometheus_client
.
Gauge
(
name
=
"vllm:num_requests_running"
,
name
=
"vllm:num_requests_running"
,
...
@@ -162,6 +162,13 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -162,6 +162,13 @@ class PrometheusStatLogger(StatLoggerBase):
buckets
=
build_1_2_5_buckets
(
max_model_len
),
buckets
=
build_1_2_5_buckets
(
max_model_len
),
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_iteration_tokens
=
\
prometheus_client
.
Histogram
(
name
=
"vllm:iteration_tokens_total"
,
documentation
=
"Histogram of number of tokens per engine_step."
,
buckets
=
build_cudagraph_buckets
(
vllm_config
),
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_time_to_first_token
=
\
self
.
histogram_time_to_first_token
=
\
prometheus_client
.
Histogram
(
prometheus_client
.
Histogram
(
name
=
"vllm:time_to_first_token_seconds"
,
name
=
"vllm:time_to_first_token_seconds"
,
...
@@ -237,6 +244,9 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -237,6 +244,9 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
counter_prompt_tokens
.
inc
(
iteration_stats
.
num_prompt_tokens
)
self
.
counter_prompt_tokens
.
inc
(
iteration_stats
.
num_prompt_tokens
)
self
.
counter_generation_tokens
.
inc
(
self
.
counter_generation_tokens
.
inc
(
iteration_stats
.
num_generation_tokens
)
iteration_stats
.
num_generation_tokens
)
self
.
histogram_iteration_tokens
.
observe
(
iteration_stats
.
num_prompt_tokens
+
\
iteration_stats
.
num_generation_tokens
)
for
finished_request
in
iteration_stats
.
finished_requests
:
for
finished_request
in
iteration_stats
.
finished_requests
:
self
.
counter_request_success
[
finished_request
.
finish_reason
].
inc
()
self
.
counter_request_success
[
finished_request
.
finish_reason
].
inc
()
...
@@ -293,3 +303,13 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
...
@@ -293,3 +303,13 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
[1, 2, 5, 10, 20, 50, 100]
[1, 2, 5, 10, 20, 50, 100]
"""
"""
return
build_buckets
([
1
,
2
,
5
],
max_value
)
return
build_buckets
([
1
,
2
,
5
],
max_value
)
def
build_cudagraph_buckets
(
vllm_config
:
VllmConfig
)
->
List
[
int
]:
if
not
vllm_config
.
model_config
.
enforce_eager
:
buckets
=
vllm_config
.
compilation_config
.
\
cudagraph_capture_sizes
.
copy
()
buckets
.
sort
()
return
buckets
else
:
return
[
1
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8096
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment