Unverified Commit ea961060 authored by Feng Su's avatar Feng Su Committed by GitHub

[Feature] Sglang Tracing: Fine-Grained Tracking for Request Latency - Part 2 (#10804)


Signed-off-by: Feng Su <sufeng@linux.alibaba.com>
parent b1e13e7c
@@ -95,6 +95,8 @@ SGLang supports various environment variables that can be used to configure its
| --- | --- | --- |
| `SGLANG_TORCH_PROFILER_DIR` | Directory for PyTorch profiler output | `/tmp` |
| `SGLANG_PROFILE_WITH_STACK` | Set `with_stack` option (bool) for PyTorch profiler (capture stack trace) | `true` |
| `SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS` | Configures `BatchSpanProcessor.schedule_delay_millis` when tracing is enabled | `500` |
| `SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE` | Configures `BatchSpanProcessor.max_export_batch_size` when tracing is enabled | `64` |
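Internally, these two variables feed the OpenTelemetry span batching pipeline. A sketch mirroring `process_tracing_init` from this PR (the endpoint is illustrative):

```python
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import BatchSpanProcessor

from sglang.srt.utils import get_int_env_var

# Both knobs fall back to the defaults listed in the table above.
processor = BatchSpanProcessor(
    OTLPSpanExporter(endpoint="0.0.0.0:4317", insecure=True),  # illustrative endpoint
    schedule_delay_millis=get_int_env_var("SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", 500),
    max_export_batch_size=get_int_env_var("SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", 64),
)
```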
## Storage & Caching
...
SGLang exports request trace data based on the OpenTelemetry Collector. You can enable tracing by adding the `--enable-trace` flag and configure the OpenTelemetry Collector endpoint using `--otlp-traces-endpoint` when launching the server.
You can find example screenshots of the visualization in https://github.com/sgl-project/sglang/issues/8965.
@@ -22,7 +22,13 @@ This section explains how to configure the request tracing and export the trace
3. start your SGLang server with tracing enabled
```bash
# set env variables
export SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS=500
export SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE=64
# start the prefill and decode server
python -m sglang.launch_server --enable-trace --otlp-traces-endpoint 0.0.0.0:4317 <other option>
# start the mini lb
python -m sglang_router.launch_router --enable-trace --otlp-traces-endpoint 0.0.0.0:4317 <other option>
```
Replace `0.0.0.0:4317` with the actual endpoint of the OpenTelemetry Collector. If you launched the OpenTelemetry Collector with tracing_compose.yaml, the default receiving port is 4317.
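Tracing can also be enabled with the offline Engine API. A minimal sketch, assuming `ServerArgs`-style keyword arguments (the model path is illustrative):

```python
import sglang as sgl

# enable_trace / otlp_traces_endpoint mirror the CLI flags above
llm = sgl.Engine(
    model_path="Qwen/Qwen2.5-7B-Instruct",  # hypothetical model
    enable_trace=True,
    otlp_traces_endpoint="0.0.0.0:4317",
)
```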
@@ -39,9 +45,9 @@ We have already inserted instrumentation points in the tokenizer and scheduler m
During the initialization phase, every process involved in tracing should execute:
```python
process_tracing_init(otlp_traces_endpoint, server_name)
```
The `otlp_traces_endpoint` is obtained from the arguments; you can set `server_name` freely, but it should remain consistent across all processes.
During the initialization phase, every thread involved in tracing should execute:
```python
trace_set_thread_info(thread_label)
```
@@ -95,24 +101,52 @@ We have already inserted instrumentation points in the tokenizer and scheduler m
trace_set_proc_propagate_context(rid, req.trace_context)
```
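For reference, the sending side of a process boundary captures the context into the outgoing object before it crosses the boundary. A sketch condensed from this PR's tokenizer manager (`obj`, `tokenized_obj`, and `send_to_scheduler` are that module's names):

```python
# sender: capture the current trace context and ship it with the request
trace_slice_start("dispatch", obj.rid)
tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
send_to_scheduler.send_pyobj(tokenized_obj)
trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
```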
5. When the request execution flow transfers to another node (PD disaggregation), the trace context needs to be explicitly propagated.
- sender: Execute the following code before sending the request to the remote node via HTTP
```python
trace_context = trace_get_remote_propagate_context(bootstrap_room_list)
headers = {"trace_context": trace_context}
session.post(url, headers=headers)
```
- receiver: Execute the following code after receiving the request via HTTP
```python
trace_set_remote_propagate_context(request.headers['trace_context'])
```
## How to Extend the Tracing Framework to Support Complex Tracing Scenarios
The currently provided tracing package still has potential for further development. If you wish to build more advanced features upon it, you must first understand its existing design principles.
The core of the tracing framework's implementation lies in the design of the span structure and the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a four-level span structure and a two-level trace context structure, the latter consisting of `SglangTraceReqContext` and `SglangTraceThreadContext`. Their relationship is as follows:
```
SglangTraceReqContext (req_id="req-123")
├── SglangTraceThreadContext(thread_label="scheduler", tp_rank=0)
|
└── SglangTraceThreadContext(thread_label="scheduler", tp_rank=1)
```
Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is stored in a list.
In addition to the above hierarchy, each slice also records its previous slice via `Span.add_link()`, which can be used to trace the execution flow.
When the request execution flow transfers to a new thread, the trace context needs to be explicitly propagated. In the framework, this is represented by `SglangTracePropagateContext`, which contains the context of the request span and the previous slice span.
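A condensed sketch of these context objects, simplified from `trace.py` in this PR (fields trimmed; `SglangTraceSliceInfo` stands in for the slice bookkeeping type):

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from opentelemetry import context, propagate, trace


@dataclass
class SglangTraceThreadContext:
    # per-thread span plus the stack of currently open (possibly nested) slices
    thread_span: trace.span.Span
    cur_slice_stack: List["SglangTraceSliceInfo"] = field(default_factory=list)


@dataclass
class SglangTraceReqContext:
    # one per traced request; composes the per-thread contexts
    rid: str
    root_span: Optional[trace.span.Span] = None
    root_span_context: Optional[context.Context] = None
    threads_context: Dict[int, SglangTraceThreadContext] = field(default_factory=dict)


@dataclass
class SglangTracePropagateContext:
    # crosses thread/process boundaries: request root span + previous slice span
    root_span_context: context.Context
    prev_span_context: Optional[context.Context] = None

    def to_dict(self) -> Dict[str, dict]:
        # serialize via the configured propagator; the previous slice span
        # is injected the same way in the real implementation
        carrier: Dict[str, str] = {}
        propagate.inject(carrier, self.root_span_context)
        return {"root_span": carrier}
```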
We designed a four-level span structure consisting of `bootstrap_room_span`, `req_root_span`, `thread_span`, and `slice_span`. Among them, `req_root_span` and `thread_span` correspond to `SglangTraceReqContext` and `SglangTraceThreadContext`, respectively, and the `slice_span` is stored within the `SglangTraceThreadContext`. The `bootstrap_room_span` exists to accommodate PD disaggregation: on different nodes we may want to add certain attributes to the `req_root_span`, but if a single `req_root_span` were shared across all nodes, the prefill and decode nodes would not be allowed to add attributes, due to the constraints imposed by OpenTelemetry's design. The resulting hierarchy is shown below, followed by a sketch of the pattern:
```
bootstrap room span
├── router req root span
| └── router thread span
| └── slice span
├── prefill req root span
| ├── tokenizer thread span
| | └── slice span
| └── scheduler thread span
| └── slice span
└── decode req root span
├── tokenizer thread span
| └── slice span
└── scheduler thread span
└── slice span
```
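In OpenTelemetry terms, each node starts its own `req_root_span` as a child of the shared bootstrap room span's context, so every node stays free to set attributes on its own root span. A minimal sketch of the pattern (span names illustrative, adapted from `trace_req_start` in this PR):

```python
from opentelemetry import trace

tracer = trace.get_tracer("sglang")

# The bootstrap room span (or, on remote nodes, its propagated context)
# is the common parent of every node's request root span.
bootstrap_room_span = tracer.start_span("Bootstrap Room 0x2a")
bootstrap_ctx = trace.set_span_in_context(bootstrap_room_span)

# Each node creates its own req_root_span under that shared context and can
# set attributes on it freely, which a single shared root span would forbid.
req_root_span = tracer.start_span("prefill Req deadbeef", context=bootstrap_ctx)
req_root_span.set_attribute("rid", "deadbeef")
```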
@@ -58,6 +58,11 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool,
SWAKVPool,
)
from sglang.srt.tracing.trace import (
trace_event_batch,
trace_slice_batch,
trace_slice_end,
)
from sglang.srt.utils import get_int_env_var, require_mlp_sync
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -313,6 +318,7 @@ class DecodePreallocQueue:
)
req.add_latency(RequestStage.DECODE_PREPARE)
trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True)
self.queue.append(
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
)
@@ -527,6 +533,9 @@ class DecodePreallocQueue:
time.perf_counter()
)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
trace_slice_end(
RequestStage.DECODE_BOOTSTRAP, decode_req.req.rid, auto_next_anon=True
)
self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -775,8 +784,19 @@ class DecodeTransferQueue:
[decode_req.req], decode_req.req.return_logprob
)
self.tree_cache.cache_finished_req(decode_req.req)
trace_slice_end(
RequestStage.DECODE_QUICK_FINISH,
decode_req.req.rid,
thread_finish_flag=True,
)
else:
transferred_reqs.append(decode_req.req)
trace_slice_end(
RequestStage.DECODE_TRANSFERRED,
decode_req.req.rid,
auto_next_anon=True,
)
elif poll in [
KVPoll.Bootstrapping,
KVPoll.WaitingForInput,
@@ -822,6 +842,7 @@ class SchedulerDisaggregationDecodeMixin:
self.stream_output(
batch.reqs, any(req.return_logprob for req in batch.reqs)
)
trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
if prepare_mlp_sync_flag:
self._prepare_idle_batch_and_run(None)
else:
@@ -871,6 +892,7 @@ class SchedulerDisaggregationDecodeMixin:
self.stream_output(
batch.reqs, any(req.return_logprob for req in batch.reqs)
)
trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
if prepare_mlp_sync_flag:
batch_, batch_result = self._prepare_idle_batch_and_run(
None, delay_process=True
@@ -953,6 +975,9 @@ class SchedulerDisaggregationDecodeMixin:
self.running_batch = self.update_running_batch(self.running_batch)
ret = self.running_batch if not self.running_batch.is_empty() else None
if ret:
attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
trace_event_batch("schedule", ret.reqs, attrs=attrs)
return ret
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
...
@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
NSATokenToKVPool,
SWAKVPool,
)
from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end
from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync
if TYPE_CHECKING:
@@ -198,6 +199,7 @@ class PrefillBootstrapQueue:
self._process_req(req)
req.add_latency(RequestStage.PREFILL_PREPARE)
self.queue.append(req)
trace_slice_end(RequestStage.PREFILL_PREPARE, req.rid, auto_next_anon=True)
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
for req in reqs:
@@ -289,6 +291,10 @@ class PrefillBootstrapQueue:
req.time_stats.wait_queue_entry_time = time.perf_counter()
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
trace_slice_end(
RequestStage.PREFILL_BOOTSTRAP, req.rid, auto_next_anon=True
)
self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
]
@@ -316,6 +322,9 @@ class SchedulerDisaggregationPrefillMixin:
)
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()
if batch:
attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
trace_event_batch("schedule", batch.reqs, attrs=attrs)
if require_mlp_sync(self.server_args):
batch = self.prepare_mlp_sync_batch(batch)
@@ -348,6 +357,9 @@ class SchedulerDisaggregationPrefillMixin:
)
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()
if batch:
attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
trace_event_batch("schedule", batch.reqs, attrs=attrs)
if require_mlp_sync(self.server_args):
batch = self.prepare_mlp_sync_batch(batch)
@@ -423,6 +435,7 @@ class SchedulerDisaggregationPrefillMixin:
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
req.add_latency(RequestStage.PREFILL_FORWARD)
trace_slice(RequestStage.PREFILL_FORWARD, req.rid, auto_next_anon=True)
self.disagg_prefill_inflight_queue.append(req)
if self.spec_algorithm.is_eagle() and batch.spec_info is not None:
req.output_topk_p = batch.spec_info.topk_p[i]
@@ -487,6 +500,9 @@ class SchedulerDisaggregationPrefillMixin:
if self.enable_overlap:
self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
trace_slice(
RequestStage.PREFILL_CHUNKED_FORWARD, req.rid, auto_next_anon=True
)
self.maybe_send_health_check_signal()
@@ -558,6 +574,9 @@ class SchedulerDisaggregationPrefillMixin:
req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
req.metadata_buffer_index = -1
trace_slice(
RequestStage.PREFILL_TRANSFER_KV_CACHE, req.rid, thread_finish_flag=True
)
self.disagg_prefill_inflight_queue = undone_reqs
...
@@ -143,10 +143,13 @@ class Engine(EngineBase):
# Enable tracing
if server_args.enable_trace:
process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
thread_label = "Tokenizer"
if server_args.disaggregation_mode == "prefill":
thread_label = "Prefill Tokenizer"
elif server_args.disaggregation_mode == "decode":
thread_label = "Decode Tokenizer"
trace_set_thread_info(thread_label)
try:
self.loop = asyncio.get_running_loop()
...
@@ -220,9 +220,12 @@ async def lifespan(fast_api_app: FastAPI):
# Init tracing
if server_args.enable_trace:
process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "prefill":
thread_label = "Prefill" + thread_label
elif server_args.disaggregation_mode == "decode":
thread_label = "Decode" + thread_label
trace_set_thread_info(thread_label)
# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
...
@@ -129,6 +129,8 @@ class Envs:
SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")
SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS = EnvInt(500)
SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE = EnvInt(64)
# Scheduler: memory leak test
SGLANG_TEST_RETRACT = EnvBool(False)
...
@@ -34,13 +34,21 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput,
WatchLoadUpdateReq,
)
from sglang.srt.managers.schedule_batch import Req, RequestStage
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import (
DP_ATTENTION_HANDSHAKE_PORT_DELTA,
PortArgs,
ServerArgs,
)
from sglang.srt.tracing.trace import (
process_tracing_init,
trace_get_proc_propagate_context,
trace_set_proc_propagate_context,
trace_set_thread_info,
trace_slice_end,
trace_slice_start,
)
from sglang.srt.utils import (
bind_port,
configure_logger,
@@ -170,11 +178,22 @@ class DataParallelController:
def handle_load_update_req(self, obj):
self.dp_budget.update_budget(obj)
def dispatching_with_trace(self, req: Req):
if self.server_args.enable_trace:
trace_set_proc_propagate_context(req.rid, req.trace_context)
trace_slice_start(RequestStage.DC_DISPATCH, req.rid)
req.trace_context = trace_get_proc_propagate_context(req.rid)
self.dispatching(req)
if self.server_args.enable_trace:
trace_slice_end(RequestStage.DC_DISPATCH, req.rid, thread_finish_flag=True)
def init_dispatcher(self):
self._request_dispatcher = TypeBasedDispatcher(
[
(TokenizedGenerateReqInput, self.dispatching_with_trace),
(TokenizedEmbeddingReqInput, self.dispatching_with_trace),
(BlockReqInput, self.send_to_all_workers),
(WatchLoadUpdateReq, self.handle_load_update_req),
]
@@ -487,6 +506,14 @@ def run_data_parallel_controller_process(
pipe_writer,
):
kill_itself_when_parent_died()
if server_args.enable_trace:
process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
thread_label = "DP Controller"
if server_args.disaggregation_mode == "prefill":
thread_label = "Prefill DP Controller"
elif server_args.disaggregation_mode == "decode":
thread_label = "Decode DP Controller"
trace_set_thread_info(thread_label)
setproctitle.setproctitle("sglang::data_parallel_controller")
faulthandler.enable()
configure_logger(server_args)
...
@@ -396,13 +396,23 @@ class MultimodalInputs:
class RequestStage(str, enum.Enum):
# Tokenizer
TOKENIZE = "tokenize"
TOKENIZER_DISPATCH = "dispatch"
# DP controller
DC_DISPATCH = "dc_dispatch"
# common/non-disaggregation
PREFILL_WAITING = "prefill_waiting"
REQUEST_PROCESS = "request_process"
DECODE_LOOP = "decode_loop"
PREFILL_FORWARD = "prefill_forward"
PREFILL_CHUNKED_FORWARD = "chunked_prefill"
# disaggregation prefill
PREFILL_PREPARE = "prefill_prepare"
PREFILL_BOOTSTRAP = "prefill_bootstrap"
PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
# disaggregation decode
@@ -410,6 +420,8 @@ class RequestStage(str, enum.Enum):
DECODE_BOOTSTRAP = "decode_bootstrap"
DECODE_WAITING = "decode_waiting"
DECODE_TRANSFERRED = "decode_transferred"
DECODE_FAKE_OUTPUT = "fake_output"
DECODE_QUICK_FINISH = "quick_finish"
class Req:
...
@@ -157,6 +157,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import (
process_tracing_init,
trace_event_batch,
trace_set_proc_propagate_context,
trace_set_thread_info,
trace_slice_batch,
@@ -1354,7 +1355,7 @@ class Scheduler(
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
req.time_stats.wait_queue_entry_time = time.perf_counter()
trace_slice_end(RequestStage.REQUEST_PROCESS, req.rid, auto_next_anon=True)
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
self._prefetch_kvcache(req)
self.disagg_prefill_bootstrap_queue.add(
@@ -1618,6 +1619,10 @@ class Scheduler(
if need_dp_attn_preparation:
ret = self.prepare_mlp_sync_batch(ret)
if ret:
attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
trace_event_batch("schedule", ret.reqs, attrs=attrs)
return ret
def get_num_allocatable_reqs(self, running_bs):
@@ -1993,13 +1998,10 @@ class Scheduler(
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
trace_slice_batch(RequestStage.DECODE_LOOP, batch.reqs)
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
@@ -2741,10 +2743,13 @@ def run_scheduler_process(
# Set up tracing
if server_args.enable_trace:
process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
thread_label = "Scheduler"
if server_args.disaggregation_mode == "prefill":
thread_label = "Prefill Scheduler"
elif server_args.disaggregation_mode == "decode":
thread_label = "Decode Scheduler"
trace_set_thread_info(thread_label, tp_rank, dp_rank)
# Create a scheduler and run the event loop
try:
...
@@ -14,7 +14,13 @@ from sglang.srt.managers.io_struct import (
BatchEmbeddingOutput,
BatchTokenIDOutput,
)
from sglang.srt.managers.schedule_batch import (
BaseFinishReason,
Req,
RequestStage,
ScheduleBatch,
)
from sglang.srt.tracing.trace import trace_slice
from sglang.srt.utils.common import ceil_div
if TYPE_CHECKING:
@@ -160,6 +166,14 @@ class SchedulerOutputProcessorMixin:
)
self.abort_request(AbortReq(rid=req.rid))
req.grammar.finished = req.finished()
trace_slice(
RequestStage.PREFILL_FORWARD,
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
else:
# being chunked reqs' prefill is not finished
req.is_chunked -= 1
@@ -188,6 +202,12 @@ class SchedulerOutputProcessorMixin:
)
logprob_pt += num_input_logprobs
trace_slice(
RequestStage.PREFILL_CHUNKED_FORWARD,
req.rid,
auto_next_anon=True,
)
else: # embedding or reward model
is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
@@ -227,6 +247,13 @@ class SchedulerOutputProcessorMixin:
# being chunked reqs' prefill is not finished
req.is_chunked -= 1
trace_slice(
RequestStage.PREFILL_FORWARD,
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
def _resolve_spec_overlap_token_ids(
...
@@ -68,6 +68,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.mm_utils import TensorTransportMode
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.managers.schedule_batch import RequestStage
from sglang.srt.managers.scheduler import is_health_check_generate_req
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
@@ -79,6 +80,7 @@ from sglang.srt.tracing.trace import (
trace_get_proc_propagate_context,
trace_req_finish,
trace_req_start,
trace_set_remote_propagate_context,
trace_slice_end,
trace_slice_start,
)
@@ -383,6 +385,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
if request:
if "trace_context" in request.headers:
trace_set_remote_propagate_context(request.headers["trace_context"])
if self.server_args.tokenizer_worker_num > 1:
self._attach_multi_http_worker_info(obj)
@@ -605,7 +611,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
mm_inputs = None
self._validate_one_request(obj, input_ids)
trace_slice_end(RequestStage.TOKENIZE, obj.rid)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
)
@@ -831,7 +837,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
req, req.text, input_ids_list[i], None, None, token_type_ids
)
)
trace_slice_end(RequestStage.TOKENIZE, req.rid)
logger.debug(f"Completed batch processing for {batch_size} requests")
return tokenized_objs
...
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
created_time: Optional[float] = None,
):
trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid)
tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
self.send_to_scheduler.send_pyobj(tokenized_obj)
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
self.rid_to_state[obj.rid] = state
trace_slice_end(
RequestStage.TOKENIZER_DISPATCH, obj.rid, thread_finish_flag=True
)
return state
def _send_batch_request(
@@ -2131,7 +2139,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
bootstrap_room = (
obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
)
trace_req_start(
obj.rid,
bootstrap_room,
ts=int(created_time * 1e9),
role=self.server_args.disaggregation_mode,
)
trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
else:
for i in range(len(obj.rid)):
@@ -2140,7 +2153,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
else None
)
trace_req_start(
obj.rid[i],
bootstrap_room,
ts=int(created_time * 1e9),
role=self.server_args.disaggregation_mode,
)
trace_slice_start(
"", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
)
...
@@ -299,7 +299,7 @@ class ServerArgs:
enable_request_time_stats_logging: bool = False
kv_events_config: Optional[str] = None
enable_trace: bool = False
otlp_traces_endpoint: str = "localhost:4317"
# API related
api_key: Optional[str] = None
@@ -2340,7 +2340,7 @@ class ServerArgs:
help="Enable opentelemetry trace",
)
parser.add_argument(
"--otlp-traces-endpoint",
type=str,
default="localhost:4317",
help="Config opentelemetry collector endpoint if --enable-trace is set. format: <ip>:<port>",
...
@@ -15,6 +15,8 @@
from __future__ import annotations
import base64
import json
import logging
import os
import random
@@ -24,6 +26,8 @@ import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from sglang.srt.utils import get_int_env_var
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import Req
@@ -85,6 +89,8 @@ class SglangTraceReqContext:
# Indicates whether this instance is a replica from the main process.
# When True, root_span is None and only root_span_context is preserved.
is_copy: bool = False
bootstrap_room_span: Optional[trace.span.Span] = None
bootstrap_room_span_context: Optional[context.Context] = None
root_span: Optional[trace.span.Span] = None
root_span_context: Optional[context.Context] = None
@@ -96,8 +102,7 @@ class SglangTracePropagateContext:
def to_dict(self):
carrier: dict[str, str] = {}
propagate.inject(carrier, self.root_span_context)
if self.prev_span_context:
return {
@@ -149,6 +154,7 @@ class SglangTraceCustomIdGenerator(id_generator.IdGenerator):
# global variables
remote_trace_contexts: Dict[str, SglangTracePropagateContext] = {}
threads_info: Dict[int, SglangTraceThreadInfo] = {}
reqs_context: Dict[str, SglangTraceReqContext] = {}
@@ -193,8 +199,17 @@ def process_tracing_init(otlp_endpoint, server_name):
resource=resource, id_generator=SglangTraceCustomIdGenerator()
)
schedule_delay_millis = get_int_env_var(
"SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", 500
)
max_export_batch_size = get_int_env_var(
"SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", 64
)
processor = BatchSpanProcessor(
OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True),
schedule_delay_millis=schedule_delay_millis,
max_export_batch_size=max_export_batch_size,
)
tracer_provider.add_span_processor(processor)
trace.set_tracer_provider(tracer_provider)
@@ -266,7 +281,9 @@ def __create_thread_context(pid, req_span_context, ts: Optional[int] = None):
return thread_context
def trace_get_proc_propagate_context(
rid, remote_propagate=False
) -> Optional[Dict[str, Any]]:
if not tracing_enabled:
return None
@@ -283,9 +300,11 @@ def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]:
elif thread_context.last_span_context:
prev_span_context = thread_context.last_span_context
root_span_context = reqs_context[rid].root_span_context
if remote_propagate:
root_span_context = reqs_context[rid].bootstrap_room_span_context
trace_context = SglangTracePropagateContext(root_span_context, prev_span_context)
return trace_context.to_dict()
@@ -327,10 +346,54 @@ def trace_set_proc_propagate_context(rid, trace_context: Optional[Dict[str, Any]
].last_span_context = trace_context.prev_span_context
def trace_get_remote_propagate_context(bootstrap_room_list: List[str]):
if not tracing_enabled:
return ""
reqs_trace_contexts = {}
for bootstrap_room in bootstrap_room_list:
# In the router, rid is also the bootstrap room.
bootstrap_room = str(bootstrap_room)
if bootstrap_room not in reqs_context:
continue
_context = trace_get_proc_propagate_context(
bootstrap_room, remote_propagate=True
)
reqs_trace_contexts[bootstrap_room] = _context
json_str = json.dumps(reqs_trace_contexts, ensure_ascii=False)
return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")
def trace_set_remote_propagate_context(base64_str):
if not tracing_enabled:
return
if base64_str is None or base64_str == "" or base64_str == "None":
return
base64_bytes = base64.b64decode(base64_str)
json_str = base64_bytes.decode("utf-8")
remote_reqs_trace_contexts = json.loads(json_str)
for bootstrap_room in remote_reqs_trace_contexts:
if bootstrap_room in remote_trace_contexts:
continue
remote_trace_contexts[bootstrap_room] = (
SglangTracePropagateContext.instance_from_dict(
remote_reqs_trace_contexts[bootstrap_room]
)
)
def trace_req_start(
rid: str,
bootstrap_room: Optional[int] = None,
ts: Optional[int] = None,
role: Optional[str] = "null",
):
if not tracing_enabled:
return
@@ -344,6 +407,7 @@ def trace_req_start(
return
# create req context and root span
bootstrap_room = 0 if bootstrap_room is None else bootstrap_room
reqs_context[rid] = SglangTraceReqContext(
rid=rid,
start_time_ns=ts,
@@ -352,23 +416,42 @@ def trace_req_start(
is_copy=False,
)
# create bootstrap room span
tracer = threads_info[pid].tracer
if str(bootstrap_room) not in remote_trace_contexts:
attrs = {"bootstrap_room": str(hex(bootstrap_room))}
bootstrap_room_span = tracer.start_span(
name=f"Bootstrap Room {hex(bootstrap_room)}",
start_time=ts,
attributes=attrs,
)
reqs_context[rid].bootstrap_room_span = bootstrap_room_span
bootstrap_room_span_context = trace.set_span_in_context(bootstrap_room_span)
else:
bootstrap_room_span_context = remote_trace_contexts[
str(bootstrap_room)
].root_span_context
# Drop the worker_id added by MultiTokenizer
orig_rid = rid.split("_")[-1]
role = "" if role == "null" else role
attrs = {"rid": orig_rid}
root_span = tracer.start_span(
name=f"{role} Req {orig_rid[:8]}",
start_time=ts,
context=bootstrap_room_span_context,
attributes=attrs,
)
root_span.set_attributes(
{
"rid": rid,
}
)
reqs_context[rid].root_span = root_span
reqs_context[rid].root_span_context = trace.set_span_in_context(root_span)
reqs_context[rid].bootstrap_room_span_context = bootstrap_room_span_context
# create thread context and thread span
reqs_context[rid].threads_context[pid] = __create_thread_context(
@@ -376,6 +459,10 @@ def trace_req_start(
reqs_context[rid].root_span_context,
ts,
)
if str(bootstrap_room) in remote_trace_contexts:
reqs_context[rid].threads_context[pid].last_span_context = (
remote_trace_contexts[str(bootstrap_room)].prev_span_context
)
def trace_req_finish(
@@ -399,6 +486,10 @@ def trace_req_finish(
req_context.root_span.set_attributes(attrs)
req_context.root_span.end(end_time=ts)
if str(req_context.bootstrap_room) in remote_trace_contexts:
del remote_trace_contexts[str(req_context.bootstrap_room)]
else:
req_context.bootstrap_room_span.end(end_time=ts)
del reqs_context[rid]
@@ -518,7 +609,9 @@ trace_slice = trace_slice_end
# Add event to the current slice on the same thread with the same rid.
def trace_event(
name: str, rid: str, ts: Optional[int] = None, attrs: Dict[str, Any] = None
):
if not tracing_enabled:
return
@@ -539,7 +632,7 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None):
ts = ts or __get_cur_time_ns()
slice_info = thread_context.cur_slice_stack[-1]
slice_info.span.add_event(name=name, timestamp=ts, attributes=attrs)
# Add attrs to the current slice on the same thread with the same rid.
@@ -569,6 +662,9 @@ def trace_slice_batch(
name: str,
reqs: List[Req],
):
if not tracing_enabled:
return
for req in reqs:
trace_slice(
name,
@@ -576,3 +672,16 @@ def trace_slice_batch(
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
def trace_event_batch(
name: str,
reqs: List[Req],
ts: Optional[int] = None,
attrs: Dict[str, Any] = None,
):
if not tracing_enabled:
return
for req in reqs:
trace_event(name, req.rid, ts=ts, attrs=attrs)
import argparse
import bisect
import json
import time
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple
parser = argparse.ArgumentParser(
description="Convert SGLang OTEL trace files to Perfetto format.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"-i",
"--input",
dest="input_file",
required=True,
type=str,
help="Path to the input OTEL trace file (JSON or JSONL format).",
)
parser.add_argument(
"-o",
"--output",
dest="output_file",
type=str,
default="sglang_trace_perfetto.json",
help="Path to the output Perfetto JSON file.",
)
parser.add_argument(
"-f", "--torch-file", dest="torch_file", help="specify torch profile file"
)
args = parser.parse_args()
perfetto_data = None
if args.torch_file:
with open(args.torch_file, "r", encoding="utf-8") as file:
perfetto_data = json.load(file)
baseline = perfetto_data["baseTimeNanoseconds"]
else:
baseline = 0
def id_generator():
i = 0
while True:
yield i
i += 1
relation_id_gen = id_generator()
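# Tracks the occupied [start, end) time intervals of one Perfetto row ("slot"),
# so overlapping spans can be detected and pushed onto another row.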
class SpanLayoutContainer:
def __init__(self):
self.intervals = []
def check_overlap(self, start, end):
idx = bisect.bisect_left(self.intervals, (start, float("-inf")))
if idx > 0:
prev_start, prev_end = self.intervals[idx - 1]
if prev_end > start:
return True
if idx < len(self.intervals):
next_start, next_end = self.intervals[idx]
if next_start < end:
return True
return False
def insert_span(self, start, end):
bisect.insort_left(self.intervals, (start, end))
def new_metadata_level1(name: str, pid):
return {
"name": "process_name",
"ph": "M",
"pid": pid,
"args": {"name": name},
}
def new_metadata_level2(name: str, pid, slot_seq):
return {
"name": "thread_name",
"ph": "M",
"pid": pid,
"tid": slot_seq,
"args": {"name": name},
}
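# Choose a row (tid) under the given pid for a span: prefer the row this
# request last used, else the first row without overlap, else open a new row.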
def __find_line(graph, trans_graph_status, slot_meta_data, pid, start, end):
if pid in trans_graph_status:
line = trans_graph_status[pid]
if start == end:
return line
# check conflict
if not graph[pid][line].check_overlap(start, end):
return line
if pid not in graph:
line = 1
graph[pid] = {line: SpanLayoutContainer()}
trans_graph_status[pid] = line
slot_meta_data.append(new_metadata_level2("slot", pid, line))
return line
for line in graph[pid]:
if not graph[pid][line].check_overlap(start, end):
trans_graph_status[pid] = line
return line
new_line = len(graph[pid]) + 1
graph[pid][new_line] = SpanLayoutContainer()
trans_graph_status[pid] = new_line
slot_meta_data.append(new_metadata_level2("slot", pid, new_line))
return new_line
OtelSpan = Dict[str, Any]
def load_otel_data(path: str | Path):
p = Path(path)
with p.open("rt", encoding="utf-8") as f:
first = f.read(1)
f.seek(0)
if first == "[":
data = json.load(f) # JSON array
else:
data = [json.loads(line) for line in f if line.strip()] # JSONL
return data
def extract_all_otel_spans(otel_data):
otel_spans = []
for line_data in otel_data:
for resource_spans in line_data["resourceSpans"]:
for scope_spans in resource_spans["scopeSpans"]:
for span in scope_spans["spans"]:
if "attributes" in span:
attributes_dict = {
attr.get("key"): next(
iter(attr.get("value", {}).values()), None
)
for attr in span["attributes"]
}
span["attributes"] = attributes_dict
else:
span["attributes"] = {}
otel_spans.append(span)
return otel_spans
def build_otel_span_tree(otel_spans):
span_id_map = {span["spanId"]: span for span in otel_spans}
for span in otel_spans:
span["child"] = []
bootstrap_room_spans = []
for span in otel_spans:
span_id = span["spanId"]
parent_span_id = span.get("parentSpanId", "")
if parent_span_id == "":
# check if root span is a request span
attrs = span.get("attributes", {})
bootstrap_room_spans.append(span)
elif parent_span_id in span_id_map:
parent_span = span_id_map[parent_span_id]
parent_span["child"].append(span)
link_spans = []
if "links" in span:
for link in span["links"]:
link_span = span_id_map.get(link["spanId"])
if link_span:
link_spans.append(link_span)
span["links"] = link_spans
return bootstrap_room_spans
def generate_perfetto_span(otel_bootstrap_room_spans, thread_meta_data):
for bootstrap_room_span in otel_bootstrap_room_spans:
bootstrap_room = bootstrap_room_span["attributes"]["bootstrap_room"]
bootstrap_room_span["spans"] = []
for node_req_span in bootstrap_room_span["child"]:
rid = node_req_span["attributes"]["rid"]
for thread_span in node_req_span["child"]:
pid = int(thread_span["attributes"]["pid"])
thread_name = f'{thread_span["attributes"]["host_id"][:8]}:{thread_span["attributes"]["thread_label"]}'
if "tp_rank" in thread_span["attributes"]:
thread_name += f"-TP{thread_span['attributes']['tp_rank']}"
if pid not in thread_meta_data:
thread_meta_data[pid] = new_metadata_level1(thread_name, pid)
for span in thread_span["child"]:
span["attributes"]["bootstrap_room"] = bootstrap_room
span["attributes"]["rid"] = rid
span["host_id"] = thread_span["attributes"]["host_id"]
span["pid"] = pid
span["startTimeUnixNano"] = int(span["startTimeUnixNano"])
span["endTimeUnixNano"] = int(span["endTimeUnixNano"])
ts = span["startTimeUnixNano"]
dur = span["endTimeUnixNano"] - ts
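# Perfetto timestamps are in microseconds; the 1000 ns trimmed from "dur"
# below presumably keeps back-to-back slices visually separable.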
perfetto_span = {
"ph": "X",
"name": span.get("name", "unknown"),
"cat": "sglang",
"ts": (ts - baseline) / 1000.0,
"dur": (dur - 1000) / 1000.0,
"pid": pid,
"tid": 0,
"args": span["attributes"],
}
span["perfetto_span"] = perfetto_span
bootstrap_room_span["spans"].append(span)
def generate_perfetto_span_layout(otel_bootstrap_room_spans, slot_meta_data):
for bootstrap_room_span in otel_bootstrap_room_spans:
bootstrap_room_span["spans"] = sorted(
bootstrap_room_span["spans"], key=lambda x: int(x["startTimeUnixNano"])
)
otel_bootstrap_room_spans = sorted(
otel_bootstrap_room_spans, key=lambda x: int(x["spans"][0]["startTimeUnixNano"])
)
graph = {}
for bootstrap_room_span in otel_bootstrap_room_spans:
req_thread_status = {}
for span in bootstrap_room_span["spans"]:
line = __find_line(
graph,
req_thread_status,
slot_meta_data,
span["perfetto_span"]["pid"],
span["startTimeUnixNano"],
span["endTimeUnixNano"],
)
graph[span["perfetto_span"]["pid"]][line].insert_span(
span["startTimeUnixNano"], span["endTimeUnixNano"]
)
span["perfetto_span"]["tid"] = line
def generate_perfetto_events(otel_bootstrap_room_spans):
for bootstrap_room_span in otel_bootstrap_room_spans:
for span in bootstrap_room_span["spans"]:
span["perfetto_events"] = []
if "events" in span:
for event in span["events"]:
attributes_dict = {
attr.get("key"): next(
iter(attr.get("value", {}).values()), None
)
for attr in event["attributes"]
}
perfetto_event = {
"ph": "i",
"cat": "sglang",
"ts": (int(event["timeUnixNano"]) - baseline) / 1000.0,
"pid": span["perfetto_span"]["pid"],
"tid": span["perfetto_span"]["tid"],
"name": event.get("name", "unknown"),
"args": attributes_dict,
}
span["perfetto_events"].append(perfetto_event)
def generate_perfetto_links(otel_bootstrap_room_spans):
for bootstrap_room_span in otel_bootstrap_room_spans:
for span in bootstrap_room_span["spans"]:
span["perfetto_links"] = []
if "links" in span:
for link_span in span["links"]:
if "correlation" in link_span["perfetto_span"]["args"]:
id = link_span["perfetto_span"]["args"]["correlation"]
else:
id = next(relation_id_gen)
link_span["perfetto_span"]["args"]["correlation"] = id
perfetto_start_node = {
"ph": "s",
"id": id,
"pid": link_span["perfetto_span"]["pid"],
"tid": link_span["perfetto_span"]["tid"],
"ts": link_span["perfetto_span"]["ts"],
"cat": "ac2g",
"name": "ac2g",
}
perfetto_end_node = {
"ph": "f",
"id": id,
"pid": span["perfetto_span"]["pid"],
"tid": span["perfetto_span"]["tid"],
"ts": span["perfetto_span"]["ts"],
"cat": "ac2g",
"name": "ac2g",
"bp": "e",
}
span["perfetto_links"].append(perfetto_start_node)
span["perfetto_links"].append(perfetto_end_node)
def gather_all_perfetto_elems(
otel_bootstrap_room_spans, thread_meta_data, slot_meta_data
):
elems = []
elems.extend(thread_meta_data.values())
elems.extend(slot_meta_data)
for bootstrap_room_span in otel_bootstrap_room_spans:
for span in bootstrap_room_span["spans"]:
elems.append(span["perfetto_span"])
elems.extend(span["perfetto_events"])
elems.extend(span["perfetto_links"])
return elems
def write_json(perfetto_elems):
global perfetto_data
if args.torch_file:
perfetto_data["traceEvents"].extend(perfetto_elems)
filtered_data = [
item
for item in perfetto_data["traceEvents"]
if item.get("cat") != "gpu_user_annotation"
]
perfetto_data["traceEvents"] = filtered_data
else:
perfetto_data = perfetto_elems
with open(args.output_file, "w", encoding="utf-8") as file:
json.dump(perfetto_data, file, ensure_ascii=False, indent=4)
def main():
start_time = time.time()
otel_data = load_otel_data(args.input_file)
otel_spans = extract_all_otel_spans(otel_data)
otel_bootstrap_room_spans = build_otel_span_tree(otel_spans)
thread_meta_data = {}
generate_perfetto_span(otel_bootstrap_room_spans, thread_meta_data)
slot_meta_data = []
generate_perfetto_span_layout(otel_bootstrap_room_spans, slot_meta_data)
generate_perfetto_events(otel_bootstrap_room_spans)
generate_perfetto_links(otel_bootstrap_room_spans)
perfetto_elems = gather_all_perfetto_elems(
otel_bootstrap_room_spans, thread_meta_data, slot_meta_data
)
write_json(perfetto_elems)
end_time = time.time()
execution_time = end_time - start_time
print(f"\nConversion finished successfully!")
print(f"Output written to: {args.output_file}")
print(f"Execution time: {execution_time * 1000:.4f} ms")


if __name__ == "__main__":
    main()
...
@@ -38,9 +38,22 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
     router_args = args

     if router_args.mini_lb:
+        if router_args.enable_trace:
+            from sglang.srt.tracing.trace import (
+                process_tracing_init,
+                trace_set_thread_info,
+            )
+
+            process_tracing_init(router_args.otlp_traces_endpoint, "sglang")
+            trace_set_thread_info("Mini lb")
+
         mini_lb = MiniLoadBalancer(router_args)
         mini_lb.start()
     else:
+        # TODO: support tracing for the Rust router.
+        del router_args.enable_trace
+        del router_args.otlp_traces_endpoint
+
         if Router is None:
             raise RuntimeError("Rust Router is not installed")
         router_args._validate_router_args()
...
@@ -18,6 +18,14 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 from sglang_router.router_args import RouterArgs

+from sglang.srt.tracing.trace import (
+    trace_get_remote_propagate_context,
+    trace_req_finish,
+    trace_req_start,
+    trace_slice_end,
+    trace_slice_start,
+)
+
 logger = logging.getLogger(__name__)

 AIOHTTP_STREAM_READ_CHUNK_SIZE = (
...
@@ -46,6 +54,7 @@ class MiniLoadBalancer:
         self.prefill_urls = [url[0] for url in router_args.prefill_urls]
         self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
         self.decode_urls = router_args.decode_urls
+        self.enable_trace = router_args.enable_trace

     def _validate_router_args(self, router_args: RouterArgs):
         logger.warning(
...
@@ -91,11 +100,33 @@ class MiniLoadBalancer:
                 total=self.timeout
             )  # Add timeout for request reliability
         ) as session:
+            headers = {}
+            bootstrap_room_list = []
+            if self.enable_trace:
+                bootstrap_room_list = (
+                    modified_request["bootstrap_room"]
+                    if isinstance(modified_request["bootstrap_room"], list)
+                    else [modified_request["bootstrap_room"]]
+                )
+                trace_context = trace_get_remote_propagate_context(bootstrap_room_list)
+                headers = {"trace_context": trace_context}
+
             tasks = [
-                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
-                session.post(f"{decode_server}/{endpoint}", json=modified_request),
+                session.post(
+                    f"{prefill_server}/{endpoint}",
+                    json=modified_request,
+                    headers=headers,
+                ),
+                session.post(
+                    f"{decode_server}/{endpoint}",
+                    json=modified_request,
+                    headers=headers,
+                ),
             ]
+
+            for bootstrap_room in bootstrap_room_list:
+                trace_slice_end("mini_lb_launch", bootstrap_room, auto_next_anon=True)
+
             # Wait for both responses to complete. Prefill should end first.
             prefill_response, decode_response = await asyncio.gather(*tasks)
...
@@ -114,6 +145,14 @@ class MiniLoadBalancer:
             else:
                 ret_json = await decode_response.json()

+            for bootstrap_room in bootstrap_room_list:
+                trace_slice_end(
+                    "wait_PD_finish",
+                    bootstrap_room,
+                    thread_finish_flag=True,
+                )
+                trace_req_finish(bootstrap_room)
+
             return ORJSONResponse(
                 content=ret_json,
                 status_code=decode_response.status,
...
@@ -131,10 +170,36 @@ class MiniLoadBalancer:
             )  # Add timeout for request reliability
         ) as session:
             # Create the tasks for both prefill and decode requests
+            headers = {}
+            bootstrap_room_list = []
+            if self.enable_trace:
+                bootstrap_room_list = (
+                    modified_request["bootstrap_room"]
+                    if isinstance(modified_request["bootstrap_room"], list)
+                    else [modified_request["bootstrap_room"]]
+                )
+                trace_context = trace_get_remote_propagate_context(
+                    bootstrap_room_list
+                )
+                headers = {"trace_context": trace_context}
+
             tasks = [
-                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
-                session.post(f"{decode_server}/{endpoint}", json=modified_request),
+                session.post(
+                    f"{prefill_server}/{endpoint}",
+                    json=modified_request,
+                    headers=headers,
+                ),
+                session.post(
+                    f"{decode_server}/{endpoint}",
+                    json=modified_request,
+                    headers=headers,
+                ),
             ]
+
+            for bootstrap_room in bootstrap_room_list:
+                trace_slice_end(
+                    "mini_lb_launch", bootstrap_room, auto_next_anon=True
+                )
+
             # Wait for both responses to complete. Since this is streaming, they return immediately.
             prefill_response, decode_response = await asyncio.gather(*tasks)
...
@@ -174,6 +239,14 @@ class MiniLoadBalancer:
             ):
                 yield chunk

+            for bootstrap_room in bootstrap_room_list:
+                trace_slice_end(
+                    "wait_PD_finish",
+                    bootstrap_room,
+                    thread_finish_flag=True,
+                )
+                trace_req_finish(bootstrap_room)
+
         return StreamingResponse(
             stream_results(),
             media_type="text/event-stream",
...
@@ -367,7 +440,10 @@ async def handle_completion_request(request_data: dict):
 def _generate_bootstrap_room():
-    return random.randint(0, 2**63 - 1)
+    bootstrap_room = random.randint(0, 2**63 - 1)
+    trace_req_start(bootstrap_room, bootstrap_room, role="router")
+    trace_slice_start("mini_lb_launch", bootstrap_room)
+    return bootstrap_room

 # We may utilize `GenerateReqInput`'s logic later
...
@@ -112,6 +112,9 @@ class RouterArgs:
     client_cert_path: Optional[str] = None
     client_key_path: Optional[str] = None
     ca_cert_paths: List[str] = dataclasses.field(default_factory=list)
+    # Trace
+    enable_trace: bool = False
+    otlp_traces_endpoint: str = "localhost:4317"

     @staticmethod
     def add_cli_args(
...
@@ -608,6 +611,17 @@ class RouterArgs:
             default=[],
             help="Path(s) to CA certificate(s) for verifying worker TLS certificates. Can specify multiple CAs.",
         )
+        parser.add_argument(
+            f"--{prefix}enable-trace",
+            action="store_true",
+            help="Enable OpenTelemetry tracing",
+        )
+        parser.add_argument(
+            f"--{prefix}otlp-traces-endpoint",
+            type=str,
+            default="localhost:4317",
+            help="OpenTelemetry collector endpoint, used when --enable-trace is set. Format: <ip>:<port>",
+        )

     @classmethod
     def from_cli_args(
...
@@ -74,7 +74,7 @@ class TestTrace(CustomTestCase):
             DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
             DEFAULT_URL_FOR_TEST,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-trace", "--oltp-traces-endpoint", "0.0.0.0:4317"],
+            other_args=["--enable-trace", "--otlp-traces-endpoint", "0.0.0.0:4317"],
         )

         try:
...
@@ -121,7 +121,7 @@ class TestTrace(CustomTestCase):
             model_path=model_path,
             random_seed=42,
             enable_trace=True,
-            oltp_traces_endpoint="localhost:4317",
+            otlp_traces_endpoint="localhost:4317",
         )

         try:
...
@@ -148,7 +148,7 @@ class TestTrace(CustomTestCase):
             model_path=model_path,
             random_seed=42,
             enable_trace=True,
-            oltp_traces_endpoint="localhost:4317",
+            otlp_traces_endpoint="localhost:4317",
             is_embedding=True,
         )
...