sglang · Commits · 6d3b35fa

Commit 6d3b35fa (unverified), authored Apr 08, 2025 by Byron Hsu; committed by GitHub on Apr 08, 2025.

[PD] Simplify mini LB (#4911)

Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>

Parent: a73c4df4
Showing 1 changed file with 53 additions and 125 deletions.

python/sglang/srt/disaggregation/mini_lb.py  (+53, -125)

 """
-Minimal HTTP load balancer for prefill and decode servers for testing purpose.
+Minimal HTTP load balancer for prefill and decode servers for testing.
 """

 import asyncio
...
@@ -22,64 +22,59 @@ class MiniLoadBalancer:
     def select_pair(self):
         return random.choice(self.prefill_servers), random.choice(self.decode_servers)

-    async def generate_request(self, request_data):
-        prefill_server, decode_server = self.select_pair()
-
-        # Parse and transform prefill_server
-        parsed_url = urllib.parse.urlparse(prefill_server)
-        hostname = parsed_url.hostname
-        bootstrap_host = f"{hostname}"
-
-        modified_request = request_data.copy()
-        modified_request.update(
-            {
-                "bootstrap_host": bootstrap_host,
-                "bootstrap_room": random.randint(0, 2**63 - 1),
-            }
-        )
-
+    async def generate(
+        self, modified_request, prefill_server, decode_server
+    ) -> ORJSONResponse:
         async with aiohttp.ClientSession() as session:
-            # Create the tasks
             tasks = [
                 session.post(f"{prefill_server}/generate", json=modified_request),
                 session.post(f"{decode_server}/generate", json=modified_request),
             ]
-
-            prefill_response = None
-            decode_response = None
-
-            # Process responses as they arrive
-            for i, response in enumerate(asyncio.as_completed(tasks)):
-                response = await response
-                # Check if this is the prefill or decode response based on order created
-                if i == 0:  # First completed task
-                    if str(response.url).startswith(prefill_server):
-                        prefill_response = response
-                        if response.status != 200:
-                            raise HTTPException(
-                                status_code=response.status,
-                                detail=f"Prefill server error: Status {response.status} Details: {await response.text()}",
-                            )
-                    else:
-                        decode_response = response
-                        if response.status != 200:
-                            raise HTTPException(
-                                status_code=response.status,
-                                detail=f"Decode server error: Status {response.status} Details: {await response.text()}",
-                            )
-                else:  # Second completed task
-                    if str(response.url).startswith(prefill_server):
-                        prefill_response = response
-                    else:
-                        decode_response = response
-                    if response.status != 200:
-                        raise HTTPException(
-                            status_code=response.status,
-                            detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}",
-                        )
-
-            return await decode_response.json()
+            # Wait for both responses to complete. Prefill should end first.
+            prefill_response, decode_response = await asyncio.gather(*tasks)
+
+            return ORJSONResponse(
+                content=await decode_response.json(),
+                status_code=decode_response.status,
+            )
+
+    async def generate_stream(self, modified_request, prefill_server, decode_server):
+        async def stream_results():
+            async with aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(
+                    total=3600
+                )  # Add timeout for request reliability
+            ) as session:
+                try:
+                    # Create the tasks for both prefill and decode requests
+                    tasks = [
+                        session.post(f"{prefill_server}/generate", json=modified_request),
+                        session.post(f"{decode_server}/generate", json=modified_request),
+                    ]
+                    # Wait for both responses to complete. Since this is streaming, they return immediately.
+                    prefill_response, decode_response = await asyncio.gather(*tasks)
+                    async for chunk in decode_response.content:
+                        yield chunk
+                except Exception as e:
+                    error_msg = {
+                        "error": {"message": f"Stream processing error: {str(e)}"}
+                    }
+                    yield b"data: " + orjson.dumps(
+                        error_msg, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+                finally:
+                    if prefill_response is not None:
+                        await prefill_response.release()
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+        )


 app = FastAPI()
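The hunk above carries the core of the simplification: the old generate_request built the bootstrap fields itself and then walked asyncio.as_completed with per-branch status checks, while the new generate and generate_stream methods fire both POSTs at once, await asyncio.gather, and keep only the decode server's reply. A minimal, self-contained sketch of that fan-out pattern is below; the function name, payload, and server URLs are illustrative and not part of this commit.

import asyncio

import aiohttp


async def fan_out(modified_request: dict, prefill_server: str, decode_server: str):
    # Post the same request to both servers concurrently, mirroring the
    # gather-based flow of MiniLoadBalancer.generate() in this diff.
    async with aiohttp.ClientSession() as session:
        tasks = [
            session.post(f"{prefill_server}/generate", json=modified_request),
            session.post(f"{decode_server}/generate", json=modified_request),
        ]
        # Both requests run concurrently; the prefill reply is awaited but not used.
        prefill_response, decode_response = await asyncio.gather(*tasks)
        return await decode_response.json()


# Hypothetical endpoints, for illustration only:
# asyncio.run(fan_out({"text": "hi"}, "http://127.0.0.1:30000", "http://127.0.0.1:30001"))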
...
@@ -169,81 +164,14 @@ async def handle_generate_request(request_data: dict):
             }
         )

     # Check if streaming is requested
     if request_data.get("stream", False):
-
-        async def stream_results():
-            async with aiohttp.ClientSession(
-                timeout=aiohttp.ClientTimeout(total=3600)
-            ) as session:
-                try:
-                    # Create the tasks
-                    tasks = [
-                        session.post(f"{prefill_server}/generate", json=modified_request),
-                        session.post(f"{decode_server}/generate", json=modified_request),
-                    ]
-                    prefill_response = None
-                    decode_response = None
-
-                    # Process responses as they arrive
-                    for i, response_task in enumerate(asyncio.as_completed(tasks)):
-                        response = await response_task
-                        # Check the response immediately
-                        if str(response.url).startswith(prefill_server):
-                            prefill_response = response
-                            if response.status != 200:
-                                error_msg = {
-                                    "error": {
-                                        "message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
-                                    }
-                                }
-                                yield b"data: " + orjson.dumps(
-                                    error_msg, option=orjson.OPT_NON_STR_KEYS
-                                ) + b"\n\n"
-                                return
-                        else:
-                            decode_response = response
-                            if response.status != 200:
-                                error_msg = {
-                                    "error": {
-                                        "message": f"Decode server error: Status {response.status}"
-                                    }
-                                }
-                                yield b"data: " + orjson.dumps(
-                                    error_msg, option=orjson.OPT_NON_STR_KEYS
-                                ) + b"\n\n"
-                                return
-
-                    # Stream successful decode server response
-                    async for line in decode_response.content:
-                        yield line
-                    yield b"data: [DONE]\n\n"
-
-                except Exception as e:
-                    error_msg = {
-                        "error": {"message": f"Stream processing error: {str(e)}"}
-                    }
-                    yield b"data: " + orjson.dumps(
-                        error_msg, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-                finally:
-                    if prefill_response is not None:
-                        await prefill_response.release()
-
-        return StreamingResponse(
-            stream_results(),
-            media_type="text/event-stream",
-        )
-
-    # Non-streaming case
-    result = await load_balancer.generate_request(request_data)
-    return ORJSONResponse(content=result)
+        return await load_balancer.generate_stream(
+            modified_request, prefill_server, decode_server
+        )
+    else:
+        return await load_balancer.generate(
+            modified_request, prefill_server, decode_server
+        )


 @app.get("/v1/models")
...
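After this change, handle_generate_request only chooses between the two MiniLoadBalancer methods based on the request's "stream" flag. A hedged client-side sketch follows; the load balancer's listen address and route, and the payload fields, are assumptions for illustration and are not part of this diff.

import requests

# Assumed mini LB address and route; adjust to the actual deployment.
LB_URL = "http://127.0.0.1:8000/generate"

payload = {"text": "The capital of France is", "sampling_params": {"max_new_tokens": 16}}

# Non-streaming: served by MiniLoadBalancer.generate(), which returns the decode server's JSON.
print(requests.post(LB_URL, json=payload).json())

# Streaming: served by MiniLoadBalancer.generate_stream(), which relays the decode
# server's event-stream chunks as they arrive.
with requests.post(LB_URL, json={**payload, "stream": True}, stream=True) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode())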