Unverified commit 4c21b090, authored by Feng Su, committed by GitHub

[Feature] Sglang Tracing: Fine-Grained Tracking for Request Latency - Part 1 (#9962)

Signed-off-by: Feng Su <sufeng@linux.alibaba.com>
Signed-off-by: Huaixin Chang <changhuaixin@linux.alibaba.com>
Signed-off-by: Peng Wang <rocking@linux.alibaba.com>
parent 165abeeb
SGLang exports request trace data based on the OpenTelemetry Collector. You can enable tracing by adding the `--enable-trace` flag and configure the OpenTelemetry Collector endpoint with `--oltp-traces-endpoint` when launching the server.

You can find example screenshots of the visualization in https://github.com/sgl-project/sglang/issues/8965.

## Setup Guide

This section explains how to configure request tracing and export the trace data.

1. Install the required packages and tools
* Install Docker and Docker Compose
* Install the dependencies
```bash
# enter the SGLang root directory
pip install -e "python[tracing]"
# or manually install the dependencies using pip
pip install opentelemetry-sdk opentelemetry-api opentelemetry-exporter-otlp opentelemetry-exporter-otlp-proto-grpc
```
2. Launch the OpenTelemetry Collector and Jaeger
```bash
docker compose -f examples/monitoring/tracing_compose.yaml up -d
```
3. Start your SGLang server with tracing enabled
```bash
python -m sglang.launch_server --enable-trace --oltp-traces-endpoint 0.0.0.0:4317 <other option>
```
Replace `0.0.0.0:4317` with the actual endpoint of the OpenTelemetry Collector. If you launched the OpenTelemetry Collector with `tracing_compose.yaml`, the default receiving port is 4317.
4. Send some requests
5. Verify that trace data is being exported
* Open Jaeger on port 16686 in a web browser to visualize the request traces.
* The OpenTelemetry Collector also exports trace data in JSON format to `/tmp/otel_trace.json`. In a follow-up patch, we will provide a tool to convert this data into a Perfetto-compatible format, enabling visualization of requests in the Perfetto UI.
## How to Add Tracing for Slices You're Interested In

We have already inserted instrumentation points in the tokenizer and scheduler main threads. If you wish to trace additional request execution segments or perform finer-grained tracing, use the APIs from the tracing package as described below.

1. Initialization

Every process involved in tracing should execute the following during the initialization phase:
```python
process_tracing_init(oltp_traces_endpoint, server_name)
```
The `oltp_traces_endpoint` is obtained from the server arguments. You can set `server_name` freely, but it must remain consistent across all processes.

Every thread involved in tracing should execute the following during the initialization phase:
```python
trace_set_thread_info("thread label", tp_rank, dp_rank)
```
The thread label can be regarded as the name of the thread; it is used to distinguish different threads in the visualization view.

2. Mark the beginning and end of a request
```python
trace_req_start(rid, bootstrap_room)
trace_req_finish(rid)
```
These two APIs must be called within the same process, for example, in the tokenizer.
3. Add tracing for slices
* Add slice tracing normally:
```python
trace_slice_start("slice A", rid)
trace_slice_end("slice A", rid)
```
- Use the `anonymous` flag to leave the slice name unspecified at the start, allowing the name to be determined by `trace_slice_end`.
<br>Note: Anonymous slices must not be nested.
```python
trace_slice_start("", rid, anonymous = True)
trace_slice_end("slice A", rid)
```
- In `trace_slice_end`, use `auto_next_anon` to automatically start the next anonymous slice, which reduces the number of instrumentation points needed.
```python
trace_slice_start("", rid, anonymous = True)
trace_slice_end("slice A", rid, auto_next_anon = True)
trace_slice_end("slice B", rid, auto_next_anon = True)
trace_slice_end("slice C", rid, auto_next_anon = True)
trace_slice_end("slice D", rid)
```
- The end of the last slice in a thread must be marked with `thread_finish_flag=True`; otherwise, the thread's span will not be generated properly.
```python
trace_slice_end("slice D", rid, thread_finish_flag = True)
```
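The matching rules above (LIFO nesting, anonymous naming at `trace_slice_end`, `auto_next_anon` chaining) can be modeled with a plain stack. This is a stdlib-only sketch of the bookkeeping for illustration; the real tracer additionally creates OpenTelemetry spans:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Slice:
    name: str
    anonymous: bool = False


class SliceTracker:
    """Toy model of the per-thread slice stack kept by the tracing APIs."""

    def __init__(self) -> None:
        self.stack: List[Slice] = []
        self.finished: List[str] = []  # closed slice names, in completion order

    def slice_start(self, name: str, anonymous: bool = False) -> None:
        self.stack.append(Slice(name, anonymous))

    def slice_end(self, name: str, auto_next_anon: bool = False) -> None:
        top = self.stack.pop()  # slices close LIFO, so nesting is implicit
        # Anonymous slices take their final name from slice_end.
        self.finished.append(name if top.anonymous else top.name)
        if auto_next_anon:
            # Immediately open the next anonymous slice.
            self.slice_start("", anonymous=True)


t = SliceTracker()
t.slice_start("", anonymous=True)
t.slice_end("slice A", auto_next_anon=True)
t.slice_end("slice B", auto_next_anon=True)
t.slice_end("slice C")
print(t.finished)  # ['slice A', 'slice B', 'slice C']
```

Note how three back-to-back slices need only one explicit `slice_start`; each `auto_next_anon` end doubles as the start of the next slice.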
4. When the request execution flow transfers to another thread, the trace context must be explicitly propagated.
- Sender: execute the following before sending the request to another thread via ZMQ:
```python
trace_context = trace_get_proc_propagate_context(rid)
req.trace_context = trace_context
```
- Receiver: execute the following after receiving the request via ZMQ:
```python
trace_set_proc_propagate_context(rid, req.trace_context)
```
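Conceptually, the propagated context is just a serializable dict attached to the request object before it crosses the ZMQ boundary. The sketch below mimics that sender/receiver handshake with stdlib code only; the function names, registry, and carrier fields are illustrative stand-ins for the real APIs:

```python
import json
from typing import Dict, Optional

# Per-process registry of request trace state, keyed by rid.
local_contexts: Dict[str, dict] = {}


def get_propagate_context(rid: str) -> Optional[dict]:
    """Sender side: snapshot the context to ship along with the request."""
    ctx = local_contexts.get(rid)
    if ctx is None:
        return None
    return {"root_span": ctx["root_span"], "prev_span": ctx["prev_span"]}


def set_propagate_context(rid: str, trace_context: Optional[dict]) -> None:
    """Receiver side: rebuild local state from the shipped context."""
    if trace_context is None:
        return
    local_contexts[rid] = dict(trace_context)


# Sender process: request "req-1" has an active trace.
local_contexts["req-1"] = {
    "root_span": {"traceparent": "00-abc-def-01"},  # illustrative carrier
    "prev_span": {"trace_id": 1, "span_id": 2},
}
wire = json.dumps(get_propagate_context("req-1"))  # e.g. attached to the ZMQ message

# Receiver process (simulated here by clearing the registry).
local_contexts.clear()
set_propagate_context("req-1", json.loads(wire))
print(local_contexts["req-1"]["prev_span"])  # {'trace_id': 1, 'span_id': 2}
```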
## How to Extend the Tracing Framework to Support Complex Tracing Scenarios

The tracing package still has room for further development. If you wish to build more advanced features on top of it, you should first understand its existing design principles.

The core of the tracing framework's implementation lies in the design of the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we designed a three-level trace context structure: `SglangTraceReqContext`, `SglangTraceThreadContext`, and `SglangTraceSliceContext`. Their relationship is as follows:
```
SglangTraceReqContext (req_id="req-123")
├── SglangTraceThreadContext (thread_label="scheduler", tp_rank=0)
│   └── SglangTraceSliceContext (name="prefill")   # cur slice
│
└── SglangTraceThreadContext (thread_label="scheduler", tp_rank=1)
    └── SglangTraceSliceContext (name="prefill")   # cur slice
```
Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is represented by a `SglangTraceSliceContext`, stored in the `SglangTraceThreadContext`. When slice, thread, or request tracing ends, a span is generated and the corresponding context is released.
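A stripped-down sketch of that composition, with abbreviated stand-ins for the real dataclasses (field names shortened for illustration):

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SliceCtx:
    """Stands in for SglangTraceSliceContext."""
    name: str


@dataclass
class ThreadCtx:
    """Stands in for SglangTraceThreadContext; owns a stack of live slices."""
    thread_label: str
    tp_rank: int
    cur_slice_stack: List[SliceCtx] = field(default_factory=list)


@dataclass
class ReqCtx:
    """Stands in for SglangTraceReqContext; composes per-thread contexts."""
    rid: str
    threads: Dict[int, ThreadCtx] = field(default_factory=dict)


# One request tracked concurrently by two TP scheduler threads.
req = ReqCtx(rid="req-123")
req.threads[1001] = ThreadCtx("scheduler", tp_rank=0)
req.threads[1002] = ThreadCtx("scheduler", tp_rank=1)
for t in req.threads.values():
    t.cur_slice_stack.append(SliceCtx("prefill"))

print([(t.tp_rank, t.cur_slice_stack[-1].name) for t in req.threads.values()])
# [(0, 'prefill'), (1, 'prefill')]
```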
In addition to the above hierarchy, each slice records its previous slice via `Span.add_link()`, which can be used to trace the execution flow.
When the request execution flow transfers to a new thread, the trace context needs to be explicitly propagated. In the framework, this is represented by `SglangTracePropagateContext`, which contains the context of the request span and the previous slice span.
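The round-trip encoding can be sketched with a stdlib stand-in for `SglangTracePropagateContext`: the request span context travels as an injected carrier dict and the previous slice span as raw `trace_id`/`span_id` values. The carrier contents here are illustrative; the real code fills them via OpenTelemetry's `propagate.inject`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PropagateCtx:
    """Toy stand-in for SglangTracePropagateContext."""
    root_carrier: dict  # would be filled by propagate.inject() in the real code
    prev_span: Optional[dict]  # {"trace_id": int, "span_id": int} or None

    def to_dict(self) -> dict:
        if self.prev_span is not None:
            return {"root_span": self.root_carrier, "prev_span": dict(self.prev_span)}
        # The real code uses the string "None" as an explicit sentinel.
        return {"root_span": self.root_carrier, "prev_span": "None"}

    @classmethod
    def from_dict(cls, d: dict) -> Optional["PropagateCtx"]:
        if "root_span" not in d or "prev_span" not in d:
            return None
        prev = None if d["prev_span"] == "None" else dict(d["prev_span"])
        return cls(d["root_span"], prev)


ctx = PropagateCtx({"traceparent": "00-ab-cd-01"}, {"trace_id": 7, "span_id": 9})
restored = PropagateCtx.from_dict(ctx.to_dict())
print(restored.prev_span)  # {'trace_id': 7, 'span_id': 9}
```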
```yaml
receivers:
  otlp:
    protocols:
      grpc:
        endpoint: 0.0.0.0:4317
      http:
        endpoint: 0.0.0.0:4318

processors:
  batch:

exporters:
  otlp:
    endpoint: jaeger:4317
    tls:
      insecure: true
  file:
    path: /tmp/otel_trace.json

extensions:
  health_check:
  pprof:
  zpages:

service:
  extensions: [health_check, pprof, zpages]
  pipelines:
    traces:
      receivers: [otlp]
      processors: [batch]
      exporters: [otlp, file]
    metrics:
      receivers: [otlp]
      processors: [batch]
      exporters: [otlp]
    logs:
      receivers: [otlp]
      processors: [batch]
      exporters: [otlp]
```
```yaml
services:
  otel-collector:
    image: docker.io/otel/opentelemetry-collector
    volumes:
      - ./opentelemetry.yaml:/etc/otelcol/config.yaml
      - /tmp:/tmp
    ports:
      - "4317:4317" # OTLP gRPC
      - "4318:4318" # OTLP HTTP
    depends_on:
      - jaeger
    restart: unless-stopped

  jaeger:
    image: jaegertracing/all-in-one
    container_name: jaeger
    ports:
      - "16686:16686"
    environment:
      - COLLECTOR_OTLP_ENABLED=true
    restart: unless-stopped
```
```diff
@@ -56,6 +56,13 @@ runtime_common = [
     "xgrammar==0.1.24",
 ]
 
+tracing = [
+    "opentelemetry-sdk",
+    "opentelemetry-api",
+    "opentelemetry-exporter-otlp",
+    "opentelemetry-exporter-otlp-proto-grpc",
+]
+
 srt = [
     "sglang[runtime_common]",
     "sgl-kernel==0.3.9.post2",
```
```diff
@@ -33,6 +33,8 @@ import zmq
 import zmq.asyncio
 from PIL.Image import Image
 
+from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -138,6 +140,12 @@ class Engine(EngineBase):
             context, zmq.DEALER, self.port_args.rpc_ipc_name, True
         )
 
+        if server_args.enable_trace:
+            process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+            if server_args.disaggregation_mode == "null":
+                thread_label = "Tokenizer"
+                trace_set_thread_info(thread_label)
+
     def generate(
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
```
```diff
@@ -31,6 +31,8 @@ from typing import Any, AsyncIterator, Callable, Dict, List, Optional
 import setproctitle
 
+from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -179,6 +181,13 @@ async def init_multi_tokenizer() -> ServerArgs:
             scheduler_info=scheduler_info,
         )
     )
 
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = f"MultiTokenizer-{tokenizer_manager.worker_id}"
+            trace_set_thread_info(thread_label)
+
     return server_args
@@ -1203,6 +1212,12 @@ def launch_server(
         server_args=server_args,
     )
 
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Tokenizer"
+            trace_set_thread_info(thread_label)
+
     set_global_state(
         _GlobalState(
             tokenizer_manager=tokenizer_manager,
```
```diff
@@ -605,6 +605,9 @@ class TokenizedGenerateReqInput:
     # Image gen grpc migration
     return_bytes: bool = False
 
+    # tracing context
+    trace_context: Optional[Dict] = None
+
 
 @dataclass
 class BatchTokenizedGenerateReqInput:
@@ -654,6 +657,9 @@ class EmbeddingReqInput:
     # For background responses (OpenAI responses API)
     background: bool = False
 
+    # tracing context
+    trace_context: Optional[Dict] = None
+
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
```
```diff
@@ -149,6 +149,15 @@ from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.tracing.trace import (
+    process_tracing_init,
+    trace_event,
+    trace_set_proc_propagate_context,
+    trace_set_thread_info,
+    trace_slice,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -826,6 +835,10 @@ class Scheduler(
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
 
+            if batch:
+                for req in batch.reqs:
+                    trace_event("schedule", req.rid)
+
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
@@ -847,6 +860,10 @@ class Scheduler(
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
 
+            if batch:
+                for req in batch.reqs:
+                    trace_event("schedule", req.rid)
+
             if batch:
                 batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
@@ -1110,6 +1127,12 @@ class Scheduler(
                 self.tp_cpu_group,
                 src=self.tp_group.ranks[0],
             )
 
+        for req in recv_reqs:
+            if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
+                trace_set_proc_propagate_context(req.rid, req.trace_context)
+                trace_slice_start("", req.rid, anonymous=True)
+
         return recv_reqs
 
     def process_input_requests(self, recv_reqs: List):
@@ -1347,6 +1370,7 @@ class Scheduler(
         else:
             self._prefetch_kvcache(req)
             self.waiting_queue.append(req)
+            trace_slice_end("process req", req.rid, auto_next_anon=True)
 
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
@@ -1914,8 +1938,23 @@ class Scheduler(
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result, launch_done)
+            for req in batch.reqs:
+                trace_slice(
+                    "decode loop",
+                    req.rid,
+                    auto_next_anon=not req.finished(),
+                    thread_finish_flag=req.finished(),
+                )
+
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result, launch_done)
+            for req in batch.reqs:
+                trace_slice(
+                    "prefill",
+                    req.rid,
+                    auto_next_anon=not req.finished(),
+                    thread_finish_flag=req.finished(),
+                )
+
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
@@ -2600,6 +2639,12 @@ def run_scheduler_process(
     pipe_writer,
     balance_meta: Optional[DPBalanceMeta] = None,
 ):
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Scheduler"
+            trace_set_thread_info(thread_label, tp_rank, dp_rank)
+
     if (numa_node := server_args.numa_node) is not None:
         numa_bind_to_node(numa_node[gpu_id])
```
```diff
@@ -82,6 +82,13 @@ from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicat
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.tracing.trace import (
+    trace_get_proc_propagate_context,
+    trace_req_finish,
+    trace_req_start,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.utils import (
     configure_gc_warning,
     dataclass_to_string_truncated,
@@ -358,6 +365,24 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             # If it's a single value, add worker_id prefix
             obj.rid = f"{self.worker_id}_{obj.rid}"
 
+        if obj.is_single:
+            bootstrap_room = (
+                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
+            )
+            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
+        else:
+            for i in range(len(obj.rid)):
+                bootstrap_room = (
+                    obj.bootstrap_room[i]
+                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
+                    else None
+                )
+                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                trace_slice_start(
+                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+                )
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
@@ -574,6 +599,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             mm_inputs = None
 
         self._validate_one_request(obj, input_ids)
+        trace_slice_end("tokenize", obj.rid)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )
@@ -752,6 +778,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
+            trace_slice_end("tokenize", req.rid)
 
         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs
@@ -779,9 +806,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        trace_slice_start("dispatch", obj.rid)
+        tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
         self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
         return state
 
     def _send_batch_request(
@@ -1429,6 +1459,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
             state.finished_time = time.time()
             meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+            trace_req_finish(rid, ts=int(state.finished_time * 1e9))
+
             del self.rid_to_state[rid]
 
             # Mark ongoing LoRA request as finished.
```
```diff
@@ -215,6 +215,8 @@ class ServerArgs:
     enable_request_time_stats_logging: bool = False
     kv_events_config: Optional[str] = None
     gc_warning_threshold_secs: float = 0.0
+    enable_trace: bool = False
+    oltp_traces_endpoint: str = "localhost:4317"
 
     # API related
     api_key: Optional[str] = None
@@ -1390,6 +1392,17 @@ class ServerArgs:
             default=None,
             help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
         )
+        parser.add_argument(
+            "--enable-trace",
+            action="store_true",
+            help="Enable opentelemetry trace",
+        )
+        parser.add_argument(
+            "--oltp-traces-endpoint",
+            type=str,
+            default="localhost:4317",
+            help="Config opentelemetry collector endpoint if --enable-trace is set. format: <ip>:<port>",
+        )
 
         # API related
         parser.add_argument(
```
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""package for sglang requests tracing"""
from __future__ import annotations
import ctypes
import logging
import os
import random
import threading
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
opentelemetry_imported = False
tracing_enabled = False
try:
from opentelemetry import context, propagate, trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider, id_generator
from opentelemetry.sdk.trace.export import BatchSpanProcessor
opentelemetry_imported = True
except ImportError:
class id_generator:
class IdGenerator:
pass
logger.info("opentelemetry package is not installed, tracing disabled")
@dataclass
class SglangTraceThreadInfo:
host_id: str
pid: int
thread_label: str
tp_rank: int
dp_rank: int
tracer: trace.Tracer
@dataclass
class SglangTraceSliceContext:
slice_name: str
span: Optional[trace.span.Span] = None
# When True, defers slice_name assignment until trace_slice_end()
anonymous: bool = False
@dataclass
class SglangTraceThreadContext:
thread_info: SglangTraceThreadInfo
cur_slice_stack: List[SglangTraceSliceContext]
thread_span: Optional[trace.span.Span] = None
# Record the most recently completed span as the previous span for the next span to be created.
last_span_context: Optional[trace.span.SpanContext] = None
@dataclass
class SglangTraceReqContext:
rid: str
start_time_ns: int
threads_context: Dict[int, SglangTraceThreadContext]
bootstrap_room: Optional[int] = None
# Indicates whether this instance is a replica from the main process.
# When True, root_span is None and only root_span_context is preserved.
is_copy: bool = False
root_span: Optional[trace.span.Span] = None
root_span_context: Optional[context.Context] = None
@dataclass
class SglangTracePropagateContext:
root_span_context: context.Context
prev_span_context: Optional[trace.span.SpanContext]
def to_dict(self):
carrier: dict[str, str] = {}
context.attach(self.root_span_context)
propagate.inject(carrier)
if self.prev_span_context:
return {
"root_span": carrier,
"prev_span": {
"span_id": self.prev_span_context.span_id,
"trace_id": self.prev_span_context.trace_id,
},
}
else:
return {"root_span": carrier, "prev_span": "None"}
@classmethod
def instance_from_dict(cls, d):
if "root_span" not in d or "prev_span" not in d:
return None
carrier = d["root_span"]
root_span_context = propagate.extract(carrier)
if d["prev_span"] == "None":
prev_span_context = None
else:
prev_span_context = trace.span.SpanContext(
trace_id=d["prev_span"]["trace_id"],
span_id=d["prev_span"]["span_id"],
is_remote=True,
)
return cls(root_span_context, prev_span_context)
class SglangTraceCustomIdGenerator(id_generator.IdGenerator):
"""
The default IdGenerator may produce duplicate trace IDs across multiple TP scheduler processes,
hence a custom IdGenerator is implemented.
"""
def __init__(self):
super().__init__()
self.local_random = random.Random()
self.local_random.seed(time.time())
def generate_trace_id(self) -> int:
return self.local_random.getrandbits(64)
def generate_span_id(self) -> int:
return self.local_random.getrandbits(64)
# global variables
threads_info: Dict[int, SglangTraceThreadInfo] = {}
reqs_context: Dict[str, SglangTraceReqContext] = {}
__get_cur_time_ns = lambda: int(time.time() * 1e9)
def __get_host_id() -> str:
"""
In distributed tracing systems, obtain a unique node identifier
and inject it into all subsequently generated spans
to prevent PID conflicts between threads on different nodes.
"""
if os.path.exists("/etc/machine-id"):
try:
with open("/etc/machine-id", "r") as f:
return f.read().strip()
except:
pass
mac = uuid.getnode()
if mac != 0:
return uuid.UUID(int=mac).hex
return "unknown"
# Should be called by each tracked process.
def process_tracing_init(otlp_endpoint, server_name):
global tracing_enabled
global __get_cur_time_ns
if not opentelemetry_imported:
tracing_enabled = False
return
try:
resource = Resource.create(
attributes={
SERVICE_NAME: server_name,
}
)
tracer_provider = TracerProvider(
resource=resource, id_generator=SglangTraceCustomIdGenerator()
)
processor = BatchSpanProcessor(
OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
)
tracer_provider.add_span_processor(processor)
trace.set_tracer_provider(tracer_provider)
except Exception as e:
logger.error(f": initialize opentelemetry error:{e}")
logger.warning("pelease set correct otlp endpoint")
tracing_enabled = False
return
if hasattr(time, "time_ns"):
__get_cur_time_ns = lambda: int(time.time_ns())
tracing_enabled = True
# Should be called by each tracked thread.
def trace_set_thread_info(
thread_label: str, tp_rank: Optional[int] = None, dp_rank: Optional[int] = None
):
if not tracing_enabled:
return
pid = threading.get_native_id()
if pid in threads_info:
return
threads_info[pid] = SglangTraceThreadInfo(
host_id=__get_host_id(),
pid=pid,
thread_label=thread_label,
tp_rank=tp_rank,
dp_rank=dp_rank,
tracer=trace.get_tracer("sglang server"),
)
def __create_thread_context(pid, req_span_context, ts: Optional[int] = None):
if pid not in threads_info:
trace_set_thread_info("unknown")
thread_info = threads_info[pid]
thread_context = SglangTraceThreadContext(
thread_info=thread_info,
cur_slice_stack=[],
)
thread_name = f"{thread_info.thread_label}"
if thread_info.tp_rank is not None:
thread_name += f" [TP {thread_info.tp_rank}] "
thread_name += f"(host:{thread_info.host_id[:8]} | pid:{pid})"
ts = ts or __get_cur_time_ns()
thread_context.thread_span = thread_context.thread_info.tracer.start_span(
name=thread_name,
start_time=ts,
context=req_span_context,
)
if thread_info.tp_rank is not None:
thread_context.thread_span.set_attributes({"tp_rank": thread_info.tp_rank})
thread_context.thread_span.set_attributes(
{
"host_id": thread_info.host_id,
"pid": thread_info.pid,
"thread_label": thread_info.thread_label,
}
)
return thread_context
def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]:
if not tracing_enabled:
return None
rid = str(rid)
if rid not in reqs_context or not reqs_context[rid].root_span_context:
return None
pid = threading.get_native_id()
prev_span_context = None
thread_context = reqs_context[rid].threads_context[pid]
if thread_context.cur_slice_stack:
cur_slice_info = thread_context.cur_slice_stack[0]
prev_span_context = cur_slice_info.span.get_span_context()
elif thread_context.last_span_context:
prev_span_context = thread_context.last_span_context
trace_context = SglangTracePropagateContext(
reqs_context[rid].root_span_context, prev_span_context
)
return trace_context.to_dict()
def trace_set_proc_propagate_context(rid, trace_context: Optional[Dict[str, Any]]):
if not tracing_enabled:
return
if not trace_context:
return
trace_context = SglangTracePropagateContext.instance_from_dict(trace_context)
if not trace_context:
return
rid = str(rid)
# Create a copy of the request context
if rid not in reqs_context:
reqs_context[rid] = SglangTraceReqContext(
rid=rid,
start_time_ns=__get_cur_time_ns(),
threads_context={},
root_span_context=trace_context.root_span_context,
is_copy=True,
)
pid = threading.get_native_id()
if pid in reqs_context[rid].threads_context:
return
# Create new thread context.
reqs_context[rid].threads_context[pid] = __create_thread_context(
pid,
trace_context.root_span_context,
reqs_context[rid].start_time_ns,
)
reqs_context[rid].threads_context[
pid
].last_span_context = trace_context.prev_span_context
def trace_req_start(
rid: str,
bootstrap_room: Optional[int] = None,
ts: Optional[int] = None,
):
if not tracing_enabled:
return
rid = str(rid)
ts = ts or __get_cur_time_ns()
pid = threading.get_native_id()
if pid not in threads_info:
return
# create req context and root span
reqs_context[rid] = SglangTraceReqContext(
rid=rid,
start_time_ns=ts,
threads_context={},
bootstrap_room=bootstrap_room,
is_copy=False,
)
# Drop the worker_id added by MultiTokenizer
orig_rid = rid.split("_")[-1]
tracer = threads_info[pid].tracer
root_span = tracer.start_span(
name=f"Req {orig_rid[:8]}",
start_time=ts,
)
root_span.set_attributes(
{
"rid": rid,
"bootstrap_room": bootstrap_room if bootstrap_room else "None",
}
)
reqs_context[rid].root_span = root_span
reqs_context[rid].root_span_context = trace.set_span_in_context(root_span)
# create thread context and thread span
reqs_context[rid].threads_context[pid] = __create_thread_context(
pid,
reqs_context[rid].root_span_context,
ts,
)
def trace_req_finish(
rid: str, ts: Optional[int] = None, attrs: Optional[Dict[str, Any]] = None
):
if not tracing_enabled:
return
rid = str(rid)
if rid not in reqs_context:
return
req_context = reqs_context[rid]
ts = ts or __get_cur_time_ns()
# End all unclosed thread spans.
for thread_context in req_context.threads_context.values():
thread_context.thread_span.end(end_time=ts)
if attrs:
req_context.root_span.set_attributes(attrs)
req_context.root_span.end(end_time=ts)
del reqs_context[rid]
def trace_slice_start(
name: str,
rid: str,
ts: Optional[int] = None,
anonymous: bool = False,
):
rid = str(rid)
if not tracing_enabled or rid not in reqs_context:
return
pid = threading.get_native_id()
if pid not in reqs_context[rid].threads_context:
return
thread_context = reqs_context[rid].threads_context[pid]
ts = ts or __get_cur_time_ns()
slice_info = SglangTraceSliceContext(
slice_name=name,
anonymous=anonymous,
)
# find prev slice
prev_span_context = None
if thread_context.cur_slice_stack:
parent_span = thread_context.cur_slice_stack[-1].span
else:
parent_span = thread_context.thread_span
prev_span_context = thread_context.last_span_context
parent_span_context = trace.set_span_in_context(parent_span)
span = thread_context.thread_info.tracer.start_span(
name=slice_info.slice_name,
start_time=ts,
context=parent_span_context,
)
if prev_span_context:
span.add_link(prev_span_context)
slice_info.span = span
thread_context.cur_slice_stack.append(slice_info)
def trace_slice_end(
name: str,
rid: str,
ts: Optional[int] = None,
attrs: Optional[Dict[str, Any]] = None,
auto_next_anon: bool = False,
thread_finish_flag: bool = False,
):
rid = str(rid)
if not tracing_enabled or rid not in reqs_context:
return
pid = threading.get_native_id()
if pid not in reqs_context[rid].threads_context:
return
thread_context = reqs_context[rid].threads_context[pid]
if not thread_context.cur_slice_stack:
logger.warning(f"Slice end '{name}' has no matching slice start; ignoring.")
return
ts = ts or __get_cur_time_ns()
slice_info = thread_context.cur_slice_stack[-1]
span = slice_info.span
if slice_info.anonymous:
span.update_name(name)
elif slice_info.slice_name != name:
span.set_status(trace.Status(trace.StatusCode.ERROR))
logger.warning(f"Slice name mismatch: {name} != {slice_info.slice_name}")
if attrs:
span.set_attributes(attrs)
span.end(end_time=ts)
thread_context.cur_slice_stack.pop()
if len(thread_context.cur_slice_stack) == 0:
thread_context.last_span_context = span.get_span_context()
# If this is the last slice in the thread,
# release the thread context and check whether to release the request context.
if thread_finish_flag:
thread_context.thread_span.end(end_time=ts)
del reqs_context[rid].threads_context[pid]
if reqs_context[rid].is_copy and not reqs_context[rid].threads_context:
del reqs_context[rid]
return
if auto_next_anon:
trace_slice_start("", rid, ts, True)
# alias
trace_slice = trace_slice_end
# Add event to the current slice on the same thread with the same rid.
def trace_event(name: str, rid: str, ts: Optional[int] = None):
rid = str(rid)
if not tracing_enabled or rid not in reqs_context:
return
pid = threading.get_native_id()
if pid not in reqs_context[rid].threads_context:
return
thread_context = reqs_context[rid].threads_context[pid]
if not thread_context.cur_slice_stack:
logger.warning("No slice is currently being traced.")
return
ts = ts or __get_cur_time_ns()
slice_info = thread_context.cur_slice_stack[-1]
slice_info.span.add_event(name=name, timestamp=ts)
# Add attrs to the current slice on the same thread with the same rid.
def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
rid = str(rid)
if not tracing_enabled or rid not in reqs_context:
return
pid = threading.get_native_id()
if pid not in reqs_context[rid].threads_context:
return
thread_context = reqs_context[rid].threads_context[pid]
if not thread_context.cur_slice_stack:
logger.warning("No slice is currently being traced.")
return
slice_info = thread_context.cur_slice_stack[-1]
slice_info.span.set_attributes(attrs)
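A minimal, self-contained sketch (illustrative only, not part of the tracing module; `simulate_slices` is a hypothetical name) of the parenting/linking rule that `trace_slice_start` and `trace_slice_end` implement with the per-thread slice stack: a nested slice is parented to the slice on top of the stack, while a top-level slice is parented to the thread span and linked to the previously completed top-level slice via `last_span_context`.

```python
def simulate_slices(events):
    """Replay ("start"/"end", name) events; return (name, parent, prev_link) per slice."""
    stack, last, out = [], None, []
    for kind, name in events:
        if kind == "start":
            # Nested slice: parent is the enclosing slice on the stack.
            # Top-level slice: parent is the thread span, linked to the
            # previously finished top-level slice (if any).
            parent = stack[-1] if stack else "thread"
            prev = last if not stack else None
            out.append((name, parent, prev))
            stack.append(name)
        else:  # "end"
            done = stack.pop()
            if not stack:
                last = done  # remember the last completed top-level slice
    return out
```

For example, ending slice A, then starting B with nested C, yields B parented to the thread span with a link back to A, and C parented to B.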
import multiprocessing as mp
import os
import subprocess
import time
import unittest
from dataclasses import dataclass
from typing import Any, Dict, Optional
import requests
import zmq
from sglang import Engine
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.tracing.trace import *
from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
@dataclass
class Req:
rid: int
trace_context: Optional[Dict[str, Any]] = None
class TestTrace(CustomTestCase):
def __launch_otel_jaeger(self):
cmd = [
"docker",
"compose",
"-f",
"../../examples/monitoring/tracing_compose.yaml",
"up",
"-d",
]
proc = subprocess.run(cmd)
if proc.returncode != 0:
print("Failed to launch the OpenTelemetry Collector and Jaeger containers")
return False
return True
def __stop_otel_jaeger(self):
cmd = [
"docker",
"compose",
"-f",
"../../examples/monitoring/tracing_compose.yaml",
"down",
]
proc = subprocess.run(cmd)
if proc.returncode != 0:
print("Failed to stop the OpenTelemetry Collector and Jaeger containers")
return False
return True
def __clear_trace_file(self):
try:
os.remove("/tmp/otel_trace.json")
except FileNotFoundError:
pass
def test_trace_enable(self):
self.__clear_trace_file()
assert self.__launch_otel_jaeger()
process = popen_launch_server(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--enable-trace", "--oltp-traces-endpoint", "0.0.0.0:4317"],
)
try:
# Make some requests to generate trace data
response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
self.assertEqual(response.status_code, 200)
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
"stream": True,
},
stream=True,
)
for _ in response.iter_lines(decode_unicode=False):
pass
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
time.sleep(10)
# check trace file
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
finally:
kill_process_tree(process.pid)
assert self.__stop_otel_jaeger()
def test_trace_engine_enable(self):
self.__clear_trace_file()
assert self.__launch_otel_jaeger()
prompt = "Today is a sunny day and I like"
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
sampling_params = {"temperature": 0, "max_new_tokens": 8}
engine = Engine(
model_path=model_path,
random_seed=42,
enable_trace=True,
oltp_traces_endpoint="localhost:4317",
)
try:
engine.generate(prompt, sampling_params)
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
time.sleep(10)
# check trace file
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
finally:
engine.shutdown()
assert self.__stop_otel_jaeger()
def test_trace_engine_encode(self):
self.__clear_trace_file()
assert self.__launch_otel_jaeger()
prompt = "Today is a sunny day and I like"
model_path = "Qwen/Qwen2-7B"
engine = Engine(
model_path=model_path,
random_seed=42,
enable_trace=True,
oltp_traces_endpoint="localhost:4317",
is_embedding=True,
)
try:
engine.encode(prompt)
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
time.sleep(10)
# check trace file
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
finally:
engine.shutdown()
assert self.__stop_otel_jaeger()
def test_slice_trace_simple(self):
self.__clear_trace_file()
assert self.__launch_otel_jaeger()
try:
process_tracing_init("0.0.0.0:4317", "test")
trace_set_thread_info("Test")
trace_req_start(0)
trace_slice_start("test slice", 0)
time.sleep(1)
trace_slice_end("test slice", 0)
trace_req_finish(0)
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
time.sleep(10)
# check trace file
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
finally:
assert self.__stop_otel_jaeger()
def test_slice_trace_complex(self):
self.__clear_trace_file()
assert self.__launch_otel_jaeger()
try:
process_tracing_init("0.0.0.0:4317", "test")
trace_set_thread_info("Test")
trace_req_start(0)
trace_slice_start("", 0, anonymous=True)
time.sleep(1)
trace_slice_end("slice A", 0, auto_next_anon=True)
time.sleep(1)
trace_slice_end("slice B", 0, auto_next_anon=True)
time.sleep(1)
trace_slice_end("slice C", 0, thread_finish_flag=True)
trace_req_finish(0)
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
time.sleep(10)
# check trace file
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
finally:
assert self.__stop_otel_jaeger()
def test_trace_context_propagate(self):
def __process_work():
process_tracing_init("0.0.0.0:4317", "test")
trace_set_thread_info("Sub Process")
context = zmq.Context(2)
recv_from_main = get_zmq_socket(
context, zmq.PULL, "ipc:///tmp/zmq_test.ipc", True
)
try:
req = recv_from_main.recv_pyobj()
trace_set_proc_propagate_context(req.rid, req.trace_context)
trace_slice_start("work", req.rid)
time.sleep(1)
trace_slice_end("work", req.rid, thread_finish_flag=True)
finally:
recv_from_main.close()
context.term()
self.__clear_trace_file()
assert self.__launch_otel_jaeger()
context = zmq.Context(2)
send_to_subproc = get_zmq_socket(
context, zmq.PUSH, "ipc:///tmp/zmq_test.ipc", False
)
try:
process_tracing_init("0.0.0.0:4317", "test")
trace_set_thread_info("Main Process")
subproc = mp.Process(target=__process_work)
subproc.start()
# sleep for a second to ensure the subprocess has initialized
time.sleep(1)
req = Req(rid=0)
trace_req_start(req.rid)
trace_slice_start("dispatch", req.rid)
time.sleep(1)
req.trace_context = trace_get_proc_propagate_context(req.rid)
send_to_subproc.send_pyobj(req)
trace_slice_end("dispatch", req.rid)
subproc.join()
trace_req_finish(req.rid)
# sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
time.sleep(10)
# check trace file
assert os.path.isfile("/tmp/otel_trace.json"), "trace file does not exist"
assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
finally:
send_to_subproc.close()
context.term()
assert self.__stop_otel_jaeger()
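The cross-process test above follows the standard OpenTelemetry carrier pattern: the sender serializes the active span context into a plain dict, ships it over IPC alongside the request, and the receiver restores it as the parent context. A dependency-free sketch of that pattern (the `inject`/`extract` helpers are hypothetical names, not SGLang or OpenTelemetry APIs):

```python
def inject(span_context: dict) -> dict:
    # Sender side: copy the context into a serializable carrier dict.
    return {"traceparent": "{trace_id}-{span_id}".format(**span_context)}

def extract(carrier: dict) -> dict:
    # Receiver side: rebuild the parent context from the carrier.
    trace_id, span_id = carrier["traceparent"].rsplit("-", 1)
    return {"trace_id": trace_id, "parent_span_id": span_id}
```

In the test, `trace_get_proc_propagate_context` plays the `inject` role before `send_pyobj`, and `trace_set_proc_propagate_context` plays the `extract` role after `recv_pyobj`.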
if __name__ == "__main__":
unittest.main()