Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
11e27d09
Unverified
Commit
11e27d09
authored
Apr 26, 2025
by
IAN
Committed by
GitHub
Apr 26, 2025
Browse files
[PD]: Support Multi Prefill in one node (#5704)
Co-authored-by:
shuaills
<
shishuaiuoe@gmail.com
>
parent
50eda839
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
55 additions
and
9 deletions
+55
-9
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+1
-1
python/sglang/srt/disaggregation/mini_lb.py
python/sglang/srt/disaggregation/mini_lb.py
+45
-8
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+5
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+2
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+1
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+1
-0
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
11e27d09
...
@@ -137,7 +137,7 @@ class DecodePreallocQueue:
...
@@ -137,7 +137,7 @@ class DecodePreallocQueue:
kv_receiver_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
RECEIVER
)
kv_receiver_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
RECEIVER
)
kv_receiver
=
kv_receiver_class
(
kv_receiver
=
kv_receiver_class
(
mgr
=
self
.
kv_manager
,
mgr
=
self
.
kv_manager
,
bootstrap_addr
=
f
"
{
req
.
bootstrap_host
}
:
{
self
.
bootstrap_port
}
"
,
bootstrap_addr
=
f
"
{
req
.
bootstrap_host
}
:
{
req
.
bootstrap_port
}
"
,
bootstrap_room
=
req
.
bootstrap_room
,
bootstrap_room
=
req
.
bootstrap_room
,
)
)
self
.
queue
.
append
(
DecodeRequest
(
req
=
req
,
kv_receiver
=
kv_receiver
))
self
.
queue
.
append
(
DecodeRequest
(
req
=
req
,
kv_receiver
=
kv_receiver
))
...
...
python/sglang/srt/disaggregation/mini_lb.py
View file @
11e27d09
...
@@ -6,6 +6,7 @@ import asyncio
...
@@ -6,6 +6,7 @@ import asyncio
import
random
import
random
import
urllib
import
urllib
from
itertools
import
chain
from
itertools
import
chain
from
typing
import
List
import
aiohttp
import
aiohttp
import
orjson
import
orjson
...
@@ -14,13 +15,22 @@ from fastapi import FastAPI, HTTPException
...
@@ -14,13 +15,22 @@ from fastapi import FastAPI, HTTPException
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
class PrefillConfig:
    """Pair a prefill server URL with the bootstrap port used for it.

    Both values are stored exactly as given; no validation or
    normalization is performed here.

    Args:
        url: URL of the prefill server.
        bootstrap_port: Bootstrap port associated with that server.
    """

    def __init__(self, url: str, bootstrap_port: int):
        self.url = url
        self.bootstrap_port = bootstrap_port

    def __repr__(self) -> str:
        # Readable form for logs and debugging output.
        return (
            f"{type(self).__name__}"
            f"(url={self.url!r}, bootstrap_port={self.bootstrap_port!r})"
        )
class
MiniLoadBalancer
:
class
MiniLoadBalancer
:
def
__init__
(
self
,
prefill_servers
,
decode_servers
):
def
__init__
(
self
,
prefill_configs
:
List
[
PrefillConfig
],
decode_servers
:
List
[
str
]):
self
.
prefill_servers
=
prefill_servers
self
.
prefill_configs
=
prefill_configs
self
.
prefill_servers
=
[
p
.
url
for
p
in
prefill_configs
]
self
.
decode_servers
=
decode_servers
self
.
decode_servers
=
decode_servers
def
select_pair
(
self
):
def
select_pair
(
self
):
return
random
.
choice
(
self
.
prefill_servers
),
random
.
choice
(
self
.
decode_servers
)
prefill_config
=
random
.
choice
(
self
.
prefill_configs
)
decode_server
=
random
.
choice
(
self
.
decode_servers
)
return
prefill_config
.
url
,
prefill_config
.
bootstrap_port
,
decode_server
async
def
generate
(
async
def
generate
(
self
,
modified_request
,
prefill_server
,
decode_server
,
endpoint
self
,
modified_request
,
prefill_server
,
decode_server
,
endpoint
...
@@ -160,7 +170,7 @@ async def get_model_info():
...
@@ -160,7 +170,7 @@ async def get_model_info():
@
app
.
post
(
"/generate"
)
@
app
.
post
(
"/generate"
)
async
def
handle_generate_request
(
request_data
:
dict
):
async
def
handle_generate_request
(
request_data
:
dict
):
prefill_server
,
decode_server
=
load_balancer
.
select_pair
()
prefill_server
,
bootstrap_port
,
decode_server
=
load_balancer
.
select_pair
()
# Parse and transform prefill_server for bootstrap data
# Parse and transform prefill_server for bootstrap data
parsed_url
=
urllib
.
parse
.
urlparse
(
prefill_server
)
parsed_url
=
urllib
.
parse
.
urlparse
(
prefill_server
)
...
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
...
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
modified_request
.
update
(
modified_request
.
update
(
{
{
"bootstrap_host"
:
[
hostname
]
*
batch_size
,
"bootstrap_host"
:
[
hostname
]
*
batch_size
,
"bootstrap_port"
:
[
bootstrap_port
]
*
batch_size
,
"bootstrap_room"
:
[
"bootstrap_room"
:
[
_generate_bootstrap_room
()
for
_
in
range
(
batch_size
)
_generate_bootstrap_room
()
for
_
in
range
(
batch_size
)
],
],
...
@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
...
@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
modified_request
.
update
(
modified_request
.
update
(
{
{
"bootstrap_host"
:
hostname
,
"bootstrap_host"
:
hostname
,
"bootstrap_port"
:
bootstrap_port
,
"bootstrap_room"
:
_generate_bootstrap_room
(),
"bootstrap_room"
:
_generate_bootstrap_room
(),
}
}
)
)
...
@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):
...
@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):
@
app
.
post
(
"/v1/chat/completions"
)
@
app
.
post
(
"/v1/chat/completions"
)
async
def
handle_completion_request
(
request_data
:
dict
):
async
def
handle_completion_request
(
request_data
:
dict
):
prefill_server
,
decode_server
=
load_balancer
.
select_pair
()
prefill_server
,
bootstrap_port
,
decode_server
=
load_balancer
.
select_pair
()
# Parse and transform prefill_server for bootstrap data
# Parse and transform prefill_server for bootstrap data
parsed_url
=
urllib
.
parse
.
urlparse
(
prefill_server
)
parsed_url
=
urllib
.
parse
.
urlparse
(
prefill_server
)
...
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
...
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
modified_request
.
update
(
modified_request
.
update
(
{
{
"bootstrap_host"
:
hostname
,
"bootstrap_host"
:
hostname
,
"bootstrap_port"
:
bootstrap_port
,
"bootstrap_room"
:
random
.
randint
(
0
,
2
**
63
-
1
),
"bootstrap_room"
:
random
.
randint
(
0
,
2
**
63
-
1
),
}
}
)
)
...
@@ -255,9 +268,9 @@ async def get_models():
...
@@ -255,9 +268,9 @@ async def get_models():
raise
HTTPException
(
status_code
=
500
,
detail
=
str
(
e
))
raise
HTTPException
(
status_code
=
500
,
detail
=
str
(
e
))
def
run
(
prefill_
addr
s
,
decode_addrs
,
host
,
port
):
def
run
(
prefill_
config
s
,
decode_addrs
,
host
,
port
):
global
load_balancer
global
load_balancer
load_balancer
=
MiniLoadBalancer
(
prefill_
addr
s
,
decode_addrs
)
load_balancer
=
MiniLoadBalancer
(
prefill_
config
s
,
decode_addrs
)
uvicorn
.
run
(
app
,
host
=
host
,
port
=
port
)
uvicorn
.
run
(
app
,
host
=
host
,
port
=
port
)
...
@@ -268,6 +281,11 @@ if __name__ == "__main__":
...
@@ -268,6 +281,11 @@ if __name__ == "__main__":
parser
.
add_argument
(
parser
.
add_argument
(
"--prefill"
,
required
=
True
,
help
=
"Comma-separated URLs for prefill servers"
"--prefill"
,
required
=
True
,
help
=
"Comma-separated URLs for prefill servers"
)
)
parser
.
add_argument
(
"--prefill-bootstrap-ports"
,
help
=
"Comma-separated bootstrap ports for prefill servers"
,
default
=
"8998"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--decode"
,
required
=
True
,
help
=
"Comma-separated URLs for decode servers"
"--decode"
,
required
=
True
,
help
=
"Comma-separated URLs for decode servers"
)
)
...
@@ -278,4 +296,23 @@ if __name__ == "__main__":
...
@@ -278,4 +296,23 @@ if __name__ == "__main__":
"--port"
,
type
=
int
,
default
=
8000
,
help
=
"Port to bind the server (default: 8000)"
"--port"
,
type
=
int
,
default
=
8000
,
help
=
"Port to bind the server (default: 8000)"
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
run
(
args
.
prefill
.
split
(
","
),
args
.
decode
.
split
(
","
),
args
.
host
,
args
.
port
)
# Build one PrefillConfig per prefill URL from the comma-separated CLI values.
prefill_urls = args.prefill.split(",")
bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]

# A single bootstrap port is reused for every prefill URL; otherwise the
# ports must pair up one-to-one with the URLs.
if len(bootstrap_ports) == 1:
    bootstrap_ports = bootstrap_ports * len(prefill_urls)
elif len(bootstrap_ports) != len(prefill_urls):
    # Raising already aborts the script, so no explicit exit call is needed.
    raise ValueError("Number of prefill URLs must match number of bootstrap ports")

prefill_configs = [
    PrefillConfig(url, bootstrap_port)
    for url, bootstrap_port in zip(prefill_urls, bootstrap_ports)
]

decode_addrs = args.decode.split(",")

run(prefill_configs, decode_addrs, args.host, args.port)
python/sglang/srt/managers/io_struct.py
View file @
11e27d09
...
@@ -97,6 +97,7 @@ class GenerateReqInput:
...
@@ -97,6 +97,7 @@ class GenerateReqInput:
# For disaggregated inference
# For disaggregated inference
bootstrap_host
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
bootstrap_host
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
bootstrap_port
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
bootstrap_room
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
bootstrap_room
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
def
normalize_batch_and_arguments
(
self
):
def
normalize_batch_and_arguments
(
self
):
...
@@ -400,6 +401,9 @@ class GenerateReqInput:
...
@@ -400,6 +401,9 @@ class GenerateReqInput:
bootstrap_host
=
(
bootstrap_host
=
(
self
.
bootstrap_host
[
i
]
if
self
.
bootstrap_host
is
not
None
else
None
self
.
bootstrap_host
[
i
]
if
self
.
bootstrap_host
is
not
None
else
None
),
),
bootstrap_port
=
(
self
.
bootstrap_port
[
i
]
if
self
.
bootstrap_port
is
not
None
else
None
),
bootstrap_room
=
(
bootstrap_room
=
(
self
.
bootstrap_room
[
i
]
if
self
.
bootstrap_room
is
not
None
else
None
self
.
bootstrap_room
[
i
]
if
self
.
bootstrap_room
is
not
None
else
None
),
),
...
@@ -447,6 +451,7 @@ class TokenizedGenerateReqInput:
...
@@ -447,6 +451,7 @@ class TokenizedGenerateReqInput:
# For disaggregated inference
# For disaggregated inference
bootstrap_host
:
Optional
[
str
]
=
None
bootstrap_host
:
Optional
[
str
]
=
None
bootstrap_port
:
Optional
[
int
]
=
None
bootstrap_room
:
Optional
[
int
]
=
None
bootstrap_room
:
Optional
[
int
]
=
None
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
11e27d09
...
@@ -391,6 +391,7 @@ class Req:
...
@@ -391,6 +391,7 @@ class Req:
return_hidden_states
:
bool
=
False
,
return_hidden_states
:
bool
=
False
,
eos_token_ids
:
Optional
[
Set
[
int
]]
=
None
,
eos_token_ids
:
Optional
[
Set
[
int
]]
=
None
,
bootstrap_host
:
Optional
[
str
]
=
None
,
bootstrap_host
:
Optional
[
str
]
=
None
,
bootstrap_port
:
Optional
[
int
]
=
None
,
bootstrap_room
:
Optional
[
int
]
=
None
,
bootstrap_room
:
Optional
[
int
]
=
None
,
):
):
# Input and output info
# Input and output info
...
@@ -526,6 +527,7 @@ class Req:
...
@@ -526,6 +527,7 @@ class Req:
# For disaggregation
# For disaggregation
self
.
bootstrap_host
:
str
=
bootstrap_host
self
.
bootstrap_host
:
str
=
bootstrap_host
self
.
bootstrap_port
:
Optional
[
int
]
=
bootstrap_port
self
.
bootstrap_room
:
Optional
[
int
]
=
bootstrap_room
self
.
bootstrap_room
:
Optional
[
int
]
=
bootstrap_room
self
.
disagg_kv_sender
:
Optional
[
BaseKVSender
]
=
None
self
.
disagg_kv_sender
:
Optional
[
BaseKVSender
]
=
None
...
...
python/sglang/srt/managers/scheduler.py
View file @
11e27d09
...
@@ -791,6 +791,7 @@ class Scheduler(
...
@@ -791,6 +791,7 @@ class Scheduler(
return_hidden_states
=
recv_req
.
return_hidden_states
,
return_hidden_states
=
recv_req
.
return_hidden_states
,
eos_token_ids
=
self
.
model_config
.
hf_eos_token_id
,
eos_token_ids
=
self
.
model_config
.
hf_eos_token_id
,
bootstrap_host
=
recv_req
.
bootstrap_host
,
bootstrap_host
=
recv_req
.
bootstrap_host
,
bootstrap_port
=
recv_req
.
bootstrap_port
,
bootstrap_room
=
recv_req
.
bootstrap_room
,
bootstrap_room
=
recv_req
.
bootstrap_room
,
)
)
req
.
tokenizer
=
self
.
tokenizer
req
.
tokenizer
=
self
.
tokenizer
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
11e27d09
...
@@ -498,6 +498,7 @@ class TokenizerManager:
...
@@ -498,6 +498,7 @@ class TokenizerManager:
token_ids_logprob
,
token_ids_logprob
,
obj
.
stream
,
obj
.
stream
,
bootstrap_host
=
obj
.
bootstrap_host
,
bootstrap_host
=
obj
.
bootstrap_host
,
bootstrap_port
=
obj
.
bootstrap_port
,
bootstrap_room
=
obj
.
bootstrap_room
,
bootstrap_room
=
obj
.
bootstrap_room
,
lora_path
=
obj
.
lora_path
,
lora_path
=
obj
.
lora_path
,
input_embeds
=
input_embeds
,
input_embeds
=
input_embeds
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment