Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
54181767
Unverified
Commit
54181767
authored
May 16, 2025
by
Seiji Eicher
Committed by
GitHub
May 16, 2025
Browse files
[Misc] Add Ray Prometheus logger to V1 (#17925)
Signed-off-by:
Seiji Eicher
<
seiji@anyscale.com
>
parent
67da5720
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
223 additions
and
35 deletions
+223
-35
tests/v1/metrics/test_ray_metrics.py
tests/v1/metrics/test_ray_metrics.py
+57
-0
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+29
-25
vllm/v1/metrics/ray_wrappers.py
vllm/v1/metrics/ray_wrappers.py
+120
-0
vllm/v1/spec_decode/metrics.py
vllm/v1/spec_decode/metrics.py
+17
-10
No files found.
tests/v1/metrics/test_ray_metrics.py
0 → 100644
View file @
54181767
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
ray
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.engine.async_llm
import
AsyncEngineArgs
,
AsyncLLM
from
vllm.v1.metrics.ray_wrappers
import
RayPrometheusStatLogger
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v1_only
(
monkeypatch
):
"""
The change relies on V1 APIs, so set VLLM_USE_V1=1.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'1'
)
MODELS
=
[
"distilbert/distilgpt2"
,
]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
16
])
def
test_engine_log_metrics_ray
(
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
)
->
None
:
""" Simple smoke test, verifying this can be used without exceptions.
Need to start a Ray cluster in order to verify outputs."""
@
ray
.
remote
(
num_gpus
=
1
)
class
EngineTestActor
:
async
def
run
(
self
):
engine_args
=
AsyncEngineArgs
(
model
=
model
,
dtype
=
dtype
,
disable_log_stats
=
False
,
)
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
,
stat_loggers
=
[
RayPrometheusStatLogger
])
for
i
,
prompt
in
enumerate
(
example_prompts
):
engine
.
generate
(
request_id
=
f
"request-id-
{
i
}
"
,
prompt
=
prompt
,
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
),
)
# Create the actor and call the async method
actor
=
EngineTestActor
.
remote
()
# type: ignore[attr-defined]
ray
.
get
(
actor
.
run
.
remote
())
vllm/v1/metrics/loggers.py
View file @
54181767
...
@@ -138,6 +138,10 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -138,6 +138,10 @@ class LoggingStatLogger(StatLoggerBase):
class
PrometheusStatLogger
(
StatLoggerBase
):
class
PrometheusStatLogger
(
StatLoggerBase
):
_gauge_cls
=
prometheus_client
.
Gauge
_counter_cls
=
prometheus_client
.
Counter
_histogram_cls
=
prometheus_client
.
Histogram
_spec_decoding_cls
=
SpecDecodingProm
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
self
.
_unregister_vllm_metrics
()
self
.
_unregister_vllm_metrics
()
...
@@ -156,18 +160,18 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -156,18 +160,18 @@ class PrometheusStatLogger(StatLoggerBase):
max_model_len
=
vllm_config
.
model_config
.
max_model_len
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
spec_decoding_prom
=
S
pec
D
ecoding
Prom
(
self
.
spec_decoding_prom
=
self
.
_s
pec
_d
ecoding
_cls
(
vllm_config
.
speculative_config
,
labelnames
,
labelvalues
)
vllm_config
.
speculative_config
,
labelnames
,
labelvalues
)
#
#
# Scheduler state
# Scheduler state
#
#
self
.
gauge_scheduler_running
=
prometheus_client
.
Gauge
(
self
.
gauge_scheduler_running
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_running"
,
name
=
"vllm:num_requests_running"
,
documentation
=
"Number of requests in model execution batches."
,
documentation
=
"Number of requests in model execution batches."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
gauge_scheduler_waiting
=
prometheus_client
.
Gauge
(
self
.
gauge_scheduler_waiting
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_waiting"
,
name
=
"vllm:num_requests_waiting"
,
documentation
=
"Number of requests waiting to be processed."
,
documentation
=
"Number of requests waiting to be processed."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
...
@@ -175,18 +179,18 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -175,18 +179,18 @@ class PrometheusStatLogger(StatLoggerBase):
#
#
# GPU cache
# GPU cache
#
#
self
.
gauge_gpu_cache_usage
=
prometheus_client
.
Gauge
(
self
.
gauge_gpu_cache_usage
=
self
.
_gauge_cls
(
name
=
"vllm:gpu_cache_usage_perc"
,
name
=
"vllm:gpu_cache_usage_perc"
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_gpu_prefix_cache_queries
=
prometheus_client
.
C
ounter
(
self
.
counter_gpu_prefix_cache_queries
=
self
.
_c
ounter
_cls
(
name
=
"vllm:gpu_prefix_cache_queries"
,
name
=
"vllm:gpu_prefix_cache_queries"
,
documentation
=
documentation
=
"GPU prefix cache queries, in terms of number of queried tokens."
,
"GPU prefix cache queries, in terms of number of queried tokens."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_gpu_prefix_cache_hits
=
prometheus_client
.
C
ounter
(
self
.
counter_gpu_prefix_cache_hits
=
self
.
_c
ounter
_cls
(
name
=
"vllm:gpu_prefix_cache_hits"
,
name
=
"vllm:gpu_prefix_cache_hits"
,
documentation
=
documentation
=
"GPU prefix cache hits, in terms of number of cached tokens."
,
"GPU prefix cache hits, in terms of number of cached tokens."
,
...
@@ -195,24 +199,24 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -195,24 +199,24 @@ class PrometheusStatLogger(StatLoggerBase):
#
#
# Counters
# Counters
#
#
self
.
counter_num_preempted_reqs
=
prometheus_client
.
C
ounter
(
self
.
counter_num_preempted_reqs
=
self
.
_c
ounter
_cls
(
name
=
"vllm:num_preemptions_total"
,
name
=
"vllm:num_preemptions_total"
,
documentation
=
"Cumulative number of preemption from the engine."
,
documentation
=
"Cumulative number of preemption from the engine."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_prompt_tokens
=
prometheus_client
.
C
ounter
(
self
.
counter_prompt_tokens
=
self
.
_c
ounter
_cls
(
name
=
"vllm:prompt_tokens_total"
,
name
=
"vllm:prompt_tokens_total"
,
documentation
=
"Number of prefill tokens processed."
,
documentation
=
"Number of prefill tokens processed."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_generation_tokens
=
prometheus_client
.
C
ounter
(
self
.
counter_generation_tokens
=
self
.
_c
ounter
_cls
(
name
=
"vllm:generation_tokens_total"
,
name
=
"vllm:generation_tokens_total"
,
documentation
=
"Number of generation tokens processed."
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_request_success
:
dict
[
FinishReason
,
self
.
counter_request_success
:
dict
[
FinishReason
,
prometheus_client
.
Counter
]
=
{}
prometheus_client
.
Counter
]
=
{}
counter_request_success_base
=
prometheus_client
.
C
ounter
(
counter_request_success_base
=
self
.
_c
ounter
_cls
(
name
=
"vllm:request_success_total"
,
name
=
"vllm:request_success_total"
,
documentation
=
"Count of successfully processed requests."
,
documentation
=
"Count of successfully processed requests."
,
labelnames
=
labelnames
+
[
"finished_reason"
])
labelnames
=
labelnames
+
[
"finished_reason"
])
...
@@ -225,21 +229,21 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -225,21 +229,21 @@ class PrometheusStatLogger(StatLoggerBase):
# Histograms of counts
# Histograms of counts
#
#
self
.
histogram_num_prompt_tokens_request
=
\
self
.
histogram_num_prompt_tokens_request
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:request_prompt_tokens"
,
name
=
"vllm:request_prompt_tokens"
,
documentation
=
"Number of prefill tokens processed."
,
documentation
=
"Number of prefill tokens processed."
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
buckets
=
build_1_2_5_buckets
(
max_model_len
),
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_num_generation_tokens_request
=
\
self
.
histogram_num_generation_tokens_request
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:request_generation_tokens"
,
name
=
"vllm:request_generation_tokens"
,
documentation
=
"Number of generation tokens processed."
,
documentation
=
"Number of generation tokens processed."
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
buckets
=
build_1_2_5_buckets
(
max_model_len
),
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_iteration_tokens
=
\
self
.
histogram_iteration_tokens
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:iteration_tokens_total"
,
name
=
"vllm:iteration_tokens_total"
,
documentation
=
"Histogram of number of tokens per engine_step."
,
documentation
=
"Histogram of number of tokens per engine_step."
,
buckets
=
[
buckets
=
[
...
@@ -249,7 +253,7 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -249,7 +253,7 @@ class PrometheusStatLogger(StatLoggerBase):
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_max_num_generation_tokens_request
=
\
self
.
histogram_max_num_generation_tokens_request
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:request_max_num_generation_tokens"
,
name
=
"vllm:request_max_num_generation_tokens"
,
documentation
=
documentation
=
"Histogram of maximum number of requested generation tokens."
,
"Histogram of maximum number of requested generation tokens."
,
...
@@ -257,14 +261,14 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -257,14 +261,14 @@ class PrometheusStatLogger(StatLoggerBase):
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_n_request
=
\
self
.
histogram_n_request
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:request_params_n"
,
name
=
"vllm:request_params_n"
,
documentation
=
"Histogram of the n request parameter."
,
documentation
=
"Histogram of the n request parameter."
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
buckets
=
[
1
,
2
,
5
,
10
,
20
],
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_max_tokens_request
=
\
self
.
histogram_max_tokens_request
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:request_params_max_tokens"
,
name
=
"vllm:request_params_max_tokens"
,
documentation
=
"Histogram of the max_tokens request parameter."
,
documentation
=
"Histogram of the max_tokens request parameter."
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
buckets
=
build_1_2_5_buckets
(
max_model_len
),
...
@@ -274,7 +278,7 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -274,7 +278,7 @@ class PrometheusStatLogger(StatLoggerBase):
# Histogram of timing intervals
# Histogram of timing intervals
#
#
self
.
histogram_time_to_first_token
=
\
self
.
histogram_time_to_first_token
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:time_to_first_token_seconds"
,
name
=
"vllm:time_to_first_token_seconds"
,
documentation
=
"Histogram of time to first token in seconds."
,
documentation
=
"Histogram of time to first token in seconds."
,
buckets
=
[
buckets
=
[
...
@@ -285,7 +289,7 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -285,7 +289,7 @@ class PrometheusStatLogger(StatLoggerBase):
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_time_per_output_token
=
\
self
.
histogram_time_per_output_token
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:time_per_output_token_seconds"
,
name
=
"vllm:time_per_output_token_seconds"
,
documentation
=
"Histogram of time per output token in seconds."
,
documentation
=
"Histogram of time per output token in seconds."
,
buckets
=
[
buckets
=
[
...
@@ -299,34 +303,34 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -299,34 +303,34 @@ class PrometheusStatLogger(StatLoggerBase):
40.0
,
50.0
,
60.0
,
120.0
,
240.0
,
480.0
,
960.0
,
1920.0
,
7680.0
40.0
,
50.0
,
60.0
,
120.0
,
240.0
,
480.0
,
960.0
,
1920.0
,
7680.0
]
]
self
.
histogram_e2e_time_request
=
\
self
.
histogram_e2e_time_request
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:e2e_request_latency_seconds"
,
name
=
"vllm:e2e_request_latency_seconds"
,
documentation
=
"Histogram of e2e request latency in seconds."
,
documentation
=
"Histogram of e2e request latency in seconds."
,
buckets
=
request_latency_buckets
,
buckets
=
request_latency_buckets
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_queue_time_request
=
\
self
.
histogram_queue_time_request
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:request_queue_time_seconds"
,
name
=
"vllm:request_queue_time_seconds"
,
documentation
=
documentation
=
"Histogram of time spent in WAITING phase for request."
,
"Histogram of time spent in WAITING phase for request."
,
buckets
=
request_latency_buckets
,
buckets
=
request_latency_buckets
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_inference_time_request
=
\
self
.
histogram_inference_time_request
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:request_inference_time_seconds"
,
name
=
"vllm:request_inference_time_seconds"
,
documentation
=
documentation
=
"Histogram of time spent in RUNNING phase for request."
,
"Histogram of time spent in RUNNING phase for request."
,
buckets
=
request_latency_buckets
,
buckets
=
request_latency_buckets
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_prefill_time_request
=
\
self
.
histogram_prefill_time_request
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:request_prefill_time_seconds"
,
name
=
"vllm:request_prefill_time_seconds"
,
documentation
=
documentation
=
"Histogram of time spent in PREFILL phase for request."
,
"Histogram of time spent in PREFILL phase for request."
,
buckets
=
request_latency_buckets
,
buckets
=
request_latency_buckets
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
histogram_decode_time_request
=
\
self
.
histogram_decode_time_request
=
\
prometheus_client
.
H
istogram
(
self
.
_h
istogram
_cls
(
name
=
"vllm:request_decode_time_seconds"
,
name
=
"vllm:request_decode_time_seconds"
,
documentation
=
documentation
=
"Histogram of time spent in DECODE phase for request."
,
"Histogram of time spent in DECODE phase for request."
,
...
@@ -343,7 +347,7 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -343,7 +347,7 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
labelname_running_lora_adapters
=
"running_lora_adapters"
self
.
labelname_running_lora_adapters
=
"running_lora_adapters"
self
.
max_lora
=
vllm_config
.
lora_config
.
max_loras
self
.
max_lora
=
vllm_config
.
lora_config
.
max_loras
self
.
gauge_lora_info
=
\
self
.
gauge_lora_info
=
\
prometheus_client
.
Gauge
(
self
.
_gauge_cls
(
name
=
"vllm:lora_requests_info"
,
name
=
"vllm:lora_requests_info"
,
documentation
=
"Running stats on lora requests."
,
documentation
=
"Running stats on lora requests."
,
labelnames
=
[
labelnames
=
[
...
@@ -365,7 +369,7 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -365,7 +369,7 @@ class PrometheusStatLogger(StatLoggerBase):
# Info type metrics are syntactic sugar for a gauge permanently set to 1
# Info type metrics are syntactic sugar for a gauge permanently set to 1
# Since prometheus multiprocessing mode does not support Info, emulate
# Since prometheus multiprocessing mode does not support Info, emulate
# info here with a gauge.
# info here with a gauge.
info_gauge
=
prometheus_client
.
Gauge
(
info_gauge
=
self
.
_gauge_cls
(
name
=
name
,
name
=
name
,
documentation
=
documentation
,
documentation
=
documentation
,
labelnames
=
metrics_info
.
keys
()).
labels
(
**
metrics_info
)
labelnames
=
metrics_info
.
keys
()).
labels
(
**
metrics_info
)
...
...
vllm/v1/metrics/ray_wrappers.py
0 → 100644
View file @
54181767
# SPDX-License-Identifier: Apache-2.0
import
time
from
typing
import
Optional
,
Union
from
vllm.config
import
VllmConfig
from
vllm.v1.metrics.loggers
import
PrometheusStatLogger
from
vllm.v1.spec_decode.metrics
import
SpecDecodingProm
try
:
from
ray.util
import
metrics
as
ray_metrics
from
ray.util.metrics
import
Metric
except
ImportError
:
ray_metrics
=
None
class
RayPrometheusMetric
:
def
__init__
(
self
):
if
ray_metrics
is
None
:
raise
ImportError
(
"RayPrometheusMetric requires Ray to be installed."
)
self
.
metric
:
Metric
=
None
def
labels
(
self
,
*
labels
,
**
labelskwargs
):
if
labelskwargs
:
for
k
,
v
in
labelskwargs
.
items
():
if
not
isinstance
(
v
,
str
):
labelskwargs
[
k
]
=
str
(
v
)
self
.
metric
.
set_default_tags
(
labelskwargs
)
return
self
class
RayGaugeWrapper
(
RayPrometheusMetric
):
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""
def
__init__
(
self
,
name
:
str
,
documentation
:
Optional
[
str
]
=
""
,
labelnames
:
Optional
[
list
[
str
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
metric
=
ray_metrics
.
Gauge
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
)
def
set
(
self
,
value
:
Union
[
int
,
float
]):
return
self
.
metric
.
set
(
value
)
def
set_to_current_time
(
self
):
# ray metrics doesn't have set_to_current time, https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
return
self
.
metric
.
set
(
time
.
time
())
class
RayCounterWrapper
(
RayPrometheusMetric
):
"""Wraps around ray.util.metrics.Counter to provide same API as
prometheus_client.Counter"""
def
__init__
(
self
,
name
:
str
,
documentation
:
Optional
[
str
]
=
""
,
labelnames
:
Optional
[
list
[
str
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
metric
=
ray_metrics
.
Counter
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
)
def
inc
(
self
,
value
:
Union
[
int
,
float
]
=
1.0
):
if
value
==
0
:
return
return
self
.
metric
.
inc
(
value
)
class
RayHistogramWrapper
(
RayPrometheusMetric
):
"""Wraps around ray.util.metrics.Histogram to provide same API as
prometheus_client.Histogram"""
def
__init__
(
self
,
name
:
str
,
documentation
:
Optional
[
str
]
=
""
,
labelnames
:
Optional
[
list
[
str
]]
=
None
,
buckets
:
Optional
[
list
[
float
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
boundaries
=
buckets
if
buckets
else
[]
self
.
metric
=
ray_metrics
.
Histogram
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
,
boundaries
=
boundaries
)
def
observe
(
self
,
value
:
Union
[
int
,
float
]):
return
self
.
metric
.
observe
(
value
)
class
RaySpecDecodingProm
(
SpecDecodingProm
):
"""
RaySpecDecodingProm is used by RayMetrics to log to Ray metrics.
Provides the same metrics as SpecDecodingProm but uses Ray's
util.metrics library.
"""
_counter_cls
=
RayCounterWrapper
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_gauge_cls
=
RayGaugeWrapper
_counter_cls
=
RayCounterWrapper
_histogram_cls
=
RayHistogramWrapper
_spec_decoding_cls
=
RaySpecDecodingProm
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
super
().
__init__
(
vllm_config
,
engine_index
)
@
staticmethod
def
_unregister_vllm_metrics
():
# No-op on purpose
pass
vllm/v1/spec_decode/metrics.py
View file @
54181767
...
@@ -120,24 +120,30 @@ class SpecDecodingProm:
...
@@ -120,24 +120,30 @@ class SpecDecodingProm:
vllm:spec_decode_num_drafts[$interval]
vllm:spec_decode_num_drafts[$interval]
"""
"""
def
__init__
(
self
,
speculative_config
:
Optional
[
SpeculativeConfig
],
_counter_cls
=
prometheus_client
.
Counter
labelnames
:
list
[
str
],
labelvalues
:
list
[
str
]):
def
__init__
(
self
,
speculative_config
:
Optional
[
SpeculativeConfig
],
labelnames
:
list
[
str
],
labelvalues
:
list
[
str
],
):
self
.
spec_decoding_enabled
=
speculative_config
is
not
None
self
.
spec_decoding_enabled
=
speculative_config
is
not
None
if
not
self
.
spec_decoding_enabled
:
if
not
self
.
spec_decoding_enabled
:
return
return
self
.
counter_spec_decode_num_drafts
=
\
self
.
counter_spec_decode_num_drafts
=
\
prometheus_client
.
C
ounter
(
self
.
_c
ounter
_cls
(
name
=
"vllm:spec_decode_num_drafts_total"
,
name
=
"vllm:spec_decode_num_drafts_total"
,
documentation
=
"Number of spec decoding drafts."
,
documentation
=
"Number of spec decoding drafts."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
self
.
counter_spec_decode_num_draft_tokens
=
\
self
.
counter_spec_decode_num_draft_tokens
=
\
prometheus_client
.
C
ounter
(
self
.
_c
ounter
_cls
(
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
documentation
=
"Number of draft tokens."
,
documentation
=
"Number of draft tokens."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
,
).
labels
(
*
labelvalues
)
self
.
counter_spec_decode_num_accepted_tokens
=
\
self
.
counter_spec_decode_num_accepted_tokens
=
\
prometheus_client
.
C
ounter
(
self
.
_c
ounter
_cls
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
labelnames
=
labelnames
).
labels
(
*
labelvalues
)
...
@@ -146,12 +152,13 @@ class SpecDecodingProm:
...
@@ -146,12 +152,13 @@ class SpecDecodingProm:
num_spec_tokens
=
(
speculative_config
.
num_speculative_tokens
num_spec_tokens
=
(
speculative_config
.
num_speculative_tokens
if
self
.
spec_decoding_enabled
else
0
)
if
self
.
spec_decoding_enabled
else
0
)
pos_labelnames
=
labelnames
+
[
"position"
]
pos_labelnames
=
labelnames
+
[
"position"
]
base_counter
=
prometheus_client
.
C
ounter
(
base_counter
=
self
.
_c
ounter
_cls
(
name
=
"vllm:spec_decode_num_accepted_tokens_per_pos"
,
name
=
"vllm:spec_decode_num_accepted_tokens_per_pos"
,
documentation
=
"Accepted tokens per draft position."
,
documentation
=
"Accepted tokens per draft position."
,
labelnames
=
pos_labelnames
)
labelnames
=
pos_labelnames
,
self
.
counter_spec_decode_num_accepted_tokens_per_pos
:
\
)
list
[
prometheus_client
.
Counter
]
=
[]
self
.
counter_spec_decode_num_accepted_tokens_per_pos
:
list
[
prometheus_client
.
Counter
]
=
[]
for
pos
in
range
(
num_spec_tokens
):
for
pos
in
range
(
num_spec_tokens
):
pos_labelvalues
=
labelvalues
+
[
str
(
pos
)]
pos_labelvalues
=
labelvalues
+
[
str
(
pos
)]
self
.
counter_spec_decode_num_accepted_tokens_per_pos
.
append
(
self
.
counter_spec_decode_num_accepted_tokens_per_pos
.
append
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment