Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c6c62640
Unverified
Commit
c6c62640
authored
Apr 29, 2025
by
ybyang
Committed by
GitHub
Apr 29, 2025
Browse files
[PD] support pd fake transfer for warmup (#5726)
parent
92ab0a20
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
146 additions
and
7 deletions
+146
-7
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+8
-2
python/sglang/srt/disaggregation/fake/__init__.py
python/sglang/srt/disaggregation/fake/__init__.py
+1
-0
python/sglang/srt/disaggregation/fake/conn.py
python/sglang/srt/disaggregation/fake/conn.py
+88
-0
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+6
-1
python/sglang/srt/disaggregation/utils.py
python/sglang/srt/disaggregation/utils.py
+16
-2
python/sglang/srt/entrypoints/http_server.py
python/sglang/srt/entrypoints/http_server.py
+27
-2
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
c6c62640
...
@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
...
@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
from
sglang.srt.disaggregation.base
import
BaseKVManager
,
BaseKVReceiver
,
KVArgs
,
KVPoll
from
sglang.srt.disaggregation.base
import
BaseKVManager
,
BaseKVReceiver
,
KVArgs
,
KVPoll
from
sglang.srt.disaggregation.utils
import
(
from
sglang.srt.disaggregation.utils
import
(
DisaggregationMode
,
DisaggregationMode
,
FakeBootstrapHost
,
KVClassType
,
KVClassType
,
ReqToMetadataIdxAllocator
,
ReqToMetadataIdxAllocator
,
TransferBackend
,
TransferBackend
,
...
@@ -133,8 +134,13 @@ class DecodePreallocQueue:
...
@@ -133,8 +134,13 @@ class DecodePreallocQueue:
def
add
(
self
,
req
:
Req
)
->
None
:
def
add
(
self
,
req
:
Req
)
->
None
:
"""Add a request to the pending queue."""
"""Add a request to the pending queue."""
if
req
.
bootstrap_host
==
FakeBootstrapHost
:
kv_receiver_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
RECEIVER
)
# Fake transfer for warmup reqs
kv_receiver_class
=
get_kv_class
(
TransferBackend
.
FAKE
,
KVClassType
.
RECEIVER
)
else
:
kv_receiver_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
RECEIVER
)
kv_receiver
=
kv_receiver_class
(
kv_receiver
=
kv_receiver_class
(
mgr
=
self
.
kv_manager
,
mgr
=
self
.
kv_manager
,
bootstrap_addr
=
f
"
{
req
.
bootstrap_host
}
:
{
req
.
bootstrap_port
}
"
,
bootstrap_addr
=
f
"
{
req
.
bootstrap_host
}
:
{
req
.
bootstrap_port
}
"
,
...
...
python/sglang/srt/disaggregation/fake/__init__.py
0 → 100644
View file @
c6c62640
from
.conn
import
FakeKVReceiver
,
FakeKVSender
python/sglang/srt/disaggregation/fake/conn.py
0 → 100644
View file @
c6c62640
import
logging
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy.typing
as
npt
from
sglang.srt.disaggregation.base.conn
import
(
BaseKVManager
,
BaseKVReceiver
,
BaseKVSender
,
KVArgs
,
KVPoll
,
)
logger
=
logging
.
getLogger
(
__name__
)
# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
class
FakeKVSender
(
BaseKVSender
):
def
__init__
(
self
,
mgr
:
BaseKVManager
,
bootstrap_addr
:
str
,
bootstrap_room
:
int
):
self
.
has_sent
=
False
def
poll
(
self
)
->
KVPoll
:
if
self
.
has_sent
is
False
:
# Assume handshake completed instantly
return
KVPoll
.
WaitingForInput
else
:
# Assume transfer completed instantly
logger
.
info
(
"FakeKVSender poll success"
)
return
KVPoll
.
Success
def
init
(
self
,
kv_indices
:
list
[
int
],
aux_index
:
Optional
[
int
]
=
None
,
dest_ranks
:
Optional
[
list
[
int
]]
=
None
,
):
logger
.
info
(
f
"FakeKVSender init with kv_indices:
{
kv_indices
}
, aux_index:
{
aux_index
}
, dest_ranks:
{
dest_ranks
}
"
)
pass
def
send
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
index_slice
:
slice
,
is_last
:
bool
,
):
logger
.
info
(
f
"FakeKVSender send with kv_indices:
{
kv_indices
}
, index_slice:
{
index_slice
}
, is_last:
{
is_last
}
"
)
if
is_last
:
self
.
has_sent
=
True
logger
.
info
(
f
"FakeKVSender send success"
)
else
:
self
.
has_sent
=
False
logger
.
info
(
f
"FakeKVSender send fake transfering"
)
def
failure_exception
(
self
):
raise
Exception
(
"Fake KVSender Exception"
)
class
FakeKVReceiver
(
BaseKVReceiver
):
def
__init__
(
self
,
mgr
:
BaseKVManager
,
bootstrap_addr
:
str
,
bootstrap_room
:
Optional
[
int
]
=
None
,
):
self
.
has_init
=
False
def
poll
(
self
)
->
KVPoll
:
if
self
.
has_init
is
False
:
# Assume handshake completed instantly
return
KVPoll
.
WaitingForInput
else
:
# Assume transfer completed instantly
logger
.
info
(
"FakeKVReceiver poll success"
)
return
KVPoll
.
Success
def
init
(
self
,
kv_indices
:
list
[
int
],
aux_index
:
Optional
[
int
]
=
None
):
self
.
has_init
=
True
logger
.
info
(
f
"FakeKVReceiver init with kv_indices:
{
kv_indices
}
, aux_index:
{
aux_index
}
"
)
def
failure_exception
(
self
):
raise
Exception
(
"Fake KVReceiver Exception"
)
python/sglang/srt/disaggregation/prefill.py
View file @
c6c62640
...
@@ -29,6 +29,7 @@ import torch
...
@@ -29,6 +29,7 @@ import torch
from
sglang.srt.disaggregation.base
import
BaseKVManager
,
KVArgs
,
KVPoll
from
sglang.srt.disaggregation.base
import
BaseKVManager
,
KVArgs
,
KVPoll
from
sglang.srt.disaggregation.utils
import
(
from
sglang.srt.disaggregation.utils
import
(
DisaggregationMode
,
DisaggregationMode
,
FakeBootstrapHost
,
KVClassType
,
KVClassType
,
ReqToMetadataIdxAllocator
,
ReqToMetadataIdxAllocator
,
TransferBackend
,
TransferBackend
,
...
@@ -116,6 +117,10 @@ class PrefillBootstrapQueue:
...
@@ -116,6 +117,10 @@ class PrefillBootstrapQueue:
return
kv_manager
return
kv_manager
def
add
(
self
,
req
:
Req
)
->
None
:
def
add
(
self
,
req
:
Req
)
->
None
:
if
req
.
bootstrap_host
==
FakeBootstrapHost
:
# Fake transfer for warmup reqs
kv_sender_class
=
get_kv_class
(
TransferBackend
.
FAKE
,
KVClassType
.
SENDER
)
else
:
kv_sender_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
SENDER
)
kv_sender_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
SENDER
)
req
.
disagg_kv_sender
=
kv_sender_class
(
req
.
disagg_kv_sender
=
kv_sender_class
(
mgr
=
self
.
kv_manager
,
mgr
=
self
.
kv_manager
,
...
...
python/sglang/srt/disaggregation/utils.py
View file @
c6c62640
...
@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
...
@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
DECODE
=
"decode"
DECODE
=
"decode"
FakeBootstrapHost
=
"2.2.2.2"
def
poll_and_all_reduce
(
pollers
,
gloo_group
):
def
poll_and_all_reduce
(
pollers
,
gloo_group
):
polls
=
[
int
(
poller
.
poll
())
for
poller
in
pollers
]
polls
=
[
int
(
poller
.
poll
())
for
poller
in
pollers
]
tensor_to_reduce
=
torch
.
tensor
(
polls
,
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
tensor_to_reduce
=
torch
.
tensor
(
polls
,
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
...
@@ -59,6 +62,8 @@ class KVClassType(Enum):
...
@@ -59,6 +62,8 @@ class KVClassType(Enum):
def
get_kv_class
(
transfer_backend
:
TransferBackend
,
class_type
:
KVClassType
):
def
get_kv_class
(
transfer_backend
:
TransferBackend
,
class_type
:
KVClassType
):
from
sglang.srt.disaggregation.fake
import
FakeKVReceiver
,
FakeKVSender
if
transfer_backend
==
TransferBackend
.
MOONCAKE
:
if
transfer_backend
==
TransferBackend
.
MOONCAKE
:
from
sglang.srt.disaggregation.mooncake
import
(
from
sglang.srt.disaggregation.mooncake
import
(
MooncakeKVBootstrapServer
,
MooncakeKVBootstrapServer
,
...
@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
...
@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
class_mapping
=
{
class_mapping
=
{
KVClassType
.
MANAGER
:
MooncakeKVManager
,
KVClassType
.
MANAGER
:
MooncakeKVManager
,
KVClassType
.
SENDER
:
MooncakeKVSender
,
KVClassType
.
SENDER
:
MooncakeKVSender
,
KVClassType
.
RECEIVER
:
MooncakeKVReceiver
,
KVClassType
.
RECEIVER
:
(
MooncakeKVReceiver
)
,
KVClassType
.
BOOTSTRAP_SERVER
:
MooncakeKVBootstrapServer
,
KVClassType
.
BOOTSTRAP_SERVER
:
MooncakeKVBootstrapServer
,
}
}
return
class_mapping
.
get
(
class_type
)
return
class_mapping
.
get
(
class_type
)
...
@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
...
@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
class_mapping
=
{
class_mapping
=
{
KVClassType
.
MANAGER
:
NixlKVManager
,
KVClassType
.
MANAGER
:
NixlKVManager
,
KVClassType
.
SENDER
:
NixlKVSender
,
KVClassType
.
SENDER
:
NixlKVSender
,
KVClassType
.
RECEIVER
:
NixlKVReceiver
,
KVClassType
.
RECEIVER
:
(
NixlKVReceiver
)
,
KVClassType
.
BOOTSTRAP_SERVER
:
NixlKVBootstrapServer
,
KVClassType
.
BOOTSTRAP_SERVER
:
NixlKVBootstrapServer
,
}
}
return
class_mapping
.
get
(
class_type
)
return
class_mapping
.
get
(
class_type
)
if
transfer_backend
==
TransferBackend
.
FAKE
:
from
sglang.srt.disaggregation.fake
import
FakeKVReceiver
,
FakeKVSender
class_mapping
=
{
KVClassType
.
SENDER
:
FakeKVSender
,
KVClassType
.
RECEIVER
:
(
FakeKVReceiver
),
}
return
class_mapping
.
get
(
class_type
)
raise
ValueError
(
f
"Unsupported transfer backend:
{
transfer_backend
}
"
)
raise
ValueError
(
f
"Unsupported transfer backend:
{
transfer_backend
}
"
)
...
...
python/sglang/srt/entrypoints/http_server.py
View file @
c6c62640
...
@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
...
@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
ORJSONResponse
,
Response
,
StreamingResponse
from
sglang.srt.disaggregation.utils
import
FakeBootstrapHost
from
sglang.srt.entrypoints.engine
import
_launch_subprocesses
from
sglang.srt.entrypoints.engine
import
_launch_subprocesses
from
sglang.srt.function_call_parser
import
FunctionCallParser
from
sglang.srt.function_call_parser
import
FunctionCallParser
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
...
@@ -821,8 +822,32 @@ def _wait_and_warmup(
...
@@ -821,8 +822,32 @@ def _wait_and_warmup(
)
)
assert
res
.
status_code
==
200
,
f
"
{
res
}
"
assert
res
.
status_code
==
200
,
f
"
{
res
}
"
else
:
else
:
# Warmup request currently hangs in disaggregation mode, so we skip it.
logger
.
info
(
f
"Start of prefill warmup ..."
)
logger
.
info
(
"Skipping warmup request in disaggregation mode"
)
json_data
=
{
"sampling_params"
:
{
"temperature"
:
0.0
,
"max_new_tokens"
:
8
,
"ignore_eos"
:
True
,
},
"bootstrap_host"
:
[
FakeBootstrapHost
]
*
server_args
.
dp_size
,
# This is a hack to ensure fake transfer is enabled during prefill warmup
# ensure each dp rank has a unique bootstrap_room during prefill warmup
"bootstrap_room"
:
[
i
*
(
2
**
63
//
server_args
.
dp_size
)
+
(
i
%
server_args
.
tp_size
)
for
i
in
range
(
server_args
.
dp_size
)
],
"input_ids"
:
[[
0
,
1
,
2
,
3
]]
*
server_args
.
dp_size
,
}
res
=
requests
.
post
(
url
+
request_name
,
json
=
json_data
,
headers
=
headers
,
timeout
=
1800
,
# because of deep gemm precache is very long if not precache.
)
logger
.
info
(
f
"End of prefill warmup with status
{
res
.
status_code
}
, resp:
{
res
.
json
()
}
"
)
except
Exception
:
except
Exception
:
last_traceback
=
get_exception_traceback
()
last_traceback
=
get_exception_traceback
()
if
pipe_finish_writer
is
not
None
:
if
pipe_finish_writer
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment