"vscode:/vscode.git/clone" did not exist on "b4c8d9481d00225bedb065471cf8b6bf35f769e1"
Unverified Commit f29753dc authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

feat: Add NVTX markers for sglang EPD (#7079)


Signed-off-by: default avatarWang, Yi <yi.a.wang@intel.com>
parent 5aa2f53f
......@@ -17,6 +17,7 @@ from transformers import AutoTokenizer
import dynamo.nixl_connect as connect
from dynamo._core import Client, Context
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import SglangMultimodalRequest
......@@ -96,6 +97,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
def cleanup(self) -> None:
pass
@_nvtx.range_decorator("mm:enc:generate", color="blue")
async def generate(
self, request: SglangMultimodalRequest, context: Context
) -> AsyncIterator[str]:
......@@ -138,9 +140,10 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
)
image_urls.append(mm_input.image_url)
image_grid_dim, precomputed_embeddings = await self.encoder._encode(
image_urls
)
with _nvtx.annotate("mm:enc:vision_encode", color="red"):
image_grid_dim, precomputed_embeddings = await self.encoder._encode(
image_urls
)
image_grid_thw_list = (
image_grid_dim.tolist()
......@@ -251,7 +254,8 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json()
)
await readable.wait_for_completion()
with _nvtx.annotate("mm:enc:embedding_transfer", color="purple"):
await readable.wait_for_completion()
async for response in response_generator:
yield response.data() if hasattr(response, "data") else str(
......
......@@ -4,7 +4,7 @@
import asyncio
import json
import logging
from typing import Any, AsyncIterator, Optional
from typing import Any, AsyncIterator, Callable, Optional
import sglang as sgl
import torch
......@@ -12,6 +12,7 @@ import torch
import dynamo.nixl_connect as connect
from dynamo._core import Client, Context
from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import (
......@@ -113,8 +114,10 @@ class EmbeddingsProcessor:
)
self._connector = connect.Connector()
read_op = await self._connector.begin_read(serialized_request, descriptor)
await read_op.wait_for_completion()
with _nvtx.annotate("mm:nixl:begin_read", color="blue"):
read_op = await self._connector.begin_read(serialized_request, descriptor)
with _nvtx.annotate("mm:nixl:wait_completion", color="cyan"):
await read_op.wait_for_completion()
return embeddings, descriptor
......@@ -317,23 +320,48 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
request: Multimodal request with input and parameters.
context: Context object for cancellation handling.
"""
rng_pd = _nvtx.start_range("mm:pd:generate", color="green")
rng_ttft = _nvtx.start_range("mm:pd:ttft", color="yellow")
ttft_ended = False
def _end_ttft() -> None:
nonlocal ttft_ended
if not ttft_ended:
_nvtx.end_range(rng_ttft)
ttft_ended = True
try:
request = self._validate_and_parse_request(request)
# Route to appropriate generation method based on serving mode
if self.serving_mode == DisaggregationMode.DECODE:
async for output in self._generate_disaggregated(request):
yield output
rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red")
try:
async for output in self._generate_disaggregated(
request, _end_ttft
):
yield output
finally:
_nvtx.end_range(rng_disagg)
else:
async for output in self._generate_aggregated(request):
yield output
rng_agg = _nvtx.start_range("mm:pd:generate_agg", color="red")
try:
async for output in self._generate_aggregated(request, _end_ttft):
yield output
finally:
_nvtx.end_range(rng_agg)
except Exception as e:
logger.error(f"Error in multimodal generation: {e}", exc_info=True)
yield ErrorResponseBuilder.build_error_response(e)
finally:
_end_ttft()
_nvtx.end_range(rng_pd)
async def _generate_disaggregated(
self, request: SglangMultimodalRequest
self,
request: SglangMultimodalRequest,
end_ttft: Callable[[], None],
) -> AsyncIterator[str]:
"""Handle disaggregated mode generation"""
input_ids = request.request.token_ids
......@@ -357,11 +385,24 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
bootstrap_room=bootstrap_info["bootstrap_room"],
)
async for output in StreamProcessor.process_sglang_stream(decode_stream):
yield output
rng_first = _nvtx.start_range("mm:dec:first_token", color="purple")
first_token = True
try:
async for output in StreamProcessor.process_sglang_stream(decode_stream):
if first_token:
end_ttft()
_nvtx.end_range(rng_first)
first_token = False
yield output
finally:
if first_token:
end_ttft()
_nvtx.end_range(rng_first)
async def _generate_aggregated(
self, request: SglangMultimodalRequest
self,
request: SglangMultimodalRequest,
end_ttft: Callable[[], None],
) -> AsyncIterator[str]:
"""Handle aggregated mode generation"""
input_ids = request.request.token_ids
......@@ -370,9 +411,10 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
try:
sampling_params = SglangUtils.build_sampling_params(request)
mm_items, combined_embeddings = await _build_mm_items(
request, self.embeddings_processor
)
with _nvtx.annotate("mm:pd:load_multimodal", color="cyan"):
mm_items, combined_embeddings = await _build_mm_items(
request, self.embeddings_processor
)
logger.debug(
"Generated combined multimodal item with embeddings shape: "
......@@ -387,8 +429,19 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
stream=True,
)
async for output in StreamProcessor.process_sglang_stream(agg_stream):
yield output
rng_first = _nvtx.start_range("mm:dec:first_token", color="purple")
first_token = True
try:
async for output in StreamProcessor.process_sglang_stream(agg_stream):
if first_token:
end_ttft()
_nvtx.end_range(rng_first)
first_token = False
yield output
finally:
if first_token:
end_ttft()
_nvtx.end_range(rng_first)
except RuntimeError as e:
if "shape mismatch" in str(e):
......@@ -479,6 +532,15 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
disagg_request: Disaggregated multimodal request.
context: Context object for cancellation handling.
"""
rng_bootstrap = _nvtx.start_range("mm:prefill:bootstrap", color="yellow")
bootstrap_ended = False
def _end_bootstrap() -> None:
nonlocal bootstrap_ended
if not bootstrap_ended:
_nvtx.end_range(rng_bootstrap)
bootstrap_ended = True
bootstrap_room = None
try:
# Validate and parse request
......@@ -492,6 +554,7 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
"bootstrap_room": bootstrap_room,
}
_end_bootstrap()
yield json.dumps(bootstrap_info)
# Process prefill generation
......@@ -503,6 +566,8 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
{"bootstrap_room": bootstrap_room} if bootstrap_room is not None else {}
)
yield ErrorResponseBuilder.build_error_response(e, extra_fields)
finally:
_end_bootstrap()
def _validate_and_parse_disagg_request(
self, disagg_request
......@@ -529,18 +594,20 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
sampling_params = disagg_request.sampling_params
# Process embeddings from encode worker using our embeddings processor
mm_items, _ = await _build_mm_items(request, self.embeddings_processor)
with _nvtx.annotate("mm:prefill:load_multimodal", color="cyan"):
mm_items, _ = await _build_mm_items(request, self.embeddings_processor)
# Start SGLang prefill generation (like regular SGLang)
results = await self.engine.async_generate(
input_ids=input_ids,
image_data=mm_items,
sampling_params=sampling_params,
stream=True,
bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port,
bootstrap_room=bootstrap_room,
)
with _nvtx.annotate("mm:prefill:engine_async_generate", color="blue"):
results = await self.engine.async_generate(
input_ids=input_ids,
image_data=mm_items,
sampling_params=sampling_params,
stream=True,
bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port,
bootstrap_room=bootstrap_room,
)
# Consume results without yielding (prefill doesn't return text, just coordinates)
asyncio.create_task(self._consume_results(results))
......
......@@ -403,6 +403,39 @@ export SGLANG_ENCODER_MM_LOAD_WORKERS=16
Only applies to the EPD encode worker (which uses [SGLang's MMEncoder](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/disaggregation/encode_server.py) internally).
## Profiling
Dynamo's SGLang multimodal workers include NVTX markers for `nsys` profiling. They are disabled by default (zero overhead) and enabled by setting `DYN_NVTX=1`.
```bash
cd $DYNAMO_HOME/examples/backends/sglang
DYN_NVTX=1 nsys profile --trace=cuda,nvtx -o profile.nsys-rep \
bash launch/multimodal_epd.sh ...
```
| ENV Variable | Default | Description |
|---|---|---|
| `DYN_NVTX` | `0` | Set to `1` to enable NVTX range/mark annotations in multimodal encode/prefill/decode worker paths for `nsys` profiling |
Key NVTX ranges emitted:
| Range | Worker | Description |
|-------|--------|-------------|
| `mm:enc:generate` | Encode | Full encode request lifetime |
| `mm:enc:vision_encode` | Encode | Vision encode call (`MMEncoder._encode`) |
| `mm:enc:embedding_transfer` | Encode | Embedding handoff to downstream worker |
| `mm:nixl:begin_read` | PD (agg) / Prefill | Begin NIXL read operation for embeddings |
| `mm:nixl:wait_completion` | PD (agg) / Prefill | Wait for NIXL embedding transfer completion |
| `mm:pd:generate` | Aggregated worker / Decode worker (`MultimodalWorkerHandler`) | Full worker-side request lifetime |
| `mm:pd:generate_agg` | PD (agg) | Aggregated generation path |
| `mm:pd:load_multimodal` | PD (agg) | Build multimodal items from transferred embeddings |
| `mm:pd:generate_disagg` | Decode worker (disagg entrypoint) | Disaggregated generation path |
| `mm:prefill:bootstrap` | Prefill (disagg) | Bootstrap coordination path before returning `{bootstrap_host, bootstrap_port, bootstrap_room}` |
| `mm:prefill:load_multimodal` | Prefill (disagg) | Build multimodal items from transferred embeddings in the prefill worker |
| `mm:prefill:engine_async_generate` | Prefill (disagg) | SGLang prefill engine invocation (`engine.async_generate`) |
| `mm:pd:ttft` | Aggregated worker / Decode worker (`MultimodalWorkerHandler`) | Worker-entry TTFT: from request arrival at this worker to first output token (excludes client->frontend->worker network transit) |
| `mm:dec:first_token` | Aggregated worker / Decode worker (`MultimodalWorkerHandler`) | Decode-stage first-token range (starts when decode stream is launched; not worker-entry TTFT) |
## Known Limitations
- **No Data URL support** - Only HTTP/HTTPS URLs supported; `data:image/...` base64 URLs not supported
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment