Unverified commit 29d1ffc5, authored by Robert Shaw, committed by GitHub
Browse files

[DP] Fix Prometheus Logging (#21257)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
parent 304dce7e
......@@ -336,9 +336,10 @@ async def test_customize_loggers(monkeypatch):
await engine.do_log_stats()
assert len(engine.stat_loggers) == 1
assert len(engine.stat_loggers[0]) == 1
engine.stat_loggers[0][0].log.assert_called_once()
stat_loggers = engine.logger_manager.per_engine_logger_dict
assert len(stat_loggers) == 1
assert len(stat_loggers[0]) == 1
stat_loggers[0][0].log.assert_called_once()
@pytest.mark.asyncio(scope="module")
......
......@@ -90,8 +90,10 @@ async def test_load(output_kind: RequestOutputKind,
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
stats_loggers[engine_index] = self
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
def record(self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0):
if iteration_stats:
self.finished_req_count += len(
iteration_stats.finished_requests)
......
......@@ -36,10 +36,9 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers)
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__)
......@@ -95,14 +94,6 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests
self.log_stats = log_stats
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
vllm_config=vllm_config,
log_stats=self.log_stats,
engine_num=vllm_config.parallel_config.data_parallel_size,
custom_stat_loggers=stat_loggers,
)
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
......@@ -121,7 +112,6 @@ class AsyncLLM(EngineClient):
log_stats=self.log_stats)
# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client(
vllm_config=vllm_config,
executor_class=executor_class,
......@@ -129,9 +119,17 @@ class AsyncLLM(EngineClient):
client_addresses=client_addresses,
client_index=client_index,
)
if self.stat_loggers:
for stat_logger in self.stat_loggers[0]:
stat_logger.log_engine_initialized()
# Loggers.
self.logger_manager: Optional[StatLoggerManager] = None
if self.log_stats:
self.logger_manager = StatLoggerManager(
vllm_config=vllm_config,
engine_idxs=self.engine_core.engine_ranks,
custom_stat_loggers=stat_loggers,
)
self.logger_manager.log_engine_initialized()
self.output_handler: Optional[asyncio.Task] = None
try:
# Start output handler eagerly if we are in the asyncio eventloop.
......@@ -370,7 +368,7 @@ class AsyncLLM(EngineClient):
engine_core = self.engine_core
output_processor = self.output_processor
log_stats = self.log_stats
stat_loggers = self.stat_loggers if log_stats else None
logger_manager = self.logger_manager
async def output_handler():
try:
......@@ -410,9 +408,9 @@ class AsyncLLM(EngineClient):
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if stat_loggers:
AsyncLLM._record_stats(
stat_loggers[outputs.engine_index],
if logger_manager:
logger_manager.record(
engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
......@@ -431,18 +429,6 @@ class AsyncLLM(EngineClient):
if self.log_requests:
logger.info("Aborted request %s.", request_id)
@staticmethod
def _record_stats(
stat_loggers: list[StatLoggerBase],
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
):
"""static so that it can be used from the output_handler task
without a circular ref to AsyncLLM."""
for stat_logger in stat_loggers:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
async def encode(
self,
prompt: PromptType,
......@@ -547,9 +533,8 @@ class AsyncLLM(EngineClient):
scheduler_outputs=None,
model_output=None,
) -> None:
for loggers in self.stat_loggers:
for stat_logger in loggers:
stat_logger.log()
if self.logger_manager:
self.logger_manager.log()
async def check_health(self) -> None:
logger.debug("Called check_health.")
......@@ -653,18 +638,16 @@ class AsyncLLM(EngineClient):
new_data_parallel_size
# recreate stat loggers
if new_data_parallel_size > old_data_parallel_size:
stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
if new_data_parallel_size > old_data_parallel_size and self.log_stats:
# TODO(rob): fix this after talking with Ray team.
# This resets all the prometheus metrics since we
# unregister during initialization. Need to understand
# the intended behavior here better.
self.logger_manager = StatLoggerManager(
vllm_config=self.vllm_config,
log_stats=self.log_stats,
engine_num=new_data_parallel_size,
engine_idxs=list(range(new_data_parallel_size)),
custom_stat_loggers=None,
)
num_new_engines = len(stat_loggers) - len(self.stat_loggers)
self.stat_loggers.extend(stat_loggers[-num_new_engines:])
else:
for _ in range(old_data_parallel_size - new_data_parallel_size):
self.stat_loggers.pop()
@property
def is_running(self) -> bool:
......
......@@ -432,14 +432,15 @@ class MPClient(EngineCoreClient):
external_dp_lb = parallel_config.data_parallel_external_lb
offline_mode = parallel_config.data_parallel_rank_local is not None
engine_ranks = [dp_rank] if (offline_mode
or external_dp_lb) else range(dp_size)
self.engine_ranks = ([dp_rank] if
(offline_mode or external_dp_lb) else list(
range(dp_size)))
assert parallel_config.data_parallel_size_local <= len(
engine_ranks)
self.engine_ranks)
# ZMQ identity of each engine that this client will talk to.
self.core_engines: list[EngineIdentity] = [
index.to_bytes(2, "little") for index in engine_ranks
index.to_bytes(2, "little") for index in self.engine_ranks
]
# Wait for ready messages from each engine on the input socket.
......
This diff is collapsed.
......@@ -3,7 +3,6 @@
import time
from typing import Optional, Union
from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm
......@@ -128,9 +127,6 @@ class RayPrometheusStatLogger(PrometheusStatLogger):
_histogram_cls = RayHistogramWrapper
_spec_decoding_cls = RaySpecDecodingProm
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
super().__init__(vllm_config, engine_index)
@staticmethod
def _unregister_vllm_metrics():
# No-op on purpose
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment