Unverified commit 96476be5, authored by Krishnan Prashanth, committed by GitHub
Browse files

feat: expose multimodal embedding cache metrics via Prometheus (#8031)


Signed-off-by: Krishnan Prashanth <kprashanth@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent d2aad651
...@@ -59,6 +59,7 @@ class MultimodalEmbeddingCacheManager: ...@@ -59,6 +59,7 @@ class MultimodalEmbeddingCacheManager:
# Stats # Stats
self._hits = 0 self._hits = 0
self._misses = 0 self._misses = 0
self._evictions = 0
logger.info( logger.info(
f"MultimodalEmbeddingCacheManager initialized: capacity={capacity_bytes / 1024**3:.2f}GB" f"MultimodalEmbeddingCacheManager initialized: capacity={capacity_bytes / 1024**3:.2f}GB"
...@@ -138,6 +139,7 @@ class MultimodalEmbeddingCacheManager: ...@@ -138,6 +139,7 @@ class MultimodalEmbeddingCacheManager:
evicted_key, evicted_entry = self._cache.popitem(last=False) evicted_key, evicted_entry = self._cache.popitem(last=False)
evicted_size = self._tensor_size(evicted_entry.tensor) evicted_size = self._tensor_size(evicted_entry.tensor)
self._current_bytes -= evicted_size self._current_bytes -= evicted_size
self._evictions += 1
logger.debug( logger.debug(
f"Evicted key={evicted_key[:16]}..., size={evicted_size / 1024**2:.2f}MB" f"Evicted key={evicted_key[:16]}..., size={evicted_size / 1024**2:.2f}MB"
) )
...@@ -174,5 +176,6 @@ class MultimodalEmbeddingCacheManager: ...@@ -174,5 +176,6 @@ class MultimodalEmbeddingCacheManager:
else 0, else 0,
"hits": self._hits, "hits": self._hits,
"misses": self._misses, "misses": self._misses,
"evictions": self._evictions,
"hit_rate": hit_rate, "hit_rate": hit_rate,
} }
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for register_embedding_cache_metrics."""
from unittest.mock import MagicMock
import pytest
import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
from dynamo.common.utils.prometheus import (
EmbeddingCacheMetrics,
register_embedding_cache_metrics,
)
# Module-level pytest markers (pytest applies `pytestmark` to every test in the file).
# NOTE(review): the module docstring describes these as unit tests, yet they
# carry the `integration` marker — confirm that is intended.
pytestmark = [
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
    pytest.mark.integration,
]
def _parse_metric(text: str, name: str) -> float | None:
"""Parse a metric value from Prometheus expfmt text.
TODO: Consolidate with _get_metric_value() in trtllm/tests/test_trtllm_additional_metrics.py
into a shared test utility once more metric tests are added.
"""
for line in text.split("\n"):
if line.startswith(name + "{") or line.startswith(name + " "):
parts = line.rsplit(" ", 1)
if len(parts) == 2:
return float(parts[1])
return None
@pytest.fixture()
def cache_env():
    """Provide a 1 MiB cache whose metrics are registered on a mocked endpoint.

    Returns:
        ``(cache, callback)`` where ``callback`` is the function that
        ``register_embedding_cache_metrics`` passed to
        ``endpoint.metrics.register_prometheus_expfmt_callback``.
    """
    mock_endpoint = MagicMock()
    mock_endpoint.metrics.register_prometheus_expfmt_callback = MagicMock()
    register_hook = mock_endpoint.metrics.register_prometheus_expfmt_callback

    manager = MultimodalEmbeddingCacheManager(capacity_bytes=1024 * 1024)
    register_embedding_cache_metrics(mock_endpoint, manager, "test-model", "encoder")

    # Exactly one registration is expected; its first positional argument is
    # the scrape callback the tests invoke directly.
    register_hook.assert_called_once()
    scrape_callback = register_hook.call_args[0][0]
    return manager, scrape_callback
class TestCounters:
    """Counter metrics, which are advanced by per-scrape deltas of the cache tallies."""

    def test_accumulation_and_noop(self, cache_env):
        """Misses add up over successive scrapes; an idle scrape changes nothing."""
        cache, callback = cache_env
        misses_metric = "dynamo_component_embedding_cache_misses_total"

        cache.get("miss1")
        first_scrape = callback()
        assert _parse_metric(first_scrape, misses_metric) == 1.0

        # Scrape again without touching the cache: the value must not move.
        idle_scrape = callback()
        assert _parse_metric(idle_scrape, misses_metric) == 1.0

        # Two further misses bring the running total to three.
        cache.get("miss2")
        cache.get("miss3")
        third_scrape = callback()
        assert _parse_metric(third_scrape, misses_metric) == 3.0

    def test_hits_and_misses(self, cache_env):
        """After one insertion, two lookups of it hit and one unknown key misses."""
        cache, callback = cache_env
        cache.set("k", CachedEmbedding(torch.randn(10, 10)))
        for key in ("k", "k", "absent"):  # hit, hit, miss
            cache.get(key)

        scrape = callback()
        hits = _parse_metric(scrape, "dynamo_component_embedding_cache_hits_total")
        misses = _parse_metric(scrape, "dynamo_component_embedding_cache_misses_total")
        assert hits == 2.0
        assert misses == 1.0

    def test_evictions(self):
        """Inserting a second tensor into a one-tensor cache records one eviction."""
        mock_endpoint = MagicMock()
        mock_endpoint.metrics.register_prometheus_expfmt_callback = MagicMock()

        # Capacity leaves room for a single 100-element float32 tensor
        # (100 * 4 bytes) plus 10 spare bytes, so the second set() cannot fit
        # alongside the first.
        tensor_bytes = 100 * 4
        cache = MultimodalEmbeddingCacheManager(capacity_bytes=tensor_bytes + 10)
        register_embedding_cache_metrics(mock_endpoint, cache, "m", "c")
        callback = (
            mock_endpoint.metrics.register_prometheus_expfmt_callback.call_args[0][0]
        )

        for key in ("a", "b"):
            cache.set(key, CachedEmbedding(torch.zeros(100, dtype=torch.float32)))

        scrape = callback()
        evictions = _parse_metric(
            scrape, "dynamo_component_embedding_cache_evictions_total"
        )
        assert evictions == 1.0
class TestGauges:
    """Gauge metrics, which are set from the cache stats on every scrape."""

    def test_empty_then_populated(self, cache_env):
        """Gauges read zero for an empty cache and track subsequent insertions."""
        cache, callback = cache_env
        entries_metric = "dynamo_component_embedding_cache_entries"
        bytes_metric = "dynamo_component_embedding_cache_current_bytes"
        utilization_metric = "dynamo_component_embedding_cache_utilization"

        # Nothing cached yet: every gauge reads zero.
        empty_scrape = callback()
        assert _parse_metric(empty_scrape, entries_metric) == 0.0
        assert _parse_metric(empty_scrape, bytes_metric) == 0.0
        assert _parse_metric(empty_scrape, utilization_metric) == 0.0

        # First insertion: a 100-element float32 tensor, i.e. 100 * 4 bytes.
        payload = torch.zeros(100, dtype=torch.float32)
        payload_bytes = 100 * 4
        cache.set("k1", CachedEmbedding(payload))
        one_entry_scrape = callback()
        assert _parse_metric(one_entry_scrape, entries_metric) == 1.0
        assert _parse_metric(one_entry_scrape, bytes_metric) == payload_bytes
        # Utilization is compared against the fixture's 1 MiB capacity.
        expected_utilization = payload_bytes / (1024 * 1024)
        observed_utilization = _parse_metric(one_entry_scrape, utilization_metric)
        assert abs(observed_utilization - expected_utilization) < 1e-6

        # Second insertion: only the entry count is checked here.
        cache.set("k2", CachedEmbedding(torch.zeros(50, dtype=torch.float32)))
        two_entry_scrape = callback()
        assert _parse_metric(two_entry_scrape, entries_metric) == 2.0
class TestLabelsAndCompleteness:
    """Checks on the emitted labels and on the full set of metric names."""

    def test_labels_present(self):
        """The model and component names passed at registration show up as labels."""
        mock_endpoint = MagicMock()
        mock_endpoint.metrics.register_prometheus_expfmt_callback = MagicMock()
        cache = MultimodalEmbeddingCacheManager(capacity_bytes=1024)
        register_embedding_cache_metrics(
            mock_endpoint, cache, "Qwen/Qwen2.5-VL-3B", "encoder"
        )
        callback = (
            mock_endpoint.metrics.register_prometheus_expfmt_callback.call_args[0][0]
        )

        scrape = callback()
        assert 'model="Qwen/Qwen2.5-VL-3B"' in scrape
        assert 'dynamo_component="encoder"' in scrape

    def test_all_metric_names_present(self, cache_env):
        """Each member of EmbeddingCacheMetrics has a sample in the scrape output."""
        cache, callback = cache_env

        # Produce one hit and one miss before scraping so the counters have activity.
        cache.set("k", CachedEmbedding(torch.zeros(10, dtype=torch.float32)))
        cache.get("k")
        cache.get("absent")

        scrape = callback()
        for name in EmbeddingCacheMetrics:
            value = _parse_metric(scrape, name)
            assert value is not None, f"metric {name} missing from output"
...@@ -11,8 +11,10 @@ Note: Engine metrics take time to appear after engine initialization, ...@@ -11,8 +11,10 @@ Note: Engine metrics take time to appear after engine initialization,
while Dynamo runtime metrics are available immediately after component creation. while Dynamo runtime metrics are available immediately after component creation.
""" """
import enum
import logging import logging
import re import re
import threading
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Optional, Pattern from typing import TYPE_CHECKING, Optional, Pattern
...@@ -25,6 +27,8 @@ from dynamo.prometheus_names import kvstats, labels, model_info, name_prefix ...@@ -25,6 +27,8 @@ from dynamo.prometheus_names import kvstats, labels, model_info, name_prefix
if TYPE_CHECKING: if TYPE_CHECKING:
from prometheus_client import CollectorRegistry from prometheus_client import CollectorRegistry
from dynamo.common.memory import MultimodalEmbeddingCacheManager
# Auto-label injection: always injects dynamo_namespace, dynamo_component, dynamo_endpoint labels # Auto-label injection: always injects dynamo_namespace, dynamo_component, dynamo_endpoint labels
# into engine metrics based on the endpoint hierarchy. # into engine metrics based on the endpoint hierarchy.
# #
...@@ -32,6 +36,18 @@ if TYPE_CHECKING: ...@@ -32,6 +36,18 @@ if TYPE_CHECKING:
# Label constants defined in: lib/runtime/src/metrics/prometheus_names.rs labels module # Label constants defined in: lib/runtime/src/metrics/prometheus_names.rs labels module
# Single source of truth for embedding cache metric names.
# Every name is this shared stem plus a metric-specific suffix.
_EMBEDDING_CACHE_PREFIX = f"{name_prefix.COMPONENT}_embedding_cache"


class EmbeddingCacheMetrics(str, enum.Enum):
    """Prometheus metric names for the multimodal embedding cache.

    Members are ``str`` instances, so they can be handed directly to APIs that
    expect a metric-name string.
    """

    # Counters
    HITS_TOTAL = f"{_EMBEDDING_CACHE_PREFIX}_hits_total"
    MISSES_TOTAL = f"{_EMBEDDING_CACHE_PREFIX}_misses_total"
    EVICTIONS_TOTAL = f"{_EMBEDDING_CACHE_PREFIX}_evictions_total"

    # Gauges
    UTILIZATION = f"{_EMBEDDING_CACHE_PREFIX}_utilization"
    CURRENT_BYTES = f"{_EMBEDDING_CACHE_PREFIX}_current_bytes"
    ENTRIES = f"{_EMBEDDING_CACHE_PREFIX}_entries"
def register_engine_metrics_callback( def register_engine_metrics_callback(
endpoint: Endpoint, endpoint: Endpoint,
registry: "CollectorRegistry", registry: "CollectorRegistry",
...@@ -365,3 +381,130 @@ class LLMBackendMetrics: ...@@ -365,3 +381,130 @@ class LLMBackendMetrics:
self.model_load_time.labels( self.model_load_time.labels(
**{labels.MODEL: self.model_name, labels.COMPONENT: self.component_name} **{labels.MODEL: self.model_name, labels.COMPONENT: self.component_name}
).set(value) ).set(value)
def register_embedding_cache_metrics(
    endpoint: "Endpoint",
    cache: "MultimodalEmbeddingCacheManager",
    model_name: str = "",
    component_name: str = "",
) -> None:
    """Expose a MultimodalEmbeddingCacheManager's statistics as Prometheus metrics.

    Three counters (hits, misses, evictions) and three gauges (utilization,
    current bytes, entries) are created in a CollectorRegistry private to this
    call, and a callback rendering that registry is registered on
    ``endpoint.metrics``. On each invocation the callback reads ``cache.stats``,
    advances every counter by the growth of the matching stat since the
    previous invocation, overwrites the gauges with the current values, and
    returns the registry in Prometheus text format.

    A dedicated registry is used to avoid prometheus_client import-ordering
    issues with SGLang's multiprocess mode; for the same reason this function
    must be called AFTER engine initialization, when importing
    prometheus_client is safe.

    Thread safety note:
        A threading.Lock serializes concurrent scrapes against each other
        (axum may serve /metrics from multiple threads). It does not guard
        the cache itself: the callback reads cache.stats from the HTTP thread
        while the asyncio event loop thread mutates the cache. Under CPython
        the individual reads are GIL-protected, so the worst case is a
        slightly inconsistent snapshot within one scrape (e.g. hits already
        incremented but misses not yet), which is acceptable for monitoring
        metrics — values are eventually consistent.

    Args:
        endpoint: Dynamo Endpoint with metrics.register_prometheus_expfmt_callback().
        cache: The MultimodalEmbeddingCacheManager instance to observe.
        model_name: Model name for the 'model' label.
        component_name: Component name for the 'dynamo_component' label.
    """
    # Lazy import: prometheus_client must be imported AFTER set_prometheus_multiproc_dir()
    # in SGLang's multiprocess mode. This matches the existing pattern used by
    # get_prometheus_expfmt() and LLMBackendMetrics.__init__() in this file.
    from prometheus_client import CollectorRegistry, Counter, Gauge, generate_latest

    private_registry = CollectorRegistry()
    label_keys = [labels.MODEL, labels.COMPONENT]
    label_kwargs = {labels.MODEL: model_name, labels.COMPONENT: component_name}
    metric_names = EmbeddingCacheMetrics

    # (key in cache.stats, metric name, help text). Creation order below is
    # hits, misses, evictions, then the three gauges.
    counter_specs = (
        ("hits", metric_names.HITS_TOTAL, "Total embedding cache hits."),
        ("misses", metric_names.MISSES_TOTAL, "Total embedding cache misses."),
        (
            "evictions",
            metric_names.EVICTIONS_TOTAL,
            "Total embedding cache evictions.",
        ),
    )
    gauge_specs = (
        (
            "utilization",
            metric_names.UTILIZATION,
            "Cache memory utilization ratio (0.0-1.0).",
        ),
        (
            "current_bytes",
            metric_names.CURRENT_BYTES,
            "Current cache memory usage in bytes.",
        ),
        ("entries", metric_names.ENTRIES, "Number of entries in the cache."),
    )

    # Counters are delta-incremented from the cache's monotonic stats on each scrape.
    counters = {
        stat_key: Counter(
            metric_name,
            help_text,
            labelnames=label_keys,
            registry=private_registry,
        )
        for stat_key, metric_name, help_text in counter_specs
    }
    # Gauges are snapshot values set on each scrape.
    gauges = {
        stat_key: Gauge(
            metric_name,
            help_text,
            labelnames=label_keys,
            registry=private_registry,
        )
        for stat_key, metric_name, help_text in gauge_specs
    }

    # Touch each labelled counter once so it is exported as zero from the very
    # first scrape instead of being absent until the first event.
    for counter in counters.values():
        counter.labels(**label_kwargs)

    scrape_lock = threading.Lock()
    # Stat values observed at the previous scrape, used to derive increments.
    last_seen = dict.fromkeys(counters, 0)

    def _collect_embedding_cache_metrics() -> str:
        """Callback invoked on each /metrics scrape."""
        with scrape_lock:
            snapshot = cache.stats

            # Read every counter source before mutating any metric.
            current = {stat_key: snapshot[stat_key] for stat_key in counters}
            for stat_key, counter in counters.items():
                increase = current[stat_key] - last_seen[stat_key]
                # Only positive growth is forwarded to the counter.
                if increase > 0:
                    counter.labels(**label_kwargs).inc(increase)
            last_seen.update(current)

            for stat_key, gauge in gauges.items():
                gauge.labels(**label_kwargs).set(snapshot[stat_key])

            return generate_latest(private_registry).decode("utf-8")

    endpoint.metrics.register_prometheus_expfmt_callback(
        _collect_embedding_cache_metrics
    )
    logging.info(
        "Registered embedding cache metrics (model=%s, component=%s)",
        model_name,
        component_name,
    )
...@@ -9,6 +9,7 @@ import sglang as sgl ...@@ -9,6 +9,7 @@ import sglang as sgl
from dynamo import prometheus_names from dynamo import prometheus_names
from dynamo.common.constants import DisaggregationMode from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.prometheus import register_embedding_cache_metrics
from dynamo.llm import ModelInput from dynamo.llm import ModelInput
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
...@@ -47,6 +48,14 @@ async def init_multimodal_encode_worker( ...@@ -47,6 +48,14 @@ async def init_multimodal_encode_worker(
handler = MultimodalEncodeWorkerHandler(config, pd_worker_client, shutdown_event) handler = MultimodalEncodeWorkerHandler(config, pd_worker_client, shutdown_event)
if handler._embedding_cache is not None:
register_embedding_cache_metrics(
endpoint=generate_endpoint,
cache=handler._embedding_cache,
model_name=server_args.served_model_name,
component_name=dynamo_args.component,
)
await pd_worker_client.wait_for_instances() await pd_worker_client.wait_for_instances()
ready_event = asyncio.Event() ready_event = asyncio.Event()
......
...@@ -35,6 +35,7 @@ from dynamo.common.config_dump import dump_config ...@@ -35,6 +35,7 @@ from dynamo.common.config_dump import dump_config
from dynamo.common.utils.endpoint_types import parse_endpoint_types from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.prometheus import ( from dynamo.common.utils.prometheus import (
LLMBackendMetrics, LLMBackendMetrics,
register_embedding_cache_metrics,
register_engine_metrics_callback, register_engine_metrics_callback,
) )
from dynamo.common.utils.runtime import parse_endpoint from dynamo.common.utils.runtime import parse_endpoint
...@@ -584,6 +585,16 @@ async def init_llm_worker( ...@@ -584,6 +585,16 @@ async def init_llm_worker(
) as publisher: ) as publisher:
handler_config.publisher = publisher handler_config.publisher = publisher
handler = RequestHandlerFactory().get_request_handler(handler_config) handler = RequestHandlerFactory().get_request_handler(handler_config)
encoder_cache = getattr(handler, "_encoder_cache", None)
if encoder_cache is not None:
register_embedding_cache_metrics(
endpoint=endpoint,
cache=encoder_cache,
model_name=model_name_for_metrics,
component_name=config.component,
)
await endpoint.serve_endpoint( await endpoint.serve_endpoint(
handler.generate, handler.generate,
metrics_labels=metrics_labels, metrics_labels=metrics_labels,
......
...@@ -17,7 +17,10 @@ from vllm.v1.engine.async_llm import AsyncLLM ...@@ -17,7 +17,10 @@ from vllm.v1.engine.async_llm import AsyncLLM
from dynamo import prometheus_names from dynamo import prometheus_names
from dynamo.common.utils.endpoint_types import parse_endpoint_types from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.prometheus import LLMBackendMetrics from dynamo.common.utils.prometheus import (
LLMBackendMetrics,
register_embedding_cache_metrics,
)
from dynamo.llm import ModelInput, ModelType from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
...@@ -317,6 +320,16 @@ class WorkerFactory: ...@@ -317,6 +320,16 @@ class WorkerFactory:
handler.fpm_relays = fpm_relays handler.fpm_relays = fpm_relays
self.setup_metrics_collection(config, generate_endpoint, logger) self.setup_metrics_collection(config, generate_endpoint, logger)
embedding_cache = getattr(handler, "embedding_cache_manager", None)
if embedding_cache is not None:
register_embedding_cache_metrics(
endpoint=generate_endpoint,
cache=embedding_cache,
model_name=config.served_model_name or config.model,
component_name=config.component,
)
# Register sleep/wake_up engine routes # Register sleep/wake_up engine routes
runtime.register_engine_route("sleep", handler.sleep) runtime.register_engine_route("sleep", handler.sleep)
runtime.register_engine_route("wake_up", handler.wake_up) runtime.register_engine_route("wake_up", handler.wake_up)
...@@ -541,6 +554,16 @@ class WorkerFactory: ...@@ -541,6 +554,16 @@ class WorkerFactory:
handler.fpm_relays = fpm_relays handler.fpm_relays = fpm_relays
self.setup_metrics_collection(config, generate_endpoint, logger) self.setup_metrics_collection(config, generate_endpoint, logger)
embedding_cache = getattr(handler, "embedding_cache_manager", None)
if embedding_cache is not None:
register_embedding_cache_metrics(
endpoint=generate_endpoint,
cache=embedding_cache,
model_name=config.served_model_name or config.model,
component_name=config.component,
)
# Register sleep/wake_up engine routes # Register sleep/wake_up engine routes
runtime.register_engine_route("sleep", handler.sleep) runtime.register_engine_route("sleep", handler.sleep)
runtime.register_engine_route("wake_up", handler.wake_up) runtime.register_engine_route("wake_up", handler.wake_up)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment