Unverified Commit 325ab6b0 authored by emricksini-h's avatar emricksini-h Committed by GitHub
Browse files

[Feature] OTEL tracing during loading (#31162)

parent 91a07ff6
......@@ -1049,6 +1049,13 @@ setup(
"petit-kernel": ["petit-kernel"],
# Optional deps for Helion kernel development
"helion": ["helion"],
# Optional deps for OpenTelemetry tracing
"otel": [
"opentelemetry-sdk>=1.26.0",
"opentelemetry-api>=1.26.0",
"opentelemetry-exporter-otlp>=1.26.0",
"opentelemetry-semantic-conventions-ai>=0.4.1",
],
},
cmdclass=cmdclass,
package_data=package_data,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections.abc import Callable, Generator, Iterable
from concurrent import futures
from typing import Any, Literal
import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
ExportTraceServiceRequest,
ExportTraceServiceResponse,
)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
TraceServiceServicer,
add_TraceServiceServicer_to_server,
)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
FieldName = Literal[
"bool_value", "string_value", "int_value", "double_value", "array_value"
]
def decode_value(value: AnyValue):
"""Decode an OpenTelemetry AnyValue protobuf message to a Python value."""
field_decoders: dict[FieldName, Callable] = {
"bool_value": (lambda v: v.bool_value),
"string_value": (lambda v: v.string_value),
"int_value": (lambda v: v.int_value),
"double_value": (lambda v: v.double_value),
"array_value": (
lambda v: [decode_value(item) for item in v.array_value.values]
),
}
for field, decoder in field_decoders.items():
if value.HasField(field):
return decoder(value)
raise ValueError(f"Couldn't decode value: {value}")
def decode_attributes(attributes: Iterable[KeyValue]) -> dict[str, Any]:
"""Decode OpenTelemetry KeyValue attributes to a Python dictionary."""
return {kv.key: decode_value(kv.value) for kv in attributes}
class FakeTraceService(TraceServiceServicer):
"""A fake gRPC trace service for testing OpenTelemetry trace exports."""
def __init__(self):
self.requests: list[ExportTraceServiceRequest] = []
self.evt = threading.Event()
self._lock = threading.Lock()
def Export(self, request, context):
with self._lock:
self.requests.append(request)
self.evt.set()
return ExportTraceServiceResponse()
@property
def request(self) -> ExportTraceServiceRequest | None:
"""Returns the first request received (for backward compatibility)."""
with self._lock:
return self.requests[0] if self.requests else None
def get_all_spans(self) -> list[dict]:
"""Returns all spans from all received requests as decoded dicts."""
spans = []
with self._lock:
for request in self.requests:
for resource_span in request.resource_spans:
for scope_span in resource_span.scope_spans:
for span in scope_span.spans:
spans.append(
{
"name": span.name,
"attributes": decode_attributes(span.attributes),
"trace_id": span.trace_id.hex(),
"span_id": span.span_id.hex(),
"parent_span_id": span.parent_span_id.hex()
if span.parent_span_id
else None,
"start_time_unix_nano": span.start_time_unix_nano,
"end_time_unix_nano": span.end_time_unix_nano,
}
)
return spans
def wait_for_spans(self, count: int = 1, timeout: float = 10) -> bool:
"""Wait until at least `count` spans have been received."""
import time
deadline = time.time() + timeout
while time.time() < deadline:
if len(self.get_all_spans()) >= count:
return True
time.sleep(0.1)
return False
def clear(self):
"""Clear all received requests."""
with self._lock:
self.requests.clear()
self.evt.clear()
@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
"""Fixture to set up a fake gRPC trace service."""
server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
service = FakeTraceService()
add_TraceServiceServicer_to_server(service, server)
server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
server.start()
yield service
server.stop(grace=None)
@pytest.fixture
def trace_server_address() -> str:
"""Returns the address of the fake trace server."""
return FAKE_TRACE_SERVER_ADDRESS
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import pytest
from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_TRACES_INSECURE
from tests.tracing.conftest import FAKE_TRACE_SERVER_ADDRESS, FakeTraceService
from vllm.tracing import init_tracer, instrument, is_otel_available
# Skip everything if OTel is missing
pytestmark = pytest.mark.skipif(not is_otel_available(), reason="OTel required")
class TestCoreInstrumentation:
"""Focuses on the @instrument decorator's ability to capture execution data."""
@pytest.fixture(autouse=True)
def setup_tracing(self, monkeypatch):
monkeypatch.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
init_tracer("test.core", FAKE_TRACE_SERVER_ADDRESS)
def test_decorator_captures_sync_and_async(self, trace_service: FakeTraceService):
"""Verify basic span creation for both sync and async functions."""
@instrument(span_name="sync_task")
def sync_task():
return True
@instrument(span_name="async_task")
async def async_task():
return True
sync_task()
asyncio.run(async_task())
assert trace_service.wait_for_spans(count=2)
span_names = [s["name"] for s in trace_service.get_all_spans()]
assert "sync_task" in span_names
assert "async_task" in span_names
def test_nested_spans_hierarchy(self, trace_service: FakeTraceService):
"""Verify that nested calls create a parent-child relationship."""
@instrument(span_name="child")
def child():
pass
@instrument(span_name="parent")
def parent():
child()
parent()
assert trace_service.wait_for_spans(count=2)
spans = trace_service.get_all_spans()
parent_span = next(s for s in spans if s["name"] == "parent")
child_span = next(s for s in spans if s["name"] == "child")
assert child_span["parent_span_id"] == parent_span["span_id"]
class TestInterProcessPropagation:
"""Test the propagation of trace context between processes."""
def test_pickup_external_context(self, monkeypatch, trace_service):
"""Test that vLLM attaches to an existing trace ID if in environment."""
monkeypatch.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
# Manually simulate an external parent trace ID
fake_trace_id = "4bf92f3577b34da6a3ce929d0e0e4736"
fake_parent_id = "00f067aa0ba902b7"
monkeypatch.setenv("traceparent", f"00-{fake_trace_id}-{fake_parent_id}-01")
init_tracer("test.external", FAKE_TRACE_SERVER_ADDRESS)
@instrument(span_name="follower")
def follower_func():
pass
follower_func()
assert trace_service.wait_for_spans(count=1)
span = trace_service.get_all_spans()[0]
assert span["trace_id"] == fake_trace_id
assert span["parent_span_id"] == fake_parent_id
......@@ -2,76 +2,19 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# type: ignore
import threading
from collections.abc import Iterable
from concurrent import futures
from typing import Callable, Generator, Literal
import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
ExportTraceServiceResponse,
)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
TraceServiceServicer,
add_TraceServiceServicer_to_server,
)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
import time
from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_TRACES_INSECURE
from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
FieldName = Literal[
"bool_value", "string_value", "int_value", "double_value", "array_value"
]
def decode_value(value: AnyValue):
field_decoders: dict[FieldName, Callable] = {
"bool_value": (lambda v: v.bool_value),
"string_value": (lambda v: v.string_value),
"int_value": (lambda v: v.int_value),
"double_value": (lambda v: v.double_value),
"array_value": (
lambda v: [decode_value(item) for item in v.array_value.values]
),
}
for field, decoder in field_decoders.items():
if value.HasField(field):
return decoder(value)
raise ValueError(f"Couldn't decode value: {value}")
def decode_attributes(attributes: Iterable[KeyValue]):
return {kv.key: decode_value(kv.value) for kv in attributes}
class FakeTraceService(TraceServiceServicer):
def __init__(self):
self.request = None
self.evt = threading.Event()
def Export(self, request, context):
self.request = request
self.evt.set()
return ExportTraceServiceResponse()
@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
"""Fixture to set up a fake gRPC trace service"""
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
service = FakeTraceService()
add_TraceServiceServicer_to_server(service, server)
server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
server.start()
yield service
server.stop(None)
# Import shared fixtures from the tracing conftest
from tests.tracing.conftest import ( # noqa: F401
FAKE_TRACE_SERVER_ADDRESS,
FakeTraceService,
trace_service,
)
def test_traces(
......@@ -97,29 +40,25 @@ def test_traces(
outputs = llm.generate(prompts, sampling_params=sampling_params)
print(f"test_traces outputs is : {outputs}")
timeout = 10
if not trace_service.evt.wait(timeout):
raise TimeoutError(
f"The fake trace service didn't receive a trace within "
f"the {timeout} seconds timeout"
# Wait for the "llm_request" span to be exported.
# The BatchSpanProcessor batches spans and exports them periodically,
# so we need to wait specifically for the llm_request span to appear.
timeout = 15
deadline = time.time() + timeout
llm_request_spans = []
while time.time() < deadline:
all_spans = trace_service.get_all_spans()
llm_request_spans = [s for s in all_spans if s["name"] == "llm_request"]
if llm_request_spans:
break
time.sleep(0.5)
assert len(llm_request_spans) == 1, (
f"Expected exactly 1 'llm_request' span, but got {len(llm_request_spans)}. "
f"All span names: {[s['name'] for s in all_spans]}"
)
request = trace_service.request
assert len(request.resource_spans) == 1, (
f"Expected 1 resource span, but got {len(request.resource_spans)}"
)
assert len(request.resource_spans[0].scope_spans) == 1, (
f"Expected 1 scope span, "
f"but got {len(request.resource_spans[0].scope_spans)}"
)
assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
f"Expected 1 span, "
f"but got {len(request.resource_spans[0].scope_spans[0].spans)}"
)
attributes = decode_attributes(
request.resource_spans[0].scope_spans[0].spans[0].attributes
)
attributes = llm_request_spans[0]["attributes"]
# assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
assert (
......
......@@ -33,6 +33,7 @@ from vllm.config.utils import Range, hash_factors
from vllm.logger import init_logger
from vllm.logging_utils import lazy
from vllm.platforms import current_platform
from vllm.tracing import instrument, instrument_manual
from vllm.utils.import_utils import resolve_obj_by_qualname
from .compiler_interface import (
......@@ -234,6 +235,7 @@ class CompilerManager:
)
return compiled_graph
@instrument(span_name="Compile graph")
def compile(
self,
graph: fx.GraphModule,
......@@ -497,6 +499,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
# When True, it annoyingly dumps the torch.fx.Graph on errors.
self.extra_traceback = False
@instrument(span_name="Inductor compilation")
def run(self, *args: Any) -> Any:
# maybe instead just assert inputs are fake?
fake_args = [
......@@ -922,6 +925,11 @@ class VllmBackend:
)
self.compilation_config.compilation_time += dynamo_time
# Record Dynamo time in tracing if available
start_time = int(torch_compile_start_time * 1e9)
attributes = {"dynamo.time_seconds": dynamo_time}
instrument_manual("Dynamo bytecode transform", start_time, None, attributes)
# we control the compilation process, each instance can only be
# called once
assert not self._called, "VllmBackend can only be called once"
......
......@@ -122,9 +122,9 @@ class ObservabilityConfig:
@classmethod
def _validate_otlp_traces_endpoint(cls, value: str | None) -> str | None:
if value is not None:
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.tracing import is_tracing_available, otel_import_error_traceback
if not is_otel_available():
if not is_tracing_available():
raise ValueError(
"OpenTelemetry is not available. Unable to configure "
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
......
......@@ -50,6 +50,7 @@ from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tool_parsers import ToolParserManager
from vllm.tracing import instrument
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
......@@ -377,6 +378,7 @@ def validate_api_server_args(args):
)
@instrument(span_name="API server setup")
def setup_server(args):
"""Validate API server args, set up signal handler, create socket
ready to serve."""
......
......@@ -14,6 +14,7 @@ from vllm.model_executor.model_loader.utils import (
process_weights_after_loading,
)
from vllm.platforms import current_platform
from vllm.tracing import instrument
from vllm.utils.mem_utils import format_gib
from vllm.utils.torch_utils import set_default_torch_dtype
......@@ -37,6 +38,7 @@ class BaseModelLoader(ABC):
inplace weights loading for an already-initialized model"""
raise NotImplementedError
@instrument(span_name="Load model")
def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module:
......
......@@ -30,6 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator,
safetensors_weights_iterator,
)
from vllm.tracing import instrument
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
logger = init_logger(__name__)
......@@ -274,6 +275,7 @@ class DefaultModelLoader(BaseModelLoader):
allow_patterns_overrides=None,
)
@instrument(span_name="Load weights")
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao":
quant_config = get_quant_config(model_config, self.load_config)
......
......@@ -23,11 +23,13 @@ from vllm.model_executor.model_loader.reload import (
set_torchao_reload_attrs,
)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.tracing import instrument
from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__)
@instrument(span_name="Initialize model")
def initialize_model(
vllm_config: VllmConfig,
*,
......
......@@ -36,6 +36,7 @@ from vllm.model_executor.layers.quantization import (
get_quantization_config,
)
from vllm.platforms import current_platform
from vllm.tracing import instrument
from vllm.utils.import_utils import PlaceholderModule
try:
......@@ -443,6 +444,7 @@ def download_gguf(
return local_files[0]
@instrument(span_name="Download weights - HF")
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: str | None,
......
......@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.tracing import instrument
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_mk_alignment_for_contiguous_layout,
......@@ -358,6 +359,7 @@ def _count_warmup_iterations(model: torch.nn.Module, max_tokens: int) -> int:
return total
@instrument(span_name="DeepGemm warmup")
def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
total = _count_warmup_iterations(model, max_tokens)
if total == 0:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
from typing import Any, TypeAlias
# Import the implementation details
from .otel import (
SpanKind,
extract_trace_context,
init_otel_tracer,
init_otel_worker_tracer,
instrument_otel,
is_otel_available,
manual_instrument_otel,
otel_import_error_traceback,
)
from .utils import (
SpanAttributes,
contains_trace_headers,
extract_trace_headers,
log_tracing_disabled_warning,
)
__all__ = [
"instrument",
"instrument_manual",
"init_tracer",
"maybe_init_worker_tracer",
"is_tracing_available",
"SpanAttributes",
"SpanKind",
"extract_trace_context",
"extract_trace_headers",
"log_tracing_disabled_warning",
"contains_trace_headers",
"otel_import_error_traceback",
]
BackendAvailableFunc: TypeAlias = Callable[[], bool]
InstrumentFunc: TypeAlias = Callable[..., Any]
InstrumentManualFunc: TypeAlias = Callable[..., Any]
InitTracerFunc: TypeAlias = Callable[..., Any]
InitWorkerTracerFunc: TypeAlias = Callable[..., Any]
_REGISTERED_TRACING_BACKENDS: dict[
str,
tuple[
BackendAvailableFunc,
InitTracerFunc,
InitWorkerTracerFunc,
InstrumentFunc,
InstrumentManualFunc,
],
] = {
"otel": (
is_otel_available,
init_otel_tracer,
init_otel_worker_tracer,
instrument_otel,
manual_instrument_otel,
),
}
def init_tracer(
instrumenting_module_name: str,
otlp_traces_endpoint: str,
extra_attributes: dict[str, str] | None = None,
):
is_available, init_tracer_fn, _, _, _ = _REGISTERED_TRACING_BACKENDS["otel"]
if is_available():
return init_tracer_fn(
instrumenting_module_name, otlp_traces_endpoint, extra_attributes
)
def maybe_init_worker_tracer(
instrumenting_module_name: str,
process_kind: str,
process_name: str,
):
is_available, _, init_worker_tracer_fn, _, _ = _REGISTERED_TRACING_BACKENDS["otel"]
if is_available():
return init_worker_tracer_fn(
instrumenting_module_name, process_kind, process_name
)
def instrument(
obj: Callable | None = None,
*,
span_name: str = "",
attributes: dict[str, str] | None = None,
record_exception: bool = True,
):
"""
Generic decorator to instrument functions.
"""
if obj is None:
return functools.partial(
instrument,
span_name=span_name,
attributes=attributes,
record_exception=record_exception,
)
# Dispatch to OTel (and potentially others later)
is_available, _, _, otel_instrument, _ = _REGISTERED_TRACING_BACKENDS["otel"]
if is_available():
return otel_instrument(
func=obj,
span_name=span_name,
attributes=attributes,
record_exception=record_exception,
)
else:
return obj
def instrument_manual(
span_name: str,
start_time: int,
end_time: int | None = None,
attributes: dict[str, Any] | None = None,
context: Any = None,
kind: Any = None,
):
"""Manually create a span with explicit timestamps.
Args:
span_name: Name of the span to create.
start_time: Start time in nanoseconds since epoch.
end_time: Optional end time in nanoseconds. If None, ends immediately.
attributes: Optional dict of span attributes.
context: Optional trace context (e.g., from extract_trace_context).
kind: Optional SpanKind (e.g., SpanKind.SERVER).
"""
is_available, _, _, _, manual_instrument_fn = _REGISTERED_TRACING_BACKENDS["otel"]
if is_available():
return manual_instrument_fn(
span_name, start_time, end_time, attributes, context, kind
)
else:
return None
def is_tracing_available() -> bool:
"""
Returns True if any tracing backend (OTel, Profiler, etc.) is available.
Use this to guard expensive tracing logic in the main code.
"""
check_available = [
is_available
for is_available, _, _, _, _ in _REGISTERED_TRACING_BACKENDS.values()
]
return any(check_available)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import functools
import inspect
import os
import traceback
from collections.abc import Mapping
from contextlib import contextmanager
from typing import Any
from vllm.logger import init_logger
from vllm.tracing.utils import TRACE_HEADERS, LoadingSpanAttributes
logger = init_logger(__name__)
try:
from opentelemetry import trace
from opentelemetry.context.context import Context
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter as OTLPGrpcExporter,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPHttpExporter,
)
from opentelemetry.propagate import inject
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_TRACES_PROTOCOL,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import (
SpanKind, # noqa: F401
Tracer,
set_tracer_provider,
)
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
_IS_OTEL_AVAILABLE = True
otel_import_error_traceback = None
except ImportError:
_IS_OTEL_AVAILABLE = False
otel_import_error_traceback = traceback.format_exc()
trace = None # type: ignore
Context = Any # type: ignore
Tracer = Any # type: ignore
inject = None # type: ignore
Resource = None # type: ignore
SpanKind = Any # type: ignore
def is_otel_available() -> bool:
return _IS_OTEL_AVAILABLE
def init_otel_tracer(
instrumenting_module_name: str,
otlp_traces_endpoint: str,
extra_attributes: dict[str, str] | None = None,
) -> Tracer:
"""Initializes the OpenTelemetry tracer provider."""
if not _IS_OTEL_AVAILABLE:
raise ValueError(
"OpenTelemetry is not available. Unable to initialize "
"a tracer. Ensure OpenTelemetry packages are installed. "
f"Original error:\n{otel_import_error_traceback}"
)
# Store the endpoint in environment so child processes can inherit it
os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = otlp_traces_endpoint
resource_attrs = {}
resource_attrs["vllm.instrumenting_module_name"] = instrumenting_module_name
resource_attrs["vllm.process_id"] = str(os.getpid())
if extra_attributes:
resource_attrs.update(extra_attributes)
resource = Resource.create(resource_attrs)
trace_provider = TracerProvider(resource=resource)
span_exporter = get_span_exporter(otlp_traces_endpoint)
trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
set_tracer_provider(trace_provider)
atexit.register(trace_provider.shutdown)
tracer = trace_provider.get_tracer(instrumenting_module_name)
return tracer
def get_span_exporter(endpoint):
protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc")
if protocol == "grpc":
exporter = OTLPGrpcExporter(endpoint=endpoint, insecure=True)
elif protocol == "http/protobuf":
exporter = OTLPHttpExporter(endpoint=endpoint)
else:
raise ValueError(f"Unsupported OTLP protocol '{protocol}' is configured")
return exporter
def init_otel_worker_tracer(
instrumenting_module_name: str,
process_kind: str,
process_name: str,
) -> Tracer:
"""
Backend-specific initialization for OpenTelemetry in a worker process.
"""
# Initialize the tracer if an OTLP endpoint is configured.
# The endpoint is propagated via environment variable from the main process.
otlp_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT")
if not otlp_endpoint:
return None
extra_attrs = {
"vllm.process_kind": process_kind,
"vllm.process_name": process_name,
}
return init_otel_tracer(instrumenting_module_name, otlp_endpoint, extra_attrs)
def extract_trace_context(headers: Mapping[str, str] | None) -> Context | None:
"""Extracts context from HTTP headers."""
if _IS_OTEL_AVAILABLE and headers:
return TraceContextTextMapPropagator().extract(headers)
return None
def instrument_otel(func, span_name, attributes, record_exception):
"""Internal wrapper logic for sync and async functions."""
# Pre-calculate static code attributes once (these don't change)
code_attrs = {
LoadingSpanAttributes.CODE_FUNCTION: func.__qualname__,
LoadingSpanAttributes.CODE_NAMESPACE: func.__module__,
LoadingSpanAttributes.CODE_FILEPATH: func.__code__.co_filename,
LoadingSpanAttributes.CODE_LINENO: str(func.__code__.co_firstlineno),
}
if attributes:
code_attrs.update(attributes)
final_span_name = span_name or func.__qualname__
module_name = func.__module__
@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
tracer = trace.get_tracer(module_name)
ctx = _get_smart_context()
with (
tracer.start_as_current_span(
final_span_name,
context=ctx,
attributes=code_attrs,
record_exception=record_exception,
),
propagate_trace_to_env(),
):
return await func(*args, **kwargs)
@functools.wraps(func)
def sync_wrapper(*args, **kwargs):
tracer = trace.get_tracer(module_name)
ctx = _get_smart_context()
with (
tracer.start_as_current_span(
final_span_name,
context=ctx,
attributes=code_attrs,
record_exception=record_exception,
),
propagate_trace_to_env(),
):
return func(*args, **kwargs)
return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
def manual_instrument_otel(
span_name: str,
start_time: int,
end_time: int | None = None,
attributes: dict[str, Any] | None = None,
context: Context | None = None,
kind: Any = None, # SpanKind, but typed as Any for when OTEL unavailable
):
"""Manually create and end a span with explicit timestamps."""
if not _IS_OTEL_AVAILABLE:
return
tracer = trace.get_tracer(__name__)
# Use provided context, or fall back to smart context detection
ctx = context if context is not None else _get_smart_context()
span_kwargs: dict[str, Any] = {
"name": span_name,
"context": ctx,
"start_time": start_time,
}
if kind is not None:
span_kwargs["kind"] = kind
span = tracer.start_span(**span_kwargs)
if attributes:
span.set_attributes(attributes)
if end_time is not None:
span.end(end_time=end_time)
else:
span.end()
def _get_smart_context() -> Context | None:
"""
Determines the parent context.
1. If a Span is already active in this process, use it.
2. If not, extract from os.environ, handling the case-sensitivity mismatch.
"""
current_span = trace.get_current_span()
if current_span.get_span_context().is_valid:
return None
carrier = {}
if tp := os.environ.get("traceparent", os.environ.get("TRACEPARENT")): # noqa: SIM112
carrier["traceparent"] = tp
if ts := os.environ.get("tracestate", os.environ.get("TRACESTATE")): # noqa: SIM112
carrier["tracestate"] = ts
if not carrier:
carrier = dict(os.environ)
return TraceContextTextMapPropagator().extract(carrier)
@contextmanager
def propagate_trace_to_env():
"""
Temporarily injects the current OTel context into os.environ.
This ensures that any subprocesses (like vLLM workers) spawned
within this context inherit the correct traceparent.
"""
if not _IS_OTEL_AVAILABLE:
yield
return
# Capture original state of relevant keys
original_state = {k: os.environ.get(k) for k in TRACE_HEADERS}
try:
# inject() writes 'traceparent' and 'tracestate' to os.environ
inject(os.environ)
yield
finally:
# Restore original environment
for key, original_value in original_state.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Mapping
from vllm.logger import init_logger
from vllm.utils.func_utils import run_once
TRACE_HEADERS = ["traceparent", "tracestate"]
logger = init_logger(__name__)
_is_otel_imported = False
otel_import_error_traceback: str | None = None
try:
from opentelemetry.context.context import Context
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_TRACES_PROTOCOL,
)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
_is_otel_imported = True
except ImportError:
# Capture and format traceback to provide detailed context for the import
# error. Only the string representation of the error is retained to avoid
# memory leaks.
# See https://github.com/vllm-project/vllm/pull/7266#discussion_r1707395458
import traceback
otel_import_error_traceback = traceback.format_exc()
class Context: # type: ignore
pass
class BaseSpanAttributes: # type: ignore
pass
class SpanKind: # type: ignore
pass
class Tracer: # type: ignore
pass
def is_otel_available() -> bool:
return _is_otel_imported
def init_tracer(
instrumenting_module_name: str, otlp_traces_endpoint: str
) -> Tracer | None:
if not is_otel_available():
raise ValueError(
"OpenTelemetry is not available. Unable to initialize "
"a tracer. Ensure OpenTelemetry packages are installed. "
f"Original error:\n{otel_import_error_traceback}"
)
trace_provider = TracerProvider()
span_exporter = get_span_exporter(otlp_traces_endpoint)
trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
set_tracer_provider(trace_provider)
tracer = trace_provider.get_tracer(instrumenting_module_name)
return tracer
def get_span_exporter(endpoint):
protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc")
if protocol == "grpc":
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter,
)
elif protocol == "http/protobuf":
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter, # type: ignore
)
else:
raise ValueError(f"Unsupported OTLP protocol '{protocol}' is configured")
return OTLPSpanExporter(endpoint=endpoint)
def extract_trace_context(headers: Mapping[str, str] | None) -> Context | None:
if is_otel_available():
headers = headers or {}
return TraceContextTextMapPropagator().extract(headers)
else:
return None
# Standard W3C headers used for context propagation
TRACE_HEADERS = ["traceparent", "tracestate"]
def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]:
return {h: headers[h] for h in TRACE_HEADERS if h in headers}
class SpanAttributes:
"""
Standard attributes for spans.
These are largely based on OpenTelemetry Semantic Conventions but are defined
here as constants so they can be used by any backend or logger.
"""
class SpanAttributes:
# Attribute names copied from here to avoid version conflicts:
# https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
# Attribute names copied from OTel semantic conventions to avoid version conflicts
GEN_AI_USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens"
GEN_AI_USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens"
GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
# Attribute names added until they are added to the semantic conventions:
# Custom attributes added until they are standardized
GEN_AI_REQUEST_ID = "gen_ai.request.id"
GEN_AI_REQUEST_N = "gen_ai.request.n"
GEN_AI_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences"
......@@ -116,20 +36,37 @@ class SpanAttributes:
GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e"
GEN_AI_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler"
# Time taken in the forward pass for this across all workers
# Latency breakdowns
GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward"
# Time taken in the model execute function. This will include model
# forward, block/sync across workers, cpu-gpu sync time and sampling time.
GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute"
GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = "gen_ai.latency.time_in_model_prefill"
GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = "gen_ai.latency.time_in_model_inference"
class LoadingSpanAttributes:
"""Custom attributes for code-level tracing (file, line number)."""
CODE_NAMESPACE = "code.namespace"
CODE_FUNCTION = "code.function"
CODE_FILEPATH = "code.filepath"
CODE_LINENO = "code.lineno"
def contains_trace_headers(headers: Mapping[str, str]) -> bool:
"""Check if the provided headers dictionary contains trace context."""
return any(h in headers for h in TRACE_HEADERS)
def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]:
"""
Extract only trace-related headers from a larger header dictionary.
Useful for logging or passing context to a non-OTel client.
"""
return {h: headers[h] for h in TRACE_HEADERS if h in headers}
@run_once
def log_tracing_disabled_warning() -> None:
logger.warning("Received a request with trace context but tracing is disabled")
......@@ -110,6 +110,10 @@ class AsyncLLM(EngineClient):
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config
tracing_endpoint = self.observability_config.otlp_traces_endpoint
if tracing_endpoint is not None:
init_tracer("vllm.llm_engine", tracing_endpoint)
self.log_requests = log_requests
custom_stat_loggers = list(stat_loggers or [])
......@@ -136,10 +140,8 @@ class AsyncLLM(EngineClient):
log_stats=self.log_stats,
stream_interval=self.vllm_config.scheduler_config.stream_interval,
)
endpoint = self.observability_config.otlp_traces_endpoint
if endpoint is not None:
tracer = init_tracer("vllm.llm_engine", endpoint)
self.output_processor.tracer = tracer
if tracing_endpoint is not None:
self.output_processor.tracing_enabled = True
# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client(
......
......@@ -24,6 +24,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tracing import instrument, maybe_init_worker_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.gc_utils import (
freeze_gc_heap,
......@@ -217,6 +218,7 @@ class EngineCore:
# environment variable overrides after this point)
enable_envs_cache()
@instrument(span_name="Prepare model")
def _initialize_kv_caches(
self, vllm_config: VllmConfig
) -> tuple[int, int, KVCacheConfig]:
......@@ -658,6 +660,7 @@ class EngineCoreProc(EngineCore):
ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
@instrument(span_name="EngineCoreProc init")
def __init__(
self,
vllm_config: VllmConfig,
......@@ -926,8 +929,18 @@ class EngineCoreProc(EngineCore):
data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
if data_parallel:
parallel_config.data_parallel_rank_local = local_dp_rank
maybe_init_worker_tracer(
instrumenting_module_name="vllm.engine_core",
process_kind="engine_core",
process_name=f"EngineCore_DP{dp_rank}",
)
set_process_title("EngineCore", f"DP{dp_rank}")
else:
maybe_init_worker_tracer(
instrumenting_module_name="vllm.engine_core",
process_kind="engine_core",
process_name="EngineCore",
)
set_process_title("EngineCore")
decorate_logs()
......@@ -956,6 +969,7 @@ class EngineCoreProc(EngineCore):
parallel_config.data_parallel_rank = 0
engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
assert engine_core is not None
engine_core.run_busy_loop()
except SystemExit:
......@@ -1485,6 +1499,13 @@ class EngineCoreActorMixin:
dp_rank: int = 0,
local_dp_rank: int = 0,
):
# Initialize tracer for distributed tracing if configured.
maybe_init_worker_tracer(
instrumenting_module_name="vllm.engine_core",
process_kind="engine_core",
process_name=f"DPEngineCoreActor_DP{dp_rank}",
)
self.addresses = addresses
vllm_config.parallel_config.data_parallel_index = dp_rank
vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank
......
......@@ -24,6 +24,7 @@ from vllm.envs import VLLM_ENGINE_READY_TIMEOUT_S
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.tracing import instrument
from vllm.utils.async_utils import in_loop
from vllm.utils.network_utils import (
close_sockets,
......@@ -96,6 +97,7 @@ class EngineCoreClient(ABC):
return InprocClient(vllm_config, executor_class, log_stats)
@staticmethod
@instrument(span_name="Overall Loading")
def make_async_mp_client(
vllm_config: VllmConfig,
executor_class: type[Executor],
......@@ -650,6 +652,7 @@ def _process_utility_output(
class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore."""
@instrument(span_name="SyncMPClient init")
def __init__(
self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool
):
......@@ -819,6 +822,7 @@ class SyncMPClient(MPClient):
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
@instrument(span_name="AsyncMPClient init")
def __init__(
self,
vllm_config: VllmConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment