Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f0b9933
Unverified
Commit
5f0b9933
authored
Jul 17, 2024
by
Antoni Baum
Committed by
GitHub
Jul 17, 2024
Browse files
[Bugfix] Fix Ray Metrics API usage (#6354)
parent
a38524f3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
195 additions
and
40 deletions
+195
-40
tests/metrics/test_metrics.py
tests/metrics/test_metrics.py
+54
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+19
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-0
vllm/engine/metrics.py
vllm/engine/metrics.py
+120
-40
No files found.
tests/metrics/test_metrics.py
View file @
5f0b9933
from
typing
import
List
import
pytest
import
ray
from
prometheus_client
import
REGISTRY
from
vllm
import
EngineArgs
,
LLMEngine
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.metrics
import
RayPrometheusStatLogger
from
vllm.sampling_params
import
SamplingParams
MODELS
=
[
...
...
@@ -241,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
labels
)
assert
(
metric_value
==
num_requests
),
"Metrics should be collected"
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
16
])
def
test_engine_log_metrics_ray
(
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
)
->
None
:
# This test is quite weak - it only checks that we can use
# RayPrometheusStatLogger without exceptions.
# Checking whether the metrics are actually emitted is unfortunately
# non-trivial.
# We have to run in a Ray task for Ray metrics to be emitted correctly
@
ray
.
remote
(
num_gpus
=
1
)
def
_inner
():
class
_RayPrometheusStatLogger
(
RayPrometheusStatLogger
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
_i
=
0
super
().
__init__
(
*
args
,
**
kwargs
)
def
log
(
self
,
*
args
,
**
kwargs
):
self
.
_i
+=
1
return
super
().
log
(
*
args
,
**
kwargs
)
engine_args
=
EngineArgs
(
model
=
model
,
dtype
=
dtype
,
disable_log_stats
=
False
,
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
logger
=
_RayPrometheusStatLogger
(
local_interval
=
0.5
,
labels
=
dict
(
model_name
=
engine
.
model_config
.
served_model_name
),
max_model_len
=
engine
.
model_config
.
max_model_len
)
engine
.
add_logger
(
"ray"
,
logger
)
for
i
,
prompt
in
enumerate
(
example_prompts
):
engine
.
add_request
(
f
"request-id-
{
i
}
"
,
prompt
,
SamplingParams
(
max_tokens
=
max_tokens
),
)
while
engine
.
has_unfinished_requests
():
engine
.
step
()
assert
logger
.
_i
>
0
,
".log must be called at least once"
ray
.
get
(
_inner
.
remote
())
vllm/engine/async_llm_engine.py
View file @
5f0b9933
...
...
@@ -12,6 +12,7 @@ from vllm.core.scheduler import SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.metrics
import
StatLoggerBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.inputs
import
LLMInputs
,
PromptInputs
from
vllm.logger
import
init_logger
...
...
@@ -389,6 +390,7 @@ class AsyncLLMEngine:
engine_args
:
AsyncEngineArgs
,
start_engine_loop
:
bool
=
True
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"AsyncLLMEngine"
:
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
...
...
@@ -451,6 +453,7 @@ class AsyncLLMEngine:
max_log_len
=
engine_args
.
max_log_len
,
start_engine_loop
=
start_engine_loop
,
usage_context
=
usage_context
,
stat_loggers
=
stat_loggers
,
)
return
engine
...
...
@@ -957,3 +960,19 @@ class AsyncLLMEngine:
)
else
:
return
self
.
engine
.
is_tracing_enabled
()
def
add_logger
(
self
,
logger_name
:
str
,
logger
:
StatLoggerBase
)
->
None
:
if
self
.
engine_use_ray
:
ray
.
get
(
self
.
engine
.
add_logger
.
remote
(
# type: ignore
logger_name
=
logger_name
,
logger
=
logger
))
else
:
self
.
engine
.
add_logger
(
logger_name
=
logger_name
,
logger
=
logger
)
def
remove_logger
(
self
,
logger_name
:
str
)
->
None
:
if
self
.
engine_use_ray
:
ray
.
get
(
self
.
engine
.
remove_logger
.
remote
(
# type: ignore
logger_name
=
logger_name
))
else
:
self
.
engine
.
remove_logger
(
logger_name
=
logger_name
)
vllm/engine/llm_engine.py
View file @
5f0b9933
...
...
@@ -379,6 +379,7 @@ class LLMEngine:
cls
,
engine_args
:
EngineArgs
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"LLMEngine"
:
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
...
...
@@ -423,6 +424,7 @@ class LLMEngine:
executor_class
=
executor_class
,
log_stats
=
not
engine_args
.
disable_log_stats
,
usage_context
=
usage_context
,
stat_loggers
=
stat_loggers
,
)
return
engine
...
...
vllm/engine/metrics.py
View file @
5f0b9933
...
...
@@ -30,55 +30,55 @@ prometheus_client.disable_created_metrics()
# begin-metrics-definitions
class
Metrics
:
labelname_finish_reason
=
"finished_reason"
_base_library
=
prometheus_client
_gauge_cls
=
prometheus_client
.
Gauge
_counter_cls
=
prometheus_client
.
Counter
_histogram_cls
=
prometheus_client
.
Histogram
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
# Unregister any existing vLLM collectors
self
.
_unregister_vllm_metrics
()
# Config Information
self
.
info_cache_config
=
prometheus_client
.
Info
(
name
=
'vllm:cache_config'
,
documentation
=
'information of cache_config'
)
self
.
_create_info_cache_config
()
# System stats
# Scheduler State
self
.
gauge_scheduler_running
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_scheduler_running
=
self
.
_
gauge_cls
(
name
=
"vllm:num_requests_running"
,
documentation
=
"Number of requests currently running on GPU."
,
labelnames
=
labelnames
)
self
.
gauge_scheduler_waiting
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_scheduler_waiting
=
self
.
_
gauge_cls
(
name
=
"vllm:num_requests_waiting"
,
documentation
=
"Number of requests waiting to be processed."
,
labelnames
=
labelnames
)
self
.
gauge_scheduler_swapped
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_scheduler_swapped
=
self
.
_
gauge_cls
(
name
=
"vllm:num_requests_swapped"
,
documentation
=
"Number of requests swapped to CPU."
,
labelnames
=
labelnames
)
# KV Cache Usage in %
self
.
gauge_gpu_cache_usage
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_gpu_cache_usage
=
self
.
_
gauge_cls
(
name
=
"vllm:gpu_cache_usage_perc"
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
)
self
.
gauge_cpu_cache_usage
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_cpu_cache_usage
=
self
.
_
gauge_cls
(
name
=
"vllm:cpu_cache_usage_perc"
,
documentation
=
"CPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
)
# Iteration stats
self
.
counter_num_preemption
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_num_preemption
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:num_preemptions_total"
,
documentation
=
"Cumulative number of preemption from the engine."
,
labelnames
=
labelnames
)
self
.
counter_prompt_tokens
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_prompt_tokens
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:prompt_tokens_total"
,
documentation
=
"Number of prefill tokens processed."
,
labelnames
=
labelnames
)
self
.
counter_generation_tokens
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_generation_tokens
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:generation_tokens_total"
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
)
self
.
histogram_time_to_first_token
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_time_to_first_token
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:time_to_first_token_seconds"
,
documentation
=
"Histogram of time to first token in seconds."
,
labelnames
=
labelnames
,
...
...
@@ -86,7 +86,7 @@ class Metrics:
0.001
,
0.005
,
0.01
,
0.02
,
0.04
,
0.06
,
0.08
,
0.1
,
0.25
,
0.5
,
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
])
self
.
histogram_time_per_output_token
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_time_per_output_token
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:time_per_output_token_seconds"
,
documentation
=
"Histogram of time per output token in seconds."
,
labelnames
=
labelnames
,
...
...
@@ -97,83 +97,157 @@ class Metrics:
# Request stats
# Latency
self
.
histogram_e2e_time_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_e2e_time_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:e2e_request_latency_seconds"
,
documentation
=
"Histogram of end to end request latency in seconds."
,
labelnames
=
labelnames
,
buckets
=
[
1.0
,
2.5
,
5.0
,
10.0
,
15.0
,
20.0
,
30.0
,
40.0
,
50.0
,
60.0
])
# Metadata
self
.
histogram_num_prompt_tokens_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_num_prompt_tokens_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_prompt_tokens"
,
documentation
=
"Number of prefill tokens processed."
,
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
self
.
histogram_num_generation_tokens_request
=
\
self
.
_
base_library
.
H
istogram
(
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_generation_tokens"
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
self
.
histogram_best_of_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_best_of_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_params_best_of"
,
documentation
=
"Histogram of the best_of request parameter."
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
self
.
histogram_n_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_n_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_params_n"
,
documentation
=
"Histogram of the n request parameter."
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
self
.
counter_request_success
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_request_success
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:request_success_total"
,
documentation
=
"Count of successfully processed requests."
,
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
# Speculatie decoding stats
self
.
gauge_spec_decode_draft_acceptance_rate
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_spec_decode_draft_acceptance_rate
=
self
.
_
gauge_cls
(
name
=
"vllm:spec_decode_draft_acceptance_rate"
,
documentation
=
"Speulative token acceptance rate."
,
labelnames
=
labelnames
)
self
.
gauge_spec_decode_efficiency
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_spec_decode_efficiency
=
self
.
_
gauge_cls
(
name
=
"vllm:spec_decode_efficiency"
,
documentation
=
"Speculative decoding system efficiency."
,
labelnames
=
labelnames
)
self
.
counter_spec_decode_num_accepted_tokens
=
(
self
.
_base_library
.
Counter
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
))
self
.
counter_spec_decode_num_draft_tokens
=
self
.
_base_library
.
Counter
(
self
.
counter_spec_decode_num_accepted_tokens
=
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
))
self
.
counter_spec_decode_num_draft_tokens
=
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
documentation
=
"Number of draft tokens."
,
labelnames
=
labelnames
)
self
.
counter_spec_decode_num_emitted_tokens
=
(
self
.
_base_library
.
Counter
(
name
=
"vllm:spec_decode_num_emitted_tokens_total"
,
documentation
=
"Number of emitted tokens."
,
labelnames
=
labelnames
))
self
.
counter_spec_decode_num_emitted_tokens
=
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_emitted_tokens_total"
,
documentation
=
"Number of emitted tokens."
,
labelnames
=
labelnames
))
# Deprecated in favor of vllm:prompt_tokens_total
self
.
gauge_avg_prompt_throughput
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_avg_prompt_throughput
=
self
.
_
gauge_cls
(
name
=
"vllm:avg_prompt_throughput_toks_per_s"
,
documentation
=
"Average prefill throughput in tokens/s."
,
labelnames
=
labelnames
,
)
# Deprecated in favor of vllm:generation_tokens_total
self
.
gauge_avg_generation_throughput
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_avg_generation_throughput
=
self
.
_
gauge_cls
(
name
=
"vllm:avg_generation_throughput_toks_per_s"
,
documentation
=
"Average generation throughput in tokens/s."
,
labelnames
=
labelnames
,
)
def
_create_info_cache_config
(
self
)
->
None
:
# Config Information
self
.
info_cache_config
=
prometheus_client
.
Info
(
name
=
'vllm:cache_config'
,
documentation
=
'information of cache_config'
)
def
_unregister_vllm_metrics
(
self
)
->
None
:
for
collector
in
list
(
self
.
_base_library
.
REGISTRY
.
_collector_to_names
):
for
collector
in
list
(
prometheus_client
.
REGISTRY
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
self
.
_base_library
.
REGISTRY
.
unregister
(
collector
)
prometheus_client
.
REGISTRY
.
unregister
(
collector
)
# end-metrics-definitions
class
_RayGaugeWrapper
:
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""
def
__init__
(
self
,
name
:
str
,
documentation
:
str
=
""
,
labelnames
:
Optional
[
List
[
str
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
_gauge
=
ray_metrics
.
Gauge
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
)
def
labels
(
self
,
**
labels
):
self
.
_gauge
.
set_default_tags
(
labels
)
return
self
def
set
(
self
,
value
:
Union
[
int
,
float
]):
return
self
.
_gauge
.
set
(
value
)
class
_RayCounterWrapper
:
"""Wraps around ray.util.metrics.Counter to provide same API as
prometheus_client.Counter"""
def
__init__
(
self
,
name
:
str
,
documentation
:
str
=
""
,
labelnames
:
Optional
[
List
[
str
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
_counter
=
ray_metrics
.
Counter
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
)
def
labels
(
self
,
**
labels
):
self
.
_counter
.
set_default_tags
(
labels
)
return
self
def
inc
(
self
,
value
:
Union
[
int
,
float
]
=
1.0
):
if
value
==
0
:
return
return
self
.
_counter
.
inc
(
value
)
class
_RayHistogramWrapper
:
"""Wraps around ray.util.metrics.Histogram to provide same API as
prometheus_client.Histogram"""
def
__init__
(
self
,
name
:
str
,
documentation
:
str
=
""
,
labelnames
:
Optional
[
List
[
str
]]
=
None
,
buckets
:
Optional
[
List
[
float
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
_histogram
=
ray_metrics
.
Histogram
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
,
boundaries
=
buckets
)
def
labels
(
self
,
**
labels
):
self
.
_histogram
.
set_default_tags
(
labels
)
return
self
def
observe
(
self
,
value
:
Union
[
int
,
float
]):
return
self
.
_histogram
.
observe
(
value
)
class
RayMetrics
(
Metrics
):
...
...
@@ -181,7 +255,9 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_base_library
=
ray_metrics
_gauge_cls
=
_RayGaugeWrapper
_counter_cls
=
_RayCounterWrapper
_histogram_cls
=
_RayHistogramWrapper
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
if
ray_metrics
is
None
:
...
...
@@ -192,8 +268,9 @@ class RayMetrics(Metrics):
# No-op on purpose
pass
# end-metrics-definitions
def
_create_info_cache_config
(
self
)
->
None
:
# No-op on purpose
pass
def
build_1_2_5_buckets
(
max_value
:
int
)
->
List
[
int
]:
...
...
@@ -498,3 +575,6 @@ class PrometheusStatLogger(StatLoggerBase):
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls
=
RayMetrics
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
return
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment