Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6e95f5e5
Unverified
Commit
6e95f5e5
authored
Sep 05, 2025
by
Liangsheng Yin
Committed by
GitHub
Sep 05, 2025
Browse files
Simplify `Router` arguments passing and build it in docker image (#9964)
parent
0e9387a9
Changes
24
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1107 additions
and
1560 deletions
+1107
-1560
docker/Dockerfile
docker/Dockerfile
+14
-1
docs/advanced_features/pd_disaggregation.md
docs/advanced_features/pd_disaggregation.md
+3
-3
docs/advanced_features/router.md
docs/advanced_features/router.md
+15
-15
docs/references/multi_node_deployment/lws_pd/lws-examples/lb.yaml
...erences/multi_node_deployment/lws_pd/lws-examples/lb.yaml
+2
-1
docs/references/multi_node_deployment/lws_pd/lws_pd_deploy.md
.../references/multi_node_deployment/lws_pd/lws_pd_deploy.md
+2
-1
python/sglang/srt/disaggregation/launch_lb.py
python/sglang/srt/disaggregation/launch_lb.py
+0
-118
python/sglang/srt/disaggregation/mini_lb.py
python/sglang/srt/disaggregation/mini_lb.py
+6
-445
python/sglang/srt/disaggregation/utils.py
python/sglang/srt/disaggregation/utils.py
+2
-49
python/sglang/srt/entrypoints/http_server.py
python/sglang/srt/entrypoints/http_server.py
+1
-13
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+0
-7
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+19
-0
scripts/ci/ci_install_dependency.sh
scripts/ci/ci_install_dependency.sh
+4
-0
sgl-router/py_src/sglang_router/__init__.py
sgl-router/py_src/sglang_router/__init__.py
+1
-5
sgl-router/py_src/sglang_router/launch_router.py
sgl-router/py_src/sglang_router/launch_router.py
+21
-758
sgl-router/py_src/sglang_router/mini_lb.py
sgl-router/py_src/sglang_router/mini_lb.py
+395
-0
sgl-router/py_src/sglang_router/router.py
sgl-router/py_src/sglang_router/router.py
+41
-123
sgl-router/py_src/sglang_router/router_args.py
sgl-router/py_src/sglang_router/router_args.py
+577
-0
sgl-router/py_test/test_launch_router.py
sgl-router/py_test/test_launch_router.py
+3
-15
sgl-router/py_test/test_launch_server.py
sgl-router/py_test/test_launch_server.py
+1
-1
sgl-router/pyproject.toml
sgl-router/pyproject.toml
+0
-5
No files found.
docker/Dockerfile
View file @
6e95f5e5
...
@@ -36,7 +36,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
...
@@ -36,7 +36,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
ibverbs-providers infiniband-diags perftest
\
ibverbs-providers infiniband-diags perftest
\
libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev
\
libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev
\
libboost-all-dev libssl-dev
\
libboost-all-dev libssl-dev
\
libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler-grpc
\
libgrpc-dev libgrpc++-dev libprotobuf-dev
protobuf-compiler
protobuf-compiler-grpc
\
pybind11-dev
\
pybind11-dev
\
libhiredis-dev libcurl4-openssl-dev
\
libhiredis-dev libcurl4-openssl-dev
\
libczmq4 libczmq-dev
\
libczmq4 libczmq-dev
\
...
@@ -218,6 +218,19 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1
...
@@ -218,6 +218,19 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1
&&
cp
-r
cmake-3.31.1-linux-x86_64/share/
*
/usr/local/share/
\
&&
cp
-r
cmake-3.31.1-linux-x86_64/share/
*
/usr/local/share/
\
&&
rm
-rf
cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz
&&
rm
-rf
cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz
# Install Rust toolchain for sgl-router
ENV
PATH="/root/.cargo/bin:${PATH}"
RUN
curl
--proto
'=https'
--tlsv1
.2
-sSf
https://sh.rustup.rs | sh
-s
--
-y
\
&&
rustc
--version
&&
cargo
--version
# Build and install sgl-router
RUN
python3
-m
pip
install
--no-cache-dir
setuptools-rust
\
&&
cd
/sgl-workspace/sglang/sgl-router
\
&&
cargo build
--release
\
&&
python3
-m
pip
install
--no-cache-dir
.
\
&&
rm
-rf
/root/.cache
# Add yank script
# Add yank script
COPY
--chown=root:root <<-"EOF" /usr/local/bin/yank
COPY
--chown=root:root <<-"EOF" /usr/local/bin/yank
#!/bin/bash
#!/bin/bash
...
...
docs/advanced_features/pd_disaggregation.md
View file @
6e95f5e5
...
@@ -36,7 +36,7 @@ uv pip install mooncake-transfer-engine
...
@@ -36,7 +36,7 @@ uv pip install mooncake-transfer-engine
```
bash
```
bash
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
prefill
--disaggregation-ib-device
mlx5_roce0
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
prefill
--disaggregation-ib-device
mlx5_roce0
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
decode
--port
30001
--base-gpu-id
1
--disaggregation-ib-device
mlx5_roce0
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
decode
--port
30001
--base-gpu-id
1
--disaggregation-ib-device
mlx5_roce0
$
python
-m
sglang
.srt.
disaggregation
.mini_lb
--prefill
http://127.0.0.1:30000
--decode
http://127.0.0.1:30001
--host
0.0.0.0
--port
8000
$
python
-m
sglang
_router.launch_router
--pd-
disaggregation
--prefill
http://127.0.0.1:30000
--decode
http://127.0.0.1:30001
--host
0.0.0.0
--port
8000
```
```
### DeepSeek Multi-Node
### DeepSeek Multi-Node
...
@@ -100,7 +100,7 @@ pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
...
@@ -100,7 +100,7 @@ pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
```
bash
```
bash
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
prefill
--disaggregation-transfer-backend
nixl
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
prefill
--disaggregation-transfer-backend
nixl
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
decode
--port
30001
--base-gpu-id
1
--disaggregation-transfer-backend
nixl
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
decode
--port
30001
--base-gpu-id
1
--disaggregation-transfer-backend
nixl
$
python
-m
sglang
.srt.
disaggregation
.mini_lb
--prefill
http://127.0.0.1:30000
--decode
http://127.0.0.1:30001
--host
0.0.0.0
--port
8000
$
python
-m
sglang
_router.launch_router
--pd-
disaggregation
--prefill
http://127.0.0.1:30000
--decode
http://127.0.0.1:30001
--host
0.0.0.0
--port
8000
```
```
### DeepSeek Multi-Node
### DeepSeek Multi-Node
...
@@ -137,7 +137,7 @@ export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true
...
@@ -137,7 +137,7 @@ export ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE=true
```
bash
```
bash
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
prefill
--disaggregation-transfer-backend
ascend
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
prefill
--disaggregation-transfer-backend
ascend
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
decode
--port
30001
--base-gpu-id
1
--disaggregation-transfer-backend
ascend
$
python
-m
sglang.launch_server
--model-path
meta-llama/Llama-3.1-8B-Instruct
--disaggregation-mode
decode
--port
30001
--base-gpu-id
1
--disaggregation-transfer-backend
ascend
$
python
-m
sglang
.srt.
disaggregation
.mini_lb
--prefill
http://127.0.0.1:30000
--decode
http://127.0.0.1:30001
--host
0.0.0.0
--port
8000
$
python
-m
sglang
_router.launch_router
--pd-
disaggregation
--prefill
http://127.0.0.1:30000
--decode
http://127.0.0.1:30001
--host
0.0.0.0
--port
8000
```
```
### DeepSeek Multi-Node
### DeepSeek Multi-Node
...
...
docs/advanced_features/router.md
View file @
6e95f5e5
...
@@ -278,7 +278,7 @@ The most sophisticated policy that combines cache optimization with load balanci
...
@@ -278,7 +278,7 @@ The most sophisticated policy that combines cache optimization with load balanci
3.
**Cache Management**
:
3.
**Cache Management**
:
-
Maintains approximate radix trees per worker
-
Maintains approximate radix trees per worker
-
Periodically evicts LRU entries based on
`--eviction-interval`
and
`--max-tree-size`
-
Periodically evicts LRU entries based on
`--eviction-interval
-secs
`
and
`--max-tree-size`
### Data Parallelism Aware Routing
### Data Parallelism Aware Routing
...
@@ -296,7 +296,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
...
@@ -296,7 +296,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Core Settings
### Core Settings
| Parameter | Type | Default | Description |
| Parameter | Type | Default | Description |
|---------------------------
--|--
----
|
-----------
--|
---------------------------------------------------------------
--
|
|
---------------------------
|
----
|
-----------
|
---------------------------------------------------------------
|
|
`--host`
| str | 127.0.0.1 | Router server host address |
|
`--host`
| str | 127.0.0.1 | Router server host address |
|
`--port`
| int | 30000 | Router server port |
|
`--port`
| int | 30000 | Router server port |
|
`--worker-urls`
| list | [] | Worker URLs for separate launch mode |
|
`--worker-urls`
| list | [] | Worker URLs for separate launch mode |
...
@@ -307,18 +307,18 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
...
@@ -307,18 +307,18 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Cache-Aware Routing Parameters
### Cache-Aware Routing Parameters
| Parameter | Type | Default | Description |
| Parameter
| Type | Default | Description |
|--------------------------
-|
-----
--|
--------
--|
------------------------------------------------------
--
|
|
--------------------------
|
-----
|
--------
|
------------------------------------------------------
|
|
`--cache-threshold`
| float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) |
|
`--cache-threshold`
| float | 0.5 | Minimum prefix match ratio for cache routing (0.0-1.0) |
|
`--balance-abs-threshold`
| int | 32 | Absolute load difference threshold |
|
`--balance-abs-threshold`
| int | 32 | Absolute load difference threshold |
|
`--balance-rel-threshold`
| float | 1.0001 | Relative load ratio threshold |
|
`--balance-rel-threshold`
| float | 1.0001 | Relative load ratio threshold |
|
`--eviction-interval
`
| int | 60 | Seconds between cache eviction cycles |
|
`--eviction-interval
-secs`
| int | 60 | Seconds between cache eviction cycles |
|
`--max-tree-size`
| int | 16777216 | Maximum nodes in routing tree |
|
`--max-tree-size`
| int | 16777216 | Maximum nodes in routing tree |
### Fault Tolerance Parameters
### Fault Tolerance Parameters
| Parameter | Type | Default | Description |
| Parameter | Type | Default | Description |
|----------------------------
--|--
-----
|
-------
--|
-------------------------------------
--
|
|
----------------------------
|
-----
|
-------
|
-------------------------------------
|
|
`--retry-max-retries`
| int | 3 | Maximum retry attempts per request |
|
`--retry-max-retries`
| int | 3 | Maximum retry attempts per request |
|
`--retry-initial-backoff-ms`
| int | 100 | Initial retry backoff in milliseconds |
|
`--retry-initial-backoff-ms`
| int | 100 | Initial retry backoff in milliseconds |
|
`--retry-max-backoff-ms`
| int | 10000 | Maximum retry backoff in milliseconds |
|
`--retry-max-backoff-ms`
| int | 10000 | Maximum retry backoff in milliseconds |
...
@@ -334,7 +334,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
...
@@ -334,7 +334,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Prefill-Decode Disaggregation Parameters
### Prefill-Decode Disaggregation Parameters
| Parameter | Type | Default | Description |
| Parameter | Type | Default | Description |
|---------------------------------
--|
----
--|
-------
--|
-----------------------------------------------------
--
|
|
---------------------------------
|
----
|
-------
|
-----------------------------------------------------
|
|
`--pd-disaggregation`
| flag | False | Enable PD disaggregated mode |
|
`--pd-disaggregation`
| flag | False | Enable PD disaggregated mode |
|
`--prefill`
| list | [] | Prefill server URLs with optional bootstrap ports |
|
`--prefill`
| list | [] | Prefill server URLs with optional bootstrap ports |
|
`--decode`
| list | [] | Decode server URLs |
|
`--decode`
| list | [] | Decode server URLs |
...
@@ -346,7 +346,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
...
@@ -346,7 +346,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Kubernetes Integration
### Kubernetes Integration
| Parameter | Type | Default | Description |
| Parameter | Type | Default | Description |
|-------------------------------
--|
----
--|
------------------------
--|
----------------------------------------------------
--
|
|
-------------------------------
|
----
|
------------------------
|
----------------------------------------------------
|
|
`--service-discovery`
| flag | False | Enable Kubernetes service discovery |
|
`--service-discovery`
| flag | False | Enable Kubernetes service discovery |
|
`--selector`
| list | [] | Label selector for workers (key1=value1 key2=value2) |
|
`--selector`
| list | [] | Label selector for workers (key1=value1 key2=value2) |
|
`--prefill-selector`
| list | [] | Label selector for prefill servers in PD mode |
|
`--prefill-selector`
| list | [] | Label selector for prefill servers in PD mode |
...
@@ -358,7 +358,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
...
@@ -358,7 +358,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### Observability
### Observability
| Parameter | Type | Default | Description |
| Parameter | Type | Default | Description |
|----------------------
--|
----
--|
---------
--|
-----------------------------------------------------
--
|
|
----------------------
|
----
|
---------
|
-----------------------------------------------------
|
|
`--prometheus-port`
| int | 29000 | Prometheus metrics port |
|
`--prometheus-port`
| int | 29000 | Prometheus metrics port |
|
`--prometheus-host`
| str | 127.0.0.1 | Prometheus metrics host |
|
`--prometheus-host`
| str | 127.0.0.1 | Prometheus metrics host |
|
`--log-dir`
| str | None | Directory for log files |
|
`--log-dir`
| str | None | Directory for log files |
...
@@ -368,7 +368,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
...
@@ -368,7 +368,7 @@ This mode coordinates with SGLang's DP controller for optimized request distribu
### CORS Configuration
### CORS Configuration
| Parameter | Type | Default | Description |
| Parameter | Type | Default | Description |
|------------------------
--|
----
--|
-------
--|
--------------------
--
|
|
------------------------
|
----
|
-------
|
--------------------
|
|
`--cors-allowed-origins`
| list | [] | Allowed CORS origins |
|
`--cors-allowed-origins`
| list | [] | Allowed CORS origins |
## Advanced Features
## Advanced Features
...
@@ -429,7 +429,7 @@ python -m sglang_router.launch_router \
...
@@ -429,7 +429,7 @@ python -m sglang_router.launch_router \
2.
**High latency**
: Check if cache-aware routing is causing imbalance. Try adjusting
`--balance-abs-threshold`
and
`--balance-rel-threshold`
.
2.
**High latency**
: Check if cache-aware routing is causing imbalance. Try adjusting
`--balance-abs-threshold`
and
`--balance-rel-threshold`
.
3.
**Memory growth**
: Reduce
`--max-tree-size`
or decrease
`--eviction-interval`
for more aggressive cache cleanup.
3.
**Memory growth**
: Reduce
`--max-tree-size`
or decrease
`--eviction-interval
-secs
`
for more aggressive cache cleanup.
4.
**Circuit breaker triggering frequently**
: Increase
`--cb-failure-threshold`
or extend
`--cb-window-duration-secs`
.
4.
**Circuit breaker triggering frequently**
: Increase
`--cb-failure-threshold`
or extend
`--cb-window-duration-secs`
.
...
...
docs/references/multi_node_deployment/lws_pd/lws-examples/lb.yaml
View file @
6e95f5e5
...
@@ -27,7 +27,8 @@ spec:
...
@@ -27,7 +27,8 @@ spec:
command
:
command
:
-
python
-
python
-
-m
-
-m
-
sglang.srt.disaggregation.mini_lb
-
sglang_router.launch_router
-
--pd-disaggregation
-
--prefill
-
--prefill
-
http://deepseekr10528-prefill-main:30000
-
http://deepseekr10528-prefill-main:30000
-
--decode
-
--decode
...
...
docs/references/multi_node_deployment/lws_pd/lws_pd_deploy.md
View file @
6e95f5e5
...
@@ -714,7 +714,8 @@ spec:
...
@@ -714,7 +714,8 @@ spec:
command
:
command
:
-
python
-
python
-
-m
-
-m
-
sglang.srt.disaggregation.mini_lb
-
sglang_router.launch_router
-
--pd-disaggregation
-
--prefill
-
--prefill
-
http://deepseekr10528-prefill-main:30000
-
http://deepseekr10528-prefill-main:30000
-
--decode
-
--decode
...
...
python/sglang/srt/disaggregation/launch_lb.py
deleted
100644 → 0
View file @
0e9387a9
import
argparse
import
dataclasses
from
sglang.srt.disaggregation.mini_lb
import
PrefillConfig
,
run
@
dataclasses
.
dataclass
class
LBArgs
:
host
:
str
=
"0.0.0.0"
port
:
int
=
8000
policy
:
str
=
"random"
prefill_infos
:
list
=
dataclasses
.
field
(
default_factory
=
list
)
decode_infos
:
list
=
dataclasses
.
field
(
default_factory
=
list
)
log_interval
:
int
=
5
timeout
:
int
=
600
@
staticmethod
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
LBArgs
.
host
,
help
=
f
"Host to bind the server (default:
{
LBArgs
.
host
}
)"
,
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
LBArgs
.
port
,
help
=
f
"Port to bind the server (default:
{
LBArgs
.
port
}
)"
,
)
parser
.
add_argument
(
"--policy"
,
type
=
str
,
default
=
LBArgs
.
policy
,
choices
=
[
"random"
,
"po2"
],
help
=
f
"Policy to use for load balancing (default:
{
LBArgs
.
policy
}
)"
,
)
parser
.
add_argument
(
"--prefill"
,
type
=
str
,
default
=
[],
nargs
=
"+"
,
help
=
"URLs for prefill servers"
,
)
parser
.
add_argument
(
"--decode"
,
type
=
str
,
default
=
[],
nargs
=
"+"
,
help
=
"URLs for decode servers"
,
)
parser
.
add_argument
(
"--prefill-bootstrap-ports"
,
type
=
int
,
nargs
=
"+"
,
help
=
"Bootstrap ports for prefill servers"
,
)
parser
.
add_argument
(
"--log-interval"
,
type
=
int
,
default
=
LBArgs
.
log_interval
,
help
=
f
"Log interval in seconds (default:
{
LBArgs
.
log_interval
}
)"
,
)
parser
.
add_argument
(
"--timeout"
,
type
=
int
,
default
=
LBArgs
.
timeout
,
help
=
f
"Timeout in seconds (default:
{
LBArgs
.
timeout
}
)"
,
)
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
)
->
"LBArgs"
:
bootstrap_ports
=
args
.
prefill_bootstrap_ports
if
bootstrap_ports
is
None
:
bootstrap_ports
=
[
None
]
*
len
(
args
.
prefill
)
elif
len
(
bootstrap_ports
)
==
1
:
bootstrap_ports
=
bootstrap_ports
*
len
(
args
.
prefill
)
else
:
if
len
(
bootstrap_ports
)
!=
len
(
args
.
prefill
):
raise
ValueError
(
"Number of prefill URLs must match number of bootstrap ports"
)
prefill_infos
=
[
(
url
,
port
)
for
url
,
port
in
zip
(
args
.
prefill
,
bootstrap_ports
)
]
return
cls
(
host
=
args
.
host
,
port
=
args
.
port
,
policy
=
args
.
policy
,
prefill_infos
=
prefill_infos
,
decode_infos
=
args
.
decode
,
log_interval
=
args
.
log_interval
,
timeout
=
args
.
timeout
,
)
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
"PD Disaggregation Load Balancer Server"
)
LBArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
lb_args
=
LBArgs
.
from_cli_args
(
args
)
prefill_configs
=
[
PrefillConfig
(
url
,
port
)
for
url
,
port
in
lb_args
.
prefill_infos
]
run
(
prefill_configs
,
lb_args
.
decode_infos
,
lb_args
.
host
,
lb_args
.
port
,
lb_args
.
timeout
,
)
if
__name__
==
"__main__"
:
main
()
python/sglang/srt/disaggregation/mini_lb.py
View file @
6e95f5e5
"""
raise
RuntimeError
(
Minimal HTTP load balancer for prefill and decode servers for testing.
"""The 'mini_lb' module has been relocated to the 'sglang_router' package.
"""
We recommend installing 'sglang-router' with Rust support for optimal performance.
If you encounter issues building the router with Rust, set the environment variable
import
asyncio
'SGLANG_ROUTER_BUILD_NO_RUST=1' and add '--mini-lb' to the command line to use the Python version of 'mini_lb'."""
import
dataclasses
)
import
logging
import
random
import
urllib
from
http
import
HTTPStatus
from
itertools
import
chain
from
typing
import
List
,
Optional
import
aiohttp
import
orjson
import
uvicorn
from
fastapi
import
FastAPI
,
HTTPException
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
sglang.srt.disaggregation.utils
import
PDRegistryRequest
from
sglang.srt.utils
import
maybe_wrap_ipv6_address
AIOHTTP_STREAM_READ_CHUNK_SIZE
=
(
1024
*
64
)
# 64KB, to prevent aiohttp's "Chunk too big" error
def
setup_logger
():
logger
=
logging
.
getLogger
(
"pdlb"
)
logger
.
setLevel
(
logging
.
INFO
)
formatter
=
logging
.
Formatter
(
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s"
,
datefmt
=
"%Y-%m-%d %H:%M:%S"
,
)
handler
=
logging
.
StreamHandler
()
handler
.
setFormatter
(
formatter
)
logger
.
addHandler
(
handler
)
return
logger
logger
=
setup_logger
()
@
dataclasses
.
dataclass
class
PrefillConfig
:
url
:
str
bootstrap_port
:
Optional
[
int
]
=
None
class
MiniLoadBalancer
:
def
__init__
(
self
,
prefill_configs
:
List
[
PrefillConfig
],
decode_servers
:
List
[
str
],
timeout
:
int
,
):
self
.
prefill_configs
=
prefill_configs
self
.
prefill_servers
=
[
p
.
url
for
p
in
prefill_configs
]
self
.
decode_servers
=
decode_servers
self
.
timeout
=
timeout
def
add_prefill_server
(
self
,
new_prefill_config
:
PrefillConfig
):
self
.
prefill_configs
.
append
(
new_prefill_config
)
self
.
prefill_servers
.
append
(
new_prefill_config
.
url
)
def
add_decode_server
(
self
,
new_decode_server
:
str
):
self
.
decode_servers
.
append
(
new_decode_server
)
def
select_pair
(
self
):
# TODO: return some message instead of panic
assert
len
(
self
.
prefill_configs
)
>
0
,
"No prefill servers available"
assert
len
(
self
.
decode_servers
)
>
0
,
"No decode servers available"
prefill_config
=
random
.
choice
(
self
.
prefill_configs
)
decode_server
=
random
.
choice
(
self
.
decode_servers
)
return
prefill_config
.
url
,
prefill_config
.
bootstrap_port
,
decode_server
async
def
generate
(
self
,
modified_request
,
prefill_server
,
decode_server
,
endpoint
)
->
ORJSONResponse
:
assert
endpoint
[
0
]
!=
"/"
,
f
"Endpoint should not start with '/':
{
endpoint
}
"
async
with
aiohttp
.
ClientSession
(
timeout
=
aiohttp
.
ClientTimeout
(
total
=
self
.
timeout
)
# Add timeout for request reliability
)
as
session
:
tasks
=
[
session
.
post
(
f
"
{
prefill_server
}
/
{
endpoint
}
"
,
json
=
modified_request
),
session
.
post
(
f
"
{
decode_server
}
/
{
endpoint
}
"
,
json
=
modified_request
),
]
# Wait for both responses to complete. Prefill should end first.
prefill_response
,
decode_response
=
await
asyncio
.
gather
(
*
tasks
)
if
"return_logprob"
in
modified_request
:
prefill_json
=
await
prefill_response
.
json
()
ret_json
=
await
decode_response
.
json
()
# merge `meta_info.input_token_logprobs` from prefill to decode
if
"meta_info"
in
ret_json
:
if
"input_token_logprobs"
in
ret_json
[
"meta_info"
]:
ret_json
[
"meta_info"
][
"input_token_logprobs"
]
=
(
prefill_json
[
"meta_info"
][
"input_token_logprobs"
]
+
ret_json
[
"meta_info"
][
"input_token_logprobs"
]
)
else
:
ret_json
=
await
decode_response
.
json
()
return
ORJSONResponse
(
content
=
ret_json
,
status_code
=
decode_response
.
status
,
)
async
def
generate_stream
(
self
,
modified_request
,
prefill_server
,
decode_server
,
endpoint
=
"generate"
):
assert
endpoint
[
0
]
!=
"/"
,
f
"Endpoint should not start with '/':
{
endpoint
}
"
async
def
stream_results
():
async
with
aiohttp
.
ClientSession
(
timeout
=
aiohttp
.
ClientTimeout
(
total
=
self
.
timeout
)
# Add timeout for request reliability
)
as
session
:
# Create the tasks for both prefill and decode requests
tasks
=
[
session
.
post
(
f
"
{
prefill_server
}
/
{
endpoint
}
"
,
json
=
modified_request
),
session
.
post
(
f
"
{
decode_server
}
/
{
endpoint
}
"
,
json
=
modified_request
),
]
# Wait for both responses to complete. Since this is streaming, they return immediately.
prefill_response
,
decode_response
=
await
asyncio
.
gather
(
*
tasks
)
if
modified_request
.
get
(
"return_logprob"
,
False
):
prefill_chunks
=
[]
async
for
chunk
in
prefill_response
.
content
:
prefill_chunks
.
append
(
chunk
)
first_prefill_chunk
=
(
prefill_chunks
[
0
].
decode
(
"utf-8"
)[
5
:].
strip
(
"
\n
"
)
)
first_prefill_chunk_json
=
orjson
.
loads
(
first_prefill_chunk
)
async
for
chunk
in
decode_response
.
content
:
# Note: This is inefficient
# merge prefill input_token_logprobs, output_token_logprobs to decode
decoded_chunk
=
chunk
.
decode
(
"utf-8"
)
if
(
decoded_chunk
and
decoded_chunk
.
startswith
(
"data:"
)
and
"[DONE]"
not
in
decoded_chunk
):
ret_json
=
orjson
.
loads
(
decoded_chunk
[
5
:].
strip
(
"
\n
"
))
ret_json
[
"meta_info"
][
"input_token_logprobs"
]
=
(
first_prefill_chunk_json
[
"meta_info"
][
"input_token_logprobs"
]
+
ret_json
[
"meta_info"
][
"input_token_logprobs"
]
)
yield
b
"data: "
+
orjson
.
dumps
(
ret_json
)
+
b
"
\n\n
"
else
:
yield
chunk
else
:
async
for
chunk
in
decode_response
.
content
.
iter_chunked
(
AIOHTTP_STREAM_READ_CHUNK_SIZE
):
yield
chunk
return
StreamingResponse
(
stream_results
(),
media_type
=
"text/event-stream"
,
)
app
=
FastAPI
()
load_balancer
:
Optional
[
MiniLoadBalancer
]
=
None
@
app
.
get
(
"/health"
)
async
def
health_check
():
return
Response
(
status_code
=
200
)
@
app
.
get
(
"/health_generate"
)
async
def
health_generate
():
prefill_servers
,
decode_servers
=
(
load_balancer
.
prefill_servers
,
load_balancer
.
decode_servers
,
)
async
with
aiohttp
.
ClientSession
()
as
session
:
# Create the tasks
tasks
=
[]
for
server
in
chain
(
prefill_servers
,
decode_servers
):
tasks
.
append
(
session
.
get
(
f
"
{
server
}
/health_generate"
))
for
i
,
response
in
enumerate
(
asyncio
.
as_completed
(
tasks
)):
await
response
return
Response
(
status_code
=
200
)
@
app
.
post
(
"/flush_cache"
)
async
def
flush_cache
():
prefill_servers
,
decode_servers
=
(
load_balancer
.
prefill_servers
,
load_balancer
.
decode_servers
,
)
async
with
aiohttp
.
ClientSession
()
as
session
:
# Create the tasks
tasks
=
[]
for
server
in
chain
(
prefill_servers
,
decode_servers
):
tasks
.
append
(
session
.
post
(
f
"
{
server
}
/flush_cache"
))
for
i
,
response
in
enumerate
(
asyncio
.
as_completed
(
tasks
)):
await
response
return
Response
(
status_code
=
200
)
@
app
.
get
(
"/get_server_info"
)
async
def
get_server_info
():
prefill_servers
,
decode_servers
=
(
load_balancer
.
prefill_servers
,
load_balancer
.
decode_servers
,
)
prefill_infos
=
[]
decode_infos
=
[]
all_internal_states
=
[]
async
with
aiohttp
.
ClientSession
()
as
session
:
for
server
in
chain
(
prefill_servers
):
server_info
=
await
session
.
get
(
f
"
{
server
}
/get_server_info"
)
prefill_infos
.
append
(
await
server_info
.
json
())
for
server
in
chain
(
decode_servers
):
server_info
=
await
session
.
get
(
f
"
{
server
}
/get_server_info"
)
info_json
=
await
server_info
.
json
()
decode_infos
.
append
(
info_json
)
# Extract internal_states from decode servers
if
"internal_states"
in
info_json
:
all_internal_states
.
extend
(
info_json
[
"internal_states"
])
# Return format expected by bench_one_batch_server.py
if
all_internal_states
:
return
{
"internal_states"
:
all_internal_states
,
"prefill"
:
prefill_infos
,
"decode"
:
decode_infos
,
}
else
:
# Fallback with dummy data if no internal states found
return
{
"internal_states"
:
[
{
"last_gen_throughput"
:
0.0
,
"avg_spec_accept_length"
:
None
,
}
],
"prefill"
:
prefill_infos
,
"decode"
:
decode_infos
,
}
@
app
.
get
(
"/get_model_info"
)
async
def
get_model_info
():
global
load_balancer
if
not
load_balancer
or
not
load_balancer
.
prefill_servers
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
SERVICE_UNAVAILABLE
,
detail
=
"There is no server registered"
,
)
target_server_url
=
load_balancer
.
prefill_servers
[
0
]
endpoint_url
=
f
"
{
target_server_url
}
/get_model_info"
async
with
aiohttp
.
ClientSession
()
as
session
:
try
:
async
with
session
.
get
(
endpoint_url
)
as
response
:
if
response
.
status
!=
200
:
error_text
=
await
response
.
text
()
raise
HTTPException
(
status_code
=
HTTPStatus
.
BAD_GATEWAY
,
detail
=
(
f
"Failed to get model info from
{
target_server_url
}
"
f
"Status:
{
response
.
status
}
, Response:
{
error_text
}
"
),
)
model_info_json
=
await
response
.
json
()
return
ORJSONResponse
(
content
=
model_info_json
)
except
aiohttp
.
ClientError
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
SERVICE_UNAVAILABLE
,
detail
=
f
"Failed to get model info from backend"
,
)
@
app
.
post
(
"/generate"
)
async
def
handle_generate_request
(
request_data
:
dict
):
prefill_server
,
bootstrap_port
,
decode_server
=
load_balancer
.
select_pair
()
# Parse and transform prefill_server for bootstrap data
parsed_url
=
urllib
.
parse
.
urlparse
(
prefill_server
)
hostname
=
maybe_wrap_ipv6_address
(
parsed_url
.
hostname
)
modified_request
=
request_data
.
copy
()
batch_size
=
_get_request_batch_size
(
modified_request
)
if
batch_size
is
not
None
:
modified_request
.
update
(
{
"bootstrap_host"
:
[
hostname
]
*
batch_size
,
"bootstrap_port"
:
[
bootstrap_port
]
*
batch_size
,
"bootstrap_room"
:
[
_generate_bootstrap_room
()
for
_
in
range
(
batch_size
)
],
}
)
else
:
modified_request
.
update
(
{
"bootstrap_host"
:
hostname
,
"bootstrap_port"
:
bootstrap_port
,
"bootstrap_room"
:
_generate_bootstrap_room
(),
}
)
if
request_data
.
get
(
"stream"
,
False
):
return
await
load_balancer
.
generate_stream
(
modified_request
,
prefill_server
,
decode_server
,
"generate"
)
else
:
return
await
load_balancer
.
generate
(
modified_request
,
prefill_server
,
decode_server
,
"generate"
)
async
def
_forward_to_backend
(
request_data
:
dict
,
endpoint_name
:
str
):
prefill_server
,
bootstrap_port
,
decode_server
=
load_balancer
.
select_pair
()
# Parse and transform prefill_server for bootstrap data
parsed_url
=
urllib
.
parse
.
urlparse
(
prefill_server
)
hostname
=
maybe_wrap_ipv6_address
(
parsed_url
.
hostname
)
modified_request
=
request_data
.
copy
()
modified_request
.
update
(
{
"bootstrap_host"
:
hostname
,
"bootstrap_port"
:
bootstrap_port
,
"bootstrap_room"
:
_generate_bootstrap_room
(),
}
)
if
request_data
.
get
(
"stream"
,
False
):
return
await
load_balancer
.
generate_stream
(
modified_request
,
prefill_server
,
decode_server
,
endpoint
=
endpoint_name
,
)
else
:
return
await
load_balancer
.
generate
(
modified_request
,
prefill_server
,
decode_server
,
endpoint
=
endpoint_name
,
)
@
app
.
post
(
"/v1/chat/completions"
)
async
def
handle_chat_completion_request
(
request_data
:
dict
):
return
await
_forward_to_backend
(
request_data
,
"v1/chat/completions"
)
@
app
.
post
(
"/v1/completions"
)
async
def
handle_completion_request
(
request_data
:
dict
):
return
await
_forward_to_backend
(
request_data
,
"v1/completions"
)
def
_generate_bootstrap_room
():
return
random
.
randint
(
0
,
2
**
63
-
1
)
# We may utilize `GenerateReqInput`'s logic later
def
_get_request_batch_size
(
request
):
if
(
text
:
=
request
.
get
(
"text"
))
is
not
None
:
return
None
if
isinstance
(
text
,
str
)
else
len
(
text
)
if
(
input_ids
:
=
request
.
get
(
"input_ids"
))
is
not
None
:
return
None
if
isinstance
(
input_ids
[
0
],
int
)
else
len
(
input_ids
)
return
None
@
app
.
get
(
"/v1/models"
)
async
def
get_models
():
prefill_server
=
load_balancer
.
prefill_servers
[
0
]
# Get the first prefill server
async
with
aiohttp
.
ClientSession
()
as
session
:
try
:
response
=
await
session
.
get
(
f
"
{
prefill_server
}
/v1/models"
)
if
response
.
status
!=
200
:
raise
HTTPException
(
status_code
=
response
.
status
,
detail
=
f
"Prefill server error: Status
{
response
.
status
}
"
,
)
return
ORJSONResponse
(
content
=
await
response
.
json
())
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
500
,
detail
=
str
(
e
))
@
app
.
post
(
"/register"
)
async
def
register
(
obj
:
PDRegistryRequest
):
if
obj
.
mode
==
"prefill"
:
load_balancer
.
add_prefill_server
(
PrefillConfig
(
obj
.
registry_url
,
obj
.
bootstrap_port
)
)
logger
.
info
(
f
"Registered prefill server:
{
obj
.
registry_url
}
with bootstrap port:
{
obj
.
bootstrap_port
}
"
)
elif
obj
.
mode
==
"decode"
:
load_balancer
.
add_decode_server
(
obj
.
registry_url
)
logger
.
info
(
f
"Registered decode server:
{
obj
.
registry_url
}
"
)
else
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"Invalid mode. Must be either PREFILL or DECODE."
,
)
logger
.
info
(
f
"#Prefill servers:
{
len
(
load_balancer
.
prefill_configs
)
}
, "
f
"#Decode servers:
{
len
(
load_balancer
.
decode_servers
)
}
"
)
return
Response
(
status_code
=
200
)
def
run
(
prefill_configs
,
decode_addrs
,
host
,
port
,
timeout
):
global
load_balancer
load_balancer
=
MiniLoadBalancer
(
prefill_configs
,
decode_addrs
,
timeout
=
timeout
)
uvicorn
.
run
(
app
,
host
=
host
,
port
=
port
)
if
__name__
==
"__main__"
:
# FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
from
sglang.srt.disaggregation.launch_lb
import
main
main
()
python/sglang/srt/disaggregation/utils.py
View file @
6e95f5e5
from
__future__
import
annotations
from
__future__
import
annotations
import
dataclasses
import
os
import
os
import
random
import
random
import
threading
import
warnings
from
collections
import
deque
from
collections
import
deque
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
import
numpy
as
np
import
numpy
as
np
import
requests
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
sglang.srt.utils
import
get_ip
,
is_npu
from
sglang.srt.utils
import
is_npu
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.managers.schedule_batch
import
Req
...
@@ -305,49 +301,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
...
@@ -305,49 +301,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
return
(
num_kv_indices
+
page_size
-
1
)
//
page_size
return
(
num_kv_indices
+
page_size
-
1
)
//
page_size
#########################
# PDLB Registry
#########################
@
dataclasses
.
dataclass
class
PDRegistryRequest
:
"""A request to register a machine itself to the LB."""
mode
:
str
registry_url
:
str
bootstrap_port
:
Optional
[
int
]
=
None
def
__post_init__
(
self
):
if
self
.
mode
==
"prefill"
and
self
.
bootstrap_port
is
None
:
raise
ValueError
(
"Bootstrap port must be set in PREFILL mode."
)
elif
self
.
mode
==
"decode"
and
self
.
bootstrap_port
is
not
None
:
raise
ValueError
(
"Bootstrap port must not be set in DECODE mode."
)
elif
self
.
mode
not
in
[
"prefill"
,
"decode"
]:
raise
ValueError
(
f
"Invalid mode:
{
self
.
mode
}
. Must be 'prefill' or 'decode'."
)
def
register_disaggregation_server
(
mode
:
str
,
server_port
:
int
,
bootstrap_port
:
int
,
pdlb_url
:
str
):
boostrap_port
=
bootstrap_port
if
mode
==
"prefill"
else
None
registry_request
=
PDRegistryRequest
(
mode
=
mode
,
registry_url
=
f
"http://
{
get_ip
()
}
:
{
server_port
}
"
,
bootstrap_port
=
boostrap_port
,
)
res
=
requests
.
post
(
f
"
{
pdlb_url
}
/register"
,
json
=
dataclasses
.
asdict
(
registry_request
),
)
if
res
.
status_code
!=
200
:
warnings
.
warn
(
f
"Failed to register disaggregation server:
{
res
.
status_code
}
{
res
.
text
}
"
)
#########################
#########################
# Misc
# Misc
#########################
#########################
...
...
python/sglang/srt/entrypoints/http_server.py
View file @
6e95f5e5
...
@@ -47,11 +47,7 @@ from fastapi.exceptions import RequestValidationError
...
@@ -47,11 +47,7 @@ from fastapi.exceptions import RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
sglang.srt.disaggregation.utils
import
(
from
sglang.srt.disaggregation.utils
import
FAKE_BOOTSTRAP_HOST
,
DisaggregationMode
FAKE_BOOTSTRAP_HOST
,
DisaggregationMode
,
register_disaggregation_server
,
)
from
sglang.srt.entrypoints.engine
import
_launch_subprocesses
from
sglang.srt.entrypoints.engine
import
_launch_subprocesses
from
sglang.srt.entrypoints.openai.protocol
import
(
from
sglang.srt.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionRequest
,
...
@@ -1405,13 +1401,5 @@ def _wait_and_warmup(
...
@@ -1405,13 +1401,5 @@ def _wait_and_warmup(
if
server_args
.
debug_tensor_dump_input_file
:
if
server_args
.
debug_tensor_dump_input_file
:
kill_process_tree
(
os
.
getpid
())
kill_process_tree
(
os
.
getpid
())
if
server_args
.
pdlb_url
is
not
None
:
register_disaggregation_server
(
server_args
.
disaggregation_mode
,
server_args
.
port
,
server_args
.
disaggregation_bootstrap_port
,
server_args
.
pdlb_url
,
)
if
launch_callback
is
not
None
:
if
launch_callback
is
not
None
:
launch_callback
()
launch_callback
()
python/sglang/srt/server_args.py
View file @
6e95f5e5
...
@@ -367,7 +367,6 @@ class ServerArgs:
...
@@ -367,7 +367,6 @@ class ServerArgs:
disaggregation_prefill_pp
:
Optional
[
int
]
=
1
disaggregation_prefill_pp
:
Optional
[
int
]
=
1
disaggregation_ib_device
:
Optional
[
str
]
=
None
disaggregation_ib_device
:
Optional
[
str
]
=
None
num_reserved_decode_tokens
:
int
=
512
# used for decode kv cache offload in PD
num_reserved_decode_tokens
:
int
=
512
# used for decode kv cache offload in PD
pdlb_url
:
Optional
[
str
]
=
None
# For model weight update
# For model weight update
custom_weight_loader
:
Optional
[
List
[
str
]]
=
None
custom_weight_loader
:
Optional
[
List
[
str
]]
=
None
...
@@ -2071,12 +2070,6 @@ class ServerArgs:
...
@@ -2071,12 +2070,6 @@ class ServerArgs:
default
=
ServerArgs
.
num_reserved_decode_tokens
,
default
=
ServerArgs
.
num_reserved_decode_tokens
,
help
=
"Number of decode tokens that will have memory reserved when adding new request to the running batch."
,
help
=
"Number of decode tokens that will have memory reserved when adding new request to the running batch."
,
)
)
parser
.
add_argument
(
"--pdlb-url"
,
type
=
str
,
default
=
None
,
help
=
"The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer."
,
)
# Custom weight loader
# Custom weight loader
parser
.
add_argument
(
parser
.
add_argument
(
...
...
python/sglang/test/test_utils.py
View file @
6e95f5e5
...
@@ -466,6 +466,25 @@ def try_cached_model(model_repo: str):
...
@@ -466,6 +466,25 @@ def try_cached_model(model_repo: str):
return
model_dir
if
model_dir
else
model_repo
return
model_dir
if
model_dir
else
model_repo
def
popen_with_error_check
(
command
:
list
[
str
],
allow_exit
:
bool
=
False
):
process
=
subprocess
.
Popen
(
command
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
def
_run_and_check
():
stdout
,
stderr
=
process
.
communicate
()
while
process
.
poll
()
is
None
:
time
.
sleep
(
5
)
if
not
allow_exit
or
process
.
returncode
!=
0
:
raise
Exception
(
f
"
{
command
}
exited with code
{
process
.
returncode
}
\n
{
stdout
=
}
\n
{
stderr
=
}
"
)
t
=
threading
.
Thread
(
target
=
_run_and_check
)
t
.
start
()
return
process
def
popen_launch_server
(
def
popen_launch_server
(
model
:
str
,
model
:
str
,
base_url
:
str
,
base_url
:
str
,
...
...
scripts/ci/ci_install_dependency.sh
View file @
6e95f5e5
...
@@ -45,6 +45,10 @@ fi
...
@@ -45,6 +45,10 @@ fi
# Install the main package
# Install the main package
$PIP_CMD
install
-e
"python[dev]"
--extra-index-url
https://download.pytorch.org/whl/
${
CU_VERSION
}
$PIP_INSTALL_SUFFIX
$PIP_CMD
install
-e
"python[dev]"
--extra-index-url
https://download.pytorch.org/whl/
${
CU_VERSION
}
$PIP_INSTALL_SUFFIX
# Install router for pd-disagg test
SGLANG_ROUTER_BUILD_NO_RUST
=
1
$PIP_CMD
install
-e
"sgl-router"
$PIP_INSTALL_SUFFIX
if
[
"
$IS_BLACKWELL
"
=
"1"
]
;
then
if
[
"
$IS_BLACKWELL
"
=
"1"
]
;
then
# TODO auto determine sgl-kernel version
# TODO auto determine sgl-kernel version
SGL_KERNEL_VERSION
=
0.3.8
SGL_KERNEL_VERSION
=
0.3.8
...
...
sgl-router/py_src/sglang_router/__init__.py
View file @
6e95f5e5
# a lightweihgt wrapper on router with argument type and comments
# no wrapper on policy type => direct export
from
sglang_router.router
import
Router
from
sglang_router.version
import
__version__
from
sglang_router.version
import
__version__
from
sglang_router_rs
import
PolicyType
__all__
=
[
"Router"
,
"PolicyType"
,
"__version__"
]
__all__
=
[
"__version__"
]
sgl-router/py_src/sglang_router/launch_router.py
View file @
6e95f5e5
This diff is collapsed.
Click to expand it.
sgl-router/py_src/sglang_router/mini_lb.py
0 → 100644
View file @
6e95f5e5
"""
Minimal HTTP load balancer for prefill and decode servers for testing.
"""
import
asyncio
import
ipaddress
import
logging
import
random
import
urllib
from
http
import
HTTPStatus
from
itertools
import
chain
from
typing
import
Optional
import
aiohttp
import
orjson
import
uvicorn
from
fastapi
import
FastAPI
,
HTTPException
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
sglang_router.router_args
import
RouterArgs
logger
=
logging
.
getLogger
(
__name__
)
AIOHTTP_STREAM_READ_CHUNK_SIZE
=
(
1024
*
64
)
# 64KB, to prevent aiohttp's "Chunk too big" error
def
maybe_wrap_ipv6_address
(
address
:
str
)
->
str
:
try
:
ipaddress
.
IPv6Address
(
address
)
return
f
"[
{
address
}
]"
except
ValueError
:
return
address
class
MiniLoadBalancer
:
def
__init__
(
self
,
router_args
:
RouterArgs
,
):
self
.
_validate_router_args
(
router_args
)
self
.
host
=
router_args
.
host
self
.
port
=
router_args
.
port
self
.
timeout
=
router_args
.
request_timeout_secs
self
.
prefill_urls
=
[
url
[
0
]
for
url
in
router_args
.
prefill_urls
]
self
.
prefill_bootstrap_ports
=
[
url
[
1
]
for
url
in
router_args
.
prefill_urls
]
self
.
decode_urls
=
router_args
.
decode_urls
def
_validate_router_args
(
self
,
router_args
:
RouterArgs
):
logger
.
warning
(
"
\x1b
[33mMiniLB is only for debugging purposes, it only supports random policy!
\033
[0m"
)
# NOTE: too many arguments unsupported, just validate some important ones
if
router_args
.
policy
!=
"random"
:
logger
.
warning
(
"[MiniLB] Overriding policy to random"
)
router_args
.
policy
=
"random"
if
not
router_args
.
pd_disaggregation
:
raise
ValueError
(
"MiniLB only supports PD disaggregation mode"
)
if
len
(
router_args
.
prefill_urls
)
==
0
or
len
(
router_args
.
decode_urls
)
==
0
:
raise
ValueError
(
"MiniLB requires at least one prefill and one decode server"
)
def
start
(
self
):
global
lb
lb
=
self
uvicorn
.
run
(
app
,
host
=
self
.
host
,
port
=
self
.
port
)
def
select_pair
(
self
):
assert
len
(
self
.
prefill_urls
)
>
0
,
"No prefill servers available"
assert
len
(
self
.
decode_urls
)
>
0
,
"No decode servers available"
pidx
=
random
.
randint
(
0
,
len
(
self
.
prefill_urls
)
-
1
)
didx
=
random
.
randint
(
0
,
len
(
self
.
decode_urls
)
-
1
)
return
(
self
.
prefill_urls
[
pidx
],
self
.
prefill_bootstrap_ports
[
pidx
],
self
.
decode_urls
[
didx
],
)
async
def
generate
(
self
,
modified_request
,
prefill_server
,
decode_server
,
endpoint
)
->
ORJSONResponse
:
assert
endpoint
[
0
]
!=
"/"
,
f
"Endpoint should not start with '/':
{
endpoint
}
"
async
with
aiohttp
.
ClientSession
(
timeout
=
aiohttp
.
ClientTimeout
(
total
=
self
.
timeout
)
# Add timeout for request reliability
)
as
session
:
tasks
=
[
session
.
post
(
f
"
{
prefill_server
}
/
{
endpoint
}
"
,
json
=
modified_request
),
session
.
post
(
f
"
{
decode_server
}
/
{
endpoint
}
"
,
json
=
modified_request
),
]
# Wait for both responses to complete. Prefill should end first.
prefill_response
,
decode_response
=
await
asyncio
.
gather
(
*
tasks
)
if
"return_logprob"
in
modified_request
:
prefill_json
=
await
prefill_response
.
json
()
ret_json
=
await
decode_response
.
json
()
# merge `meta_info.input_token_logprobs` from prefill to decode
if
"meta_info"
in
ret_json
:
if
"input_token_logprobs"
in
ret_json
[
"meta_info"
]:
ret_json
[
"meta_info"
][
"input_token_logprobs"
]
=
(
prefill_json
[
"meta_info"
][
"input_token_logprobs"
]
+
ret_json
[
"meta_info"
][
"input_token_logprobs"
]
)
else
:
ret_json
=
await
decode_response
.
json
()
return
ORJSONResponse
(
content
=
ret_json
,
status_code
=
decode_response
.
status
,
)
async
def
generate_stream
(
self
,
modified_request
,
prefill_server
,
decode_server
,
endpoint
=
"generate"
):
assert
endpoint
[
0
]
!=
"/"
,
f
"Endpoint should not start with '/':
{
endpoint
}
"
async
def
stream_results
():
async
with
aiohttp
.
ClientSession
(
timeout
=
aiohttp
.
ClientTimeout
(
total
=
self
.
timeout
)
# Add timeout for request reliability
)
as
session
:
# Create the tasks for both prefill and decode requests
tasks
=
[
session
.
post
(
f
"
{
prefill_server
}
/
{
endpoint
}
"
,
json
=
modified_request
),
session
.
post
(
f
"
{
decode_server
}
/
{
endpoint
}
"
,
json
=
modified_request
),
]
# Wait for both responses to complete. Since this is streaming, they return immediately.
prefill_response
,
decode_response
=
await
asyncio
.
gather
(
*
tasks
)
if
modified_request
.
get
(
"return_logprob"
,
False
):
prefill_chunks
=
[]
async
for
chunk
in
prefill_response
.
content
:
prefill_chunks
.
append
(
chunk
)
first_prefill_chunk
=
(
prefill_chunks
[
0
].
decode
(
"utf-8"
)[
5
:].
strip
(
"
\n
"
)
)
first_prefill_chunk_json
=
orjson
.
loads
(
first_prefill_chunk
)
async
for
chunk
in
decode_response
.
content
:
# Note: This is inefficient
# merge prefill input_token_logprobs, output_token_logprobs to decode
decoded_chunk
=
chunk
.
decode
(
"utf-8"
)
if
(
decoded_chunk
and
decoded_chunk
.
startswith
(
"data:"
)
and
"[DONE]"
not
in
decoded_chunk
):
ret_json
=
orjson
.
loads
(
decoded_chunk
[
5
:].
strip
(
"
\n
"
))
ret_json
[
"meta_info"
][
"input_token_logprobs"
]
=
(
first_prefill_chunk_json
[
"meta_info"
][
"input_token_logprobs"
]
+
ret_json
[
"meta_info"
][
"input_token_logprobs"
]
)
yield
b
"data: "
+
orjson
.
dumps
(
ret_json
)
+
b
"
\n\n
"
else
:
yield
chunk
else
:
async
for
chunk
in
decode_response
.
content
.
iter_chunked
(
AIOHTTP_STREAM_READ_CHUNK_SIZE
):
yield
chunk
return
StreamingResponse
(
stream_results
(),
media_type
=
"text/event-stream"
,
)
app
=
FastAPI
()
lb
:
Optional
[
MiniLoadBalancer
]
=
None
@
app
.
get
(
"/health"
)
async
def
health_check
():
return
Response
(
status_code
=
200
)
@
app
.
get
(
"/health_generate"
)
async
def
health_generate
():
async
with
aiohttp
.
ClientSession
()
as
session
:
# Create the tasks
tasks
=
[]
for
server
in
chain
(
lb
.
prefill_urls
,
lb
.
decode_urls
):
tasks
.
append
(
session
.
get
(
f
"
{
server
}
/health_generate"
))
for
i
,
response
in
enumerate
(
asyncio
.
as_completed
(
tasks
)):
await
response
return
Response
(
status_code
=
200
)
@
app
.
post
(
"/flush_cache"
)
async
def
flush_cache
():
async
with
aiohttp
.
ClientSession
()
as
session
:
# Create the tasks
tasks
=
[]
for
server
in
chain
(
lb
.
prefill_urls
,
lb
.
decode_urls
):
tasks
.
append
(
session
.
post
(
f
"
{
server
}
/flush_cache"
))
for
i
,
response
in
enumerate
(
asyncio
.
as_completed
(
tasks
)):
await
response
return
Response
(
status_code
=
200
)
@
app
.
get
(
"/get_server_info"
)
async
def
get_server_info
():
prefill_infos
=
[]
decode_infos
=
[]
all_internal_states
=
[]
async
with
aiohttp
.
ClientSession
()
as
session
:
for
server
in
lb
.
prefill_urls
:
server_info
=
await
session
.
get
(
f
"
{
server
}
/get_server_info"
)
prefill_infos
.
append
(
await
server_info
.
json
())
for
server
in
lb
.
decode_urls
:
server_info
=
await
session
.
get
(
f
"
{
server
}
/get_server_info"
)
info_json
=
await
server_info
.
json
()
decode_infos
.
append
(
info_json
)
# Extract internal_states from decode servers
if
"internal_states"
in
info_json
:
all_internal_states
.
extend
(
info_json
[
"internal_states"
])
# Return format expected by bench_one_batch_server.py
if
all_internal_states
:
return
{
"internal_states"
:
all_internal_states
,
"prefill"
:
prefill_infos
,
"decode"
:
decode_infos
,
}
else
:
# Fallback with dummy data if no internal states found
return
{
"internal_states"
:
[
{
"last_gen_throughput"
:
0.0
,
"avg_spec_accept_length"
:
None
,
}
],
"prefill"
:
prefill_infos
,
"decode"
:
decode_infos
,
}
@
app
.
get
(
"/get_model_info"
)
async
def
get_model_info
():
if
not
lb
or
not
lb
.
prefill_urls
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
SERVICE_UNAVAILABLE
,
detail
=
"There is no server registered"
,
)
target_server_url
=
lb
.
prefill_urls
[
0
]
endpoint_url
=
f
"
{
target_server_url
}
/get_model_info"
async
with
aiohttp
.
ClientSession
()
as
session
:
try
:
async
with
session
.
get
(
endpoint_url
)
as
response
:
if
response
.
status
!=
200
:
error_text
=
await
response
.
text
()
raise
HTTPException
(
status_code
=
HTTPStatus
.
BAD_GATEWAY
,
detail
=
(
f
"Failed to get model info from
{
target_server_url
}
"
f
"Status:
{
response
.
status
}
, Response:
{
error_text
}
"
),
)
model_info_json
=
await
response
.
json
()
return
ORJSONResponse
(
content
=
model_info_json
)
except
aiohttp
.
ClientError
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
SERVICE_UNAVAILABLE
,
detail
=
f
"Failed to get model info from backend"
,
)
@
app
.
post
(
"/generate"
)
async
def
handle_generate_request
(
request_data
:
dict
):
prefill_server
,
bootstrap_port
,
decode_server
=
lb
.
select_pair
()
# Parse and transform prefill_server for bootstrap data
parsed_url
=
urllib
.
parse
.
urlparse
(
prefill_server
)
hostname
=
maybe_wrap_ipv6_address
(
parsed_url
.
hostname
)
modified_request
=
request_data
.
copy
()
batch_size
=
_get_request_batch_size
(
modified_request
)
if
batch_size
is
not
None
:
modified_request
.
update
(
{
"bootstrap_host"
:
[
hostname
]
*
batch_size
,
"bootstrap_port"
:
[
bootstrap_port
]
*
batch_size
,
"bootstrap_room"
:
[
_generate_bootstrap_room
()
for
_
in
range
(
batch_size
)
],
}
)
else
:
modified_request
.
update
(
{
"bootstrap_host"
:
hostname
,
"bootstrap_port"
:
bootstrap_port
,
"bootstrap_room"
:
_generate_bootstrap_room
(),
}
)
if
request_data
.
get
(
"stream"
,
False
):
return
await
lb
.
generate_stream
(
modified_request
,
prefill_server
,
decode_server
,
"generate"
)
else
:
return
await
lb
.
generate
(
modified_request
,
prefill_server
,
decode_server
,
"generate"
)
async
def
_forward_to_backend
(
request_data
:
dict
,
endpoint_name
:
str
):
prefill_server
,
bootstrap_port
,
decode_server
=
lb
.
select_pair
()
# Parse and transform prefill_server for bootstrap data
parsed_url
=
urllib
.
parse
.
urlparse
(
prefill_server
)
hostname
=
maybe_wrap_ipv6_address
(
parsed_url
.
hostname
)
modified_request
=
request_data
.
copy
()
modified_request
.
update
(
{
"bootstrap_host"
:
hostname
,
"bootstrap_port"
:
bootstrap_port
,
"bootstrap_room"
:
_generate_bootstrap_room
(),
}
)
if
request_data
.
get
(
"stream"
,
False
):
return
await
lb
.
generate_stream
(
modified_request
,
prefill_server
,
decode_server
,
endpoint
=
endpoint_name
,
)
else
:
return
await
lb
.
generate
(
modified_request
,
prefill_server
,
decode_server
,
endpoint
=
endpoint_name
,
)
@
app
.
post
(
"/v1/chat/completions"
)
async
def
handle_chat_completion_request
(
request_data
:
dict
):
return
await
_forward_to_backend
(
request_data
,
"v1/chat/completions"
)
@
app
.
post
(
"/v1/completions"
)
async
def
handle_completion_request
(
request_data
:
dict
):
return
await
_forward_to_backend
(
request_data
,
"v1/completions"
)
def
_generate_bootstrap_room
():
return
random
.
randint
(
0
,
2
**
63
-
1
)
# We may utilize `GenerateReqInput`'s logic later
def
_get_request_batch_size
(
request
):
if
(
text
:
=
request
.
get
(
"text"
))
is
not
None
:
return
None
if
isinstance
(
text
,
str
)
else
len
(
text
)
if
(
input_ids
:
=
request
.
get
(
"input_ids"
))
is
not
None
:
return
None
if
isinstance
(
input_ids
[
0
],
int
)
else
len
(
input_ids
)
return
None
@
app
.
get
(
"/v1/models"
)
async
def
get_models
():
prefill_server
=
lb
.
prefill_urls
[
0
]
# Get the first prefill server
async
with
aiohttp
.
ClientSession
()
as
session
:
try
:
response
=
await
session
.
get
(
f
"
{
prefill_server
}
/v1/models"
)
if
response
.
status
!=
200
:
raise
HTTPException
(
status_code
=
response
.
status
,
detail
=
f
"Prefill server error: Status
{
response
.
status
}
"
,
)
return
ORJSONResponse
(
content
=
await
response
.
json
())
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
500
,
detail
=
str
(
e
))
sgl-router/py_src/sglang_router/router.py
View file @
6e95f5e5
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
from
sglang_router.router_args
import
RouterArgs
from
sglang_router_rs
import
PolicyType
from
sglang_router_rs
import
PolicyType
from
sglang_router_rs
import
Router
as
_Router
from
sglang_router_rs
import
Router
as
_Router
def
policy_from_str
(
policy_str
:
Optional
[
str
])
->
PolicyType
:
"""Convert policy string to PolicyType enum."""
if
policy_str
is
None
:
return
None
policy_map
=
{
"random"
:
PolicyType
.
Random
,
"round_robin"
:
PolicyType
.
RoundRobin
,
"cache_aware"
:
PolicyType
.
CacheAware
,
"power_of_two"
:
PolicyType
.
PowerOfTwo
,
}
return
policy_map
[
policy_str
]
class
Router
:
class
Router
:
"""
"""
A high-performance router for distributing requests across worker nodes.
A high-performance router for distributing requests across worker nodes.
...
@@ -78,130 +92,34 @@ class Router:
...
@@ -78,130 +92,34 @@ class Router:
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
"""
"""
def
__init__
(
def
__init__
(
self
,
router
:
_Router
):
self
,
self
.
_router
=
router
worker_urls
:
List
[
str
],
policy
:
PolicyType
=
PolicyType
.
RoundRobin
,
host
:
str
=
"127.0.0.1"
,
port
:
int
=
3001
,
worker_startup_timeout_secs
:
int
=
600
,
worker_startup_check_interval
:
int
=
30
,
cache_threshold
:
float
=
0.3
,
balance_abs_threshold
:
int
=
64
,
balance_rel_threshold
:
float
=
1.5
,
eviction_interval_secs
:
int
=
120
,
max_tree_size
:
int
=
2
**
26
,
max_payload_size
:
int
=
512
*
1024
*
1024
,
# 512MB
dp_aware
:
bool
=
False
,
api_key
:
Optional
[
str
]
=
None
,
log_dir
:
Optional
[
str
]
=
None
,
log_level
:
Optional
[
str
]
=
None
,
service_discovery
:
bool
=
False
,
selector
:
Dict
[
str
,
str
]
=
None
,
service_discovery_port
:
int
=
80
,
service_discovery_namespace
:
Optional
[
str
]
=
None
,
prefill_selector
:
Dict
[
str
,
str
]
=
None
,
decode_selector
:
Dict
[
str
,
str
]
=
None
,
bootstrap_port_annotation
:
str
=
"sglang.ai/bootstrap-port"
,
prometheus_port
:
Optional
[
int
]
=
None
,
prometheus_host
:
Optional
[
str
]
=
None
,
request_timeout_secs
:
int
=
1800
,
request_id_headers
:
Optional
[
List
[
str
]]
=
None
,
pd_disaggregation
:
bool
=
False
,
prefill_urls
:
Optional
[
List
[
tuple
]]
=
None
,
decode_urls
:
Optional
[
List
[
str
]]
=
None
,
prefill_policy
:
Optional
[
PolicyType
]
=
None
,
decode_policy
:
Optional
[
PolicyType
]
=
None
,
max_concurrent_requests
:
int
=
256
,
queue_size
:
int
=
100
,
queue_timeout_secs
:
int
=
60
,
rate_limit_tokens_per_second
:
Optional
[
int
]
=
None
,
cors_allowed_origins
:
List
[
str
]
=
None
,
retry_max_retries
:
int
=
5
,
retry_initial_backoff_ms
:
int
=
50
,
retry_max_backoff_ms
:
int
=
30_000
,
retry_backoff_multiplier
:
float
=
1.5
,
retry_jitter_factor
:
float
=
0.2
,
cb_failure_threshold
:
int
=
10
,
cb_success_threshold
:
int
=
3
,
cb_timeout_duration_secs
:
int
=
60
,
cb_window_duration_secs
:
int
=
120
,
disable_retries
:
bool
=
False
,
disable_circuit_breaker
:
bool
=
False
,
health_failure_threshold
:
int
=
3
,
health_success_threshold
:
int
=
2
,
health_check_timeout_secs
:
int
=
5
,
health_check_interval_secs
:
int
=
60
,
health_check_endpoint
:
str
=
"/health"
,
model_path
:
Optional
[
str
]
=
None
,
tokenizer_path
:
Optional
[
str
]
=
None
,
):
if
selector
is
None
:
selector
=
{}
if
prefill_selector
is
None
:
prefill_selector
=
{}
if
decode_selector
is
None
:
decode_selector
=
{}
if
cors_allowed_origins
is
None
:
cors_allowed_origins
=
[]
self
.
_router
=
_Router
(
@
staticmethod
worker_urls
=
worker_urls
,
def
from_args
(
args
:
RouterArgs
)
->
"Router"
:
policy
=
policy
,
"""Create a router from a RouterArgs instance."""
host
=
host
,
port
=
port
,
args_dict
=
vars
(
args
)
worker_startup_timeout_secs
=
worker_startup_timeout_secs
,
# Convert RouterArgs to _Router parameters
worker_startup_check_interval
=
worker_startup_check_interval
,
args_dict
[
"worker_urls"
]
=
(
cache_threshold
=
cache_threshold
,
[]
balance_abs_threshold
=
balance_abs_threshold
,
if
args_dict
[
"service_discovery"
]
or
args_dict
[
"pd_disaggregation"
]
balance_rel_threshold
=
balance_rel_threshold
,
else
args_dict
[
"worker_urls"
]
eviction_interval_secs
=
eviction_interval_secs
,
)
max_tree_size
=
max_tree_size
,
args_dict
[
"policy"
]
=
policy_from_str
(
args_dict
[
"policy"
])
max_payload_size
=
max_payload_size
,
args_dict
[
"prefill_urls"
]
=
(
dp_aware
=
dp_aware
,
args_dict
[
"prefill_urls"
]
if
args_dict
[
"pd_disaggregation"
]
else
None
api_key
=
api_key
,
log_dir
=
log_dir
,
log_level
=
log_level
,
service_discovery
=
service_discovery
,
selector
=
selector
,
service_discovery_port
=
service_discovery_port
,
service_discovery_namespace
=
service_discovery_namespace
,
prefill_selector
=
prefill_selector
,
decode_selector
=
decode_selector
,
bootstrap_port_annotation
=
bootstrap_port_annotation
,
prometheus_port
=
prometheus_port
,
prometheus_host
=
prometheus_host
,
request_timeout_secs
=
request_timeout_secs
,
request_id_headers
=
request_id_headers
,
pd_disaggregation
=
pd_disaggregation
,
prefill_urls
=
prefill_urls
,
decode_urls
=
decode_urls
,
prefill_policy
=
prefill_policy
,
decode_policy
=
decode_policy
,
max_concurrent_requests
=
max_concurrent_requests
,
queue_size
=
queue_size
,
queue_timeout_secs
=
queue_timeout_secs
,
rate_limit_tokens_per_second
=
rate_limit_tokens_per_second
,
cors_allowed_origins
=
cors_allowed_origins
,
retry_max_retries
=
retry_max_retries
,
retry_initial_backoff_ms
=
retry_initial_backoff_ms
,
retry_max_backoff_ms
=
retry_max_backoff_ms
,
retry_backoff_multiplier
=
retry_backoff_multiplier
,
retry_jitter_factor
=
retry_jitter_factor
,
cb_failure_threshold
=
cb_failure_threshold
,
cb_success_threshold
=
cb_success_threshold
,
cb_timeout_duration_secs
=
cb_timeout_duration_secs
,
cb_window_duration_secs
=
cb_window_duration_secs
,
disable_retries
=
disable_retries
,
disable_circuit_breaker
=
disable_circuit_breaker
,
health_failure_threshold
=
health_failure_threshold
,
health_success_threshold
=
health_success_threshold
,
health_check_timeout_secs
=
health_check_timeout_secs
,
health_check_interval_secs
=
health_check_interval_secs
,
health_check_endpoint
=
health_check_endpoint
,
model_path
=
model_path
,
tokenizer_path
=
tokenizer_path
,
)
)
args_dict
[
"decode_urls"
]
=
(
args_dict
[
"decode_urls"
]
if
args_dict
[
"pd_disaggregation"
]
else
None
)
args_dict
[
"prefill_policy"
]
=
policy_from_str
(
args_dict
[
"prefill_policy"
])
args_dict
[
"decode_policy"
]
=
policy_from_str
(
args_dict
[
"decode_policy"
])
# remoge mini_lb parameter
args_dict
.
pop
(
"mini_lb"
)
return
Router
(
_Router
(
**
args_dict
))
def
start
(
self
)
->
None
:
def
start
(
self
)
->
None
:
"""Start the router server.
"""Start the router server.
...
...
sgl-router/py_src/sglang_router/router_args.py
0 → 100644
View file @
6e95f5e5
This diff is collapsed.
Click to expand it.
sgl-router/py_test/test_launch_router.py
View file @
6e95f5e5
...
@@ -33,7 +33,7 @@ class TestLaunchRouter(unittest.TestCase):
...
@@ -33,7 +33,7 @@ class TestLaunchRouter(unittest.TestCase):
cache_threshold
=
0.5
,
cache_threshold
=
0.5
,
balance_abs_threshold
=
32
,
balance_abs_threshold
=
32
,
balance_rel_threshold
=
1.0001
,
balance_rel_threshold
=
1.0001
,
eviction_interval
=
60
,
eviction_interval
_secs
=
60
,
max_tree_size
=
2
**
24
,
max_tree_size
=
2
**
24
,
max_payload_size
=
256
*
1024
*
1024
,
# 256MB
max_payload_size
=
256
*
1024
*
1024
,
# 256MB
verbose
=
False
,
verbose
=
False
,
...
@@ -176,9 +176,8 @@ class TestLaunchRouter(unittest.TestCase):
...
@@ -176,9 +176,8 @@ class TestLaunchRouter(unittest.TestCase):
"""Test basic PD router functionality without actually starting servers."""
"""Test basic PD router functionality without actually starting servers."""
# This test just verifies the PD router can be created and configured
# This test just verifies the PD router can be created and configured
# without actually starting it (which would require real prefill/decode servers)
# without actually starting it (which would require real prefill/decode servers)
from
sglang_router
import
Router
from
sglang_router.launch_router
import
RouterArgs
from
sglang_router.launch_router
import
RouterArgs
from
sglang_router
_rs
import
PolicyType
from
sglang_router
.router
import
PolicyType
,
Router
# Test RouterArgs parsing for PD mode
# Test RouterArgs parsing for PD mode
# Simulate the parsed args structure from argparse with action="append"
# Simulate the parsed args structure from argparse with action="append"
...
@@ -209,18 +208,7 @@ class TestLaunchRouter(unittest.TestCase):
...
@@ -209,18 +208,7 @@ class TestLaunchRouter(unittest.TestCase):
self
.
assertEqual
(
router_args
.
decode_urls
[
1
],
"http://decode2:8081"
)
self
.
assertEqual
(
router_args
.
decode_urls
[
1
],
"http://decode2:8081"
)
# Test Router creation in PD mode
# Test Router creation in PD mode
router
=
Router
(
router
=
Router
.
from_args
(
router_args
)
worker_urls
=
[],
# Empty for PD mode
pd_disaggregation
=
True
,
prefill_urls
=
[
(
"http://prefill1:8080"
,
9000
),
(
"http://prefill2:8080"
,
None
),
],
decode_urls
=
[
"http://decode1:8081"
,
"http://decode2:8081"
],
policy
=
PolicyType
.
CacheAware
,
host
=
"127.0.0.1"
,
port
=
3001
,
)
self
.
assertIsNotNone
(
router
)
self
.
assertIsNotNone
(
router
)
def
test_policy_validation
(
self
):
def
test_policy_validation
(
self
):
...
...
sgl-router/py_test/test_launch_server.py
View file @
6e95f5e5
...
@@ -77,7 +77,7 @@ def popen_launch_router(
...
@@ -77,7 +77,7 @@ def popen_launch_router(
port
,
port
,
"--dp"
,
"--dp"
,
str
(
dp_size
),
str
(
dp_size
),
"--router-eviction-interval"
,
"--router-eviction-interval
-secs
"
,
"5"
,
"5"
,
"--router-policy"
,
"--router-policy"
,
policy
,
policy
,
...
...
sgl-router/pyproject.toml
View file @
6e95f5e5
...
@@ -28,8 +28,3 @@ find = { where = ["py_src"] }
...
@@ -28,8 +28,3 @@ find = { where = ["py_src"] }
# workaround for https://github.com/pypa/twine/issues/1216
# workaround for https://github.com/pypa/twine/issues/1216
[tool.setuptools]
[tool.setuptools]
license-files
=
[]
license-files
=
[]
[[tool.setuptools-rust.ext-modules]]
target
=
"sglang_router_rs"
path
=
"Cargo.toml"
binding
=
"PyO3"
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment