Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f0b9933
Unverified
Commit
5f0b9933
authored
Jul 17, 2024
by
Antoni Baum
Committed by
GitHub
Jul 17, 2024
Browse files
[Bugfix] Fix Ray Metrics API usage (#6354)
parent
a38524f3
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
195 additions
and
40 deletions
+195
-40
tests/metrics/test_metrics.py
tests/metrics/test_metrics.py
+54
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+19
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-0
vllm/engine/metrics.py
vllm/engine/metrics.py
+120
-40
No files found.
tests/metrics/test_metrics.py
View file @
5f0b9933
from
typing
import
List
from
typing
import
List
import
pytest
import
pytest
import
ray
from
prometheus_client
import
REGISTRY
from
prometheus_client
import
REGISTRY
from
vllm
import
EngineArgs
,
LLMEngine
from
vllm
import
EngineArgs
,
LLMEngine
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.metrics
import
RayPrometheusStatLogger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
MODELS
=
[
MODELS
=
[
...
@@ -241,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
...
@@ -241,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
labels
)
labels
)
assert
(
assert
(
metric_value
==
num_requests
),
"Metrics should be collected"
metric_value
==
num_requests
),
"Metrics should be collected"
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
16
])
def
test_engine_log_metrics_ray
(
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
)
->
None
:
# This test is quite weak - it only checks that we can use
# RayPrometheusStatLogger without exceptions.
# Checking whether the metrics are actually emitted is unfortunately
# non-trivial.
# We have to run in a Ray task for Ray metrics to be emitted correctly
@
ray
.
remote
(
num_gpus
=
1
)
def
_inner
():
class
_RayPrometheusStatLogger
(
RayPrometheusStatLogger
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
_i
=
0
super
().
__init__
(
*
args
,
**
kwargs
)
def
log
(
self
,
*
args
,
**
kwargs
):
self
.
_i
+=
1
return
super
().
log
(
*
args
,
**
kwargs
)
engine_args
=
EngineArgs
(
model
=
model
,
dtype
=
dtype
,
disable_log_stats
=
False
,
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
logger
=
_RayPrometheusStatLogger
(
local_interval
=
0.5
,
labels
=
dict
(
model_name
=
engine
.
model_config
.
served_model_name
),
max_model_len
=
engine
.
model_config
.
max_model_len
)
engine
.
add_logger
(
"ray"
,
logger
)
for
i
,
prompt
in
enumerate
(
example_prompts
):
engine
.
add_request
(
f
"request-id-
{
i
}
"
,
prompt
,
SamplingParams
(
max_tokens
=
max_tokens
),
)
while
engine
.
has_unfinished_requests
():
engine
.
step
()
assert
logger
.
_i
>
0
,
".log must be called at least once"
ray
.
get
(
_inner
.
remote
())
vllm/engine/async_llm_engine.py
View file @
5f0b9933
...
@@ -12,6 +12,7 @@ from vllm.core.scheduler import SchedulerOutputs
...
@@ -12,6 +12,7 @@ from vllm.core.scheduler import SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.metrics
import
StatLoggerBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.inputs
import
LLMInputs
,
PromptInputs
from
vllm.inputs
import
LLMInputs
,
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -389,6 +390,7 @@ class AsyncLLMEngine:
...
@@ -389,6 +390,7 @@ class AsyncLLMEngine:
engine_args
:
AsyncEngineArgs
,
engine_args
:
AsyncEngineArgs
,
start_engine_loop
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"AsyncLLMEngine"
:
)
->
"AsyncLLMEngine"
:
"""Creates an async LLM engine from the engine arguments."""
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
# Create the engine configs.
...
@@ -451,6 +453,7 @@ class AsyncLLMEngine:
...
@@ -451,6 +453,7 @@ class AsyncLLMEngine:
max_log_len
=
engine_args
.
max_log_len
,
max_log_len
=
engine_args
.
max_log_len
,
start_engine_loop
=
start_engine_loop
,
start_engine_loop
=
start_engine_loop
,
usage_context
=
usage_context
,
usage_context
=
usage_context
,
stat_loggers
=
stat_loggers
,
)
)
return
engine
return
engine
...
@@ -957,3 +960,19 @@ class AsyncLLMEngine:
...
@@ -957,3 +960,19 @@ class AsyncLLMEngine:
)
)
else
:
else
:
return
self
.
engine
.
is_tracing_enabled
()
return
self
.
engine
.
is_tracing_enabled
()
def
add_logger
(
self
,
logger_name
:
str
,
logger
:
StatLoggerBase
)
->
None
:
if
self
.
engine_use_ray
:
ray
.
get
(
self
.
engine
.
add_logger
.
remote
(
# type: ignore
logger_name
=
logger_name
,
logger
=
logger
))
else
:
self
.
engine
.
add_logger
(
logger_name
=
logger_name
,
logger
=
logger
)
def
remove_logger
(
self
,
logger_name
:
str
)
->
None
:
if
self
.
engine_use_ray
:
ray
.
get
(
self
.
engine
.
remove_logger
.
remote
(
# type: ignore
logger_name
=
logger_name
))
else
:
self
.
engine
.
remove_logger
(
logger_name
=
logger_name
)
vllm/engine/llm_engine.py
View file @
5f0b9933
...
@@ -379,6 +379,7 @@ class LLMEngine:
...
@@ -379,6 +379,7 @@ class LLMEngine:
cls
,
cls
,
engine_args
:
EngineArgs
,
engine_args
:
EngineArgs
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"LLMEngine"
:
)
->
"LLMEngine"
:
"""Creates an LLM engine from the engine arguments."""
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
# Create the engine configs.
...
@@ -423,6 +424,7 @@ class LLMEngine:
...
@@ -423,6 +424,7 @@ class LLMEngine:
executor_class
=
executor_class
,
executor_class
=
executor_class
,
log_stats
=
not
engine_args
.
disable_log_stats
,
log_stats
=
not
engine_args
.
disable_log_stats
,
usage_context
=
usage_context
,
usage_context
=
usage_context
,
stat_loggers
=
stat_loggers
,
)
)
return
engine
return
engine
...
...
vllm/engine/metrics.py
View file @
5f0b9933
...
@@ -30,55 +30,55 @@ prometheus_client.disable_created_metrics()
...
@@ -30,55 +30,55 @@ prometheus_client.disable_created_metrics()
# begin-metrics-definitions
# begin-metrics-definitions
class
Metrics
:
class
Metrics
:
labelname_finish_reason
=
"finished_reason"
labelname_finish_reason
=
"finished_reason"
_base_library
=
prometheus_client
_gauge_cls
=
prometheus_client
.
Gauge
_counter_cls
=
prometheus_client
.
Counter
_histogram_cls
=
prometheus_client
.
Histogram
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
# Unregister any existing vLLM collectors
# Unregister any existing vLLM collectors
self
.
_unregister_vllm_metrics
()
self
.
_unregister_vllm_metrics
()
# Config Information
# Config Information
self
.
info_cache_config
=
prometheus_client
.
Info
(
self
.
_create_info_cache_config
()
name
=
'vllm:cache_config'
,
documentation
=
'information of cache_config'
)
# System stats
# System stats
# Scheduler State
# Scheduler State
self
.
gauge_scheduler_running
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_scheduler_running
=
self
.
_
gauge_cls
(
name
=
"vllm:num_requests_running"
,
name
=
"vllm:num_requests_running"
,
documentation
=
"Number of requests currently running on GPU."
,
documentation
=
"Number of requests currently running on GPU."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
self
.
gauge_scheduler_waiting
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_scheduler_waiting
=
self
.
_
gauge_cls
(
name
=
"vllm:num_requests_waiting"
,
name
=
"vllm:num_requests_waiting"
,
documentation
=
"Number of requests waiting to be processed."
,
documentation
=
"Number of requests waiting to be processed."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
self
.
gauge_scheduler_swapped
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_scheduler_swapped
=
self
.
_
gauge_cls
(
name
=
"vllm:num_requests_swapped"
,
name
=
"vllm:num_requests_swapped"
,
documentation
=
"Number of requests swapped to CPU."
,
documentation
=
"Number of requests swapped to CPU."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
# KV Cache Usage in %
# KV Cache Usage in %
self
.
gauge_gpu_cache_usage
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_gpu_cache_usage
=
self
.
_
gauge_cls
(
name
=
"vllm:gpu_cache_usage_perc"
,
name
=
"vllm:gpu_cache_usage_perc"
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
self
.
gauge_cpu_cache_usage
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_cpu_cache_usage
=
self
.
_
gauge_cls
(
name
=
"vllm:cpu_cache_usage_perc"
,
name
=
"vllm:cpu_cache_usage_perc"
,
documentation
=
"CPU KV-cache usage. 1 means 100 percent usage."
,
documentation
=
"CPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
# Iteration stats
# Iteration stats
self
.
counter_num_preemption
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_num_preemption
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:num_preemptions_total"
,
name
=
"vllm:num_preemptions_total"
,
documentation
=
"Cumulative number of preemption from the engine."
,
documentation
=
"Cumulative number of preemption from the engine."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
self
.
counter_prompt_tokens
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_prompt_tokens
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:prompt_tokens_total"
,
name
=
"vllm:prompt_tokens_total"
,
documentation
=
"Number of prefill tokens processed."
,
documentation
=
"Number of prefill tokens processed."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
self
.
counter_generation_tokens
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_generation_tokens
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:generation_tokens_total"
,
name
=
"vllm:generation_tokens_total"
,
documentation
=
"Number of generation tokens processed."
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
self
.
histogram_time_to_first_token
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_time_to_first_token
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:time_to_first_token_seconds"
,
name
=
"vllm:time_to_first_token_seconds"
,
documentation
=
"Histogram of time to first token in seconds."
,
documentation
=
"Histogram of time to first token in seconds."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
...
@@ -86,7 +86,7 @@ class Metrics:
...
@@ -86,7 +86,7 @@ class Metrics:
0.001
,
0.005
,
0.01
,
0.02
,
0.04
,
0.06
,
0.08
,
0.1
,
0.25
,
0.5
,
0.001
,
0.005
,
0.01
,
0.02
,
0.04
,
0.06
,
0.08
,
0.1
,
0.25
,
0.5
,
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
])
])
self
.
histogram_time_per_output_token
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_time_per_output_token
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:time_per_output_token_seconds"
,
name
=
"vllm:time_per_output_token_seconds"
,
documentation
=
"Histogram of time per output token in seconds."
,
documentation
=
"Histogram of time per output token in seconds."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
...
@@ -97,83 +97,157 @@ class Metrics:
...
@@ -97,83 +97,157 @@ class Metrics:
# Request stats
# Request stats
# Latency
# Latency
self
.
histogram_e2e_time_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_e2e_time_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:e2e_request_latency_seconds"
,
name
=
"vllm:e2e_request_latency_seconds"
,
documentation
=
"Histogram of end to end request latency in seconds."
,
documentation
=
"Histogram of end to end request latency in seconds."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
buckets
=
[
1.0
,
2.5
,
5.0
,
10.0
,
15.0
,
20.0
,
30.0
,
40.0
,
50.0
,
60.0
])
buckets
=
[
1.0
,
2.5
,
5.0
,
10.0
,
15.0
,
20.0
,
30.0
,
40.0
,
50.0
,
60.0
])
# Metadata
# Metadata
self
.
histogram_num_prompt_tokens_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_num_prompt_tokens_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_prompt_tokens"
,
name
=
"vllm:request_prompt_tokens"
,
documentation
=
"Number of prefill tokens processed."
,
documentation
=
"Number of prefill tokens processed."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
)
self
.
histogram_num_generation_tokens_request
=
\
self
.
histogram_num_generation_tokens_request
=
\
self
.
_
base_library
.
H
istogram
(
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_generation_tokens"
,
name
=
"vllm:request_generation_tokens"
,
documentation
=
"Number of generation tokens processed."
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
)
self
.
histogram_best_of_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_best_of_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_params_best_of"
,
name
=
"vllm:request_params_best_of"
,
documentation
=
"Histogram of the best_of request parameter."
,
documentation
=
"Histogram of the best_of request parameter."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
)
self
.
histogram_n_request
=
self
.
_
base_library
.
H
istogram
(
self
.
histogram_n_request
=
self
.
_
h
istogram
_cls
(
name
=
"vllm:request_params_n"
,
name
=
"vllm:request_params_n"
,
documentation
=
"Histogram of the n request parameter."
,
documentation
=
"Histogram of the n request parameter."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
)
self
.
counter_request_success
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_request_success
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:request_success_total"
,
name
=
"vllm:request_success_total"
,
documentation
=
"Count of successfully processed requests."
,
documentation
=
"Count of successfully processed requests."
,
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
# Speculatie decoding stats
# Speculatie decoding stats
self
.
gauge_spec_decode_draft_acceptance_rate
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_spec_decode_draft_acceptance_rate
=
self
.
_
gauge_cls
(
name
=
"vllm:spec_decode_draft_acceptance_rate"
,
name
=
"vllm:spec_decode_draft_acceptance_rate"
,
documentation
=
"Speulative token acceptance rate."
,
documentation
=
"Speulative token acceptance rate."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
self
.
gauge_spec_decode_efficiency
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_spec_decode_efficiency
=
self
.
_
gauge_cls
(
name
=
"vllm:spec_decode_efficiency"
,
name
=
"vllm:spec_decode_efficiency"
,
documentation
=
"Speculative decoding system efficiency."
,
documentation
=
"Speculative decoding system efficiency."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
self
.
counter_spec_decode_num_accepted_tokens
=
(
self
.
counter_spec_decode_num_accepted_tokens
=
(
self
.
_counter_cls
(
self
.
_base_library
.
Counter
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
documentation
=
"Number of accepted tokens."
,
labelnames
=
labelnames
))
labelnames
=
labelnames
))
self
.
counter_spec_decode_num_draft_tokens
=
self
.
_
base_library
.
C
ounter
(
self
.
counter_spec_decode_num_draft_tokens
=
self
.
_
c
ounter
_cls
(
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
name
=
"vllm:spec_decode_num_draft_tokens_total"
,
documentation
=
"Number of draft tokens."
,
documentation
=
"Number of draft tokens."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
self
.
counter_spec_decode_num_emitted_tokens
=
(
self
.
counter_spec_decode_num_emitted_tokens
=
(
self
.
_counter_cls
(
self
.
_base_library
.
Counter
(
name
=
"vllm:spec_decode_num_emitted_tokens_total"
,
name
=
"vllm:spec_decode_num_emitted_tokens_total"
,
documentation
=
"Number of emitted tokens."
,
documentation
=
"Number of emitted tokens."
,
labelnames
=
labelnames
))
labelnames
=
labelnames
))
# Deprecated in favor of vllm:prompt_tokens_total
# Deprecated in favor of vllm:prompt_tokens_total
self
.
gauge_avg_prompt_throughput
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_avg_prompt_throughput
=
self
.
_
gauge_cls
(
name
=
"vllm:avg_prompt_throughput_toks_per_s"
,
name
=
"vllm:avg_prompt_throughput_toks_per_s"
,
documentation
=
"Average prefill throughput in tokens/s."
,
documentation
=
"Average prefill throughput in tokens/s."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
)
)
# Deprecated in favor of vllm:generation_tokens_total
# Deprecated in favor of vllm:generation_tokens_total
self
.
gauge_avg_generation_throughput
=
self
.
_
base_library
.
Gauge
(
self
.
gauge_avg_generation_throughput
=
self
.
_
gauge_cls
(
name
=
"vllm:avg_generation_throughput_toks_per_s"
,
name
=
"vllm:avg_generation_throughput_toks_per_s"
,
documentation
=
"Average generation throughput in tokens/s."
,
documentation
=
"Average generation throughput in tokens/s."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
)
)
def
_create_info_cache_config
(
self
)
->
None
:
# Config Information
self
.
info_cache_config
=
prometheus_client
.
Info
(
name
=
'vllm:cache_config'
,
documentation
=
'information of cache_config'
)
def
_unregister_vllm_metrics
(
self
)
->
None
:
def
_unregister_vllm_metrics
(
self
)
->
None
:
for
collector
in
list
(
self
.
_base_library
.
REGISTRY
.
_collector_to_names
):
for
collector
in
list
(
prometheus_client
.
REGISTRY
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
self
.
_base_library
.
REGISTRY
.
unregister
(
collector
)
prometheus_client
.
REGISTRY
.
unregister
(
collector
)
# end-metrics-definitions
class
_RayGaugeWrapper
:
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""
def
__init__
(
self
,
name
:
str
,
documentation
:
str
=
""
,
labelnames
:
Optional
[
List
[
str
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
_gauge
=
ray_metrics
.
Gauge
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
)
def
labels
(
self
,
**
labels
):
self
.
_gauge
.
set_default_tags
(
labels
)
return
self
def
set
(
self
,
value
:
Union
[
int
,
float
]):
return
self
.
_gauge
.
set
(
value
)
class
_RayCounterWrapper
:
"""Wraps around ray.util.metrics.Counter to provide same API as
prometheus_client.Counter"""
def
__init__
(
self
,
name
:
str
,
documentation
:
str
=
""
,
labelnames
:
Optional
[
List
[
str
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
_counter
=
ray_metrics
.
Counter
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
)
def
labels
(
self
,
**
labels
):
self
.
_counter
.
set_default_tags
(
labels
)
return
self
def
inc
(
self
,
value
:
Union
[
int
,
float
]
=
1.0
):
if
value
==
0
:
return
return
self
.
_counter
.
inc
(
value
)
class
_RayHistogramWrapper
:
"""Wraps around ray.util.metrics.Histogram to provide same API as
prometheus_client.Histogram"""
def
__init__
(
self
,
name
:
str
,
documentation
:
str
=
""
,
labelnames
:
Optional
[
List
[
str
]]
=
None
,
buckets
:
Optional
[
List
[
float
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
_histogram
=
ray_metrics
.
Histogram
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
,
boundaries
=
buckets
)
def
labels
(
self
,
**
labels
):
self
.
_histogram
.
set_default_tags
(
labels
)
return
self
def
observe
(
self
,
value
:
Union
[
int
,
float
]):
return
self
.
_histogram
.
observe
(
value
)
class
RayMetrics
(
Metrics
):
class
RayMetrics
(
Metrics
):
...
@@ -181,7 +255,9 @@ class RayMetrics(Metrics):
...
@@ -181,7 +255,9 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
"""
_base_library
=
ray_metrics
_gauge_cls
=
_RayGaugeWrapper
_counter_cls
=
_RayCounterWrapper
_histogram_cls
=
_RayHistogramWrapper
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
if
ray_metrics
is
None
:
if
ray_metrics
is
None
:
...
@@ -192,8 +268,9 @@ class RayMetrics(Metrics):
...
@@ -192,8 +268,9 @@ class RayMetrics(Metrics):
# No-op on purpose
# No-op on purpose
pass
pass
def
_create_info_cache_config
(
self
)
->
None
:
# end-metrics-definitions
# No-op on purpose
pass
def
build_1_2_5_buckets
(
max_value
:
int
)
->
List
[
int
]:
def
build_1_2_5_buckets
(
max_value
:
int
)
->
List
[
int
]:
...
@@ -498,3 +575,6 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -498,3 +575,6 @@ class PrometheusStatLogger(StatLoggerBase):
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
"""RayPrometheusStatLogger uses Ray metrics instead."""
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls
=
RayMetrics
_metrics_cls
=
RayMetrics
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
return
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment