Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
906a19cd
Unverified
Commit
906a19cd
authored
Jun 28, 2024
by
William Lin
Committed by
GitHub
Jun 29, 2024
Browse files
[Misc] Extend vLLM Metrics logging API (#5925)
Co-authored-by:
Antoni Baum
<
antoni.baum@protonmail.com
>
parent
c4bca740
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
225 additions
and
118 deletions
+225
-118
tests/metrics/test_metrics.py
tests/metrics/test_metrics.py
+6
-6
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+30
-8
vllm/engine/metrics.py
vllm/engine/metrics.py
+189
-104
No files found.
tests/metrics/test_metrics.py
View file @
906a19cd
...
...
@@ -39,7 +39,7 @@ def test_metric_counter_prompt_tokens(
vllm_prompt_token_count
=
sum
(
prompt_token_counts
)
_
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
stat_logger
=
vllm_model
.
model
.
llm_engine
.
stat_logger
stat_logger
=
vllm_model
.
model
.
llm_engine
.
stat_logger
s
[
'prometheus'
]
metric_count
=
stat_logger
.
metrics
.
counter_prompt_tokens
.
labels
(
**
stat_logger
.
labels
).
_value
.
get
()
...
...
@@ -64,7 +64,7 @@ def test_metric_counter_generation_tokens(
gpu_memory_utilization
=
0.4
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
tokenizer
=
vllm_model
.
model
.
get_tokenizer
()
stat_logger
=
vllm_model
.
model
.
llm_engine
.
stat_logger
stat_logger
=
vllm_model
.
model
.
llm_engine
.
stat_logger
s
[
'prometheus'
]
metric_count
=
stat_logger
.
metrics
.
counter_generation_tokens
.
labels
(
**
stat_logger
.
labels
).
_value
.
get
()
vllm_generation_count
=
0
...
...
@@ -92,7 +92,7 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
disable_log_stats
=
False
,
gpu_memory_utilization
=
0.3
,
served_model_name
=
served_model_name
)
as
vllm_model
:
stat_logger
=
vllm_model
.
model
.
llm_engine
.
stat_logger
stat_logger
=
vllm_model
.
model
.
llm_engine
.
stat_logger
s
[
'prometheus'
]
metrics_tag_content
=
stat_logger
.
labels
[
"model_name"
]
if
served_model_name
is
None
or
served_model_name
==
[]:
...
...
@@ -172,10 +172,10 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
num_requests
:
int
)
->
None
:
if
disable_log_stats
:
with
pytest
.
raises
(
AttributeError
):
_
=
engine
.
stat_logger
_
=
engine
.
stat_logger
s
else
:
assert
(
engine
.
stat_logger
is
not
None
),
"engine.stat_logger should be set"
assert
(
engine
.
stat_logger
s
is
not
None
),
"engine.stat_logger
s
should be set"
# Ensure the count bucket of request-level histogram metrics matches
# the number of requests as a simple sanity check to ensure metrics are
# generated
...
...
vllm/engine/llm_engine.py
View file @
906a19cd
...
...
@@ -13,7 +13,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
from
vllm.core.scheduler
import
(
ScheduledSequenceGroup
,
Scheduler
,
SchedulerOutputs
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics
import
StatLogger
,
Stats
from
vllm.engine.metrics
import
(
LoggingStatLogger
,
PrometheusStatLogger
,
StatLoggerBase
,
Stats
)
from
vllm.engine.output_processor.interfaces
import
(
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
...
...
@@ -160,6 +161,7 @@ class LLMEngine:
executor_class
:
Type
[
ExecutorBase
],
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
None
:
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
...
...
@@ -292,11 +294,21 @@ class LLMEngine:
# Metric Logging.
if
self
.
log_stats
:
self
.
stat_logger
=
StatLogger
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
labels
=
dict
(
model_name
=
model_config
.
served_model_name
),
max_model_len
=
self
.
model_config
.
max_model_len
)
self
.
stat_logger
.
info
(
"cache_config"
,
self
.
cache_config
)
if
stat_loggers
is
not
None
:
self
.
stat_loggers
=
stat_loggers
else
:
self
.
stat_loggers
=
{
"logging"
:
LoggingStatLogger
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
),
"prometheus"
:
PrometheusStatLogger
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
labels
=
dict
(
model_name
=
model_config
.
served_model_name
),
max_model_len
=
self
.
model_config
.
max_model_len
),
}
self
.
stat_loggers
[
"prometheus"
].
info
(
"cache_config"
,
self
.
cache_config
)
self
.
tracer
=
None
if
self
.
observability_config
.
otlp_traces_endpoint
:
...
...
@@ -833,14 +845,24 @@ class LLMEngine:
return
request_outputs
def
add_logger
(
self
,
logger_name
:
str
,
logger
:
StatLoggerBase
)
->
None
:
if
logger_name
in
self
.
stat_loggers
:
raise
KeyError
(
f
"Logger with name
{
logger_name
}
already exists."
)
self
.
stat_loggers
[
logger_name
]
=
logger
def
remove_logger
(
self
,
logger_name
:
str
)
->
None
:
if
logger_name
not
in
self
.
stat_loggers
:
raise
KeyError
(
f
"Logger with name
{
logger_name
}
does not exist."
)
del
self
.
stat_loggers
[
logger_name
]
def
do_log_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
)
->
None
:
"""Forced log when no requests active."""
if
self
.
log_stats
:
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
,
model_output
))
for
logger
in
self
.
stat_logger
s
.
values
():
logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
,
model_output
))
def
_get_stats
(
self
,
...
...
vllm/engine/metrics.py
View file @
906a19cd
import
time
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
from
typing
import
Counter
as
CollectionsCounter
from
typing
import
Dict
,
List
,
Optional
,
Protocol
,
Union
import
numpy
as
np
from
prometheus_client
import
(
REGISTRY
,
Counter
,
Gauge
,
Histogram
,
Info
,
disable_created_metrics
)
import
prometheus_client
from
vllm.executor.ray_utils
import
ray
from
vllm.logger
import
init_logger
if
ray
is
not
None
:
from
ray.util
import
metrics
as
ray_metrics
else
:
ray_metrics
=
None
if
TYPE_CHECKING
:
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
logger
=
init_logger
(
__name__
)
disable_created_metrics
()
prometheus_client
.
disable_created_metrics
()
# The begin-* and end* here are used by the documentation generator
# to extract the metrics definitions.
...
...
@@ -24,56 +30,55 @@ disable_created_metrics()
# begin-metrics-definitions
class
Metrics
:
labelname_finish_reason
=
"finished_reason"
_base_library
=
prometheus_client
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
# Unregister any existing vLLM collectors
for
collector
in
list
(
REGISTRY
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
REGISTRY
.
unregister
(
collector
)
self
.
_unregister_vllm_metrics
()
# Config Information
self
.
info_cache_config
=
Info
(
self
.
info_cache_config
=
prometheus_client
.
Info
(
name
=
'vllm:cache_config'
,
documentation
=
'information of cache_config'
)
# System stats
# Scheduler State
self
.
gauge_scheduler_running
=
Gauge
(
self
.
gauge_scheduler_running
=
self
.
_base_library
.
Gauge
(
name
=
"vllm:num_requests_running"
,
documentation
=
"Number of requests currently running on GPU."
,
labelnames
=
labelnames
)
self
.
gauge_scheduler_waiting
=
Gauge
(
self
.
gauge_scheduler_waiting
=
self
.
_base_library
.
Gauge
(
name
=
"vllm:num_requests_waiting"
,
documentation
=
"Number of requests waiting to be processed."
,
labelnames
=
labelnames
)
self
.
gauge_scheduler_swapped
=
Gauge
(
self
.
gauge_scheduler_swapped
=
self
.
_base_library
.
Gauge
(
name
=
"vllm:num_requests_swapped"
,
documentation
=
"Number of requests swapped to CPU."
,
labelnames
=
labelnames
)
# KV Cache Usage in %
self
.
gauge_gpu_cache_usage
=
Gauge
(
self
.
gauge_gpu_cache_usage
=
self
.
_base_library
.
Gauge
(
name
=
"vllm:gpu_cache_usage_perc"
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
)
self
.
gauge_cpu_cache_usage
=
Gauge
(
self
.
gauge_cpu_cache_usage
=
self
.
_base_library
.
Gauge
(
name
=
"vllm:cpu_cache_usage_perc"
,
documentation
=
"CPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
)
# Iteration stats
self
.
counter_num_preemption
=
Counter
(
self
.
counter_num_preemption
=
self
.
_base_library
.
Counter
(
name
=
"vllm:num_preemptions_total"
,
documentation
=
"Cumulative number of preemption from the engine."
,
labelnames
=
labelnames
)
self
.
counter_prompt_tokens
=
Counter
(
self
.
counter_prompt_tokens
=
self
.
_base_library
.
Counter
(
name
=
"vllm:prompt_tokens_total"
,
documentation
=
"Number of prefill tokens processed."
,
labelnames
=
labelnames
)
self
.
counter_generation_tokens
=
Counter
(
self
.
counter_generation_tokens
=
self
.
_base_library
.
Counter
(
name
=
"vllm:generation_tokens_total"
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
)
self
.
histogram_time_to_first_token
=
Histogram
(
self
.
histogram_time_to_first_token
=
self
.
_base_library
.
Histogram
(
name
=
"vllm:time_to_first_token_seconds"
,
documentation
=
"Histogram of time to first token in seconds."
,
labelnames
=
labelnames
,
...
...
@@ -81,7 +86,7 @@ class Metrics:
0.001
,
0.005
,
0.01
,
0.02
,
0.04
,
0.06
,
0.08
,
0.1
,
0.25
,
0.5
,
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
])
self
.
histogram_time_per_output_token
=
Histogram
(
self
.
histogram_time_per_output_token
=
self
.
_base_library
.
Histogram
(
name
=
"vllm:time_per_output_token_seconds"
,
documentation
=
"Histogram of time per output token in seconds."
,
labelnames
=
labelnames
,
...
...
@@ -92,54 +97,77 @@ class Metrics:
# Request stats
# Latency
self
.
histogram_e2e_time_request
=
Histogram
(
self
.
histogram_e2e_time_request
=
self
.
_base_library
.
Histogram
(
name
=
"vllm:e2e_request_latency_seconds"
,
documentation
=
"Histogram of end to end request latency in seconds."
,
labelnames
=
labelnames
,
buckets
=
[
1.0
,
2.5
,
5.0
,
10.0
,
15.0
,
20.0
,
30.0
,
40.0
,
50.0
,
60.0
])
# Metadata
self
.
histogram_num_prompt_tokens_request
=
Histogram
(
self
.
histogram_num_prompt_tokens_request
=
self
.
_base_library
.
Histogram
(
name
=
"vllm:request_prompt_tokens"
,
documentation
=
"Number of prefill tokens processed."
,
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
self
.
histogram_num_generation_tokens_request
=
Histogram
(
name
=
"vllm:request_generation_tokens"
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
self
.
histogram_best_of_request
=
Histogram
(
self
.
histogram_num_generation_tokens_request
=
\
self
.
_base_library
.
Histogram
(
name
=
"vllm:request_generation_tokens"
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
self
.
histogram_best_of_request
=
self
.
_base_library
.
Histogram
(
name
=
"vllm:request_params_best_of"
,
documentation
=
"Histogram of the best_of request parameter."
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
self
.
histogram_n_request
=
Histogram
(
self
.
histogram_n_request
=
self
.
_base_library
.
Histogram
(
name
=
"vllm:request_params_n"
,
documentation
=
"Histogram of the n request parameter."
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
self
.
counter_request_success
=
Counter
(
self
.
counter_request_success
=
self
.
_base_library
.
Counter
(
name
=
"vllm:request_success_total"
,
documentation
=
"Count of successfully processed requests."
,
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
# Deprecated in favor of vllm:prompt_tokens_total
self
.
gauge_avg_prompt_throughput
=
Gauge
(
self
.
gauge_avg_prompt_throughput
=
self
.
_base_library
.
Gauge
(
name
=
"vllm:avg_prompt_throughput_toks_per_s"
,
documentation
=
"Average prefill throughput in tokens/s."
,
labelnames
=
labelnames
,
)
# Deprecated in favor of vllm:generation_tokens_total
self
.
gauge_avg_generation_throughput
=
Gauge
(
self
.
gauge_avg_generation_throughput
=
self
.
_base_library
.
Gauge
(
name
=
"vllm:avg_generation_throughput_toks_per_s"
,
documentation
=
"Average generation throughput in tokens/s."
,
labelnames
=
labelnames
,
)
def
_unregister_vllm_metrics
(
self
)
->
None
:
for
collector
in
list
(
self
.
_base_library
.
REGISTRY
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
self
.
_base_library
.
REGISTRY
.
unregister
(
collector
)
class
RayMetrics
(
Metrics
):
"""
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_base_library
=
ray_metrics
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
if
ray_metrics
is
None
:
raise
ImportError
(
"RayMetrics requires Ray to be installed."
)
super
().
__init__
(
labelnames
,
max_model_len
)
def
_unregister_vllm_metrics
(
self
)
->
None
:
# No-op on purpose
pass
# end-metrics-definitions
...
...
@@ -206,34 +234,136 @@ class SupportsMetricsInfo(Protocol):
...
class
StatLogger
:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
def
local_interval_elapsed
(
now
:
float
,
last_log
:
float
,
local_interval
:
float
)
->
bool
:
elapsed_time
=
now
-
last_log
return
elapsed_time
>
local_interval
def
get_throughput
(
tracked_stats
:
List
[
int
],
now
:
float
,
last_log
:
float
)
->
float
:
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
last_log
))
def
__init__
(
self
,
local_interval
:
float
,
labels
:
Dict
[
str
,
str
],
max_model_len
:
int
)
->
None
:
# Metadata for logging locally.
self
.
last_local_log
=
time
.
time
()
self
.
local_interval
=
local_interval
class
StatLoggerBase
(
ABC
):
"""Base class for StatLogger."""
def
__init__
(
self
,
local_interval
:
float
)
->
None
:
# Tracked stats over current local logging interval.
self
.
num_prompt_tokens
:
List
[
int
]
=
[]
self
.
num_generation_tokens
:
List
[
int
]
=
[]
self
.
last_local_log
=
time
.
time
()
self
.
local_interval
=
local_interval
@
abstractmethod
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
log
(
self
,
stats
:
Stats
)
->
None
:
raise
NotImplementedError
class
LoggingStatLogger
(
StatLoggerBase
):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
raise
NotImplementedError
def
log
(
self
,
stats
:
Stats
)
->
None
:
"""Called by LLMEngine.
Logs to Stdout every self.local_interval seconds."""
# Save tracked stats for token counters.
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
# Log locally every local_interval seconds.
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
self
.
local_interval
):
# Compute summary metrics for tracked stats (and log them
# to promethus if applicable).
prompt_throughput
=
get_throughput
(
self
.
num_prompt_tokens
,
now
=
stats
.
now
,
last_log
=
self
.
last_local_log
)
generation_throughput
=
get_throughput
(
self
.
num_generation_tokens
,
now
=
stats
.
now
,
last_log
=
self
.
last_local_log
)
# Log to stdout.
logger
.
info
(
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%."
,
prompt_throughput
,
generation_throughput
,
stats
.
num_running_sys
,
stats
.
num_swapped_sys
,
stats
.
num_waiting_sys
,
stats
.
gpu_cache_usage_sys
*
100
,
stats
.
cpu_cache_usage_sys
*
100
,
)
# Reset tracked stats for next interval.
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
if
stats
.
spec_decode_metrics
is
not
None
:
logger
.
info
(
self
.
_format_spec_decode_metrics_str
(
stats
.
spec_decode_metrics
))
def
_format_spec_decode_metrics_str
(
self
,
metrics
:
"SpecDecodeWorkerMetrics"
)
->
str
:
return
(
"Speculative metrics: "
f
"Draft acceptance rate:
{
metrics
.
draft_acceptance_rate
:.
3
f
}
, "
f
"System efficiency:
{
metrics
.
system_efficiency
:.
3
f
}
, "
f
"Number of speculative tokens:
{
metrics
.
num_spec_tokens
}
, "
f
"Number of accepted tokens:
{
metrics
.
accepted_tokens
}
, "
f
"Number of draft tokens tokens:
{
metrics
.
draft_tokens
}
, "
f
"Number of emitted tokens tokens:
{
metrics
.
emitted_tokens
}
."
)
class
PrometheusStatLogger
(
StatLoggerBase
):
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
_metrics_cls
=
Metrics
def
__init__
(
self
,
local_interval
:
float
,
labels
:
Dict
[
str
,
str
],
max_model_len
:
int
)
->
None
:
super
().
__init__
(
local_interval
)
# Prometheus metrics
self
.
labels
=
labels
self
.
metrics
=
M
etrics
(
labelnames
=
list
(
labels
.
keys
()),
max_model_len
=
max_model_len
)
self
.
metrics
=
self
.
_m
etrics
_cls
(
labelnames
=
list
(
labels
.
keys
()),
max_model_len
=
max_model_len
)
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
if
type
==
"cache_config"
:
self
.
metrics
.
info_cache_config
.
info
(
obj
.
metrics_info
())
def
_get_throughput
(
self
,
tracked_stats
:
List
[
int
],
now
:
float
)
->
float
:
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
self
.
last_local_log
))
def
_log_gauge
(
self
,
gauge
,
data
:
Union
[
int
,
float
])
->
None
:
# Convenience function for logging to gauge.
gauge
.
labels
(
**
self
.
labels
).
set
(
data
)
def
_local_interval_elapsed
(
self
,
now
:
float
)
->
bool
:
elapsed_time
=
now
-
self
.
last_local_log
return
elapsed_time
>
self
.
local_interval
def
_log_counter
(
self
,
counter
,
data
:
Union
[
int
,
float
])
->
None
:
# Convenience function for logging to counter.
counter
.
labels
(
**
self
.
labels
).
inc
(
data
)
def
_log_counter_labels
(
self
,
counter
,
data
:
CollectionsCounter
,
label_key
:
str
)
->
None
:
# Convenience function for collection counter of labels.
for
label
,
count
in
data
.
items
():
counter
.
labels
(
**
{
**
self
.
labels
,
label_key
:
label
}).
inc
(
count
)
def
_log_histogram
(
self
,
histogram
,
data
:
Union
[
List
[
int
],
List
[
float
]])
->
None
:
# Convenience function for logging list to histogram.
for
datum
in
data
:
histogram
.
labels
(
**
self
.
labels
).
observe
(
datum
)
def
_log_prometheus
(
self
,
stats
:
Stats
)
->
None
:
# System state data
...
...
@@ -279,26 +409,6 @@ class StatLogger:
self
.
_log_histogram
(
self
.
metrics
.
histogram_best_of_request
,
stats
.
best_of_requests
)
def
_log_gauge
(
self
,
gauge
:
Gauge
,
data
:
Union
[
int
,
float
])
->
None
:
# Convenience function for logging to gauge.
gauge
.
labels
(
**
self
.
labels
).
set
(
data
)
def
_log_counter
(
self
,
counter
:
Counter
,
data
:
Union
[
int
,
float
])
->
None
:
# Convenience function for logging to counter.
counter
.
labels
(
**
self
.
labels
).
inc
(
data
)
def
_log_counter_labels
(
self
,
counter
:
Counter
,
data
:
CollectionsCounter
,
label_key
:
str
)
->
None
:
# Convenience function for collection counter of labels.
for
label
,
count
in
data
.
items
():
counter
.
labels
(
**
{
**
self
.
labels
,
label_key
:
label
}).
inc
(
count
)
def
_log_histogram
(
self
,
histogram
:
Histogram
,
data
:
Union
[
List
[
int
],
List
[
float
]])
->
None
:
# Convenience function for logging list to histogram.
for
datum
in
data
:
histogram
.
labels
(
**
self
.
labels
).
observe
(
datum
)
def
_log_prometheus_interval
(
self
,
prompt_throughput
:
float
,
generation_throughput
:
float
)
->
None
:
# Logs metrics to prometheus that are computed every logging_interval.
...
...
@@ -313,11 +423,8 @@ class StatLogger:
self
.
metrics
.
gauge_avg_generation_throughput
.
labels
(
**
self
.
labels
).
set
(
generation_throughput
)
def
log
(
self
,
stats
:
Stats
)
->
None
:
"""Called by LLMEngine.
Logs to prometheus and tracked stats every iteration.
Logs to Stdout every self.local_interval seconds."""
def
log
(
self
,
stats
:
Stats
):
"""Logs to prometheus and tracked stats every iteration."""
# Log to prometheus.
self
.
_log_prometheus
(
stats
)
...
...
@@ -326,50 +433,28 @@ class StatLogger:
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens_iter
)
# Log locally every local_interval seconds.
if
self
.
_local_interval_elapsed
(
stats
.
now
):
if
local_interval_elapsed
(
stats
.
now
,
self
.
last_local_log
,
self
.
local_interval
):
# Compute summary metrics for tracked stats (and log them
# to promethus if applicable).
prompt_throughput
=
self
.
_get_throughput
(
self
.
num_prompt_tokens
,
now
=
stats
.
now
)
generation_throughput
=
self
.
_get_throughput
(
self
.
num_generation_tokens
,
now
=
stats
.
now
)
prompt_throughput
=
get_throughput
(
self
.
num_prompt_tokens
,
now
=
stats
.
now
,
last_log
=
self
.
last_local_log
)
generation_throughput
=
get_throughput
(
self
.
num_generation_tokens
,
now
=
stats
.
now
,
last_log
=
self
.
last_local_log
)
self
.
_log_prometheus_interval
(
prompt_throughput
=
prompt_throughput
,
generation_throughput
=
generation_throughput
)
# Log to stdout.
logger
.
info
(
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%."
,
prompt_throughput
,
generation_throughput
,
stats
.
num_running_sys
,
stats
.
num_swapped_sys
,
stats
.
num_waiting_sys
,
stats
.
gpu_cache_usage_sys
*
100
,
stats
.
cpu_cache_usage_sys
*
100
,
)
# Reset tracked stats for next interval.
self
.
num_prompt_tokens
=
[]
self
.
num_generation_tokens
=
[]
self
.
last_local_log
=
stats
.
now
if
stats
.
spec_decode_metrics
is
not
None
:
logger
.
info
(
self
.
_format_spec_decode_metrics_str
(
stats
.
spec_decode_metrics
))
def
_format_spec_decode_metrics_str
(
self
,
metrics
:
"SpecDecodeWorkerMetrics"
)
->
str
:
return
(
"Speculative metrics: "
f
"Draft acceptance rate:
{
metrics
.
draft_acceptance_rate
:.
3
f
}
, "
f
"System efficiency:
{
metrics
.
system_efficiency
:.
3
f
}
, "
f
"Number of speculative tokens:
{
metrics
.
num_spec_tokens
}
, "
f
"Number of accepted tokens:
{
metrics
.
accepted_tokens
}
, "
f
"Number of draft tokens tokens:
{
metrics
.
draft_tokens
}
, "
f
"Number of emitted tokens tokens:
{
metrics
.
emitted_tokens
}
."
)
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls
=
RayMetrics
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment