Unverified Commit 325ab6b0 authored by emricksini-h's avatar emricksini-h Committed by GitHub
Browse files

[Feature] OTEL tracing during loading (#31162)

parent 91a07ff6
...@@ -52,4 +52,4 @@ anthropic >= 0.71.0 ...@@ -52,4 +52,4 @@ anthropic >= 0.71.0
model-hosting-container-standards >= 0.1.13, < 1.0.0 model-hosting-container-standards >= 0.1.13, < 1.0.0
mcp mcp
grpcio grpcio
grpcio-reflection grpcio-reflection
\ No newline at end of file
...@@ -1049,6 +1049,13 @@ setup( ...@@ -1049,6 +1049,13 @@ setup(
"petit-kernel": ["petit-kernel"], "petit-kernel": ["petit-kernel"],
# Optional deps for Helion kernel development # Optional deps for Helion kernel development
"helion": ["helion"], "helion": ["helion"],
# Optional deps for OpenTelemetry tracing
"otel": [
"opentelemetry-sdk>=1.26.0",
"opentelemetry-api>=1.26.0",
"opentelemetry-exporter-otlp>=1.26.0",
"opentelemetry-semantic-conventions-ai>=0.4.1",
],
}, },
cmdclass=cmdclass, cmdclass=cmdclass,
package_data=package_data, package_data=package_data,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections.abc import Callable, Generator, Iterable
from concurrent import futures
from typing import Any, Literal
import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
ExportTraceServiceRequest,
ExportTraceServiceResponse,
)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
TraceServiceServicer,
add_TraceServiceServicer_to_server,
)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
FieldName = Literal[
"bool_value", "string_value", "int_value", "double_value", "array_value"
]
def decode_value(value: AnyValue):
"""Decode an OpenTelemetry AnyValue protobuf message to a Python value."""
field_decoders: dict[FieldName, Callable] = {
"bool_value": (lambda v: v.bool_value),
"string_value": (lambda v: v.string_value),
"int_value": (lambda v: v.int_value),
"double_value": (lambda v: v.double_value),
"array_value": (
lambda v: [decode_value(item) for item in v.array_value.values]
),
}
for field, decoder in field_decoders.items():
if value.HasField(field):
return decoder(value)
raise ValueError(f"Couldn't decode value: {value}")
def decode_attributes(attributes: Iterable[KeyValue]) -> dict[str, Any]:
"""Decode OpenTelemetry KeyValue attributes to a Python dictionary."""
return {kv.key: decode_value(kv.value) for kv in attributes}
class FakeTraceService(TraceServiceServicer):
"""A fake gRPC trace service for testing OpenTelemetry trace exports."""
def __init__(self):
self.requests: list[ExportTraceServiceRequest] = []
self.evt = threading.Event()
self._lock = threading.Lock()
def Export(self, request, context):
with self._lock:
self.requests.append(request)
self.evt.set()
return ExportTraceServiceResponse()
@property
def request(self) -> ExportTraceServiceRequest | None:
"""Returns the first request received (for backward compatibility)."""
with self._lock:
return self.requests[0] if self.requests else None
def get_all_spans(self) -> list[dict]:
"""Returns all spans from all received requests as decoded dicts."""
spans = []
with self._lock:
for request in self.requests:
for resource_span in request.resource_spans:
for scope_span in resource_span.scope_spans:
for span in scope_span.spans:
spans.append(
{
"name": span.name,
"attributes": decode_attributes(span.attributes),
"trace_id": span.trace_id.hex(),
"span_id": span.span_id.hex(),
"parent_span_id": span.parent_span_id.hex()
if span.parent_span_id
else None,
"start_time_unix_nano": span.start_time_unix_nano,
"end_time_unix_nano": span.end_time_unix_nano,
}
)
return spans
def wait_for_spans(self, count: int = 1, timeout: float = 10) -> bool:
"""Wait until at least `count` spans have been received."""
import time
deadline = time.time() + timeout
while time.time() < deadline:
if len(self.get_all_spans()) >= count:
return True
time.sleep(0.1)
return False
def clear(self):
"""Clear all received requests."""
with self._lock:
self.requests.clear()
self.evt.clear()
@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
"""Fixture to set up a fake gRPC trace service."""
server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
service = FakeTraceService()
add_TraceServiceServicer_to_server(service, server)
server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
server.start()
yield service
server.stop(grace=None)
@pytest.fixture
def trace_server_address() -> str:
"""Returns the address of the fake trace server."""
return FAKE_TRACE_SERVER_ADDRESS
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import pytest
from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_TRACES_INSECURE
from tests.tracing.conftest import FAKE_TRACE_SERVER_ADDRESS, FakeTraceService
from vllm.tracing import init_tracer, instrument, is_otel_available
# Skip everything if OTel is missing
pytestmark = pytest.mark.skipif(not is_otel_available(), reason="OTel required")
class TestCoreInstrumentation:
"""Focuses on the @instrument decorator's ability to capture execution data."""
@pytest.fixture(autouse=True)
def setup_tracing(self, monkeypatch):
monkeypatch.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
init_tracer("test.core", FAKE_TRACE_SERVER_ADDRESS)
def test_decorator_captures_sync_and_async(self, trace_service: FakeTraceService):
"""Verify basic span creation for both sync and async functions."""
@instrument(span_name="sync_task")
def sync_task():
return True
@instrument(span_name="async_task")
async def async_task():
return True
sync_task()
asyncio.run(async_task())
assert trace_service.wait_for_spans(count=2)
span_names = [s["name"] for s in trace_service.get_all_spans()]
assert "sync_task" in span_names
assert "async_task" in span_names
def test_nested_spans_hierarchy(self, trace_service: FakeTraceService):
"""Verify that nested calls create a parent-child relationship."""
@instrument(span_name="child")
def child():
pass
@instrument(span_name="parent")
def parent():
child()
parent()
assert trace_service.wait_for_spans(count=2)
spans = trace_service.get_all_spans()
parent_span = next(s for s in spans if s["name"] == "parent")
child_span = next(s for s in spans if s["name"] == "child")
assert child_span["parent_span_id"] == parent_span["span_id"]
class TestInterProcessPropagation:
"""Test the propagation of trace context between processes."""
def test_pickup_external_context(self, monkeypatch, trace_service):
"""Test that vLLM attaches to an existing trace ID if in environment."""
monkeypatch.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
# Manually simulate an external parent trace ID
fake_trace_id = "4bf92f3577b34da6a3ce929d0e0e4736"
fake_parent_id = "00f067aa0ba902b7"
monkeypatch.setenv("traceparent", f"00-{fake_trace_id}-{fake_parent_id}-01")
init_tracer("test.external", FAKE_TRACE_SERVER_ADDRESS)
@instrument(span_name="follower")
def follower_func():
pass
follower_func()
assert trace_service.wait_for_spans(count=1)
span = trace_service.get_all_spans()[0]
assert span["trace_id"] == fake_trace_id
assert span["parent_span_id"] == fake_parent_id
...@@ -2,76 +2,19 @@ ...@@ -2,76 +2,19 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa # ruff: noqa
# type: ignore # type: ignore
import threading
from collections.abc import Iterable
from concurrent import futures
from typing import Callable, Generator, Literal
import grpc
import pytest import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( import time
ExportTraceServiceResponse,
)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
TraceServiceServicer,
add_TraceServiceServicer_to_server,
)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_TRACES_INSECURE from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_TRACES_INSECURE
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes from vllm.tracing import SpanAttributes
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317" # Import shared fixtures from the tracing conftest
from tests.tracing.conftest import ( # noqa: F401
FieldName = Literal[ FAKE_TRACE_SERVER_ADDRESS,
"bool_value", "string_value", "int_value", "double_value", "array_value" FakeTraceService,
] trace_service,
)
def decode_value(value: AnyValue):
field_decoders: dict[FieldName, Callable] = {
"bool_value": (lambda v: v.bool_value),
"string_value": (lambda v: v.string_value),
"int_value": (lambda v: v.int_value),
"double_value": (lambda v: v.double_value),
"array_value": (
lambda v: [decode_value(item) for item in v.array_value.values]
),
}
for field, decoder in field_decoders.items():
if value.HasField(field):
return decoder(value)
raise ValueError(f"Couldn't decode value: {value}")
def decode_attributes(attributes: Iterable[KeyValue]):
return {kv.key: decode_value(kv.value) for kv in attributes}
class FakeTraceService(TraceServiceServicer):
def __init__(self):
self.request = None
self.evt = threading.Event()
def Export(self, request, context):
self.request = request
self.evt.set()
return ExportTraceServiceResponse()
@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
"""Fixture to set up a fake gRPC trace service"""
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
service = FakeTraceService()
add_TraceServiceServicer_to_server(service, server)
server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
server.start()
yield service
server.stop(None)
def test_traces( def test_traces(
...@@ -97,29 +40,25 @@ def test_traces( ...@@ -97,29 +40,25 @@ def test_traces(
outputs = llm.generate(prompts, sampling_params=sampling_params) outputs = llm.generate(prompts, sampling_params=sampling_params)
print(f"test_traces outputs is : {outputs}") print(f"test_traces outputs is : {outputs}")
timeout = 10 # Wait for the "llm_request" span to be exported.
if not trace_service.evt.wait(timeout): # The BatchSpanProcessor batches spans and exports them periodically,
raise TimeoutError( # so we need to wait specifically for the llm_request span to appear.
f"The fake trace service didn't receive a trace within " timeout = 15
f"the {timeout} seconds timeout" deadline = time.time() + timeout
) llm_request_spans = []
while time.time() < deadline:
request = trace_service.request all_spans = trace_service.get_all_spans()
assert len(request.resource_spans) == 1, ( llm_request_spans = [s for s in all_spans if s["name"] == "llm_request"]
f"Expected 1 resource span, but got {len(request.resource_spans)}" if llm_request_spans:
) break
assert len(request.resource_spans[0].scope_spans) == 1, ( time.sleep(0.5)
f"Expected 1 scope span, "
f"but got {len(request.resource_spans[0].scope_spans)}" assert len(llm_request_spans) == 1, (
) f"Expected exactly 1 'llm_request' span, but got {len(llm_request_spans)}. "
assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( f"All span names: {[s['name'] for s in all_spans]}"
f"Expected 1 span, "
f"but got {len(request.resource_spans[0].scope_spans[0].spans)}"
) )
attributes = decode_attributes( attributes = llm_request_spans[0]["attributes"]
request.resource_spans[0].scope_spans[0].spans[0].attributes
)
# assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model # assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id assert attributes.get(SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
assert ( assert (
......
...@@ -33,6 +33,7 @@ from vllm.config.utils import Range, hash_factors ...@@ -33,6 +33,7 @@ from vllm.config.utils import Range, hash_factors
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logging_utils import lazy from vllm.logging_utils import lazy
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tracing import instrument, instrument_manual
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from .compiler_interface import ( from .compiler_interface import (
...@@ -234,6 +235,7 @@ class CompilerManager: ...@@ -234,6 +235,7 @@ class CompilerManager:
) )
return compiled_graph return compiled_graph
@instrument(span_name="Compile graph")
def compile( def compile(
self, self,
graph: fx.GraphModule, graph: fx.GraphModule,
...@@ -497,6 +499,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] ...@@ -497,6 +499,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
# When True, it annoyingly dumps the torch.fx.Graph on errors. # When True, it annoyingly dumps the torch.fx.Graph on errors.
self.extra_traceback = False self.extra_traceback = False
@instrument(span_name="Inductor compilation")
def run(self, *args: Any) -> Any: def run(self, *args: Any) -> Any:
# maybe instead just assert inputs are fake? # maybe instead just assert inputs are fake?
fake_args = [ fake_args = [
...@@ -922,6 +925,11 @@ class VllmBackend: ...@@ -922,6 +925,11 @@ class VllmBackend:
) )
self.compilation_config.compilation_time += dynamo_time self.compilation_config.compilation_time += dynamo_time
# Record Dynamo time in tracing if available
start_time = int(torch_compile_start_time * 1e9)
attributes = {"dynamo.time_seconds": dynamo_time}
instrument_manual("Dynamo bytecode transform", start_time, None, attributes)
# we control the compilation process, each instance can only be # we control the compilation process, each instance can only be
# called once # called once
assert not self._called, "VllmBackend can only be called once" assert not self._called, "VllmBackend can only be called once"
......
...@@ -122,9 +122,9 @@ class ObservabilityConfig: ...@@ -122,9 +122,9 @@ class ObservabilityConfig:
@classmethod @classmethod
def _validate_otlp_traces_endpoint(cls, value: str | None) -> str | None: def _validate_otlp_traces_endpoint(cls, value: str | None) -> str | None:
if value is not None: if value is not None:
from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.tracing import is_tracing_available, otel_import_error_traceback
if not is_otel_available(): if not is_tracing_available():
raise ValueError( raise ValueError(
"OpenTelemetry is not available. Unable to configure " "OpenTelemetry is not available. Unable to configure "
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
......
...@@ -50,6 +50,7 @@ from vllm.logger import init_logger ...@@ -50,6 +50,7 @@ from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tool_parsers import ToolParserManager from vllm.tool_parsers import ToolParserManager
from vllm.tracing import instrument
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address from vllm.utils.network_utils import is_valid_ipv6_address
...@@ -377,6 +378,7 @@ def validate_api_server_args(args): ...@@ -377,6 +378,7 @@ def validate_api_server_args(args):
) )
@instrument(span_name="API server setup")
def setup_server(args): def setup_server(args):
"""Validate API server args, set up signal handler, create socket """Validate API server args, set up signal handler, create socket
ready to serve.""" ready to serve."""
......
...@@ -14,6 +14,7 @@ from vllm.model_executor.model_loader.utils import ( ...@@ -14,6 +14,7 @@ from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, process_weights_after_loading,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tracing import instrument
from vllm.utils.mem_utils import format_gib from vllm.utils.mem_utils import format_gib
from vllm.utils.torch_utils import set_default_torch_dtype from vllm.utils.torch_utils import set_default_torch_dtype
...@@ -37,6 +38,7 @@ class BaseModelLoader(ABC): ...@@ -37,6 +38,7 @@ class BaseModelLoader(ABC):
inplace weights loading for an already-initialized model""" inplace weights loading for an already-initialized model"""
raise NotImplementedError raise NotImplementedError
@instrument(span_name="Load model")
def load_model( def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = "" self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module: ) -> nn.Module:
......
...@@ -30,6 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -30,6 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator, safetensors_weights_iterator,
) )
from vllm.tracing import instrument
from vllm.transformers_utils.repo_utils import list_filtered_repo_files from vllm.transformers_utils.repo_utils import list_filtered_repo_files
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -274,6 +275,7 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -274,6 +275,7 @@ class DefaultModelLoader(BaseModelLoader):
allow_patterns_overrides=None, allow_patterns_overrides=None,
) )
@instrument(span_name="Load weights")
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao": if model_config.quantization == "torchao":
quant_config = get_quant_config(model_config, self.load_config) quant_config = get_quant_config(model_config, self.load_config)
......
...@@ -23,11 +23,13 @@ from vllm.model_executor.model_loader.reload import ( ...@@ -23,11 +23,13 @@ from vllm.model_executor.model_loader.reload import (
set_torchao_reload_attrs, set_torchao_reload_attrs,
) )
from vllm.model_executor.models.interfaces import SupportsQuant from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.tracing import instrument
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__) logger = init_logger(__name__)
@instrument(span_name="Initialize model")
def initialize_model( def initialize_model(
vllm_config: VllmConfig, vllm_config: VllmConfig,
*, *,
......
...@@ -36,6 +36,7 @@ from vllm.model_executor.layers.quantization import ( ...@@ -36,6 +36,7 @@ from vllm.model_executor.layers.quantization import (
get_quantization_config, get_quantization_config,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tracing import instrument
from vllm.utils.import_utils import PlaceholderModule from vllm.utils.import_utils import PlaceholderModule
try: try:
...@@ -443,6 +444,7 @@ def download_gguf( ...@@ -443,6 +444,7 @@ def download_gguf(
return local_files[0] return local_files[0]
@instrument(span_name="Download weights - HF")
def download_weights_from_hf( def download_weights_from_hf(
model_name_or_path: str, model_name_or_path: str,
cache_dir: str | None, cache_dir: str | None,
......
...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( ...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
) )
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.tracing import instrument
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
fp8_gemm_nt, fp8_gemm_nt,
get_mk_alignment_for_contiguous_layout, get_mk_alignment_for_contiguous_layout,
...@@ -358,6 +359,7 @@ def _count_warmup_iterations(model: torch.nn.Module, max_tokens: int) -> int: ...@@ -358,6 +359,7 @@ def _count_warmup_iterations(model: torch.nn.Module, max_tokens: int) -> int:
return total return total
@instrument(span_name="DeepGemm warmup")
def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int): def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
total = _count_warmup_iterations(model, max_tokens) total = _count_warmup_iterations(model, max_tokens)
if total == 0: if total == 0:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
from typing import Any, TypeAlias
# Import the implementation details
from .otel import (
SpanKind,
extract_trace_context,
init_otel_tracer,
init_otel_worker_tracer,
instrument_otel,
is_otel_available,
manual_instrument_otel,
otel_import_error_traceback,
)
from .utils import (
SpanAttributes,
contains_trace_headers,
extract_trace_headers,
log_tracing_disabled_warning,
)
__all__ = [
"instrument",
"instrument_manual",
"init_tracer",
"maybe_init_worker_tracer",
"is_tracing_available",
"SpanAttributes",
"SpanKind",
"extract_trace_context",
"extract_trace_headers",
"log_tracing_disabled_warning",
"contains_trace_headers",
"otel_import_error_traceback",
]
BackendAvailableFunc: TypeAlias = Callable[[], bool]
InstrumentFunc: TypeAlias = Callable[..., Any]
InstrumentManualFunc: TypeAlias = Callable[..., Any]
InitTracerFunc: TypeAlias = Callable[..., Any]
InitWorkerTracerFunc: TypeAlias = Callable[..., Any]
_REGISTERED_TRACING_BACKENDS: dict[
str,
tuple[
BackendAvailableFunc,
InitTracerFunc,
InitWorkerTracerFunc,
InstrumentFunc,
InstrumentManualFunc,
],
] = {
"otel": (
is_otel_available,
init_otel_tracer,
init_otel_worker_tracer,
instrument_otel,
manual_instrument_otel,
),
}
def init_tracer(
instrumenting_module_name: str,
otlp_traces_endpoint: str,
extra_attributes: dict[str, str] | None = None,
):
is_available, init_tracer_fn, _, _, _ = _REGISTERED_TRACING_BACKENDS["otel"]
if is_available():
return init_tracer_fn(
instrumenting_module_name, otlp_traces_endpoint, extra_attributes
)
def maybe_init_worker_tracer(
instrumenting_module_name: str,
process_kind: str,
process_name: str,
):
is_available, _, init_worker_tracer_fn, _, _ = _REGISTERED_TRACING_BACKENDS["otel"]
if is_available():
return init_worker_tracer_fn(
instrumenting_module_name, process_kind, process_name
)
def instrument(
obj: Callable | None = None,
*,
span_name: str = "",
attributes: dict[str, str] | None = None,
record_exception: bool = True,
):
"""
Generic decorator to instrument functions.
"""
if obj is None:
return functools.partial(
instrument,
span_name=span_name,
attributes=attributes,
record_exception=record_exception,
)
# Dispatch to OTel (and potentially others later)
is_available, _, _, otel_instrument, _ = _REGISTERED_TRACING_BACKENDS["otel"]
if is_available():
return otel_instrument(
func=obj,
span_name=span_name,
attributes=attributes,
record_exception=record_exception,
)
else:
return obj
def instrument_manual(
span_name: str,
start_time: int,
end_time: int | None = None,
attributes: dict[str, Any] | None = None,
context: Any = None,
kind: Any = None,
):
"""Manually create a span with explicit timestamps.
Args:
span_name: Name of the span to create.
start_time: Start time in nanoseconds since epoch.
end_time: Optional end time in nanoseconds. If None, ends immediately.
attributes: Optional dict of span attributes.
context: Optional trace context (e.g., from extract_trace_context).
kind: Optional SpanKind (e.g., SpanKind.SERVER).
"""
is_available, _, _, _, manual_instrument_fn = _REGISTERED_TRACING_BACKENDS["otel"]
if is_available():
return manual_instrument_fn(
span_name, start_time, end_time, attributes, context, kind
)
else:
return None
def is_tracing_available() -> bool:
"""
Returns True if any tracing backend (OTel, Profiler, etc.) is available.
Use this to guard expensive tracing logic in the main code.
"""
check_available = [
is_available
for is_available, _, _, _, _ in _REGISTERED_TRACING_BACKENDS.values()
]
return any(check_available)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import functools
import inspect
import os
import traceback
from collections.abc import Mapping
from contextlib import contextmanager
from typing import Any
from vllm.logger import init_logger
from vllm.tracing.utils import TRACE_HEADERS, LoadingSpanAttributes
logger = init_logger(__name__)
try:
from opentelemetry import trace
from opentelemetry.context.context import Context
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter as OTLPGrpcExporter,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPHttpExporter,
)
from opentelemetry.propagate import inject
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_TRACES_PROTOCOL,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import (
SpanKind, # noqa: F401
Tracer,
set_tracer_provider,
)
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
_IS_OTEL_AVAILABLE = True
otel_import_error_traceback = None
except ImportError:
_IS_OTEL_AVAILABLE = False
otel_import_error_traceback = traceback.format_exc()
trace = None # type: ignore
Context = Any # type: ignore
Tracer = Any # type: ignore
inject = None # type: ignore
Resource = None # type: ignore
SpanKind = Any # type: ignore
def is_otel_available() -> bool:
return _IS_OTEL_AVAILABLE
def init_otel_tracer(
instrumenting_module_name: str,
otlp_traces_endpoint: str,
extra_attributes: dict[str, str] | None = None,
) -> Tracer:
"""Initializes the OpenTelemetry tracer provider."""
if not _IS_OTEL_AVAILABLE:
raise ValueError(
"OpenTelemetry is not available. Unable to initialize "
"a tracer. Ensure OpenTelemetry packages are installed. "
f"Original error:\n{otel_import_error_traceback}"
)
# Store the endpoint in environment so child processes can inherit it
os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = otlp_traces_endpoint
resource_attrs = {}
resource_attrs["vllm.instrumenting_module_name"] = instrumenting_module_name
resource_attrs["vllm.process_id"] = str(os.getpid())
if extra_attributes:
resource_attrs.update(extra_attributes)
resource = Resource.create(resource_attrs)
trace_provider = TracerProvider(resource=resource)
span_exporter = get_span_exporter(otlp_traces_endpoint)
trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
set_tracer_provider(trace_provider)
atexit.register(trace_provider.shutdown)
tracer = trace_provider.get_tracer(instrumenting_module_name)
return tracer
def get_span_exporter(endpoint):
protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc")
if protocol == "grpc":
exporter = OTLPGrpcExporter(endpoint=endpoint, insecure=True)
elif protocol == "http/protobuf":
exporter = OTLPHttpExporter(endpoint=endpoint)
else:
raise ValueError(f"Unsupported OTLP protocol '{protocol}' is configured")
return exporter
def init_otel_worker_tracer(
instrumenting_module_name: str,
process_kind: str,
process_name: str,
) -> Tracer:
"""
Backend-specific initialization for OpenTelemetry in a worker process.
"""
# Initialize the tracer if an OTLP endpoint is configured.
# The endpoint is propagated via environment variable from the main process.
otlp_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT")
if not otlp_endpoint:
return None
extra_attrs = {
"vllm.process_kind": process_kind,
"vllm.process_name": process_name,
}
return init_otel_tracer(instrumenting_module_name, otlp_endpoint, extra_attrs)
def extract_trace_context(headers: Mapping[str, str] | None) -> Context | None:
"""Extracts context from HTTP headers."""
if _IS_OTEL_AVAILABLE and headers:
return TraceContextTextMapPropagator().extract(headers)
return None
def instrument_otel(func, span_name, attributes, record_exception):
"""Internal wrapper logic for sync and async functions."""
# Pre-calculate static code attributes once (these don't change)
code_attrs = {
LoadingSpanAttributes.CODE_FUNCTION: func.__qualname__,
LoadingSpanAttributes.CODE_NAMESPACE: func.__module__,
LoadingSpanAttributes.CODE_FILEPATH: func.__code__.co_filename,
LoadingSpanAttributes.CODE_LINENO: str(func.__code__.co_firstlineno),
}
if attributes:
code_attrs.update(attributes)
final_span_name = span_name or func.__qualname__
module_name = func.__module__
@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
tracer = trace.get_tracer(module_name)
ctx = _get_smart_context()
with (
tracer.start_as_current_span(
final_span_name,
context=ctx,
attributes=code_attrs,
record_exception=record_exception,
),
propagate_trace_to_env(),
):
return await func(*args, **kwargs)
@functools.wraps(func)
def sync_wrapper(*args, **kwargs):
tracer = trace.get_tracer(module_name)
ctx = _get_smart_context()
with (
tracer.start_as_current_span(
final_span_name,
context=ctx,
attributes=code_attrs,
record_exception=record_exception,
),
propagate_trace_to_env(),
):
return func(*args, **kwargs)
return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper
def manual_instrument_otel(
span_name: str,
start_time: int,
end_time: int | None = None,
attributes: dict[str, Any] | None = None,
context: Context | None = None,
kind: Any = None, # SpanKind, but typed as Any for when OTEL unavailable
):
"""Manually create and end a span with explicit timestamps."""
if not _IS_OTEL_AVAILABLE:
return
tracer = trace.get_tracer(__name__)
# Use provided context, or fall back to smart context detection
ctx = context if context is not None else _get_smart_context()
span_kwargs: dict[str, Any] = {
"name": span_name,
"context": ctx,
"start_time": start_time,
}
if kind is not None:
span_kwargs["kind"] = kind
span = tracer.start_span(**span_kwargs)
if attributes:
span.set_attributes(attributes)
if end_time is not None:
span.end(end_time=end_time)
else:
span.end()
def _get_smart_context() -> Context | None:
"""
Determines the parent context.
1. If a Span is already active in this process, use it.
2. If not, extract from os.environ, handling the case-sensitivity mismatch.
"""
current_span = trace.get_current_span()
if current_span.get_span_context().is_valid:
return None
carrier = {}
if tp := os.environ.get("traceparent", os.environ.get("TRACEPARENT")): # noqa: SIM112
carrier["traceparent"] = tp
if ts := os.environ.get("tracestate", os.environ.get("TRACESTATE")): # noqa: SIM112
carrier["tracestate"] = ts
if not carrier:
carrier = dict(os.environ)
return TraceContextTextMapPropagator().extract(carrier)
@contextmanager
def propagate_trace_to_env():
"""
Temporarily injects the current OTel context into os.environ.
This ensures that any subprocesses (like vLLM workers) spawned
within this context inherit the correct traceparent.
"""
if not _IS_OTEL_AVAILABLE:
yield
return
# Capture original state of relevant keys
original_state = {k: os.environ.get(k) for k in TRACE_HEADERS}
try:
# inject() writes 'traceparent' and 'tracestate' to os.environ
inject(os.environ)
yield
finally:
# Restore original environment
for key, original_value in original_state.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Mapping from collections.abc import Mapping
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.func_utils import run_once from vllm.utils.func_utils import run_once
TRACE_HEADERS = ["traceparent", "tracestate"]
logger = init_logger(__name__) logger = init_logger(__name__)
_is_otel_imported = False # Standard W3C headers used for context propagation
otel_import_error_traceback: str | None = None TRACE_HEADERS = ["traceparent", "tracestate"]
try:
from opentelemetry.context.context import Context
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_TRACES_PROTOCOL,
)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
_is_otel_imported = True
except ImportError:
# Capture and format traceback to provide detailed context for the import
# error. Only the string representation of the error is retained to avoid
# memory leaks.
# See https://github.com/vllm-project/vllm/pull/7266#discussion_r1707395458
import traceback
otel_import_error_traceback = traceback.format_exc()
class Context: # type: ignore
pass
class BaseSpanAttributes: # type: ignore
pass
class SpanKind: # type: ignore
pass
class Tracer: # type: ignore
pass
def is_otel_available() -> bool:
return _is_otel_imported
def init_tracer(
instrumenting_module_name: str, otlp_traces_endpoint: str
) -> Tracer | None:
if not is_otel_available():
raise ValueError(
"OpenTelemetry is not available. Unable to initialize "
"a tracer. Ensure OpenTelemetry packages are installed. "
f"Original error:\n{otel_import_error_traceback}"
)
trace_provider = TracerProvider()
span_exporter = get_span_exporter(otlp_traces_endpoint)
trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
set_tracer_provider(trace_provider)
tracer = trace_provider.get_tracer(instrumenting_module_name)
return tracer
def get_span_exporter(endpoint):
protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc")
if protocol == "grpc":
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter,
)
elif protocol == "http/protobuf":
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter, # type: ignore
)
else:
raise ValueError(f"Unsupported OTLP protocol '{protocol}' is configured")
return OTLPSpanExporter(endpoint=endpoint)
def extract_trace_context(headers: Mapping[str, str] | None) -> Context | None:
if is_otel_available():
headers = headers or {}
return TraceContextTextMapPropagator().extract(headers)
else:
return None
def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]: class SpanAttributes:
return {h: headers[h] for h in TRACE_HEADERS if h in headers} """
Standard attributes for spans.
These are largely based on OpenTelemetry Semantic Conventions but are defined
here as constants so they can be used by any backend or logger.
"""
class SpanAttributes: # Attribute names copied from OTel semantic conventions to avoid version conflicts
# Attribute names copied from here to avoid version conflicts:
# https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
GEN_AI_USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens" GEN_AI_USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens"
GEN_AI_USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens" GEN_AI_USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens"
GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens" GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p" GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature" GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
GEN_AI_RESPONSE_MODEL = "gen_ai.response.model" GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
# Attribute names added until they are added to the semantic conventions:
# Custom attributes added until they are standardized
GEN_AI_REQUEST_ID = "gen_ai.request.id" GEN_AI_REQUEST_ID = "gen_ai.request.id"
GEN_AI_REQUEST_N = "gen_ai.request.n" GEN_AI_REQUEST_N = "gen_ai.request.n"
GEN_AI_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences" GEN_AI_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences"
...@@ -116,20 +36,37 @@ class SpanAttributes: ...@@ -116,20 +36,37 @@ class SpanAttributes:
GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token" GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e" GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e"
GEN_AI_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler" GEN_AI_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler"
# Time taken in the forward pass for this across all workers
# Latency breakdowns
GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward" GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward"
# Time taken in the model execute function. This will include model
# forward, block/sync across workers, cpu-gpu sync time and sampling time.
GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute" GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute"
GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = "gen_ai.latency.time_in_model_prefill" GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = "gen_ai.latency.time_in_model_prefill"
GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode" GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = "gen_ai.latency.time_in_model_inference" GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = "gen_ai.latency.time_in_model_inference"
class LoadingSpanAttributes:
"""Custom attributes for code-level tracing (file, line number)."""
CODE_NAMESPACE = "code.namespace"
CODE_FUNCTION = "code.function"
CODE_FILEPATH = "code.filepath"
CODE_LINENO = "code.lineno"
def contains_trace_headers(headers: Mapping[str, str]) -> bool: def contains_trace_headers(headers: Mapping[str, str]) -> bool:
"""Check if the provided headers dictionary contains trace context."""
return any(h in headers for h in TRACE_HEADERS) return any(h in headers for h in TRACE_HEADERS)
def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]:
"""
Extract only trace-related headers from a larger header dictionary.
Useful for logging or passing context to a non-OTel client.
"""
return {h: headers[h] for h in TRACE_HEADERS if h in headers}
@run_once @run_once
def log_tracing_disabled_warning() -> None: def log_tracing_disabled_warning() -> None:
logger.warning("Received a request with trace context but tracing is disabled") logger.warning("Received a request with trace context but tracing is disabled")
...@@ -110,6 +110,10 @@ class AsyncLLM(EngineClient): ...@@ -110,6 +110,10 @@ class AsyncLLM(EngineClient):
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
tracing_endpoint = self.observability_config.otlp_traces_endpoint
if tracing_endpoint is not None:
init_tracer("vllm.llm_engine", tracing_endpoint)
self.log_requests = log_requests self.log_requests = log_requests
custom_stat_loggers = list(stat_loggers or []) custom_stat_loggers = list(stat_loggers or [])
...@@ -136,10 +140,8 @@ class AsyncLLM(EngineClient): ...@@ -136,10 +140,8 @@ class AsyncLLM(EngineClient):
log_stats=self.log_stats, log_stats=self.log_stats,
stream_interval=self.vllm_config.scheduler_config.stream_interval, stream_interval=self.vllm_config.scheduler_config.stream_interval,
) )
endpoint = self.observability_config.otlp_traces_endpoint if tracing_endpoint is not None:
if endpoint is not None: self.output_processor.tracing_enabled = True
tracer = init_tracer("vllm.llm_engine", endpoint)
self.output_processor.tracer = tracer
# EngineCore (starts the engine in background process). # EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client( self.engine_core = EngineCoreClient.make_async_mp_client(
......
...@@ -24,6 +24,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception ...@@ -24,6 +24,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tracing import instrument, maybe_init_worker_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.gc_utils import ( from vllm.utils.gc_utils import (
freeze_gc_heap, freeze_gc_heap,
...@@ -217,6 +218,7 @@ class EngineCore: ...@@ -217,6 +218,7 @@ class EngineCore:
# environment variable overrides after this point) # environment variable overrides after this point)
enable_envs_cache() enable_envs_cache()
@instrument(span_name="Prepare model")
def _initialize_kv_caches( def _initialize_kv_caches(
self, vllm_config: VllmConfig self, vllm_config: VllmConfig
) -> tuple[int, int, KVCacheConfig]: ) -> tuple[int, int, KVCacheConfig]:
...@@ -658,6 +660,7 @@ class EngineCoreProc(EngineCore): ...@@ -658,6 +660,7 @@ class EngineCoreProc(EngineCore):
ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD" ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
@instrument(span_name="EngineCoreProc init")
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
...@@ -926,8 +929,18 @@ class EngineCoreProc(EngineCore): ...@@ -926,8 +929,18 @@ class EngineCoreProc(EngineCore):
data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0 data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
if data_parallel: if data_parallel:
parallel_config.data_parallel_rank_local = local_dp_rank parallel_config.data_parallel_rank_local = local_dp_rank
maybe_init_worker_tracer(
instrumenting_module_name="vllm.engine_core",
process_kind="engine_core",
process_name=f"EngineCore_DP{dp_rank}",
)
set_process_title("EngineCore", f"DP{dp_rank}") set_process_title("EngineCore", f"DP{dp_rank}")
else: else:
maybe_init_worker_tracer(
instrumenting_module_name="vllm.engine_core",
process_kind="engine_core",
process_name="EngineCore",
)
set_process_title("EngineCore") set_process_title("EngineCore")
decorate_logs() decorate_logs()
...@@ -956,6 +969,7 @@ class EngineCoreProc(EngineCore): ...@@ -956,6 +969,7 @@ class EngineCoreProc(EngineCore):
parallel_config.data_parallel_rank = 0 parallel_config.data_parallel_rank = 0
engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs) engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
assert engine_core is not None
engine_core.run_busy_loop() engine_core.run_busy_loop()
except SystemExit: except SystemExit:
...@@ -1485,6 +1499,13 @@ class EngineCoreActorMixin: ...@@ -1485,6 +1499,13 @@ class EngineCoreActorMixin:
dp_rank: int = 0, dp_rank: int = 0,
local_dp_rank: int = 0, local_dp_rank: int = 0,
): ):
# Initialize tracer for distributed tracing if configured.
maybe_init_worker_tracer(
instrumenting_module_name="vllm.engine_core",
process_kind="engine_core",
process_name=f"DPEngineCoreActor_DP{dp_rank}",
)
self.addresses = addresses self.addresses = addresses
vllm_config.parallel_config.data_parallel_index = dp_rank vllm_config.parallel_config.data_parallel_index = dp_rank
vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank
......
...@@ -24,6 +24,7 @@ from vllm.envs import VLLM_ENGINE_READY_TIMEOUT_S ...@@ -24,6 +24,7 @@ from vllm.envs import VLLM_ENGINE_READY_TIMEOUT_S
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tracing import instrument
from vllm.utils.async_utils import in_loop from vllm.utils.async_utils import in_loop
from vllm.utils.network_utils import ( from vllm.utils.network_utils import (
close_sockets, close_sockets,
...@@ -96,6 +97,7 @@ class EngineCoreClient(ABC): ...@@ -96,6 +97,7 @@ class EngineCoreClient(ABC):
return InprocClient(vllm_config, executor_class, log_stats) return InprocClient(vllm_config, executor_class, log_stats)
@staticmethod @staticmethod
@instrument(span_name="Overall Loading")
def make_async_mp_client( def make_async_mp_client(
vllm_config: VllmConfig, vllm_config: VllmConfig,
executor_class: type[Executor], executor_class: type[Executor],
...@@ -650,6 +652,7 @@ def _process_utility_output( ...@@ -650,6 +652,7 @@ def _process_utility_output(
class SyncMPClient(MPClient): class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore.""" """Synchronous client for multi-proc EngineCore."""
@instrument(span_name="SyncMPClient init")
def __init__( def __init__(
self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool
): ):
...@@ -819,6 +822,7 @@ class SyncMPClient(MPClient): ...@@ -819,6 +822,7 @@ class SyncMPClient(MPClient):
class AsyncMPClient(MPClient): class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore.""" """Asyncio-compatible client for multi-proc EngineCore."""
@instrument(span_name="AsyncMPClient init")
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment