Unverified Commit ea961060 authored by Feng Su's avatar Feng Su Committed by GitHub

[Feature] Sglang Tracing: Fine-Grained Tracking for Request Latency - Part 2 (#10804)


Signed-off-by: Feng Su <sufeng@linux.alibaba.com>
parent b1e13e7c
......@@ -95,6 +95,8 @@ SGLang supports various environment variables that can be used to configure its
| --- | --- | --- |
| `SGLANG_TORCH_PROFILER_DIR` | Directory for PyTorch profiler output | `/tmp` |
| `SGLANG_PROFILE_WITH_STACK` | Set `with_stack` option (bool) for PyTorch profiler (capture stack trace) | `true` |
| `SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS` | Configure `BatchSpanProcessor.schedule_delay_millis` when tracing is enabled | `500` |
| `SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE` | Configure `BatchSpanProcessor.max_export_batch_size` when tracing is enabled | `64` |
## Storage & Caching
......
SGlang exports request trace data based on the OpenTelemetry Collector. You can enable tracing by adding the `--enable-trace` and configure the OpenTelemetry Collector endpoint using `--oltp-traces-endpoint` when launching the server.
SGLang exports request trace data based on the OpenTelemetry Collector. You can enable tracing by adding the `--enable-trace` flag and configuring the OpenTelemetry Collector endpoint with `--otlp-traces-endpoint` when launching the server.
You can find example screenshots of the visualization at https://github.com/sgl-project/sglang/issues/8965.
......@@ -22,7 +22,13 @@ This section explains how to configure the request tracing and export the trace
3. Start your SGLang server with tracing enabled
```bash
python -m sglang.launch_server --enable-trace --oltp-traces-endpoint 0.0.0.0:4317 <other option>
# set env variables
export SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS=500
export SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE=64
# start the prefill and decode servers
python -m sglang.launch_server --enable-trace --otlp-traces-endpoint 0.0.0.0:4317 <other option>
# start the mini lb
python -m sglang_router.launch_router --enable-trace --otlp-traces-endpoint 0.0.0.0:4317 <other option>
```
Replace `0.0.0.0:4317` with the actual endpoint of the OpenTelemetry Collector. If you launched the collector with `tracing_compose.yaml`, the default receiving port is 4317.
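The `SGLANG_OTLP_EXPORTER_*` variables above tune how spans are batched before export. For reference, a minimal sketch of the wiring, mirroring what `process_tracing_init` sets up (the endpoint value is illustrative):
```python
# Sketch: how the exporter env vars feed OpenTelemetry's BatchSpanProcessor.
import os

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

provider = TracerProvider()
provider.add_span_processor(
    BatchSpanProcessor(
        OTLPSpanExporter(endpoint="0.0.0.0:4317", insecure=True),
        schedule_delay_millis=int(
            os.getenv("SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", "500")
        ),
        max_export_batch_size=int(
            os.getenv("SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", "64")
        ),
    )
)
trace.set_tracer_provider(provider)
```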
......@@ -39,9 +45,9 @@ We have already inserted instrumentation points in the tokenizer and scheduler m
Every process involved in tracing should execute the following during its initialization phase:
```python
process_tracing_init(oltp_traces_endpoint, server_name)
process_tracing_init(otlp_traces_endpoint, server_name)
```
The oltp_traces_endpoint is obtained from the arguments, and you can set server_name freely, but it should remain consistent across all processes.
The `otlp_traces_endpoint` is obtained from the server arguments; `server_name` can be chosen freely, but it must remain consistent across all processes.
Every thread involved in tracing should execute the following during its initialization phase:
```python
......@@ -95,24 +101,52 @@ We have already inserted instrumentation points in the tokenizer and scheduler m
trace_set_proc_propagate_context(rid, req.trace_context)
```
5. When the request execution flow transfers to another node (PD disaggregation), the trace context needs to be explicitly propagated; a combined sketch follows the sender and receiver snippets below.
- sender: Execute the following code before sending the request to the remote node via HTTP
```python
trace_context = trace_get_remote_propagate_context(bootstrap_room_list)
headers = {"trace_context": trace_context}
session.post(url, headers=headers)
```
- receiver: Execute the following code after receiving the request via HTTP
```python
trace_set_remote_propagate_context(request.headers['trace_context'])
```
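Putting the two halves together, a minimal combined sketch; the `requests` client and URL are illustrative, while the helper functions are the ones shipped in `sglang.srt.tracing.trace` (in SGLang the sender is the mini load balancer and the receiver is the tokenizer manager):
```python
# Sketch: propagating trace context across nodes through an HTTP header.
import requests

from sglang.srt.tracing.trace import (
    trace_get_remote_propagate_context,
    trace_set_remote_propagate_context,
)

def send_with_trace(url: str, payload: dict, bootstrap_room_list: list):
    # Sender: serialize the per-room trace contexts into a single header value.
    headers = {"trace_context": trace_get_remote_propagate_context(bootstrap_room_list)}
    return requests.post(url, json=payload, headers=headers)

def restore_trace(headers: dict):
    # Receiver: restore the remote contexts before local tracing continues.
    if "trace_context" in headers:
        trace_set_remote_propagate_context(headers["trace_context"])
```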
## How to Extend the Tracing Framework to Support Complex Tracing Scenarios
The tracing package as currently provided still has room for further development. If you wish to build more advanced features on top of it, you must first understand its existing design principles.
The core of the tracing framework's implementation lies in the design of the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a trace context with a three-level structure.
The core of the tracing framework implementation lies in the design of the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a three-level trace context structure: `SglangTraceReqContext`, `SglangTraceThreadContext`, and `SglangTraceSliceContext`. Their relationship is as follows:
The core of the tracing framework's implementation lies in the design of the span structure and the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we designed a two-level trace context structure (`SglangTraceReqContext` and `SglangTraceThreadContext`) and a four-level span structure. The context hierarchy is as follows:
```
SglangTraceReqContext (req_id="req-123")
├── SglangTraceThreadContext(thread_label="scheduler", tp_rank=0)
│ └── SglangTraceSliceContext (name="prefill") # cur slice
|
└── SglangTraceThreadContext(thread_label="scheduler", tp_rank=1)
└── SglangTraceSliceContext (name="prefill") # cur slice
```
Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is represented by a `SglangTraceSliceContext`, which is stored in the `SglangTraceThreadContext`. Generate a span and release the corresponding context when slice tracing, thread tracing, or request tracing ends.
Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is stored in a list.
In addition to the above hierarchy, each slice also records its previous slice via `Span.add_link()`, which can be used to trace the execution flow.
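For illustration, a link can be attached when a slice span is created; a minimal sketch using OpenTelemetry's `Link` API (SGLang's slice helpers do the equivalent internally):
```python
# Sketch: linking a new slice span back to the previous slice.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Link

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("sglang.demo")

prev_slice = tracer.start_span("prefill_bootstrap")
prev_slice.end()

cur_slice = tracer.start_span(
    "prefill_forward",
    links=[Link(prev_slice.get_span_context())],  # back-reference to the previous slice
)
cur_slice.end()
```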
When the request execution flow transfers to a new thread, the trace context needs to be explicitly propagated. In the framework, this is represented by `SglangTracePropagateContext`, which contains the context of the request span and the previous slice span.
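For example, a sketch of handing a request between threads with the framework's helpers; `req` and `worker_queue` are illustrative placeholders:
```python
# Sketch: explicit trace-context propagation across a thread boundary.
from sglang.srt.tracing.trace import (
    trace_get_proc_propagate_context,
    trace_set_proc_propagate_context,
)

def hand_off(req, worker_queue):
    # Sender: snapshot the request's trace context into the message.
    req.trace_context = trace_get_proc_propagate_context(req.rid)
    worker_queue.put(req)

def receive(worker_queue):
    # Receiver: restore the context so new slices attach to the request's
    # root span and link back to the previous slice.
    req = worker_queue.get()
    trace_set_proc_propagate_context(req.rid, req.trace_context)
    return req
```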
We designed a four-level span structure consisting of `bootstrap_room_span`, `req_root_span`, `thread_span`, and `slice_span`. Among them, `req_root_span` and `thread_span` correspond to `SglangTraceReqContext` and `SglangTraceThreadContext`, respectively, and `slice_span` is stored within the `SglangTraceThreadContext`. The `bootstrap_room_span` exists to accommodate PD disaggregation: on different nodes we may want to add node-specific attributes to the `req_root_span`, but if a single `req_root_span` were shared across all nodes, the prefill and decode nodes could not add attributes to it due to constraints in OpenTelemetry's design.
```
bootstrap room span
├── router req root span
| └── router thread span
| └── slice span
├── prefill req root span
| ├── tokenizer thread span
| | └── slice span
| └── scheduler thread span
| └── slice span
└── decode req root span
├── tokenizer thread span
| └── slice span
└── scheduler thread span
└── slice span
```
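To make the hierarchy concrete, a minimal sketch of how such a span tree is nested with the OpenTelemetry SDK; the names are illustrative, and SGLang's `trace_req_start` and slice helpers perform the equivalent wiring:
```python
# Sketch: the four-level span nesting via parent contexts.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("sglang.demo")

room_span = tracer.start_span("Bootstrap Room 0x2a")
room_ctx = trace.set_span_in_context(room_span)

req_span = tracer.start_span("prefill Req deadbeef", context=room_ctx)
req_ctx = trace.set_span_in_context(req_span)

thread_span = tracer.start_span("Prefill Scheduler", context=req_ctx)
thread_ctx = trace.set_span_in_context(thread_span)

slice_span = tracer.start_span("prefill_forward", context=thread_ctx)

# End innermost-first, mirroring slice -> thread -> request -> room teardown.
for span in (slice_span, thread_span, req_span, room_span):
    span.end()
```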
......@@ -58,6 +58,11 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool,
SWAKVPool,
)
from sglang.srt.tracing.trace import (
trace_event_batch,
trace_slice_batch,
trace_slice_end,
)
from sglang.srt.utils import get_int_env_var, require_mlp_sync
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
......@@ -313,6 +318,7 @@ class DecodePreallocQueue:
)
req.add_latency(RequestStage.DECODE_PREPARE)
trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True)
self.queue.append(
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
)
......@@ -527,6 +533,9 @@ class DecodePreallocQueue:
time.perf_counter()
)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
trace_slice_end(
RequestStage.DECODE_BOOTSTRAP, decode_req.req.rid, auto_next_anon=True
)
self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
......@@ -775,8 +784,19 @@ class DecodeTransferQueue:
[decode_req.req], decode_req.req.return_logprob
)
self.tree_cache.cache_finished_req(decode_req.req)
trace_slice_end(
RequestStage.DECODE_QUICK_FINISH,
decode_req.req.rid,
thread_finish_flag=True,
)
else:
transferred_reqs.append(decode_req.req)
trace_slice_end(
RequestStage.DECODE_TRANSFERRED,
decode_req.req.rid,
auto_next_anon=True,
)
elif poll in [
KVPoll.Bootstrapping,
KVPoll.WaitingForInput,
......@@ -822,6 +842,7 @@ class SchedulerDisaggregationDecodeMixin:
self.stream_output(
batch.reqs, any(req.return_logprob for req in batch.reqs)
)
trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
if prepare_mlp_sync_flag:
self._prepare_idle_batch_and_run(None)
else:
......@@ -871,6 +892,7 @@ class SchedulerDisaggregationDecodeMixin:
self.stream_output(
batch.reqs, any(req.return_logprob for req in batch.reqs)
)
trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
if prepare_mlp_sync_flag:
batch_, batch_result = self._prepare_idle_batch_and_run(
None, delay_process=True
......@@ -953,6 +975,9 @@ class SchedulerDisaggregationDecodeMixin:
self.running_batch = self.update_running_batch(self.running_batch)
ret = self.running_batch if not self.running_batch.is_empty() else None
if ret:
attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
trace_event_batch("schedule", ret.reqs, attrs=attrs)
return ret
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
......
......@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
NSATokenToKVPool,
SWAKVPool,
)
from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end
from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync
if TYPE_CHECKING:
......@@ -198,6 +199,7 @@ class PrefillBootstrapQueue:
self._process_req(req)
req.add_latency(RequestStage.PREFILL_PREPARE)
self.queue.append(req)
trace_slice_end(RequestStage.PREFILL_PREPARE, req.rid, auto_next_anon=True)
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
for req in reqs:
......@@ -289,6 +291,10 @@ class PrefillBootstrapQueue:
req.time_stats.wait_queue_entry_time = time.perf_counter()
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
trace_slice_end(
RequestStage.PREFILL_BOOTSTRAP, req.rid, auto_next_anon=True
)
self.queue = [
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
]
......@@ -316,6 +322,9 @@ class SchedulerDisaggregationPrefillMixin:
)
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()
if batch:
attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
trace_event_batch("schedule", batch.reqs, attrs=attrs)
if require_mlp_sync(self.server_args):
batch = self.prepare_mlp_sync_batch(batch)
......@@ -348,6 +357,9 @@ class SchedulerDisaggregationPrefillMixin:
)
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()
if batch:
attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
trace_event_batch("schedule", batch.reqs, attrs=attrs)
if require_mlp_sync(self.server_args):
batch = self.prepare_mlp_sync_batch(batch)
......@@ -423,6 +435,7 @@ class SchedulerDisaggregationPrefillMixin:
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
req.add_latency(RequestStage.PREFILL_FORWARD)
trace_slice(RequestStage.PREFILL_FORWARD, req.rid, auto_next_anon=True)
self.disagg_prefill_inflight_queue.append(req)
if self.spec_algorithm.is_eagle() and batch.spec_info is not None:
req.output_topk_p = batch.spec_info.topk_p[i]
......@@ -487,6 +500,9 @@ class SchedulerDisaggregationPrefillMixin:
if self.enable_overlap:
self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
trace_slice(
RequestStage.PREFILL_CHUNKED_FORWARD, req.rid, auto_next_anon=True
)
self.maybe_send_health_check_signal()
......@@ -558,6 +574,9 @@ class SchedulerDisaggregationPrefillMixin:
req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
req.metadata_buffer_index = -1
trace_slice(
RequestStage.PREFILL_TRANSFER_KV_CACHE, req.rid, thread_finish_flag=True
)
self.disagg_prefill_inflight_queue = undone_reqs
......
......@@ -143,10 +143,13 @@ class Engine(EngineBase):
# Enable tracing
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = "Tokenizer"
trace_set_thread_info(thread_label)
process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
thread_label = "Tokenizer"
if server_args.disaggregation_mode == "prefill":
thread_label = "Prefill Tokenizer"
elif server_args.disaggregation_mode == "decode":
thread_label = "Decode Tokenizer"
trace_set_thread_info(thread_label)
try:
self.loop = asyncio.get_running_loop()
......
......@@ -220,9 +220,12 @@ async def lifespan(fast_api_app: FastAPI):
# Init tracing
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
trace_set_thread_info(thread_label)
process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "prefill":
thread_label = "Prefill" + thread_label
elif server_args.disaggregation_mode == "decode":
thread_label = "Decode" + thread_label
trace_set_thread_info(thread_label)
# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
......
......@@ -129,6 +129,8 @@ class Envs:
SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")
SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS = EnvInt(500)
SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE = EnvInt(64)
# Scheduler: memory leak test
SGLANG_TEST_RETRACT = EnvBool(False)
......
......@@ -34,13 +34,21 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput,
WatchLoadUpdateReq,
)
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.schedule_batch import Req, RequestStage
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import (
DP_ATTENTION_HANDSHAKE_PORT_DELTA,
PortArgs,
ServerArgs,
)
from sglang.srt.tracing.trace import (
process_tracing_init,
trace_get_proc_propagate_context,
trace_set_proc_propagate_context,
trace_set_thread_info,
trace_slice_end,
trace_slice_start,
)
from sglang.srt.utils import (
bind_port,
configure_logger,
......@@ -170,11 +178,22 @@ class DataParallelController:
def handle_load_update_req(self, obj):
self.dp_budget.update_budget(obj)
def dispatching_with_trace(self, req: Req):
if self.server_args.enable_trace:
trace_set_proc_propagate_context(req.rid, req.trace_context)
trace_slice_start(RequestStage.DC_DISPATCH, req.rid)
req.trace_context = trace_get_proc_propagate_context(req.rid)
self.dispatching(req)
if self.server_args.enable_trace:
trace_slice_end(RequestStage.DC_DISPATCH, req.rid, thread_finish_flag=True)
def init_dispatcher(self):
self._request_dispatcher = TypeBasedDispatcher(
[
(TokenizedGenerateReqInput, self.dispatching),
(TokenizedEmbeddingReqInput, self.dispatching),
(TokenizedGenerateReqInput, self.dispatching_with_trace),
(TokenizedEmbeddingReqInput, self.dispatching_with_trace),
(BlockReqInput, self.send_to_all_workers),
(WatchLoadUpdateReq, self.handle_load_update_req),
]
......@@ -487,6 +506,14 @@ def run_data_parallel_controller_process(
pipe_writer,
):
kill_itself_when_parent_died()
if server_args.enable_trace:
process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
thread_label = "DP Controller"
if server_args.disaggregation_mode == "prefill":
thread_label = "Prefill DP Controller"
elif server_args.disaggregation_mode == "decode":
thread_label = "Decode DP Controller"
trace_set_thread_info(thread_label)
setproctitle.setproctitle("sglang::data_parallel_controller")
faulthandler.enable()
configure_logger(server_args)
......
......@@ -396,13 +396,23 @@ class MultimodalInputs:
class RequestStage(str, enum.Enum):
# prefill
# Tokenizer
TOKENIZE = "tokenize"
TOKENIZER_DISPATCH = "dispatch"
# DP controller
DC_DISPATCH = "dc_dispatch"
# common/non-disaggregation
PREFILL_WAITING = "prefill_waiting"
REQUEST_PROCESS = "request_process"
DECODE_LOOP = "decode_loop"
PREFILL_FORWARD = "prefill_forward"
PREFILL_CHUNKED_FORWARD = "chunked_prefill"
# disaggregation prefill
PREFILL_PREPARE = "prefill_prepare"
PREFILL_BOOTSTRAP = "prefill_bootstrap"
PREFILL_FORWARD = "prefill_forward"
PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
# disaggregation decode
......@@ -410,6 +420,8 @@ class RequestStage(str, enum.Enum):
DECODE_BOOTSTRAP = "decode_bootstrap"
DECODE_WAITING = "decode_waiting"
DECODE_TRANSFERRED = "decode_transferred"
DECODE_FAKE_OUTPUT = "fake_output"
DECODE_QUICK_FINISH = "quick_finish"
class Req:
......
......@@ -157,6 +157,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import (
process_tracing_init,
trace_event_batch,
trace_set_proc_propagate_context,
trace_set_thread_info,
trace_slice_batch,
......@@ -1354,7 +1355,7 @@ class Scheduler(
self._prefetch_kvcache(req)
self.waiting_queue.append(req)
req.time_stats.wait_queue_entry_time = time.perf_counter()
trace_slice_end("process req", req.rid, auto_next_anon=True)
trace_slice_end(RequestStage.REQUEST_PROCESS, req.rid, auto_next_anon=True)
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
self._prefetch_kvcache(req)
self.disagg_prefill_bootstrap_queue.add(
......@@ -1618,6 +1619,10 @@ class Scheduler(
if need_dp_attn_preparation:
ret = self.prepare_mlp_sync_batch(ret)
if ret:
attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
trace_event_batch("schedule", ret.reqs, attrs=attrs)
return ret
def get_num_allocatable_reqs(self, running_bs):
......@@ -1993,13 +1998,10 @@ class Scheduler(
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
if self.enable_trace:
trace_slice_batch("decode loop", batch.reqs)
trace_slice_batch(RequestStage.DECODE_LOOP, batch.reqs)
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result)
if self.enable_trace:
trace_slice_batch("prefill", batch.reqs)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
......@@ -2741,10 +2743,13 @@ def run_scheduler_process(
# Set up tracing
if server_args.enable_trace:
process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
if server_args.disaggregation_mode == "null":
thread_label = "Scheduler"
trace_set_thread_info(thread_label, tp_rank, dp_rank)
process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
thread_label = "Scheduler"
if server_args.disaggregation_mode == "prefill":
thread_label = "Prefill Scheduler"
elif server_args.disaggregation_mode == "decode":
thread_label = "Decode Scheduler"
trace_set_thread_info(thread_label, tp_rank, dp_rank)
# Create a scheduler and run the event loop
try:
......
......@@ -14,7 +14,13 @@ from sglang.srt.managers.io_struct import (
BatchEmbeddingOutput,
BatchTokenIDOutput,
)
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
from sglang.srt.managers.schedule_batch import (
BaseFinishReason,
Req,
RequestStage,
ScheduleBatch,
)
from sglang.srt.tracing.trace import trace_slice
from sglang.srt.utils.common import ceil_div
if TYPE_CHECKING:
......@@ -160,6 +166,14 @@ class SchedulerOutputProcessorMixin:
)
self.abort_request(AbortReq(rid=req.rid))
req.grammar.finished = req.finished()
trace_slice(
RequestStage.PREFILL_FORWARD,
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
else:
# being chunked reqs' prefill is not finished
req.is_chunked -= 1
......@@ -188,6 +202,12 @@ class SchedulerOutputProcessorMixin:
)
logprob_pt += num_input_logprobs
trace_slice(
RequestStage.PREFILL_CHUNKED_FORWARD,
req.rid,
auto_next_anon=True,
)
else: # embedding or reward model
is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
......@@ -227,6 +247,13 @@ class SchedulerOutputProcessorMixin:
# being chunked reqs' prefill is not finished
req.is_chunked -= 1
trace_slice(
RequestStage.PREFILL_FORWARD,
req.rid,
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
def _resolve_spec_overlap_token_ids(
......
......@@ -68,6 +68,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.mm_utils import TensorTransportMode
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.managers.schedule_batch import RequestStage
from sglang.srt.managers.scheduler import is_health_check_generate_req
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
......@@ -79,6 +80,7 @@ from sglang.srt.tracing.trace import (
trace_get_proc_propagate_context,
trace_req_finish,
trace_req_start,
trace_set_remote_propagate_context,
trace_slice_end,
trace_slice_start,
)
......@@ -383,6 +385,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
if request:
if "trace_context" in request.headers:
trace_set_remote_propagate_context(request.headers["trace_context"])
if self.server_args.tokenizer_worker_num > 1:
self._attach_multi_http_worker_info(obj)
......@@ -605,7 +611,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
mm_inputs = None
self._validate_one_request(obj, input_ids)
trace_slice_end("tokenize", obj.rid)
trace_slice_end(RequestStage.TOKENIZE, obj.rid)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
)
......@@ -831,7 +837,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
req, req.text, input_ids_list[i], None, None, token_type_ids
)
)
trace_slice_end("tokenize", req.rid)
trace_slice_end(RequestStage.TOKENIZE, req.rid)
logger.debug(f"Completed batch processing for {batch_size} requests")
return tokenized_objs
......@@ -883,12 +889,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
created_time: Optional[float] = None,
):
trace_slice_start("dispatch", obj.rid)
trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid)
tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
self.send_to_scheduler.send_pyobj(tokenized_obj)
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
self.rid_to_state[obj.rid] = state
trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
trace_slice_end(
RequestStage.TOKENIZER_DISPATCH, obj.rid, thread_finish_flag=True
)
return state
def _send_batch_request(
......@@ -2131,7 +2139,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
bootstrap_room = (
obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
)
trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
trace_req_start(
obj.rid,
bootstrap_room,
ts=int(created_time * 1e9),
role=self.server_args.disaggregation_mode,
)
trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
else:
for i in range(len(obj.rid)):
......@@ -2140,7 +2153,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
else None
)
trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
trace_req_start(
obj.rid[i],
bootstrap_room,
ts=int(created_time * 1e9),
role=self.server_args.disaggregation_mode,
)
trace_slice_start(
"", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
)
......
......@@ -299,7 +299,7 @@ class ServerArgs:
enable_request_time_stats_logging: bool = False
kv_events_config: Optional[str] = None
enable_trace: bool = False
oltp_traces_endpoint: str = "localhost:4317"
otlp_traces_endpoint: str = "localhost:4317"
# API related
api_key: Optional[str] = None
......@@ -2340,7 +2340,7 @@ class ServerArgs:
help="Enable opentelemetry trace",
)
parser.add_argument(
"--oltp-traces-endpoint",
"--otlp-traces-endpoint",
type=str,
default="localhost:4317",
help="Config opentelemetry collector endpoint if --enable-trace is set. format: <ip>:<port>",
......
......@@ -15,6 +15,8 @@
from __future__ import annotations
import base64
import json
import logging
import os
import random
......@@ -24,6 +26,8 @@ import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from sglang.srt.utils import get_int_env_var
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import Req
......@@ -85,6 +89,8 @@ class SglangTraceReqContext:
# Indicates whether this instance is a replica from the main process.
# When True, root_span is None and only root_span_context is preserved.
is_copy: bool = False
bootstrap_room_span: Optional[trace.span.Span] = None
bootstrap_room_span_context: Optional[context.Context] = None
root_span: Optional[trace.span.Span] = None
root_span_context: Optional[context.Context] = None
......@@ -96,8 +102,7 @@ class SglangTracePropagateContext:
def to_dict(self):
carrier: dict[str, str] = {}
context.attach(self.root_span_context)
propagate.inject(carrier)
propagate.inject(carrier, self.root_span_context)
if self.prev_span_context:
return {
......@@ -149,6 +154,7 @@ class SglangTraceCustomIdGenerator(id_generator.IdGenerator):
# global variables
remote_trace_contexts: Dict[str, SglangTracePropagateContext] = {}
threads_info: Dict[int, SglangTraceThreadInfo] = {}
reqs_context: Dict[str, SglangTraceReqContext] = {}
......@@ -193,8 +199,17 @@ def process_tracing_init(otlp_endpoint, server_name):
resource=resource, id_generator=SglangTraceCustomIdGenerator()
)
schedule_delay_millis = get_int_env_var(
"SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", 500
)
max_export_batch_size = get_int_env_var(
"SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", 64
)
processor = BatchSpanProcessor(
OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True),
schedule_delay_millis=schedule_delay_millis,
max_export_batch_size=max_export_batch_size,
)
tracer_provider.add_span_processor(processor)
trace.set_tracer_provider(tracer_provider)
......@@ -266,7 +281,9 @@ def __create_thread_context(pid, req_span_context, ts: Optional[int] = None):
return thread_context
def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]:
def trace_get_proc_propagate_context(
rid, remote_propagate=False
) -> Optional[Dict[str, Any]]:
if not tracing_enabled:
return None
......@@ -283,9 +300,11 @@ def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]:
elif thread_context.last_span_context:
prev_span_context = thread_context.last_span_context
trace_context = SglangTracePropagateContext(
reqs_context[rid].root_span_context, prev_span_context
)
root_span_context = reqs_context[rid].root_span_context
if remote_propagate:
root_span_context = reqs_context[rid].bootstrap_room_span_context
trace_context = SglangTracePropagateContext(root_span_context, prev_span_context)
return trace_context.to_dict()
......@@ -327,10 +346,54 @@ def trace_set_proc_propagate_context(rid, trace_context: Optional[Dict[str, Any]
].last_span_context = trace_context.prev_span_context
def trace_get_remote_propagate_context(bootstrap_room_list: List[str]):
if not tracing_enabled:
return ""
reqs_trace_contexts = {}
for bootstrap_room in bootstrap_room_list:
# In the router, rid is also the bootstrap room.
bootstrap_room = str(bootstrap_room)
if bootstrap_room not in reqs_context:
continue
_context = trace_get_proc_propagate_context(
bootstrap_room, remote_propagate=True
)
reqs_trace_contexts[bootstrap_room] = _context
json_str = json.dumps(reqs_trace_contexts, ensure_ascii=False)
return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")
def trace_set_remote_propagate_context(base64_str):
if not tracing_enabled:
return
if base64_str is None or base64_str == "" or base64_str == "None":
return
base64_bytes = base64.b64decode(base64_str)
json_str = base64_bytes.decode("utf-8")
remote_reqs_trace_contexts = json.loads(json_str)
for bootstrap_room in remote_reqs_trace_contexts:
if bootstrap_room in remote_trace_contexts:
continue
remote_trace_contexts[bootstrap_room] = (
SglangTracePropagateContext.instance_from_dict(
remote_reqs_trace_contexts[bootstrap_room]
)
)
def trace_req_start(
rid: str,
bootstrap_room: Optional[int] = None,
ts: Optional[int] = None,
role: Optional[str] = "null",
):
if not tracing_enabled:
return
......@@ -344,6 +407,7 @@ def trace_req_start(
return
# create req context and root span
bootstrap_room = 0 if bootstrap_room is None else bootstrap_room
reqs_context[rid] = SglangTraceReqContext(
rid=rid,
start_time_ns=ts,
......@@ -352,23 +416,42 @@ def trace_req_start(
is_copy=False,
)
# create bootstrap room span
tracer = threads_info[pid].tracer
if str(bootstrap_room) not in remote_trace_contexts:
attrs = {"bootstrap_room": str(hex(bootstrap_room))}
bootstrap_room_span = tracer.start_span(
name=f"Bootstrap Room {hex(bootstrap_room)}",
start_time=ts,
attributes=attrs,
)
reqs_context[rid].bootstrap_room_span = bootstrap_room_span
bootstrap_room_span_context = trace.set_span_in_context(bootstrap_room_span)
else:
bootstrap_room_span_context = remote_trace_contexts[
str(bootstrap_room)
].root_span_context
# Drop the worker_id added by MultiTokenizer
orig_rid = rid.split("_")[-1]
tracer = threads_info[pid].tracer
role = "" if role == "null" else role
attrs = {"rid": orig_rid}
root_span = tracer.start_span(
name=f"Req {orig_rid[:8]}",
name=f"{role} Req {orig_rid[:8]}",
start_time=ts,
context=bootstrap_room_span_context,
attributes=attrs,
)
root_span.set_attributes(
{
"rid": rid,
"bootstrap_room": bootstrap_room if bootstrap_room else "None",
}
)
reqs_context[rid].root_span = root_span
reqs_context[rid].root_span_context = trace.set_span_in_context(root_span)
reqs_context[rid].bootstrap_room_span_context = bootstrap_room_span_context
# create thread context and thread span
reqs_context[rid].threads_context[pid] = __create_thread_context(
......@@ -376,6 +459,10 @@ def trace_req_start(
reqs_context[rid].root_span_context,
ts,
)
if str(bootstrap_room) in remote_trace_contexts:
reqs_context[rid].threads_context[pid].last_span_context = (
remote_trace_contexts[str(bootstrap_room)].prev_span_context
)
def trace_req_finish(
......@@ -399,6 +486,10 @@ def trace_req_finish(
req_context.root_span.set_attributes(attrs)
req_context.root_span.end(end_time=ts)
if str(req_context.bootstrap_room) in remote_trace_contexts:
del remote_trace_contexts[str(req_context.bootstrap_room)]
else:
req_context.bootstrap_room_span.end(end_time=ts)
del reqs_context[rid]
......@@ -518,7 +609,9 @@ trace_slice = trace_slice_end
# Add event to the current slice on the same thread with the same rid.
def trace_event(name: str, rid: str, ts: Optional[int] = None):
def trace_event(
name: str, rid: str, ts: Optional[int] = None, attrs: Dict[str, Any] = None
):
if not tracing_enabled:
return
......@@ -539,7 +632,7 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None):
ts = ts or __get_cur_time_ns()
slice_info = thread_context.cur_slice_stack[-1]
slice_info.span.add_event(name=name, timestamp=ts)
slice_info.span.add_event(name=name, timestamp=ts, attributes=attrs)
# Add attrs to the current slice on the same thread with the same rid.
......@@ -569,6 +662,9 @@ def trace_slice_batch(
name: str,
reqs: List[Req],
):
if not tracing_enabled:
return
for req in reqs:
trace_slice(
name,
......@@ -576,3 +672,16 @@ def trace_slice_batch(
auto_next_anon=not req.finished(),
thread_finish_flag=req.finished(),
)
def trace_event_batch(
name: str,
reqs: List[Req],
ts: Optional[int] = None,
attrs: Dict[str, Any] = None,
):
if not tracing_enabled:
return
for req in reqs:
trace_event(name, req.rid, ts=ts, attrs=attrs)
import argparse
import bisect
import json
import time
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple
parser = argparse.ArgumentParser(
description="Convert SGLang OTEL trace files to Perfetto format.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"-i",
"--input",
dest="input_file",
required=True,
type=str,
help="Path to the input OTEL trace file (JSON or JSONL format).",
)
parser.add_argument(
"-o",
"--output",
dest="output_file",
type=str,
default="sglang_trace_perfetto.json",
help="Path to the output Perfetto JSON file.",
)
parser.add_argument(
"-f", "--torch-file", dest="torch_file", help="specify torch profile file"
)
args = parser.parse_args()
perfetto_data = None
if args.torch_file:
with open(args.torch_file, "r", encoding="utf-8") as file:
perfetto_data = json.load(file)
baseline = perfetto_data["baseTimeNanoseconds"]
else:
baseline = 0
def id_generator():
i = 0
while True:
yield i
i += 1
relation_id_gen = id_generator()
class SpanLayoutContainer:
def __init__(self):
self.intervals = []
def check_overlap(self, start, end):
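# Return True if [start, end) overlaps an interval already placed on this row (touching endpoints are allowed).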
idx = bisect.bisect_left(self.intervals, (start, float("-inf")))
if idx > 0:
prev_start, prev_end = self.intervals[idx - 1]
if prev_end > start:
return True
if idx < len(self.intervals):
next_start, next_end = self.intervals[idx]
if next_start < end:
return True
return False
def insert_span(self, start, end):
bisect.insort_left(self.intervals, (start, end))
def new_metadata_level1(name: str, pid):
return {
"name": "process_name",
"ph": "M",
"pid": pid,
"args": {"name": name},
}
def new_metadata_level2(name: str, pid, slot_seq):
return {
"name": "thread_name",
"ph": "M",
"pid": pid,
"tid": slot_seq,
"args": {"name": name},
}
def __find_line(graph, trans_graph_status, slot_meta_data, pid, start, end):
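# Pick a per-pid row ("slot") whose spans do not overlap [start, end), allocating a new row when every existing row conflicts; the row index becomes the Perfetto tid.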
if pid in trans_graph_status:
line = trans_graph_status[pid]
if start == end:
return line
# check conflict
if not graph[pid][line].check_overlap(start, end):
return line
if pid not in graph:
line = 1
graph[pid] = {line: SpanLayoutContainer()}
trans_graph_status[pid] = line
slot_meta_data.append(new_metadata_level2("slot", pid, line))
return line
for line in graph[pid]:
if not graph[pid][line].check_overlap(start, end):
trans_graph_status[pid] = line
return line
new_line = len(graph[pid]) + 1
graph[pid][new_line] = SpanLayoutContainer()
trans_graph_status[pid] = new_line
slot_meta_data.append(new_metadata_level2("slot", pid, new_line))
return new_line
OtelSpan = Dict[str, Any]
def load_otel_data(path: str | Path):
p = Path(path)
with p.open("rt", encoding="utf-8") as f:
first = f.read(1)
f.seek(0)
if first == "[":
data = json.load(f) # JSON array
else:
data = [json.loads(line) for line in f if line.strip()] # JSONL
return data
def extract_all_otel_spans(otel_data):
otel_spans = []
for line_data in otel_data:
for resource_spans in line_data["resourceSpans"]:
for scope_spans in resource_spans["scopeSpans"]:
for span in scope_spans["spans"]:
if "attributes" in span:
attributes_dict = {
attr.get("key"): next(
iter(attr.get("value", {}).values()), None
)
for attr in span["attributes"]
}
span["attributes"] = attributes_dict
else:
span["attributes"] = {}
otel_spans.append(span)
return otel_spans
def build_otel_span_tree(otel_spans):
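# Rebuild the span tree from parentSpanId references and resolve link spanIds to span objects.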
span_id_map = {span["spanId"]: span for span in otel_spans}
for span in otel_spans:
span["child"] = []
bootstrap_room_spans = []
for span in otel_spans:
span_id = span["spanId"]
parent_span_id = span.get("parentSpanId", "")
if parent_span_id == "":
# a parentless span is a bootstrap-room root
bootstrap_room_spans.append(span)
elif parent_span_id in span_id_map:
parent_span = span_id_map[parent_span_id]
parent_span["child"].append(span)
link_spans = []
if "links" in span:
for link in span["links"]:
link_span = span_id_map.get(link["spanId"])
if link_span:
link_spans.append(link_span)
span["links"] = link_spans
return bootstrap_room_spans
def generate_perfetto_span(otel_bootstrap_room_spans, thread_meta_data):
for bootstrap_room_span in otel_bootstrap_room_spans:
bootstrap_room = bootstrap_room_span["attributes"]["bootstrap_room"]
bootstrap_room_span["spans"] = []
for node_req_span in bootstrap_room_span["child"]:
rid = node_req_span["attributes"]["rid"]
for thread_span in node_req_span["child"]:
pid = int(thread_span["attributes"]["pid"])
thread_name = f'{thread_span["attributes"]["host_id"][:8]}:{thread_span["attributes"]["thread_label"]}'
if "tp_rank" in thread_span["attributes"]:
thread_name += f"-TP{thread_span['attributes']['tp_rank']}"
if pid not in thread_meta_data:
thread_meta_data[pid] = new_metadata_level1(thread_name, pid)
for span in thread_span["child"]:
span["attributes"]["bootstrap_room"] = bootstrap_room
span["attributes"]["rid"] = rid
span["host_id"] = thread_span["attributes"]["host_id"]
span["pid"] = pid
span["startTimeUnixNano"] = int(span["startTimeUnixNano"])
span["endTimeUnixNano"] = int(span["endTimeUnixNano"])
ts = span["startTimeUnixNano"]
dur = span["endTimeUnixNano"] - ts
perfetto_span = {
"ph": "X",
"name": span.get("name", "unknown"),
"cat": "sglang",
"ts": (ts - baseline) / 1000.0,
"dur": (dur - 1000) / 1000.0,
"pid": pid,
"tid": 0,
"args": span["attributes"],
}
span["perfetto_span"] = perfetto_span
bootstrap_room_span["spans"].append(span)
def generate_perfetto_span_layout(otel_bootstrap_room_spans, slot_meta_data):
for bootstrap_room_span in otel_bootstrap_room_spans:
bootstrap_room_span["spans"] = sorted(
bootstrap_room_span["spans"], key=lambda x: int(x["startTimeUnixNano"])
)
otel_bootstrap_room_spans = sorted(
otel_bootstrap_room_spans, key=lambda x: int(x["spans"][0]["startTimeUnixNano"])
)
graph = {}
for bootstrap_room_span in otel_bootstrap_room_spans:
req_thread_status = {}
for span in bootstrap_room_span["spans"]:
line = __find_line(
graph,
req_thread_status,
slot_meta_data,
span["perfetto_span"]["pid"],
span["startTimeUnixNano"],
span["endTimeUnixNano"],
)
graph[span["perfetto_span"]["pid"]][line].insert_span(
span["startTimeUnixNano"], span["endTimeUnixNano"]
)
span["perfetto_span"]["tid"] = line
def generate_perfetto_events(otel_bootstrap_room_spans):
for bootstrap_room_span in otel_bootstrap_room_spans:
for span in bootstrap_room_span["spans"]:
span["perfetto_events"] = []
if "events" in span:
for event in span["events"]:
attributes_dict = {
attr.get("key"): next(
iter(attr.get("value", {}).values()), None
)
for attr in event["attributes"]
}
perfetto_event = {
"ph": "i",
"cat": "sglang",
"ts": (int(event["timeUnixNano"]) - baseline) / 1000.0,
"pid": span["perfetto_span"]["pid"],
"tid": span["perfetto_span"]["tid"],
"name": event.get("name", "unknown"),
"args": attributes_dict,
}
span["perfetto_events"].append(perfetto_event)
def generate_perfetto_links(otel_bootstrap_room_spans):
for bootstrap_room_span in otel_bootstrap_room_spans:
for span in bootstrap_room_span["spans"]:
span["perfetto_links"] = []
if "links" in span:
for link_span in span["links"]:
if "correlation" in link_span["perfetto_span"]["args"]:
id = link_span["perfetto_span"]["args"]["correlation"]
else:
id = next(relation_id_gen)
link_span["perfetto_span"]["args"]["correlation"] = id
perfetto_start_node = {
"ph": "s",
"id": id,
"pid": link_span["perfetto_span"]["pid"],
"tid": link_span["perfetto_span"]["tid"],
"ts": link_span["perfetto_span"]["ts"],
"cat": "ac2g",
"name": "ac2g",
}
perfetto_end_node = {
"ph": "f",
"id": id,
"pid": span["perfetto_span"]["pid"],
"tid": span["perfetto_span"]["tid"],
"ts": span["perfetto_span"]["ts"],
"cat": "ac2g",
"name": "ac2g",
"bp": "e",
}
span["perfetto_links"].append(perfetto_start_node)
span["perfetto_links"].append(perfetto_end_node)
def gather_all_perfetto_elems(
otel_bootstrap_room_spans, thread_meta_data, slot_meta_data
):
elems = []
elems.extend(thread_meta_data.values())
elems.extend(slot_meta_data)
for bootstrap_room_span in otel_bootstrap_room_spans:
for span in bootstrap_room_span["spans"]:
elems.append(span["perfetto_span"])
elems.extend(span["perfetto_events"])
elems.extend(span["perfetto_links"])
return elems
def write_json(perfetto_elems):
global perfetto_data
if args.torch_file:
perfetto_data["traceEvents"].extend(perfetto_elems)
filtered_data = [
item
for item in perfetto_data["traceEvents"]
if item.get("cat") != "gpu_user_annotation"
]
perfetto_data["traceEvents"] = filtered_data
else:
perfetto_data = perfetto_elems
with open(args.output_file, "w", encoding="utf-8") as file:
json.dump(perfetto_data, file, ensure_ascii=False, indent=4)
def main():
start_time = time.time()
otel_data = load_otel_data(args.input_file)
otel_spans = extract_all_otel_spans(otel_data)
otel_bootstrap_room_spans = build_otel_span_tree(otel_spans)
thread_meta_data = {}
generate_perfetto_span(otel_bootstrap_room_spans, thread_meta_data)
slot_meta_data = []
generate_perfetto_span_layout(otel_bootstrap_room_spans, slot_meta_data)
generate_perfetto_events(otel_bootstrap_room_spans)
generate_perfetto_links(otel_bootstrap_room_spans)
perfetto_elems = gather_all_perfetto_elems(
otel_bootstrap_room_spans, thread_meta_data, slot_meta_data
)
write_json(perfetto_elems)
end_time = time.time()
execution_time = end_time - start_time
print(f"\nConversion finished successfully!")
print(f"Output written to: {args.output_file}")
print(f"Execution time: {execution_time * 1000:.4f} ms")
if __name__ == "__main__":
main()
......@@ -38,9 +38,22 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
router_args = args
if router_args.mini_lb:
if router_args.enable_trace:
from sglang.srt.tracing.trace import (
process_tracing_init,
trace_set_thread_info,
)
process_tracing_init(router_args.otlp_traces_endpoint, "sglang")
trace_set_thread_info("Mini lb")
mini_lb = MiniLoadBalancer(router_args)
mini_lb.start()
else:
# TODO: support tracing for the router (Rust).
del router_args.enable_trace
del router_args.otlp_traces_endpoint
if Router is None:
raise RuntimeError("Rust Router is not installed")
router_args._validate_router_args()
......
......@@ -18,6 +18,14 @@ from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang_router.router_args import RouterArgs
from sglang.srt.tracing.trace import (
trace_get_remote_propagate_context,
trace_req_finish,
trace_req_start,
trace_slice_end,
trace_slice_start,
)
logger = logging.getLogger(__name__)
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
......@@ -46,6 +54,7 @@ class MiniLoadBalancer:
self.prefill_urls = [url[0] for url in router_args.prefill_urls]
self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
self.decode_urls = router_args.decode_urls
self.enable_trace = router_args.enable_trace
def _validate_router_args(self, router_args: RouterArgs):
logger.warning(
......@@ -91,11 +100,33 @@ class MiniLoadBalancer:
total=self.timeout
) # Add timeout for request reliability
) as session:
headers = {}
bootstrap_room_list = []
if self.enable_trace:
bootstrap_room_list = (
modified_request["bootstrap_room"]
if isinstance(modified_request["bootstrap_room"], list)
else [modified_request["bootstrap_room"]]
)
trace_context = trace_get_remote_propagate_context(bootstrap_room_list)
headers = {"trace_context": trace_context}
tasks = [
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
session.post(f"{decode_server}/{endpoint}", json=modified_request),
session.post(
f"{prefill_server}/{endpoint}",
json=modified_request,
headers=headers,
),
session.post(
f"{decode_server}/{endpoint}",
json=modified_request,
headers=headers,
),
]
for bootstrap_room in bootstrap_room_list:
trace_slice_end("mini_lb_launch", bootstrap_room, auto_next_anon=True)
# Wait for both responses to complete. Prefill should end first.
prefill_response, decode_response = await asyncio.gather(*tasks)
......@@ -114,6 +145,14 @@ class MiniLoadBalancer:
else:
ret_json = await decode_response.json()
for bootstrap_room in bootstrap_room_list:
trace_slice_end(
"wait_PD_finish",
bootstrap_room,
thread_finish_flag=True,
)
trace_req_finish(bootstrap_room)
return ORJSONResponse(
content=ret_json,
status_code=decode_response.status,
......@@ -131,10 +170,36 @@ class MiniLoadBalancer:
) # Add timeout for request reliability
) as session:
# Create the tasks for both prefill and decode requests
headers = {}
bootstrap_room_list = []
if self.enable_trace:
bootstrap_room_list = (
modified_request["bootstrap_room"]
if isinstance(modified_request["bootstrap_room"], list)
else [modified_request["bootstrap_room"]]
)
trace_context = trace_get_remote_propagate_context(
bootstrap_room_list
)
headers = {"trace_context": trace_context}
tasks = [
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
session.post(f"{decode_server}/{endpoint}", json=modified_request),
session.post(
f"{prefill_server}/{endpoint}",
json=modified_request,
headers=headers,
),
session.post(
f"{decode_server}/{endpoint}",
json=modified_request,
headers=headers,
),
]
for bootstrap_room in bootstrap_room_list:
trace_slice_end(
"mini_lb_launch", bootstrap_room, auto_next_anon=True
)
# Wait for both responses to complete. Since this is streaming, they return immediately.
prefill_response, decode_response = await asyncio.gather(*tasks)
......@@ -174,6 +239,14 @@ class MiniLoadBalancer:
):
yield chunk
for bootstrap_room in bootstrap_room_list:
trace_slice_end(
"wait_PD_finish",
bootstrap_room,
thread_finish_flag=True,
)
trace_req_finish(bootstrap_room)
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
......@@ -367,7 +440,10 @@ async def handle_completion_request(request_data: dict):
def _generate_bootstrap_room():
return random.randint(0, 2**63 - 1)
bootstrap_room = random.randint(0, 2**63 - 1)
trace_req_start(bootstrap_room, bootstrap_room, role="router")
trace_slice_start("mini_lb_launch", bootstrap_room)
return bootstrap_room
# We may utilize `GenerateReqInput`'s logic later
......
......@@ -112,6 +112,9 @@ class RouterArgs:
client_cert_path: Optional[str] = None
client_key_path: Optional[str] = None
ca_cert_paths: List[str] = dataclasses.field(default_factory=list)
# Trace
enable_trace: bool = False
otlp_traces_endpoint: str = "localhost:4317"
@staticmethod
def add_cli_args(
......@@ -608,6 +611,17 @@ class RouterArgs:
default=[],
help="Path(s) to CA certificate(s) for verifying worker TLS certificates. Can specify multiple CAs.",
)
parser.add_argument(
f"--{prefix}enable-trace",
action="store_true",
help="Enable opentelemetry trace",
)
parser.add_argument(
f"--{prefix}otlp-traces-endpoint",
type=str,
default="localhost:4317",
help="Config opentelemetry collector endpoint if --enable-trace is set. format: <ip>:<port>",
)
@classmethod
def from_cli_args(
......
......@@ -74,7 +74,7 @@ class TestTrace(CustomTestCase):
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--enable-trace", "--oltp-traces-endpoint", "0.0.0.0:4317"],
other_args=["--enable-trace", "--otlp-traces-endpoint", "0.0.0.0:4317"],
)
try:
......@@ -121,7 +121,7 @@ class TestTrace(CustomTestCase):
model_path=model_path,
random_seed=42,
enable_trace=True,
oltp_traces_endpoint="localhost:4317",
otlp_traces_endpoint="localhost:4317",
)
try:
......@@ -148,7 +148,7 @@ class TestTrace(CustomTestCase):
model_path=model_path,
random_seed=42,
enable_trace=True,
oltp_traces_endpoint="localhost:4317",
otlp_traces_endpoint="localhost:4317",
is_embedding=True,
)
......