Unverified Commit 29d1ffc5 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[DP] Fix Prometheus Logging (#21257)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
parent 304dce7e
...@@ -336,9 +336,10 @@ async def test_customize_loggers(monkeypatch): ...@@ -336,9 +336,10 @@ async def test_customize_loggers(monkeypatch):
await engine.do_log_stats() await engine.do_log_stats()
assert len(engine.stat_loggers) == 1 stat_loggers = engine.logger_manager.per_engine_logger_dict
assert len(engine.stat_loggers[0]) == 1 assert len(stat_loggers) == 1
engine.stat_loggers[0][0].log.assert_called_once() assert len(stat_loggers[0]) == 1
stat_loggers[0][0].log.assert_called_once()
@pytest.mark.asyncio(scope="module") @pytest.mark.asyncio(scope="module")
......
...@@ -90,8 +90,10 @@ async def test_load(output_kind: RequestOutputKind, ...@@ -90,8 +90,10 @@ async def test_load(output_kind: RequestOutputKind,
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
stats_loggers[engine_index] = self stats_loggers[engine_index] = self
def record(self, scheduler_stats: Optional[SchedulerStats], def record(self,
iteration_stats: Optional[IterationStats]): scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0):
if iteration_stats: if iteration_stats:
self.finished_req_count += len( self.finished_req_count += len(
iteration_stats.finished_requests) iteration_stats.finished_requests)
......
...@@ -36,10 +36,9 @@ from vllm.v1.engine.output_processor import (OutputProcessor, ...@@ -36,10 +36,9 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
setup_default_loggers)
from vllm.v1.metrics.prometheus import shutdown_prometheus from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -95,14 +94,6 @@ class AsyncLLM(EngineClient): ...@@ -95,14 +94,6 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests self.log_requests = log_requests
self.log_stats = log_stats self.log_stats = log_stats
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
vllm_config=vllm_config,
log_stats=self.log_stats,
engine_num=vllm_config.parallel_config.data_parallel_size,
custom_stat_loggers=stat_loggers,
)
# Tokenizer (+ ensure liveness if running in another process). # Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config, model_config=vllm_config.model_config,
...@@ -121,7 +112,6 @@ class AsyncLLM(EngineClient): ...@@ -121,7 +112,6 @@ class AsyncLLM(EngineClient):
log_stats=self.log_stats) log_stats=self.log_stats)
# EngineCore (starts the engine in background process). # EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client( self.engine_core = EngineCoreClient.make_async_mp_client(
vllm_config=vllm_config, vllm_config=vllm_config,
executor_class=executor_class, executor_class=executor_class,
...@@ -129,9 +119,17 @@ class AsyncLLM(EngineClient): ...@@ -129,9 +119,17 @@ class AsyncLLM(EngineClient):
client_addresses=client_addresses, client_addresses=client_addresses,
client_index=client_index, client_index=client_index,
) )
if self.stat_loggers:
for stat_logger in self.stat_loggers[0]: # Loggers.
stat_logger.log_engine_initialized() self.logger_manager: Optional[StatLoggerManager] = None
if self.log_stats:
self.logger_manager = StatLoggerManager(
vllm_config=vllm_config,
engine_idxs=self.engine_core.engine_ranks,
custom_stat_loggers=stat_loggers,
)
self.logger_manager.log_engine_initialized()
self.output_handler: Optional[asyncio.Task] = None self.output_handler: Optional[asyncio.Task] = None
try: try:
# Start output handler eagerly if we are in the asyncio eventloop. # Start output handler eagerly if we are in the asyncio eventloop.
...@@ -370,7 +368,7 @@ class AsyncLLM(EngineClient): ...@@ -370,7 +368,7 @@ class AsyncLLM(EngineClient):
engine_core = self.engine_core engine_core = self.engine_core
output_processor = self.output_processor output_processor = self.output_processor
log_stats = self.log_stats log_stats = self.log_stats
stat_loggers = self.stat_loggers if log_stats else None logger_manager = self.logger_manager
async def output_handler(): async def output_handler():
try: try:
...@@ -410,9 +408,9 @@ class AsyncLLM(EngineClient): ...@@ -410,9 +408,9 @@ class AsyncLLM(EngineClient):
# 4) Logging. # 4) Logging.
# TODO(rob): make into a coroutine and launch it in # TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial. # background thread once Prometheus overhead is non-trivial.
if stat_loggers: if logger_manager:
AsyncLLM._record_stats( logger_manager.record(
stat_loggers[outputs.engine_index], engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
) )
...@@ -431,18 +429,6 @@ class AsyncLLM(EngineClient): ...@@ -431,18 +429,6 @@ class AsyncLLM(EngineClient):
if self.log_requests: if self.log_requests:
logger.info("Aborted request %s.", request_id) logger.info("Aborted request %s.", request_id)
@staticmethod
def _record_stats(
stat_loggers: list[StatLoggerBase],
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
):
"""static so that it can be used from the output_handler task
without a circular ref to AsyncLLM."""
for stat_logger in stat_loggers:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
async def encode( async def encode(
self, self,
prompt: PromptType, prompt: PromptType,
...@@ -547,9 +533,8 @@ class AsyncLLM(EngineClient): ...@@ -547,9 +533,8 @@ class AsyncLLM(EngineClient):
scheduler_outputs=None, scheduler_outputs=None,
model_output=None, model_output=None,
) -> None: ) -> None:
for loggers in self.stat_loggers: if self.logger_manager:
for stat_logger in loggers: self.logger_manager.log()
stat_logger.log()
async def check_health(self) -> None: async def check_health(self) -> None:
logger.debug("Called check_health.") logger.debug("Called check_health.")
...@@ -653,18 +638,16 @@ class AsyncLLM(EngineClient): ...@@ -653,18 +638,16 @@ class AsyncLLM(EngineClient):
new_data_parallel_size new_data_parallel_size
# recreate stat loggers # recreate stat loggers
if new_data_parallel_size > old_data_parallel_size: if new_data_parallel_size > old_data_parallel_size and self.log_stats:
stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers( # TODO(rob): fix this after talking with Ray team.
# This resets all the prometheus metrics since we
# unregister during initialization. Need to understand
# the intended behavior here better.
self.logger_manager = StatLoggerManager(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
log_stats=self.log_stats, engine_idxs=list(range(new_data_parallel_size)),
engine_num=new_data_parallel_size,
custom_stat_loggers=None, custom_stat_loggers=None,
) )
num_new_engines = len(stat_loggers) - len(self.stat_loggers)
self.stat_loggers.extend(stat_loggers[-num_new_engines:])
else:
for _ in range(old_data_parallel_size - new_data_parallel_size):
self.stat_loggers.pop()
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
......
...@@ -432,14 +432,15 @@ class MPClient(EngineCoreClient): ...@@ -432,14 +432,15 @@ class MPClient(EngineCoreClient):
external_dp_lb = parallel_config.data_parallel_external_lb external_dp_lb = parallel_config.data_parallel_external_lb
offline_mode = parallel_config.data_parallel_rank_local is not None offline_mode = parallel_config.data_parallel_rank_local is not None
engine_ranks = [dp_rank] if (offline_mode self.engine_ranks = ([dp_rank] if
or external_dp_lb) else range(dp_size) (offline_mode or external_dp_lb) else list(
range(dp_size)))
assert parallel_config.data_parallel_size_local <= len( assert parallel_config.data_parallel_size_local <= len(
engine_ranks) self.engine_ranks)
# ZMQ identity of each engine that this client will talk to. # ZMQ identity of each engine that this client will talk to.
self.core_engines: list[EngineIdentity] = [ self.core_engines: list[EngineIdentity] = [
index.to_bytes(2, "little") for index in engine_ranks index.to_bytes(2, "little") for index in self.engine_ranks
] ]
# Wait for ready messages from each engine on the input socket. # Wait for ready messages from each engine on the input socket.
......
This diff is collapsed.
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import time import time
from typing import Optional, Union from typing import Optional, Union
from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import PrometheusStatLogger from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm from vllm.v1.spec_decode.metrics import SpecDecodingProm
...@@ -128,9 +127,6 @@ class RayPrometheusStatLogger(PrometheusStatLogger): ...@@ -128,9 +127,6 @@ class RayPrometheusStatLogger(PrometheusStatLogger):
_histogram_cls = RayHistogramWrapper _histogram_cls = RayHistogramWrapper
_spec_decoding_cls = RaySpecDecodingProm _spec_decoding_cls = RaySpecDecodingProm
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
super().__init__(vllm_config, engine_index)
@staticmethod @staticmethod
def _unregister_vllm_metrics(): def _unregister_vllm_metrics():
# No-op on purpose # No-op on purpose
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment