Unverified commit f29753dc, authored by Wang, Yi and committed by GitHub
Browse files

feat: Add NVTX markers for sglang EPD (#7079)


Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
parent 5aa2f53f
...@@ -17,6 +17,7 @@ from transformers import AutoTokenizer ...@@ -17,6 +17,7 @@ from transformers import AutoTokenizer
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo._core import Client, Context from dynamo._core import Client, Context
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import SglangMultimodalRequest from dynamo.sglang.protocol import SglangMultimodalRequest
...@@ -96,6 +97,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -96,6 +97,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
def cleanup(self) -> None: def cleanup(self) -> None:
pass pass
@_nvtx.range_decorator("mm:enc:generate", color="blue")
async def generate( async def generate(
self, request: SglangMultimodalRequest, context: Context self, request: SglangMultimodalRequest, context: Context
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
...@@ -138,6 +140,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -138,6 +140,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
) )
image_urls.append(mm_input.image_url) image_urls.append(mm_input.image_url)
with _nvtx.annotate("mm:enc:vision_encode", color="red"):
image_grid_dim, precomputed_embeddings = await self.encoder._encode( image_grid_dim, precomputed_embeddings = await self.encoder._encode(
image_urls image_urls
) )
...@@ -251,6 +254,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -251,6 +254,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
response_generator = await self.pd_worker_client.round_robin( response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json() request.model_dump_json()
) )
with _nvtx.annotate("mm:enc:embedding_transfer", color="purple"):
await readable.wait_for_completion() await readable.wait_for_completion()
async for response in response_generator: async for response in response_generator:
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import asyncio import asyncio
import json import json
import logging import logging
from typing import Any, AsyncIterator, Optional from typing import Any, AsyncIterator, Callable, Optional
import sglang as sgl import sglang as sgl
import torch import torch
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo._core import Client, Context from dynamo._core import Client, Context
from dynamo.common.constants import DisaggregationMode from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import ( from dynamo.sglang.protocol import (
...@@ -113,7 +114,9 @@ class EmbeddingsProcessor: ...@@ -113,7 +114,9 @@ class EmbeddingsProcessor:
) )
self._connector = connect.Connector() self._connector = connect.Connector()
with _nvtx.annotate("mm:nixl:begin_read", color="blue"):
read_op = await self._connector.begin_read(serialized_request, descriptor) read_op = await self._connector.begin_read(serialized_request, descriptor)
with _nvtx.annotate("mm:nixl:wait_completion", color="cyan"):
await read_op.wait_for_completion() await read_op.wait_for_completion()
return embeddings, descriptor return embeddings, descriptor
...@@ -317,23 +320,48 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -317,23 +320,48 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
request: Multimodal request with input and parameters. request: Multimodal request with input and parameters.
context: Context object for cancellation handling. context: Context object for cancellation handling.
""" """
rng_pd = _nvtx.start_range("mm:pd:generate", color="green")
rng_ttft = _nvtx.start_range("mm:pd:ttft", color="yellow")
ttft_ended = False
def _end_ttft() -> None:
nonlocal ttft_ended
if not ttft_ended:
_nvtx.end_range(rng_ttft)
ttft_ended = True
try: try:
request = self._validate_and_parse_request(request) request = self._validate_and_parse_request(request)
# Route to appropriate generation method based on serving mode # Route to appropriate generation method based on serving mode
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
async for output in self._generate_disaggregated(request): rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red")
try:
async for output in self._generate_disaggregated(
request, _end_ttft
):
yield output yield output
finally:
_nvtx.end_range(rng_disagg)
else: else:
async for output in self._generate_aggregated(request): rng_agg = _nvtx.start_range("mm:pd:generate_agg", color="red")
try:
async for output in self._generate_aggregated(request, _end_ttft):
yield output yield output
finally:
_nvtx.end_range(rng_agg)
except Exception as e: except Exception as e:
logger.error(f"Error in multimodal generation: {e}", exc_info=True) logger.error(f"Error in multimodal generation: {e}", exc_info=True)
yield ErrorResponseBuilder.build_error_response(e) yield ErrorResponseBuilder.build_error_response(e)
finally:
_end_ttft()
_nvtx.end_range(rng_pd)
async def _generate_disaggregated( async def _generate_disaggregated(
self, request: SglangMultimodalRequest self,
request: SglangMultimodalRequest,
end_ttft: Callable[[], None],
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Handle disaggregated mode generation""" """Handle disaggregated mode generation"""
input_ids = request.request.token_ids input_ids = request.request.token_ids
...@@ -357,11 +385,24 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -357,11 +385,24 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
bootstrap_room=bootstrap_info["bootstrap_room"], bootstrap_room=bootstrap_info["bootstrap_room"],
) )
rng_first = _nvtx.start_range("mm:dec:first_token", color="purple")
first_token = True
try:
async for output in StreamProcessor.process_sglang_stream(decode_stream): async for output in StreamProcessor.process_sglang_stream(decode_stream):
if first_token:
end_ttft()
_nvtx.end_range(rng_first)
first_token = False
yield output yield output
finally:
if first_token:
end_ttft()
_nvtx.end_range(rng_first)
async def _generate_aggregated( async def _generate_aggregated(
self, request: SglangMultimodalRequest self,
request: SglangMultimodalRequest,
end_ttft: Callable[[], None],
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Handle aggregated mode generation""" """Handle aggregated mode generation"""
input_ids = request.request.token_ids input_ids = request.request.token_ids
...@@ -370,6 +411,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -370,6 +411,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
try: try:
sampling_params = SglangUtils.build_sampling_params(request) sampling_params = SglangUtils.build_sampling_params(request)
with _nvtx.annotate("mm:pd:load_multimodal", color="cyan"):
mm_items, combined_embeddings = await _build_mm_items( mm_items, combined_embeddings = await _build_mm_items(
request, self.embeddings_processor request, self.embeddings_processor
) )
...@@ -387,8 +429,19 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -387,8 +429,19 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
stream=True, stream=True,
) )
rng_first = _nvtx.start_range("mm:dec:first_token", color="purple")
first_token = True
try:
async for output in StreamProcessor.process_sglang_stream(agg_stream): async for output in StreamProcessor.process_sglang_stream(agg_stream):
if first_token:
end_ttft()
_nvtx.end_range(rng_first)
first_token = False
yield output yield output
finally:
if first_token:
end_ttft()
_nvtx.end_range(rng_first)
except RuntimeError as e: except RuntimeError as e:
if "shape mismatch" in str(e): if "shape mismatch" in str(e):
...@@ -479,6 +532,15 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -479,6 +532,15 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
disagg_request: Disaggregated multimodal request. disagg_request: Disaggregated multimodal request.
context: Context object for cancellation handling. context: Context object for cancellation handling.
""" """
rng_bootstrap = _nvtx.start_range("mm:prefill:bootstrap", color="yellow")
bootstrap_ended = False
def _end_bootstrap() -> None:
nonlocal bootstrap_ended
if not bootstrap_ended:
_nvtx.end_range(rng_bootstrap)
bootstrap_ended = True
bootstrap_room = None bootstrap_room = None
try: try:
# Validate and parse request # Validate and parse request
...@@ -492,6 +554,7 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -492,6 +554,7 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
"bootstrap_room": bootstrap_room, "bootstrap_room": bootstrap_room,
} }
_end_bootstrap()
yield json.dumps(bootstrap_info) yield json.dumps(bootstrap_info)
# Process prefill generation # Process prefill generation
...@@ -503,6 +566,8 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -503,6 +566,8 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
{"bootstrap_room": bootstrap_room} if bootstrap_room is not None else {} {"bootstrap_room": bootstrap_room} if bootstrap_room is not None else {}
) )
yield ErrorResponseBuilder.build_error_response(e, extra_fields) yield ErrorResponseBuilder.build_error_response(e, extra_fields)
finally:
_end_bootstrap()
def _validate_and_parse_disagg_request( def _validate_and_parse_disagg_request(
self, disagg_request self, disagg_request
...@@ -529,9 +594,11 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -529,9 +594,11 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
sampling_params = disagg_request.sampling_params sampling_params = disagg_request.sampling_params
# Process embeddings from encode worker using our embeddings processor # Process embeddings from encode worker using our embeddings processor
with _nvtx.annotate("mm:prefill:load_multimodal", color="cyan"):
mm_items, _ = await _build_mm_items(request, self.embeddings_processor) mm_items, _ = await _build_mm_items(request, self.embeddings_processor)
# Start SGLang prefill generation (like regular SGLang) # Start SGLang prefill generation (like regular SGLang)
with _nvtx.annotate("mm:prefill:engine_async_generate", color="blue"):
results = await self.engine.async_generate( results = await self.engine.async_generate(
input_ids=input_ids, input_ids=input_ids,
image_data=mm_items, image_data=mm_items,
......
...@@ -403,6 +403,39 @@ export SGLANG_ENCODER_MM_LOAD_WORKERS=16 ...@@ -403,6 +403,39 @@ export SGLANG_ENCODER_MM_LOAD_WORKERS=16
Only applies to the EPD encode worker (which uses [SGLang's MMEncoder](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/disaggregation/encode_server.py) internally). Only applies to the EPD encode worker (which uses [SGLang's MMEncoder](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/disaggregation/encode_server.py) internally).
## Profiling
Dynamo's SGLang multimodal workers include NVTX markers for `nsys` profiling. They are disabled by default (zero overhead) and enabled by setting `DYN_NVTX=1`.
```bash
cd $DYNAMO_HOME/examples/backends/sglang
DYN_NVTX=1 nsys profile --trace=cuda,nvtx -o profile.nsys-rep \
bash launch/multimodal_epd.sh ...
```
| ENV Variable | Default | Description |
|---|---|---|
| `DYN_NVTX` | `0` | Set to `1` to enable NVTX range/mark annotations in multimodal encode/prefill/decode worker paths for `nsys` profiling |

Key NVTX ranges emitted:

| Range | Worker | Description |
|-------|--------|-------------|
| `mm:enc:generate` | Encode | Full encode request lifetime |
| `mm:enc:vision_encode` | Encode | Vision encode call (`MMEncoder._encode`) |
| `mm:enc:embedding_transfer` | Encode | Embedding handoff to downstream worker |
| `mm:nixl:begin_read` | PD (agg) / Prefill | Begin NIXL read operation for embeddings |
| `mm:nixl:wait_completion` | PD (agg) / Prefill | Wait for NIXL embedding transfer completion |
| `mm:pd:generate` | Aggregated worker / Decode worker (`MultimodalWorkerHandler`) | Full worker-side request lifetime |
| `mm:pd:generate_agg` | PD (agg) | Aggregated generation path |
| `mm:pd:load_multimodal` | PD (agg) | Build multimodal items from transferred embeddings |
| `mm:pd:generate_disagg` | Decode worker (disagg entrypoint) | Disaggregated generation path |
| `mm:prefill:bootstrap` | Prefill (disagg) | Bootstrap coordination path before returning `{bootstrap_host, bootstrap_port, bootstrap_room}` |
| `mm:prefill:load_multimodal` | Prefill (disagg) | Build multimodal items from transferred embeddings in the prefill worker |
| `mm:prefill:engine_async_generate` | Prefill (disagg) | SGLang prefill engine invocation (`engine.async_generate`) |
| `mm:pd:ttft` | Aggregated worker / Decode worker (`MultimodalWorkerHandler`) | Worker-entry TTFT: from request arrival at this worker to first output token (excludes client->frontend->worker network transit) |
| `mm:dec:first_token` | Aggregated worker / Decode worker (`MultimodalWorkerHandler`) | Decode-stage first-token range (starts when decode stream is launched; not worker-entry TTFT) |
## Known Limitations ## Known Limitations
- **No Data URL support** - Only HTTP/HTTPS URLs supported; `data:image/...` base64 URLs not supported - **No Data URL support** - Only HTTP/HTTPS URLs supported; `data:image/...` base64 URLs not supported
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment