Unverified Commit 7edb07b5 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat(sglang): disagg DP rank routing + backwards-compatible network imports (#6736)


Co-authored-by: default avatarhuitian bai <baihuitian.bht@gmail.com>
parent 14a6122e
...@@ -4,6 +4,29 @@ Dynamo's SGLang backend wraps SGLang's inference engine (`sgl.Engine`) and diffu ...@@ -4,6 +4,29 @@ Dynamo's SGLang backend wraps SGLang's inference engine (`sgl.Engine`) and diffu
generator (`DiffGenerator`) behind Dynamo's distributed runtime. It handles model generator (`DiffGenerator`) behind Dynamo's distributed runtime. It handles model
registration, request routing, metrics, and disaggregated serving. registration, request routing, metrics, and disaggregated serving.
## SGLang Backwards Compatibility
SGLang is pre-1.0 and regularly moves/renames internal APIs between releases. We
support the current version plus 1 version back (N and N-1). The pattern:
1. **All SGLang imports that have broken (or may break) across versions go through
`_compat.py`**, never directly from `sglang.*` in component code.
2. `_compat.py` uses try/except ImportError: new path first, old path fallback.
3. When SGLang introduces a new class/function that doesn't exist in older versions
(e.g., `NetworkAddress`), add a minimal polyfill in the except branch -- just
enough surface area to cover what Dynamo actually calls.
4. Each fallback branch in `_compat.py` MUST have a comment noting which SGLang
version it supports and when it can be removed, e.g.:
`# Fallback for sglang <= 0.5.9. Remove when min supported version is 0.6.0+`
5. When a new SGLang version is released and the old N-1 falls outside the support
window, delete the corresponding fallback branches and polyfills from `_compat.py`.
If `_compat.py` becomes trivial re-exports, inline the imports and delete the file.
**When you encounter a new SGLang API breakage**: add the affected imports to
`_compat.py` following the existing pattern. Do not scatter try/except blocks across
component files. Do not version-check with `sglang.__version__` -- import probing is
more reliable since SGLang's internal layout doesn't always match the version string.
## Entry Point ## Entry Point
`__main__.py` -> `main.py:main()` -> `main.py:worker()` `__main__.py` -> `main.py:main()` -> `main.py:worker()`
...@@ -272,6 +295,7 @@ Checklist for adding a new worker (e.g., a new modality or serving mode): ...@@ -272,6 +295,7 @@ Checklist for adding a new worker (e.g., a new modality or serving mode):
``` ```
sglang/ sglang/
_compat.py # SGLang version compat shim (network imports + NetworkAddress polyfill)
__main__.py # Entry point __main__.py # Entry point
main.py # Worker dispatch main.py # Worker dispatch
args.py # Config parsing (ServerArgs vs SimpleNamespace) args.py # Config parsing (ServerArgs vs SimpleNamespace)
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Compatibility shim for SGLang internal APIs.
SGLang is pre-1.0 and routinely moves, renames, or introduces APIs between
releases. This module is the single place where we handle those differences
so the rest of the component can import from here without version-specific
try/except blocks.
Policy: support current SGLang release + 1 version back (N and N-1). Each
fallback branch must document which version it covers and when it can be
removed. When the old version falls outside the support window, delete the
fallback and any associated polyfills.
"""
import ipaddress
import logging
import socket
logger = logging.getLogger(__name__)
try:
from sglang.srt.utils.network import ( # noqa: F401
NetworkAddress,
get_local_ip_auto,
get_zmq_socket,
)
_SGLANG_HAS_NETWORK_MODULE = True
except ImportError:
# Fallback for sglang <= 0.5.9. Remove when min supported version is 0.6.0+
from sglang.srt.utils import ( # type: ignore[no-redef] # noqa: F401
get_local_ip_auto,
get_zmq_socket,
)
_SGLANG_HAS_NETWORK_MODULE = False
logger.info(
"sglang.srt.utils.network not found (sglang <= 0.5.9); "
"using compatibility shim for NetworkAddress"
)
class NetworkAddress: # type: ignore[no-redef]
"""Minimal polyfill for sglang.srt.utils.network.NetworkAddress."""
def __init__(self, host: str, port: int) -> None:
self.host = host
self.port = port
@property
def is_ipv6(self) -> bool:
try:
ipaddress.IPv6Address(self.host)
return True
except ValueError:
return False
@classmethod
def parse(cls, addr: str) -> "NetworkAddress":
"""Parse 'host:port', '[IPv6]:port', or bare host."""
addr = addr.strip()
if addr.startswith("["):
end = addr.find("]")
host = addr[1:end] if end != -1 else addr.strip("[]")
rest = addr[end + 1 :] if end != -1 else ""
if rest.startswith(":") and rest[1:].isdigit():
return cls(host, int(rest[1:]))
return cls(host, 0)
if addr.count(":") == 1:
host_part, port_part = addr.rsplit(":", 1)
if port_part.isdigit():
return cls(host_part, int(port_part))
return cls(addr, 0)
def resolved(self) -> "NetworkAddress":
"""DNS-resolve the host, preserving port."""
try:
infos = socket.getaddrinfo(
self.host, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
resolved_ip = infos[0][4][0]
return NetworkAddress(resolved_ip, self.port)
except socket.gaierror:
return self
def to_host_port_str(self) -> str:
"""Return '[IPv6]:port' or 'host:port'."""
if self.is_ipv6:
return f"[{self.host}]:{self.port}"
return f"{self.host}:{self.port}"
def to_tcp(self) -> str:
"""Return 'tcp://[IPv6]:port' or 'tcp://host:port'."""
if self.is_ipv6:
return f"tcp://[{self.host}]:{self.port}"
return f"tcp://{self.host}:{self.port}"
__all__ = [
"NetworkAddress",
"get_local_ip_auto",
"get_zmq_socket",
"_SGLANG_HAS_NETWORK_MODULE",
]
...@@ -5,12 +5,14 @@ import asyncio ...@@ -5,12 +5,14 @@ import asyncio
import json import json
import logging import logging
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from urllib.parse import urlparse
import sglang as sgl import sglang as sgl
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from sglang.srt.disaggregation.kv_events import ZmqEventPublisher from sglang.srt.disaggregation.kv_events import ZmqEventPublisher
from sglang.srt.utils import get_local_ip_auto, get_zmq_socket, maybe_wrap_ipv6_address
from dynamo.sglang._compat import NetworkAddress, get_local_ip_auto, get_zmq_socket
if TYPE_CHECKING: if TYPE_CHECKING:
from prometheus_client import CollectorRegistry from prometheus_client import CollectorRegistry
...@@ -28,8 +30,7 @@ from dynamo.sglang.args import Config ...@@ -28,8 +30,7 @@ from dynamo.sglang.args import Config
def format_zmq_endpoint(endpoint_template: str, ip_address: str) -> str: def format_zmq_endpoint(endpoint_template: str, ip_address: str) -> str:
"""Format ZMQ endpoint by replacing wildcard with IP address. """Format ZMQ endpoint by replacing wildcard with IP address.
Properly handles IPv6 addresses by wrapping them in square brackets. Properly handles IPv6 addresses using SGLang's NetworkAddress utility.
Uses SGLang's maybe_wrap_ipv6_address for consistent formatting.
Args: Args:
endpoint_template: ZMQ endpoint template with wildcard (e.g., "tcp://*:5557") endpoint_template: ZMQ endpoint template with wildcard (e.g., "tcp://*:5557")
...@@ -44,9 +45,12 @@ def format_zmq_endpoint(endpoint_template: str, ip_address: str) -> str: ...@@ -44,9 +45,12 @@ def format_zmq_endpoint(endpoint_template: str, ip_address: str) -> str:
>>> format_zmq_endpoint("tcp://*:5557", "2a02:6b8:c46:2b4:0:74c1:75b0:0") >>> format_zmq_endpoint("tcp://*:5557", "2a02:6b8:c46:2b4:0:74c1:75b0:0")
'tcp://[2a02:6b8:c46:2b4:0:74c1:75b0:0]:5557' 'tcp://[2a02:6b8:c46:2b4:0:74c1:75b0:0]:5557'
""" """
# Use SGLang's utility to wrap IPv6 addresses in brackets parsed = urlparse(endpoint_template)
formatted_ip = maybe_wrap_ipv6_address(ip_address) if parsed.scheme != "tcp" or parsed.port is None:
return endpoint_template.replace("*", formatted_ip) raise ValueError(
f"Expected tcp://host:port endpoint, got {endpoint_template!r}"
)
return NetworkAddress(ip_address, parsed.port).to_tcp()
# Note: We use SGLang's ZmqEventPublisher.offset_endpoint_port() directly # Note: We use SGLang's ZmqEventPublisher.offset_endpoint_port() directly
......
...@@ -3,16 +3,15 @@ ...@@ -3,16 +3,15 @@
import asyncio import asyncio
import logging import logging
import socket
from typing import Any, List, Optional from typing import Any, List, Optional
import sglang as sgl import sglang as sgl
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Endpoint from dynamo._core import Endpoint
from dynamo.common.utils.output_modalities import get_output_modalities from dynamo.common.utils.output_modalities import get_output_modalities
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_model from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_model
from dynamo.sglang._compat import NetworkAddress, get_local_ip_auto
from dynamo.sglang.args import DynamoConfig from dynamo.sglang.args import DynamoConfig
...@@ -88,57 +87,28 @@ def _get_bootstrap_info_for_config( ...@@ -88,57 +87,28 @@ def _get_bootstrap_info_for_config(
return None, None return None, None
if inner_tm.server_args.dist_init_addr: if inner_tm.server_args.dist_init_addr:
# IPv6-ready host extraction and resolution: dist_init = NetworkAddress.parse(inner_tm.server_args.dist_init_addr)
# 1) Extract raw host from "host:port" or "[IPv6]:port"/"[IPv6]". resolved = dist_init.resolved()
# 2) Resolve via AF_UNSPEC to accept A/AAAA and literals. bootstrap_host = (
# 3) Bracket-wrap IPv6 for safe "{host}:{port}" URL formatting. NetworkAddress(resolved.host, bootstrap_port)
addr = inner_tm.server_args.dist_init_addr.strip() .to_host_port_str()
if addr.startswith("["): .rsplit(":", 1)[0]
end = addr.find("]")
host_core = addr[1:end] if end != -1 else addr.strip("[]")
else:
# Only treat single ':' with numeric suffix as host:port; otherwise it's an IPv6/FQDN host.
if addr.count(":") == 1:
host_candidate, maybe_port = addr.rsplit(":", 1)
host_core = host_candidate if maybe_port.isdigit() else addr
else:
host_core = addr
try:
infos = socket.getaddrinfo(
host_core,
None,
family=socket.AF_UNSPEC,
type=socket.SOCK_STREAM,
) )
resolved = infos[0][4][0] # let OS policy pick v4/v6
bootstrap_host = resolved
addr_family = infos[0][0]
logging.info( logging.info(
f"Resolved bootstrap host '{host_core}' -> '{resolved}' " f"Resolved bootstrap host '{dist_init.host}' -> '{resolved.host}' "
f"({'IPv6' if addr_family == socket.AF_INET6 else 'IPv4'})" f"({'IPv6' if resolved.is_ipv6 else 'IPv4'})"
)
except socket.gaierror as e:
# Fallback: keep literal/FQDN as-is (still wrap IPv6 below)
bootstrap_host = host_core
logging.warning(
f"Failed to resolve bootstrap host '{host_core}': {e}, using as-is"
) )
else: else:
# get_local_ip_auto() tries IPv4 first, then IPv6. For explicit control, # get_local_ip_auto() tries IPv4 first, then IPv6. For explicit control,
# set SGLANG_HOST_IP env var (use bracketed format for IPv6: [addr]) # set SGLANG_HOST_IP env var (use bracketed format for IPv6: [addr])
bootstrap_host = get_local_ip_auto() local_ip = get_local_ip_auto()
is_ipv6 = ":" in bootstrap_host local_addr = NetworkAddress(local_ip, bootstrap_port)
bootstrap_host = local_addr.to_host_port_str().rsplit(":", 1)[0]
logging.info( logging.info(
f"Using auto-detected local IP: {bootstrap_host} " f"Using auto-detected local IP: {local_ip} "
f"({'IPv6' if is_ipv6 else 'IPv4'})" f"({'IPv6' if local_addr.is_ipv6 else 'IPv4'})"
) )
# Wrap IPv6 literal with brackets so f"{host}:{port}" stays valid.
assert isinstance(bootstrap_host, str)
if ":" in bootstrap_host and not bootstrap_host.startswith("["):
bootstrap_host = f"[{bootstrap_host}]"
logging.info(f"Wrapped IPv6 address with brackets: {bootstrap_host}")
return bootstrap_host, bootstrap_port return bootstrap_host, bootstrap_port
except Exception as e: except Exception as e:
logging.warning(f"Failed to get bootstrap info: {e}") logging.warning(f"Failed to get bootstrap info: {e}")
......
...@@ -5,7 +5,6 @@ import asyncio ...@@ -5,7 +5,6 @@ import asyncio
import inspect import inspect
import logging import logging
import random import random
import socket
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import ( from typing import (
...@@ -20,12 +19,12 @@ from typing import ( ...@@ -20,12 +19,12 @@ from typing import (
) )
import sglang as sgl import sglang as sgl
from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.sglang._compat import NetworkAddress, get_local_ip_auto
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
...@@ -516,40 +515,18 @@ class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]): ...@@ -516,40 +515,18 @@ class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]):
bootstrap_port = inner_tm.server_args.disaggregation_bootstrap_port bootstrap_port = inner_tm.server_args.disaggregation_bootstrap_port
if inner_tm.server_args.dist_init_addr: if inner_tm.server_args.dist_init_addr:
# IPv6-ready host extraction and resolution: dist_init = NetworkAddress.parse(inner_tm.server_args.dist_init_addr)
# 1) Extract raw host from "host:port" or "[IPv6]:port"/"[IPv6]". bootstrap_host = (
# 2) Resolve via AF_UNSPEC to accept A/AAAA and literals. NetworkAddress(dist_init.resolved().host, bootstrap_port)
# 3) Bracket-wrap IPv6 for safe "{host}:{port}" URL formatting. .to_host_port_str()
addr = inner_tm.server_args.dist_init_addr.strip() .rsplit(":", 1)[0]
if addr.startswith("["):
end = addr.find("]")
host_core = addr[1:end] if end != -1 else addr.strip("[]")
else:
# Only treat single ':' with numeric suffix as host:port; otherwise it's an IPv6/FQDN host.
if addr.count(":") == 1:
host_candidate, maybe_port = addr.rsplit(":", 1)
host_core = host_candidate if maybe_port.isdigit() else addr
else:
host_core = addr
try:
infos = socket.getaddrinfo(
host_core,
None,
family=socket.AF_UNSPEC,
type=socket.SOCK_STREAM,
) )
resolved = infos[0][4][0] # let OS policy pick v4/v6
bootstrap_host = resolved
except socket.gaierror:
# Fallback: keep literal/FQDN as-is (still wrap IPv6 below)
bootstrap_host = host_core
else: else:
bootstrap_host = get_local_ip_auto() bootstrap_host = (
NetworkAddress(get_local_ip_auto(), bootstrap_port)
# Wrap IPv6 literal with brackets so f"{host}:{port}" stays valid. .to_host_port_str()
assert isinstance(bootstrap_host, str) .rsplit(":", 1)[0]
if ":" in bootstrap_host and not bootstrap_host.startswith("["): )
bootstrap_host = f"[{bootstrap_host}]"
return bootstrap_host, bootstrap_port return bootstrap_host, bootstrap_port
......
...@@ -12,6 +12,10 @@ from dynamo.sglang.args import Config ...@@ -12,6 +12,10 @@ from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
# Sentinel value matching u32::MAX from prefill_router.rs SimpleRouter path,
# indicating no specific data-parallel rank was selected.
_DP_RANK_UNSET = 2**32 - 1
class PrefillWorkerHandler(BaseWorkerHandler): class PrefillWorkerHandler(BaseWorkerHandler):
"""Handler for prefill workers in disaggregated serving mode.""" """Handler for prefill workers in disaggregated serving mode."""
...@@ -129,7 +133,12 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -129,7 +133,12 @@ class PrefillWorkerHandler(BaseWorkerHandler):
} }
input_param = self._get_input_param(inner_request) input_param = self._get_input_param(inner_request)
priority = (inner_request.get("routing") or {}).get("priority") routing = inner_request.get("routing") or {}
priority = routing.get("priority")
dp_rank = routing.get("dp_rank")
if dp_rank is not None and dp_rank == _DP_RANK_UNSET:
dp_rank = None
trace_header = self._get_trace_header(context) if self.enable_trace else None trace_header = self._get_trace_header(context) if self.enable_trace else None
...@@ -142,6 +151,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -142,6 +151,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
bootstrap_room=bootstrap_room, bootstrap_room=bootstrap_room,
external_trace_header=trace_header, external_trace_header=trace_header,
rid=trace_id, rid=trace_id,
data_parallel_rank=dp_rank,
**self._priority_kwargs(priority), **self._priority_kwargs(priority),
) )
......
...@@ -553,7 +553,7 @@ impl PrefillRouter { ...@@ -553,7 +553,7 @@ impl PrefillRouter {
r.peek_next_worker() r.peek_next_worker()
} }
.ok_or_else(|| anyhow::anyhow!("No workers available for prefill"))?; .ok_or_else(|| anyhow::anyhow!("No workers available for prefill"))?;
Ok((worker_id, 0)) Ok((worker_id, u32::MAX))
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment