Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
a3e4e9bf
".github/vscode:/vscode.git/clone" did not exist on "1f766c36fb61f7b1969664645bf38dae93f568a2"
Unverified
Commit
a3e4e9bf
authored
May 07, 2025
by
Liangsheng Yin
Committed by
GitHub
May 07, 2025
Browse files
Better PD initialization (#5751)
parent
6d4d3bc8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
141 additions
and
25 deletions
+141
-25
python/sglang/srt/disaggregation/mini_lb.py
python/sglang/srt/disaggregation/mini_lb.py
+74
-23
python/sglang/srt/disaggregation/utils.py
python/sglang/srt/disaggregation/utils.py
+44
-1
python/sglang/srt/entrypoints/http_server.py
python/sglang/srt/entrypoints/http_server.py
+12
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+4
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-0
No files found.
python/sglang/srt/disaggregation/mini_lb.py
View file @
a3e4e9bf
...
@@ -3,10 +3,12 @@ Minimal HTTP load balancer for prefill and decode servers for testing.
...
@@ -3,10 +3,12 @@ Minimal HTTP load balancer for prefill and decode servers for testing.
"""
"""
import
asyncio
import
asyncio
import
dataclasses
import
logging
import
random
import
random
import
urllib
import
urllib
from
itertools
import
chain
from
itertools
import
chain
from
typing
import
List
from
typing
import
List
,
Optional
import
aiohttp
import
aiohttp
import
orjson
import
orjson
...
@@ -14,11 +16,32 @@ import uvicorn
...
@@ -14,11 +16,32 @@ import uvicorn
from
fastapi
import
FastAPI
,
HTTPException
from
fastapi
import
FastAPI
,
HTTPException
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
sglang.srt.disaggregation.utils
import
PDRegistryRequest
def setup_logger():
    """Return the shared "pdlb" logger, configuring it on first use.

    The logger is set to INFO level and given a single stream handler that
    prefixes every record with "[PDLB (Python)]", a timestamp and the level.

    The function is idempotent: the handler is attached only when the logger
    has none yet, so calling it again does not make every message appear
    multiple times.
    """
    logger = logging.getLogger("pdlb")
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name, so an
    # unconditional addHandler would stack one more handler per call.
    if not logger.handlers:
        formatter = logging.Formatter(
            "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


logger = setup_logger()
@
dataclasses
.
dataclass
class
PrefillConfig
:
class
PrefillConfig
:
def
__init__
(
self
,
url
:
str
,
bootstrap_port
:
int
):
url
:
str
self
.
url
=
url
bootstrap_port
:
Optional
[
int
]
=
None
self
.
bootstrap_port
=
bootstrap_port
class
MiniLoadBalancer
:
class
MiniLoadBalancer
:
...
@@ -28,6 +51,10 @@ class MiniLoadBalancer:
...
@@ -28,6 +51,10 @@ class MiniLoadBalancer:
self
.
decode_servers
=
decode_servers
self
.
decode_servers
=
decode_servers
def
select_pair
(
self
):
def
select_pair
(
self
):
# TODO: return some message instead of panic
assert
len
(
self
.
prefill_configs
)
>
0
,
"No prefill servers available"
assert
len
(
self
.
decode_servers
)
>
0
,
"No decode servers available"
prefill_config
=
random
.
choice
(
self
.
prefill_configs
)
prefill_config
=
random
.
choice
(
self
.
prefill_configs
)
decode_server
=
random
.
choice
(
self
.
decode_servers
)
decode_server
=
random
.
choice
(
self
.
decode_servers
)
return
prefill_config
.
url
,
prefill_config
.
bootstrap_port
,
decode_server
return
prefill_config
.
url
,
prefill_config
.
bootstrap_port
,
decode_server
...
@@ -47,7 +74,7 @@ class MiniLoadBalancer:
...
@@ -47,7 +74,7 @@ class MiniLoadBalancer:
session
.
post
(
f
"
{
decode_server
}
/
{
endpoint
}
"
,
json
=
modified_request
),
session
.
post
(
f
"
{
decode_server
}
/
{
endpoint
}
"
,
json
=
modified_request
),
]
]
# Wait for both responses to complete. Prefill should end first.
# Wait for both responses to complete. Prefill should end first.
prefill_response
,
decode_response
=
await
asyncio
.
gather
(
*
tasks
)
_
,
decode_response
=
await
asyncio
.
gather
(
*
tasks
)
return
ORJSONResponse
(
return
ORJSONResponse
(
content
=
await
decode_response
.
json
(),
content
=
await
decode_response
.
json
(),
...
@@ -268,6 +295,32 @@ async def get_models():
...
@@ -268,6 +295,32 @@ async def get_models():
raise
HTTPException
(
status_code
=
500
,
detail
=
str
(
e
))
raise
HTTPException
(
status_code
=
500
,
detail
=
str
(
e
))
@app.post("/register")
async def register(obj: PDRegistryRequest):
    """Add a prefill or decode server to the global load balancer.

    For mode "prefill" the server URL and its bootstrap port are stored as a
    PrefillConfig; for mode "decode" only the URL is stored. A server that is
    already present (same URL and, for prefill, same bootstrap port) is not
    added a second time, so a re-registration does not give it extra weight
    in the random server selection.

    Returns:
        An empty response with status code 200.

    Raises:
        HTTPException: status 400 when the mode is neither "prefill" nor
            "decode".
    """
    if obj.mode == "prefill":
        # Compare field by field so the check does not depend on how
        # PrefillConfig implements equality.
        already_registered = any(
            cfg.url == obj.registry_url
            and cfg.bootstrap_port == obj.bootstrap_port
            for cfg in load_balancer.prefill_configs
        )
        if already_registered:
            logger.info(
                f"Prefill server already registered: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
            )
        else:
            load_balancer.prefill_configs.append(
                PrefillConfig(obj.registry_url, obj.bootstrap_port)
            )
            logger.info(
                f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
            )
    elif obj.mode == "decode":
        if obj.registry_url in load_balancer.decode_servers:
            logger.info(f"Decode server already registered: {obj.registry_url}")
        else:
            load_balancer.decode_servers.append(obj.registry_url)
            logger.info(f"Registered decode server: {obj.registry_url}")
    else:
        raise HTTPException(
            status_code=400,
            detail="Invalid mode. Must be either PREFILL or DECODE.",
        )

    logger.info(
        f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
        f"#Decode servers: {len(load_balancer.decode_servers)}"
    )

    return Response(status_code=200)
def
run
(
prefill_configs
,
decode_addrs
,
host
,
port
):
def
run
(
prefill_configs
,
decode_addrs
,
host
,
port
):
global
load_balancer
global
load_balancer
load_balancer
=
MiniLoadBalancer
(
prefill_configs
,
decode_addrs
)
load_balancer
=
MiniLoadBalancer
(
prefill_configs
,
decode_addrs
)
...
@@ -279,15 +332,16 @@ if __name__ == "__main__":
...
@@ -279,15 +332,16 @@ if __name__ == "__main__":
parser
=
argparse
.
ArgumentParser
(
description
=
"Mini Load Balancer Server"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Mini Load Balancer Server"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--prefill"
,
required
=
True
,
help
=
"Comma-separated
URLs for prefill servers"
"--prefill"
,
type
=
str
,
default
=
[],
nargs
=
"+"
,
help
=
"
URLs for prefill servers"
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--prefill-bootstrap-ports"
,
"--decode"
,
type
=
str
,
default
=
[],
nargs
=
"+"
,
help
=
"URLs for decode servers"
help
=
"Comma-separated bootstrap ports for prefill servers"
,
default
=
"8998"
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--decode"
,
required
=
True
,
help
=
"Comma-separated URLs for decode servers"
"--prefill-bootstrap-ports"
,
type
=
int
,
nargs
=
"+"
,
help
=
"Bootstrap ports for prefill servers"
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--host"
,
default
=
"0.0.0.0"
,
help
=
"Host to bind the server (default: 0.0.0.0)"
"--host"
,
default
=
"0.0.0.0"
,
help
=
"Host to bind the server (default: 0.0.0.0)"
...
@@ -297,22 +351,19 @@ if __name__ == "__main__":
...
@@ -297,22 +351,19 @@ if __name__ == "__main__":
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
prefill_url
s
=
args
.
prefill
.
split
(
","
)
bootstrap_port
s
=
args
.
prefill
_bootstrap_ports
bootstrap_ports
=
[
int
(
p
)
for
p
in
args
.
prefill_bootstrap_ports
.
split
(
","
)]
if
bootstrap_ports
is
None
:
bootstrap_ports
=
[
None
]
*
len
(
args
.
prefill
)
if
len
(
bootstrap_ports
)
==
1
:
el
if
len
(
bootstrap_ports
)
==
1
:
bootstrap_ports
=
bootstrap_ports
*
len
(
prefill
_urls
)
bootstrap_ports
=
bootstrap_ports
*
len
(
args
.
prefill
)
else
:
else
:
if
len
(
bootstrap_ports
)
!=
len
(
prefill
_urls
):
if
len
(
bootstrap_ports
)
!=
len
(
args
.
prefill
):
raise
ValueError
(
raise
ValueError
(
"Number of prefill URLs must match number of bootstrap ports"
"Number of prefill URLs must match number of bootstrap ports"
)
)
exit
(
1
)
prefill_configs
=
[]
for
url
,
port
in
zip
(
prefill_urls
,
bootstrap_ports
):
prefill_configs
.
append
(
PrefillConfig
(
url
,
port
))
decode_addrs
=
args
.
decode
.
split
(
","
)
prefill_configs
=
[
PrefillConfig
(
url
,
port
)
for
url
,
port
in
zip
(
args
.
prefill
,
bootstrap_ports
)
]
run
(
prefill_configs
,
decode
_addrs
,
args
.
host
,
args
.
port
)
run
(
prefill_configs
,
args
.
decode
,
args
.
host
,
args
.
port
)
python/sglang/srt/disaggregation/utils.py
View file @
a3e4e9bf
from
__future__
import
annotations
from
__future__
import
annotations
import
dataclasses
import
warnings
from
collections
import
deque
from
collections
import
deque
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
List
from
typing
import
List
,
Optional
import
numpy
as
np
import
numpy
as
np
import
requests
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
sglang.srt.utils
import
get_ip
class
DisaggregationMode
(
Enum
):
class
DisaggregationMode
(
Enum
):
NULL
=
"null"
NULL
=
"null"
...
@@ -119,3 +124,41 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
...
@@ -119,3 +124,41 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
def
kv_to_page_num
(
num_kv_indices
:
int
,
page_size
:
int
):
def
kv_to_page_num
(
num_kv_indices
:
int
,
page_size
:
int
):
# ceil(num_kv_indices / page_size)
# ceil(num_kv_indices / page_size)
return
(
num_kv_indices
+
page_size
-
1
)
//
page_size
return
(
num_kv_indices
+
page_size
-
1
)
//
page_size
@dataclasses.dataclass
class PDRegistryRequest:
    """Payload a prefill/decode server sends to the LB to register itself."""

    # Either "prefill" or "decode"; anything else is rejected on construction.
    mode: str
    # URL stored by the load balancer for this server.
    registry_url: str
    # Required in "prefill" mode, forbidden in "decode" mode.
    bootstrap_port: Optional[int] = None

    def __post_init__(self):
        """Validate the mode / bootstrap port combination.

        Raises:
            ValueError: if a prefill request has no bootstrap port, a decode
                request carries one, or the mode is not recognised.
        """
        port_given = self.bootstrap_port is not None

        if self.mode == "prefill":
            if not port_given:
                raise ValueError("Bootstrap port must be set in PREFILL mode.")
            return

        if self.mode == "decode":
            if port_given:
                raise ValueError("Bootstrap port must not be set in DECODE mode.")
            return

        raise ValueError(
            f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
        )
def register_disaggregation_server(
    mode: str,
    server_port: int,
    bootstrap_port: int,
    pdlb_url: str,
    timeout: float = 10.0,
):
    """Register this server with the PD load balancer (best effort).

    Builds a PDRegistryRequest and POSTs it as JSON to ``{pdlb_url}/register``.
    Any failure (the load balancer cannot be reached, the request times out,
    or it answers with a non-200 status) is reported with ``warnings.warn``
    instead of being raised.

    Args:
        mode: "prefill" or "decode"; other values make PDRegistryRequest
            raise ValueError.
        server_port: Port placed in the advertised URL
            ``http://{get_ip()}:{server_port}``.
        bootstrap_port: Forwarded only when ``mode`` is "prefill"; dropped
            (sent as None) otherwise.
        pdlb_url: Base URL of the load balancer.
        timeout: Seconds to wait for the load balancer before giving up, so
            an unresponsive load balancer cannot block the caller forever.
    """
    registry_request = PDRegistryRequest(
        mode=mode,
        # NOTE(review): get_ip() presumably returns this machine's reachable
        # address; confirm against sglang.srt.utils.
        registry_url=f"http://{get_ip()}:{server_port}",
        bootstrap_port=bootstrap_port if mode == "prefill" else None,
    )

    try:
        res = requests.post(
            f"{pdlb_url}/register",
            json=dataclasses.asdict(registry_request),
            timeout=timeout,
        )
    except requests.RequestException as e:
        # Treat an unreachable load balancer like a rejected registration:
        # warn and carry on.
        warnings.warn(f"Failed to register disaggregation server: {e}")
        return

    if res.status_code != 200:
        warnings.warn(
            f"Failed to register disaggregation server: {res.status_code} {res.text}"
        )
python/sglang/srt/entrypoints/http_server.py
View file @
a3e4e9bf
...
@@ -42,7 +42,10 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
...
@@ -42,7 +42,10 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
sglang.srt.disaggregation.utils
import
FakeBootstrapHost
from
sglang.srt.disaggregation.utils
import
(
FakeBootstrapHost
,
register_disaggregation_server
,
)
from
sglang.srt.entrypoints.engine
import
_launch_subprocesses
from
sglang.srt.entrypoints.engine
import
_launch_subprocesses
from
sglang.srt.function_call_parser
import
FunctionCallParser
from
sglang.srt.function_call_parser
import
FunctionCallParser
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
...
@@ -871,5 +874,13 @@ def _wait_and_warmup(
...
@@ -871,5 +874,13 @@ def _wait_and_warmup(
if
server_args
.
debug_tensor_dump_input_file
:
if
server_args
.
debug_tensor_dump_input_file
:
kill_process_tree
(
os
.
getpid
())
kill_process_tree
(
os
.
getpid
())
if
server_args
.
pdlb_url
is
not
None
:
register_disaggregation_server
(
server_args
.
disaggregation_mode
,
server_args
.
port
,
server_args
.
disaggregation_bootstrap_port
,
server_args
.
pdlb_url
,
)
if
launch_callback
is
not
None
:
if
launch_callback
is
not
None
:
launch_callback
()
launch_callback
()
python/sglang/srt/managers/scheduler.py
View file @
a3e4e9bf
...
@@ -925,6 +925,10 @@ class Scheduler(
...
@@ -925,6 +925,10 @@ class Scheduler(
)
)
custom_logit_processor
=
None
custom_logit_processor
=
None
if
recv_req
.
bootstrap_port
is
None
:
# Use default bootstrap port
recv_req
.
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
req
=
Req
(
req
=
Req
(
recv_req
.
rid
,
recv_req
.
rid
,
recv_req
.
input_text
,
recv_req
.
input_text
,
...
...
python/sglang/srt/server_args.py
View file @
a3e4e9bf
...
@@ -198,6 +198,7 @@ class ServerArgs:
...
@@ -198,6 +198,7 @@ class ServerArgs:
disaggregation_bootstrap_port
:
int
=
8998
disaggregation_bootstrap_port
:
int
=
8998
disaggregation_transfer_backend
:
str
=
"mooncake"
disaggregation_transfer_backend
:
str
=
"mooncake"
disaggregation_ib_device
:
Optional
[
str
]
=
None
disaggregation_ib_device
:
Optional
[
str
]
=
None
pdlb_url
:
Optional
[
str
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Expert parallelism
# Expert parallelism
...
@@ -1254,6 +1255,12 @@ class ServerArgs:
...
@@ -1254,6 +1255,12 @@ class ServerArgs:
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
"Default is None, which triggers automatic device detection when mooncake backend is enabled."
,
"Default is None, which triggers automatic device detection when mooncake backend is enabled."
,
)
)
parser
.
add_argument
(
"--pdlb-url"
,
type
=
str
,
default
=
None
,
help
=
"The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer."
,
)
@
classmethod
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment