Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
11e27d09
Unverified
Commit
11e27d09
authored
Apr 26, 2025
by
IAN
Committed by
GitHub
Apr 26, 2025
Browse files
[PD]: Support Multi Prefill in one node (#5704)
Co-authored-by:
shuaills
<
shishuaiuoe@gmail.com
>
parent
50eda839
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
55 additions
and
9 deletions
+55
-9
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+1
-1
python/sglang/srt/disaggregation/mini_lb.py
python/sglang/srt/disaggregation/mini_lb.py
+45
-8
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+5
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+2
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+1
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+1
-0
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
11e27d09
...
...
@@ -137,7 +137,7 @@ class DecodePreallocQueue:
kv_receiver_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
RECEIVER
)
kv_receiver
=
kv_receiver_class
(
mgr
=
self
.
kv_manager
,
bootstrap_addr
=
f
"
{
req
.
bootstrap_host
}
:
{
self
.
bootstrap_port
}
"
,
bootstrap_addr
=
f
"
{
req
.
bootstrap_host
}
:
{
req
.
bootstrap_port
}
"
,
bootstrap_room
=
req
.
bootstrap_room
,
)
self
.
queue
.
append
(
DecodeRequest
(
req
=
req
,
kv_receiver
=
kv_receiver
))
...
...
python/sglang/srt/disaggregation/mini_lb.py
View file @
11e27d09
...
...
@@ -6,6 +6,7 @@ import asyncio
import
random
import
urllib
from
itertools
import
chain
from
typing
import
List
import
aiohttp
import
orjson
...
...
@@ -14,13 +15,22 @@ from fastapi import FastAPI, HTTPException
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
class
PrefillConfig
:
def
__init__
(
self
,
url
:
str
,
bootstrap_port
:
int
):
self
.
url
=
url
self
.
bootstrap_port
=
bootstrap_port
class
MiniLoadBalancer
:
def
__init__
(
self
,
prefill_servers
,
decode_servers
):
self
.
prefill_servers
=
prefill_servers
def
__init__
(
self
,
prefill_configs
:
List
[
PrefillConfig
],
decode_servers
:
List
[
str
]):
self
.
prefill_configs
=
prefill_configs
self
.
prefill_servers
=
[
p
.
url
for
p
in
prefill_configs
]
self
.
decode_servers
=
decode_servers
def
select_pair
(
self
):
return
random
.
choice
(
self
.
prefill_servers
),
random
.
choice
(
self
.
decode_servers
)
prefill_config
=
random
.
choice
(
self
.
prefill_configs
)
decode_server
=
random
.
choice
(
self
.
decode_servers
)
return
prefill_config
.
url
,
prefill_config
.
bootstrap_port
,
decode_server
async
def
generate
(
self
,
modified_request
,
prefill_server
,
decode_server
,
endpoint
...
...
@@ -160,7 +170,7 @@ async def get_model_info():
@
app
.
post
(
"/generate"
)
async
def
handle_generate_request
(
request_data
:
dict
):
prefill_server
,
decode_server
=
load_balancer
.
select_pair
()
prefill_server
,
bootstrap_port
,
decode_server
=
load_balancer
.
select_pair
()
# Parse and transform prefill_server for bootstrap data
parsed_url
=
urllib
.
parse
.
urlparse
(
prefill_server
)
...
...
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
modified_request
.
update
(
{
"bootstrap_host"
:
[
hostname
]
*
batch_size
,
"bootstrap_port"
:
[
bootstrap_port
]
*
batch_size
,
"bootstrap_room"
:
[
_generate_bootstrap_room
()
for
_
in
range
(
batch_size
)
],
...
...
@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
modified_request
.
update
(
{
"bootstrap_host"
:
hostname
,
"bootstrap_port"
:
bootstrap_port
,
"bootstrap_room"
:
_generate_bootstrap_room
(),
}
)
...
...
@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):
@
app
.
post
(
"/v1/chat/completions"
)
async
def
handle_completion_request
(
request_data
:
dict
):
prefill_server
,
decode_server
=
load_balancer
.
select_pair
()
prefill_server
,
bootstrap_port
,
decode_server
=
load_balancer
.
select_pair
()
# Parse and transform prefill_server for bootstrap data
parsed_url
=
urllib
.
parse
.
urlparse
(
prefill_server
)
...
...
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
modified_request
.
update
(
{
"bootstrap_host"
:
hostname
,
"bootstrap_port"
:
bootstrap_port
,
"bootstrap_room"
:
random
.
randint
(
0
,
2
**
63
-
1
),
}
)
...
...
@@ -255,9 +268,9 @@ async def get_models():
raise
HTTPException
(
status_code
=
500
,
detail
=
str
(
e
))
def
run
(
prefill_
addr
s
,
decode_addrs
,
host
,
port
):
def
run
(
prefill_
config
s
,
decode_addrs
,
host
,
port
):
global
load_balancer
load_balancer
=
MiniLoadBalancer
(
prefill_
addr
s
,
decode_addrs
)
load_balancer
=
MiniLoadBalancer
(
prefill_
config
s
,
decode_addrs
)
uvicorn
.
run
(
app
,
host
=
host
,
port
=
port
)
...
...
@@ -268,6 +281,11 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--prefill"
,
required
=
True
,
help
=
"Comma-separated URLs for prefill servers"
)
parser
.
add_argument
(
"--prefill-bootstrap-ports"
,
help
=
"Comma-separated bootstrap ports for prefill servers"
,
default
=
"8998"
,
)
parser
.
add_argument
(
"--decode"
,
required
=
True
,
help
=
"Comma-separated URLs for decode servers"
)
...
...
@@ -278,4 +296,23 @@ if __name__ == "__main__":
"--port"
,
type
=
int
,
default
=
8000
,
help
=
"Port to bind the server (default: 8000)"
)
args
=
parser
.
parse_args
()
run
(
args
.
prefill
.
split
(
","
),
args
.
decode
.
split
(
","
),
args
.
host
,
args
.
port
)
prefill_urls
=
args
.
prefill
.
split
(
","
)
bootstrap_ports
=
[
int
(
p
)
for
p
in
args
.
prefill_bootstrap_ports
.
split
(
","
)]
if
len
(
bootstrap_ports
)
==
1
:
bootstrap_ports
=
bootstrap_ports
*
len
(
prefill_urls
)
else
:
if
len
(
bootstrap_ports
)
!=
len
(
prefill_urls
):
raise
ValueError
(
"Number of prefill URLs must match number of bootstrap ports"
)
exit
(
1
)
prefill_configs
=
[]
for
url
,
port
in
zip
(
prefill_urls
,
bootstrap_ports
):
prefill_configs
.
append
(
PrefillConfig
(
url
,
port
))
decode_addrs
=
args
.
decode
.
split
(
","
)
run
(
prefill_configs
,
decode_addrs
,
args
.
host
,
args
.
port
)
python/sglang/srt/managers/io_struct.py
View file @
11e27d09
...
...
@@ -97,6 +97,7 @@ class GenerateReqInput:
# For disaggregated inference
bootstrap_host
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
bootstrap_port
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
bootstrap_room
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
def
normalize_batch_and_arguments
(
self
):
...
...
@@ -400,6 +401,9 @@ class GenerateReqInput:
bootstrap_host
=
(
self
.
bootstrap_host
[
i
]
if
self
.
bootstrap_host
is
not
None
else
None
),
bootstrap_port
=
(
self
.
bootstrap_port
[
i
]
if
self
.
bootstrap_port
is
not
None
else
None
),
bootstrap_room
=
(
self
.
bootstrap_room
[
i
]
if
self
.
bootstrap_room
is
not
None
else
None
),
...
...
@@ -447,6 +451,7 @@ class TokenizedGenerateReqInput:
# For disaggregated inference
bootstrap_host
:
Optional
[
str
]
=
None
bootstrap_port
:
Optional
[
int
]
=
None
bootstrap_room
:
Optional
[
int
]
=
None
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
11e27d09
...
...
@@ -391,6 +391,7 @@ class Req:
return_hidden_states
:
bool
=
False
,
eos_token_ids
:
Optional
[
Set
[
int
]]
=
None
,
bootstrap_host
:
Optional
[
str
]
=
None
,
bootstrap_port
:
Optional
[
int
]
=
None
,
bootstrap_room
:
Optional
[
int
]
=
None
,
):
# Input and output info
...
...
@@ -526,6 +527,7 @@ class Req:
# For disaggregation
self
.
bootstrap_host
:
str
=
bootstrap_host
self
.
bootstrap_port
:
Optional
[
int
]
=
bootstrap_port
self
.
bootstrap_room
:
Optional
[
int
]
=
bootstrap_room
self
.
disagg_kv_sender
:
Optional
[
BaseKVSender
]
=
None
...
...
python/sglang/srt/managers/scheduler.py
View file @
11e27d09
...
...
@@ -791,6 +791,7 @@ class Scheduler(
return_hidden_states
=
recv_req
.
return_hidden_states
,
eos_token_ids
=
self
.
model_config
.
hf_eos_token_id
,
bootstrap_host
=
recv_req
.
bootstrap_host
,
bootstrap_port
=
recv_req
.
bootstrap_port
,
bootstrap_room
=
recv_req
.
bootstrap_room
,
)
req
.
tokenizer
=
self
.
tokenizer
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
11e27d09
...
...
@@ -498,6 +498,7 @@ class TokenizerManager:
token_ids_logprob
,
obj
.
stream
,
bootstrap_host
=
obj
.
bootstrap_host
,
bootstrap_port
=
obj
.
bootstrap_port
,
bootstrap_room
=
obj
.
bootstrap_room
,
lora_path
=
obj
.
lora_path
,
input_embeds
=
input_embeds
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment