Unverified commit 3e341fd6, authored by ishandhanani and committed by GitHub
Browse files

fix(sglang): expose TokenizerMetricsCollector metrics via Prometheus (#5120)

parent 0980b27f
...@@ -16,8 +16,6 @@ import re ...@@ -16,8 +16,6 @@ import re
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Optional, Pattern from typing import TYPE_CHECKING, Optional, Pattern
from prometheus_client import generate_latest
from dynamo._core import Endpoint from dynamo._core import Endpoint
# Import CollectorRegistry only for type hints to avoid importing prometheus_client at module load time. # Import CollectorRegistry only for type hints to avoid importing prometheus_client at module load time.
...@@ -119,6 +117,11 @@ def get_prometheus_expfmt( ...@@ -119,6 +117,11 @@ def get_prometheus_expfmt(
Collects all metrics from the registry and returns them in Prometheus text exposition format. Collects all metrics from the registry and returns them in Prometheus text exposition format.
Optionally filters metrics by prefix, excludes certain prefixes, and adds a prefix. Optionally filters metrics by prefix, excludes certain prefixes, and adds a prefix.
IMPORTANT: prometheus_client is imported lazily here because it must be imported AFTER
set_prometheus_multiproc_dir() is called by SGLang's engine initialization. Importing
at module level causes prometheus_client to initialize in single-process mode before
PROMETHEUS_MULTIPROC_DIR is set, which breaks TokenizerMetricsCollector metrics.
Args: Args:
registry: Prometheus registry to collect from. registry: Prometheus registry to collect from.
Pass CollectorRegistry with MultiProcessCollector for SGLang. Pass CollectorRegistry with MultiProcessCollector for SGLang.
...@@ -138,6 +141,8 @@ def get_prometheus_expfmt( ...@@ -138,6 +141,8 @@ def get_prometheus_expfmt(
# Filter out python_/process_ metrics and add trtllm_ prefix # Filter out python_/process_ metrics and add trtllm_ prefix
get_prometheus_expfmt(registry, exclude_prefixes=["python_", "process_"], add_prefix="trtllm_") get_prometheus_expfmt(registry, exclude_prefixes=["python_", "process_"], add_prefix="trtllm_")
""" """
from prometheus_client import generate_latest
try: try:
# Generate metrics in Prometheus text format # Generate metrics in Prometheus text format
metrics_text = generate_latest(registry).decode("utf-8") metrics_text = generate_latest(registry).decode("utf-8")
......
...@@ -4,14 +4,16 @@ ...@@ -4,14 +4,16 @@
import asyncio import asyncio
import json import json
import logging import logging
from typing import List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
import sglang as sgl import sglang as sgl
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from prometheus_client import CollectorRegistry, multiprocess
from sglang.srt.utils import get_local_ip_auto, get_zmq_socket, maybe_wrap_ipv6_address from sglang.srt.utils import get_local_ip_auto, get_zmq_socket, maybe_wrap_ipv6_address
if TYPE_CHECKING:
from prometheus_client import CollectorRegistry
from dynamo.common.utils.prometheus import register_engine_metrics_callback from dynamo.common.utils.prometheus import register_engine_metrics_callback
from dynamo.llm import ( from dynamo.llm import (
ForwardPassMetrics, ForwardPassMetrics,
...@@ -224,7 +226,7 @@ class DynamoSglangPublisher: ...@@ -224,7 +226,7 @@ class DynamoSglangPublisher:
def setup_prometheus_registry( def setup_prometheus_registry(
engine: sgl.Engine, generate_endpoint: Endpoint engine: sgl.Engine, generate_endpoint: Endpoint
) -> CollectorRegistry: ) -> "CollectorRegistry":
"""Set up Prometheus registry for SGLang metrics collection. """Set up Prometheus registry for SGLang metrics collection.
SGLang uses multiprocess architecture where metrics are stored in shared memory. SGLang uses multiprocess architecture where metrics are stored in shared memory.
...@@ -232,6 +234,11 @@ def setup_prometheus_registry( ...@@ -232,6 +234,11 @@ def setup_prometheus_registry(
registry collects sglang:* metrics which are exposed via the metrics server endpoint registry collects sglang:* metrics which are exposed via the metrics server endpoint
(set DYN_SYSTEM_PORT to a positive value to enable, e.g., DYN_SYSTEM_PORT=8081). (set DYN_SYSTEM_PORT to a positive value to enable, e.g., DYN_SYSTEM_PORT=8081).
IMPORTANT: prometheus_client must be imported AFTER sgl.Engine() has called
set_prometheus_multiproc_dir(). Importing at module level causes prometheus_client
to initialize in single-process mode before PROMETHEUS_MULTIPROC_DIR is set,
which breaks TokenizerMetricsCollector metrics (TTFT, ITL, e2e latency, etc.).
Args: Args:
engine: The SGLang engine instance. engine: The SGLang engine instance.
generate_endpoint: The Dynamo endpoint for generation requests. generate_endpoint: The Dynamo endpoint for generation requests.
...@@ -239,6 +246,8 @@ def setup_prometheus_registry( ...@@ -239,6 +246,8 @@ def setup_prometheus_registry(
Returns: Returns:
Configured CollectorRegistry with multiprocess support. Configured CollectorRegistry with multiprocess support.
""" """
from prometheus_client import CollectorRegistry, multiprocess
registry = CollectorRegistry() registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry) multiprocess.MultiProcessCollector(registry)
register_engine_metrics_callback( register_engine_metrics_callback(
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Unit tests for Prometheus utilities.""" """Unit tests for Prometheus utilities."""
from unittest.mock import Mock from unittest.mock import Mock, patch
import pytest import pytest
...@@ -21,12 +21,7 @@ pytestmark = [ ...@@ -21,12 +21,7 @@ pytestmark = [
class TestGetPrometheusExpfmt: class TestGetPrometheusExpfmt:
"""Test class for get_prometheus_expfmt function.""" """Test class for get_prometheus_expfmt function."""
@pytest.fixture SAMPLE_METRICS = """# HELP python_gc_objects_collected_total Objects collected during gc
def sglang_registry(self):
"""Create a mock registry with SGLang-style metrics."""
registry = Mock()
sample_metrics = """# HELP python_gc_objects_collected_total Objects collected during gc
# TYPE python_gc_objects_collected_total counter # TYPE python_gc_objects_collected_total counter
python_gc_objects_collected_total{generation="0"} 123.0 python_gc_objects_collected_total{generation="0"} 123.0
# HELP process_cpu_seconds_total Total user and system CPU time spent in seconds # HELP process_cpu_seconds_total Total user and system CPU time spent in seconds
...@@ -43,25 +38,19 @@ sglang:generation_tokens_total{model_name="meta-llama/Llama-3.1-8B-Instruct"} 75 ...@@ -43,25 +38,19 @@ sglang:generation_tokens_total{model_name="meta-llama/Llama-3.1-8B-Instruct"} 75
sglang:cache_hit_rate{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0075 sglang:cache_hit_rate{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0075
""" """
def mock_generate_latest(reg): def test_sglang_use_case(self):
return sample_metrics.encode("utf-8")
import dynamo.common.utils.prometheus
original_generate_latest = dynamo.common.utils.prometheus.generate_latest
dynamo.common.utils.prometheus.generate_latest = mock_generate_latest
yield registry
dynamo.common.utils.prometheus.generate_latest = original_generate_latest
def test_sglang_use_case(self, sglang_registry):
"""Test SGLang use case: filter to sglang: metrics and exclude python_/process_.""" """Test SGLang use case: filter to sglang: metrics and exclude python_/process_."""
result = get_prometheus_expfmt( registry = Mock()
sglang_registry,
metric_prefix_filters=["sglang:"], with patch(
exclude_prefixes=["python_", "process_"], "prometheus_client.generate_latest",
) return_value=self.SAMPLE_METRICS.encode("utf-8"),
):
result = get_prometheus_expfmt(
registry,
metric_prefix_filters=["sglang:"],
exclude_prefixes=["python_", "process_"],
)
# Should only contain sglang: metrics # Should only contain sglang: metrics
assert "sglang:prompt_tokens_total" in result assert "sglang:prompt_tokens_total" in result
...@@ -80,11 +69,13 @@ sglang:cache_hit_rate{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0075 ...@@ -80,11 +69,13 @@ sglang:cache_hit_rate{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0075
def test_error_handling(self): def test_error_handling(self):
"""Test error handling when registry fails.""" """Test error handling when registry fails."""
# Create a registry that raises an exception
bad_registry = Mock() bad_registry = Mock()
bad_registry.side_effect = Exception("Registry error")
result = get_prometheus_expfmt(bad_registry) with patch(
"prometheus_client.generate_latest",
side_effect=Exception("Registry error"),
):
result = get_prometheus_expfmt(bad_registry)
# Should return empty string on error # Should return empty string on error
assert result == "" assert result == ""
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Unit tests for Prometheus utilities.""" """Unit tests for Prometheus utilities."""
from unittest.mock import Mock from unittest.mock import Mock, patch
import pytest import pytest
...@@ -21,12 +21,7 @@ pytestmark = [ ...@@ -21,12 +21,7 @@ pytestmark = [
class TestGetPrometheusExpfmt: class TestGetPrometheusExpfmt:
"""Test class for get_prometheus_expfmt function.""" """Test class for get_prometheus_expfmt function."""
@pytest.fixture TRTLLM_SAMPLE_METRICS = """# HELP python_gc_objects_collected_total Objects collected during gc
def trtllm_registry(self):
"""Create a mock registry with TensorRT-LLM-style metrics (no existing prefixes)."""
registry = Mock()
sample_metrics = """# HELP python_gc_objects_collected_total Objects collected during gc
# TYPE python_gc_objects_collected_total counter # TYPE python_gc_objects_collected_total counter
python_gc_objects_collected_total{generation="0"} 123.0 python_gc_objects_collected_total{generation="0"} 123.0
# HELP process_cpu_seconds_total Total user and system CPU time spent in seconds # HELP process_cpu_seconds_total Total user and system CPU time spent in seconds
...@@ -44,25 +39,19 @@ num_requests_running 3.0 ...@@ -44,25 +39,19 @@ num_requests_running 3.0
tokens_per_second 245.7 tokens_per_second 245.7
""" """
def mock_generate_latest(reg): def test_trtllm_use_case(self):
return sample_metrics.encode("utf-8")
import dynamo.common.utils.prometheus
original_generate_latest = dynamo.common.utils.prometheus.generate_latest
dynamo.common.utils.prometheus.generate_latest = mock_generate_latest
yield registry
dynamo.common.utils.prometheus.generate_latest = original_generate_latest
def test_trtllm_use_case(self, trtllm_registry):
"""Test TensorRT-LLM use case: exclude python_/process_ and add trtllm_ prefix.""" """Test TensorRT-LLM use case: exclude python_/process_ and add trtllm_ prefix."""
result = get_prometheus_expfmt( registry = Mock()
trtllm_registry,
exclude_prefixes=["python_", "process_"], with patch(
add_prefix="trtllm_", "prometheus_client.generate_latest",
) return_value=self.TRTLLM_SAMPLE_METRICS.encode("utf-8"),
):
result = get_prometheus_expfmt(
registry,
exclude_prefixes=["python_", "process_"],
add_prefix="trtllm_",
)
# Should not contain excluded metrics # Should not contain excluded metrics
assert "python_gc_objects_collected_total" not in result assert "python_gc_objects_collected_total" not in result
...@@ -82,9 +71,15 @@ tokens_per_second 245.7 ...@@ -82,9 +71,15 @@ tokens_per_second 245.7
assert "trtllm_tokens_per_second 245.7" in result assert "trtllm_tokens_per_second 245.7" in result
assert result.endswith("\n") assert result.endswith("\n")
def test_no_filtering_all_frameworks(self, trtllm_registry): def test_no_filtering_all_frameworks(self):
"""Test that without any filters, all metrics are returned.""" """Test that without any filters, all metrics are returned."""
result = get_prometheus_expfmt(trtllm_registry) registry = Mock()
with patch(
"prometheus_client.generate_latest",
return_value=self.TRTLLM_SAMPLE_METRICS.encode("utf-8"),
):
result = get_prometheus_expfmt(registry)
# Should contain all metrics including excluded ones # Should contain all metrics including excluded ones
assert "python_gc_objects_collected_total" in result assert "python_gc_objects_collected_total" in result
...@@ -93,12 +88,18 @@ tokens_per_second 245.7 ...@@ -93,12 +88,18 @@ tokens_per_second 245.7
assert "num_requests_running" in result assert "num_requests_running" in result
assert result.endswith("\n") assert result.endswith("\n")
def test_empty_result_handling(self, trtllm_registry): def test_empty_result_handling(self):
"""Test handling when all metrics are filtered out.""" """Test handling when all metrics are filtered out."""
result = get_prometheus_expfmt( registry = Mock()
trtllm_registry,
exclude_prefixes=["python_", "process_", "request_", "num_", "tokens_"], with patch(
) "prometheus_client.generate_latest",
return_value=self.TRTLLM_SAMPLE_METRICS.encode("utf-8"),
):
result = get_prometheus_expfmt(
registry,
exclude_prefixes=["python_", "process_", "request_", "num_", "tokens_"],
)
# Should return empty string with newline or just newline # Should return empty string with newline or just newline
assert result == "\n" or result == "" assert result == "\n" or result == ""
...@@ -116,36 +117,31 @@ trtllm_request_success_total{model_name="test",finished_reason="stop"} 10.0 ...@@ -116,36 +117,31 @@ trtllm_request_success_total{model_name="test",finished_reason="stop"} 10.0
trtllm_time_to_first_token_seconds_count 5.0 trtllm_time_to_first_token_seconds_count 5.0
""" """
def mock_generate_latest(reg): with patch(
return sample_metrics.encode("utf-8") "prometheus_client.generate_latest",
return_value=sample_metrics.encode("utf-8"),
import dynamo.common.utils.prometheus ):
original_generate_latest = dynamo.common.utils.prometheus.generate_latest
dynamo.common.utils.prometheus.generate_latest = mock_generate_latest
try:
result = get_prometheus_expfmt( result = get_prometheus_expfmt(
registry, registry,
exclude_prefixes=["python_", "process_"], exclude_prefixes=["python_", "process_"],
add_prefix="trtllm_", add_prefix="trtllm_",
) )
# Should not double-add prefix # Should not double-add prefix
assert "trtllm_trtllm_request_success_total" not in result assert "trtllm_trtllm_request_success_total" not in result
assert "trtllm_request_success_total" in result assert "trtllm_request_success_total" in result
assert "trtllm_time_to_first_token_seconds" in result assert "trtllm_time_to_first_token_seconds" in result
assert result.endswith("\n") assert result.endswith("\n")
finally:
dynamo.common.utils.prometheus.generate_latest = original_generate_latest
def test_error_handling(self): def test_error_handling(self):
"""Test error handling when registry fails.""" """Test error handling when registry fails."""
# Create a registry that raises an exception
bad_registry = Mock() bad_registry = Mock()
bad_registry.side_effect = Exception("Registry error")
result = get_prometheus_expfmt(bad_registry) with patch(
"prometheus_client.generate_latest",
side_effect=Exception("Registry error"),
):
result = get_prometheus_expfmt(bad_registry)
# Should return empty string on error # Should return empty string on error
assert result == "" assert result == ""
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Unit tests for Prometheus utilities.""" """Unit tests for Prometheus utilities."""
from unittest.mock import Mock from unittest.mock import Mock, patch
import pytest import pytest
...@@ -21,12 +21,7 @@ pytestmark = [ ...@@ -21,12 +21,7 @@ pytestmark = [
class TestGetPrometheusExpfmt: class TestGetPrometheusExpfmt:
"""Test class for get_prometheus_expfmt function.""" """Test class for get_prometheus_expfmt function."""
@pytest.fixture SAMPLE_METRICS = """# HELP python_gc_objects_collected_total Objects collected during gc
def vllm_registry(self):
"""Create a mock registry with vLLM-style metrics."""
registry = Mock()
sample_metrics = """# HELP python_gc_objects_collected_total Objects collected during gc
# TYPE python_gc_objects_collected_total counter # TYPE python_gc_objects_collected_total counter
python_gc_objects_collected_total{generation="0"} 123.0 python_gc_objects_collected_total{generation="0"} 123.0
# HELP process_cpu_seconds_total Total user and system CPU time spent in seconds # HELP process_cpu_seconds_total Total user and system CPU time spent in seconds
...@@ -41,25 +36,19 @@ vllm:time_to_first_token_seconds_bucket{le="0.005",model_name="meta-llama/Llama- ...@@ -41,25 +36,19 @@ vllm:time_to_first_token_seconds_bucket{le="0.005",model_name="meta-llama/Llama-
vllm:time_to_first_token_seconds_count{model_name="meta-llama/Llama-3.1-8B"} 165.0 vllm:time_to_first_token_seconds_count{model_name="meta-llama/Llama-3.1-8B"} 165.0
""" """
def mock_generate_latest(reg): def test_vllm_use_case(self):
return sample_metrics.encode("utf-8")
import dynamo.common.utils.prometheus
original_generate_latest = dynamo.common.utils.prometheus.generate_latest
dynamo.common.utils.prometheus.generate_latest = mock_generate_latest
yield registry
dynamo.common.utils.prometheus.generate_latest = original_generate_latest
def test_vllm_use_case(self, vllm_registry):
"""Test vLLM use case: filter to vllm: metrics and exclude python_/process_.""" """Test vLLM use case: filter to vllm: metrics and exclude python_/process_."""
result = get_prometheus_expfmt( registry = Mock()
vllm_registry,
metric_prefix_filters=["vllm:"], with patch(
exclude_prefixes=["python_", "process_"], "prometheus_client.generate_latest",
) return_value=self.SAMPLE_METRICS.encode("utf-8"),
):
result = get_prometheus_expfmt(
registry,
metric_prefix_filters=["vllm:"],
exclude_prefixes=["python_", "process_"],
)
# Should only contain vllm: metrics # Should only contain vllm: metrics
assert "vllm:request_success_total" in result assert "vllm:request_success_total" in result
...@@ -77,11 +66,13 @@ vllm:time_to_first_token_seconds_count{model_name="meta-llama/Llama-3.1-8B"} 165 ...@@ -77,11 +66,13 @@ vllm:time_to_first_token_seconds_count{model_name="meta-llama/Llama-3.1-8B"} 165
def test_error_handling(self): def test_error_handling(self):
"""Test error handling when registry fails.""" """Test error handling when registry fails."""
# Create a registry that raises an exception
bad_registry = Mock() bad_registry = Mock()
bad_registry.side_effect = Exception("Registry error")
result = get_prometheus_expfmt(bad_registry) with patch(
"prometheus_client.generate_latest",
side_effect=Exception("Registry error"),
):
result = get_prometheus_expfmt(bad_registry)
# Should return empty string on error # Should return empty string on error
assert result == "" assert result == ""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment