Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a9499885
"tests/pytorch/test_basics.py" did not exist on "5e75f5dbfcc30d9170dc1a2999ef19efc10246d7"
Unverified
Commit
a9499885
authored
Apr 13, 2025
by
Byron Hsu
Committed by
GitHub
Apr 14, 2025
Browse files
[PD] Add transfer backend abstraction (#5328)
parent
f7655790
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
236 additions
and
41 deletions
+236
-41
python/sglang/srt/disaggregation/base/__init__.py
python/sglang/srt/disaggregation/base/__init__.py
+8
-0
python/sglang/srt/disaggregation/base/conn.py
python/sglang/srt/disaggregation/base/conn.py
+106
-0
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+18
-5
python/sglang/srt/disaggregation/mooncake/__init__.py
python/sglang/srt/disaggregation/mooncake/__init__.py
+6
-0
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+20
-27
python/sglang/srt/disaggregation/mooncake/transfer_engine.py
python/sglang/srt/disaggregation/mooncake/transfer_engine.py
+0
-0
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+18
-4
python/sglang/srt/disaggregation/utils.py
python/sglang/srt/disaggregation/utils.py
+31
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+2
-2
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+7
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+13
-3
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-0
No files found.
python/sglang/srt/disaggregation/base/__init__.py
0 → 100644
View file @
a9499885
from
.conn
import
(
BaseKVBootstrapServer
,
BaseKVManager
,
BaseKVReceiver
,
BaseKVSender
,
KVArgs
,
KVPoll
,
)
python/sglang/srt/disaggregation/base/conn.py
0 → 100644
View file @
a9499885
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
import
numpy
as
np
import
numpy.typing
as
npt
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
class
KVArgs
:
engine_rank
:
int
kv_data_ptrs
:
list
[
int
]
kv_data_lens
:
list
[
int
]
kv_item_lens
:
list
[
int
]
aux_data_ptrs
:
list
[
int
]
aux_data_lens
:
list
[
int
]
aux_item_lens
:
list
[
int
]
ib_device
:
str
class
KVPoll
:
Failed
=
0
Bootstrapping
=
1
WaitingForInput
=
2
Transferring
=
3
Success
=
4
class
BaseKVManager
(
ABC
):
"""Base class for managing transfers states"""
@
abstractmethod
def
__init__
(
self
,
args
:
KVArgs
,
disaggregation_mode
:
DisaggregationMode
):
...
class
BaseKVSender
(
ABC
):
@
abstractmethod
def
__init__
(
self
,
mgr
:
BaseKVManager
,
bootstrap_addr
:
str
,
bootstrap_room
:
int
):
...
@
abstractmethod
def
init
(
self
,
num_kv_indices
:
int
,
aux_index
:
Optional
[
int
]
=
None
):
"""
Notify the decoder server about the kv indices length and aux index
"""
...
@
abstractmethod
def
send
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
]):
"""
Send the kv cache at the given kv indices to the decoder server
"""
...
@
abstractmethod
def
poll
(
self
)
->
KVPoll
:
"""
Check the status of the kv cache transfer
"""
...
@
abstractmethod
def
failure_exception
(
self
):
"""
Raise an exception if the kv cache transfer fails
"""
...
class
BaseKVReceiver
(
ABC
):
@
abstractmethod
def
__init__
(
self
,
mgr
:
BaseKVManager
,
bootstrap_addr
:
str
,
bootstrap_room
:
Optional
[
int
]
=
None
,
):
...
@
abstractmethod
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
aux_index
:
Optional
[
int
]
=
None
):
"""
Notify the prefill server about the kv indices and aux index
"""
...
@
abstractmethod
def
poll
(
self
)
->
KVPoll
:
"""
Check the status of the kv cache transfer
"""
...
@
abstractmethod
def
failure_exception
(
self
):
"""
Raise an exception if the kv cache transfer fails
"""
...
class
BaseKVBootstrapServer
(
ABC
):
@
abstractmethod
def
__init__
(
self
,
port
:
int
):
...
python/sglang/srt/disaggregation/decode.py
View file @
a9499885
...
@@ -28,10 +28,19 @@ import numpy as np
...
@@ -28,10 +28,19 @@ import numpy as np
import
torch
import
torch
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
sglang.srt.disaggregation.conn
import
KVArgs
,
KVManager
,
KVPoll
,
KVReceiver
from
sglang.srt.disaggregation.base
import
(
BaseKVManager
,
BaseKVReceiver
,
BaseKVSender
,
KVArgs
,
KVPoll
,
)
from
sglang.srt.disaggregation.utils
import
(
from
sglang.srt.disaggregation.utils
import
(
DisaggregationMode
,
DisaggregationMode
,
KVClassType
,
ReqToMetadataIdxAllocator
,
ReqToMetadataIdxAllocator
,
TransferBackend
,
get_kv_class
,
poll_and_all_reduce
,
poll_and_all_reduce
,
)
)
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
...
@@ -51,7 +60,7 @@ if TYPE_CHECKING:
...
@@ -51,7 +60,7 @@ if TYPE_CHECKING:
@
dataclass
@
dataclass
class
DecodeRequest
:
class
DecodeRequest
:
req
:
Req
req
:
Req
kv_receiver
:
KVReceiver
kv_receiver
:
Base
KVReceiver
waiting_for_input
:
bool
=
False
waiting_for_input
:
bool
=
False
metadata_buffer_index
:
int
=
-
1
metadata_buffer_index
:
int
=
-
1
...
@@ -75,6 +84,7 @@ class DecodePreallocQueue:
...
@@ -75,6 +84,7 @@ class DecodePreallocQueue:
tp_rank
:
int
,
tp_rank
:
int
,
tp_size
:
int
,
tp_size
:
int
,
bootstrap_port
:
int
,
bootstrap_port
:
int
,
transfer_backend
:
TransferBackend
,
):
):
self
.
req_to_token_pool
=
req_to_token_pool
self
.
req_to_token_pool
=
req_to_token_pool
self
.
token_to_kv_pool_allocator
=
token_to_kv_pool_allocator
self
.
token_to_kv_pool_allocator
=
token_to_kv_pool_allocator
...
@@ -94,9 +104,10 @@ class DecodePreallocQueue:
...
@@ -94,9 +104,10 @@ class DecodePreallocQueue:
# Queue for requests pending pre-allocation
# Queue for requests pending pre-allocation
self
.
queue
:
List
[
DecodeRequest
]
=
[]
self
.
queue
:
List
[
DecodeRequest
]
=
[]
self
.
transfer_backend
=
transfer_backend
self
.
kv_manager
=
self
.
_init_kv_manager
()
self
.
kv_manager
=
self
.
_init_kv_manager
()
def
_init_kv_manager
(
self
)
->
KVManager
:
def
_init_kv_manager
(
self
)
->
Base
KVManager
:
kv_args
=
KVArgs
()
kv_args
=
KVArgs
()
kv_args
.
engine_rank
=
self
.
tp_rank
kv_args
.
engine_rank
=
self
.
tp_rank
kv_data_ptrs
,
kv_data_lens
,
kv_item_lens
=
(
kv_data_ptrs
,
kv_data_lens
,
kv_item_lens
=
(
...
@@ -117,13 +128,15 @@ class DecodePreallocQueue:
...
@@ -117,13 +128,15 @@ class DecodePreallocQueue:
metadata_buffer
[
0
].
nbytes
for
metadata_buffer
in
self
.
metadata_buffers
metadata_buffer
[
0
].
nbytes
for
metadata_buffer
in
self
.
metadata_buffers
]
]
kv_args
.
ib_device
=
"mock-ib-device"
kv_args
.
ib_device
=
"mock-ib-device"
kv_manager
=
KVManager
(
kv_args
,
DisaggregationMode
(
"decode"
))
kv_manager_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
MANAGER
)
kv_manager
=
kv_manager_class
(
kv_args
,
DisaggregationMode
.
DECODE
)
return
kv_manager
return
kv_manager
def
add
(
self
,
req
:
Req
)
->
None
:
def
add
(
self
,
req
:
Req
)
->
None
:
"""Add a request to the pending queue."""
"""Add a request to the pending queue."""
kv_receiver
=
KVReceiver
(
kv_receiver_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
RECEIVER
)
kv_receiver
=
kv_receiver_class
(
mgr
=
self
.
kv_manager
,
mgr
=
self
.
kv_manager
,
bootstrap_addr
=
f
"
{
req
.
bootstrap_host
}
:
{
self
.
bootstrap_port
}
"
,
bootstrap_addr
=
f
"
{
req
.
bootstrap_host
}
:
{
self
.
bootstrap_port
}
"
,
bootstrap_room
=
req
.
bootstrap_room
,
bootstrap_room
=
req
.
bootstrap_room
,
...
...
python/sglang/srt/disaggregation/mooncake/__init__.py
0 → 100644
View file @
a9499885
from
.conn
import
(
MooncakeKVBootstrapServer
,
MooncakeKVManager
,
MooncakeKVReceiver
,
MooncakeKVSender
,
)
python/sglang/srt/disaggregation/conn.py
→
python/sglang/srt/disaggregation/
mooncake/
conn.py
View file @
a9499885
...
@@ -12,7 +12,15 @@ import numpy.typing as npt
...
@@ -12,7 +12,15 @@ import numpy.typing as npt
import
zmq
import
zmq
from
aiohttp
import
web
from
aiohttp
import
web
from
sglang.srt.disaggregation.transfer_engine.mooncake
import
MooncakeTransferEngine
from
sglang.srt.disaggregation.base.conn
import
(
BaseKVBootstrapServer
,
BaseKVManager
,
BaseKVReceiver
,
BaseKVSender
,
KVArgs
,
KVPoll
,
)
from
sglang.srt.disaggregation.mooncake.transfer_engine
import
MooncakeTransferEngine
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -44,25 +52,6 @@ def group_concurrent_contiguous(
...
@@ -44,25 +52,6 @@ def group_concurrent_contiguous(
return
src_groups
,
dst_groups
return
src_groups
,
dst_groups
class
KVArgs
:
engine_rank
:
int
kv_data_ptrs
:
list
[
int
]
kv_data_lens
:
list
[
int
]
kv_item_lens
:
list
[
int
]
aux_data_ptrs
:
list
[
int
]
aux_data_lens
:
list
[
int
]
aux_item_lens
:
list
[
int
]
ib_device
:
str
class
KVPoll
:
Failed
=
0
Bootstrapping
=
1
WaitingForInput
=
2
Transferring
=
3
Success
=
4
RequestPoolType
=
Dict
[
int
,
Tuple
[
npt
.
NDArray
[
np
.
int64
],
Optional
[
int
]]]
RequestPoolType
=
Dict
[
int
,
Tuple
[
npt
.
NDArray
[
np
.
int64
],
Optional
[
int
]]]
WaitingPoolType
=
Dict
[
WaitingPoolType
=
Dict
[
int
,
Tuple
[
str
,
list
[
int
],
npt
.
NDArray
[
np
.
int64
],
list
[
int
],
int
]
int
,
Tuple
[
str
,
list
[
int
],
npt
.
NDArray
[
np
.
int64
],
list
[
int
],
int
]
...
@@ -71,8 +60,7 @@ KVSENDER_POLLING_PORT = 17788
...
@@ -71,8 +60,7 @@ KVSENDER_POLLING_PORT = 17788
KVRECEIVER_POLLING_PORT
=
27788
KVRECEIVER_POLLING_PORT
=
27788
class
KVManager
:
class
MooncakeKVManager
(
BaseKVManager
):
# TODO: make it general and support multiple transfer backend before merging
def
__init__
(
self
,
args
:
KVArgs
,
disaggregation_mode
:
DisaggregationMode
):
def
__init__
(
self
,
args
:
KVArgs
,
disaggregation_mode
:
DisaggregationMode
):
self
.
engine
=
MooncakeTransferEngine
()
self
.
engine
=
MooncakeTransferEngine
()
self
.
kv_args
=
args
self
.
kv_args
=
args
...
@@ -331,9 +319,11 @@ class KVManager:
...
@@ -331,9 +319,11 @@ class KVManager:
return
self
.
engine
.
get_session_id
()
return
self
.
engine
.
get_session_id
()
class
KVSender
:
class
MooncakeKVSender
(
BaseKVSender
)
:
def
__init__
(
self
,
mgr
:
KVManager
,
bootstrap_addr
:
str
,
bootstrap_room
:
int
):
def
__init__
(
self
,
mgr
:
MooncakeKVManager
,
bootstrap_addr
:
str
,
bootstrap_room
:
int
):
self
.
kv_mgr
=
mgr
self
.
kv_mgr
=
mgr
self
.
bootstrap_room
=
bootstrap_room
self
.
bootstrap_room
=
bootstrap_room
self
.
kv_mgr
.
set_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
self
.
kv_mgr
.
set_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
...
@@ -353,10 +343,13 @@ class KVSender:
...
@@ -353,10 +343,13 @@ class KVSender:
raise
Exception
(
"Fake KVSender Exception"
)
raise
Exception
(
"Fake KVSender Exception"
)
class
KVReceiver
:
class
Mooncake
KVReceiver
(
BaseKVReceiver
)
:
def
__init__
(
def
__init__
(
self
,
mgr
:
KVManager
,
bootstrap_addr
:
str
,
bootstrap_room
:
Optional
[
int
]
=
None
self
,
mgr
:
MooncakeKVManager
,
bootstrap_addr
:
str
,
bootstrap_room
:
Optional
[
int
]
=
None
,
):
):
self
.
bootstrap_room
=
bootstrap_room
self
.
bootstrap_room
=
bootstrap_room
self
.
bootstrap_addr
=
bootstrap_addr
self
.
bootstrap_addr
=
bootstrap_addr
...
@@ -403,7 +396,7 @@ class KVReceiver:
...
@@ -403,7 +396,7 @@ class KVReceiver:
raise
Exception
(
"Fake KVReceiver Exception"
)
raise
Exception
(
"Fake KVReceiver Exception"
)
class
KVBootstrapServer
:
class
Mooncake
KVBootstrapServer
(
BaseKVBootstrapServer
)
:
def
__init__
(
self
,
port
:
int
):
def
__init__
(
self
,
port
:
int
):
self
.
port
=
port
self
.
port
=
port
self
.
app
=
web
.
Application
()
self
.
app
=
web
.
Application
()
...
...
python/sglang/srt/disaggregation/transfer_engine
/mooncake
.py
→
python/sglang/srt/disaggregation/
mooncake/
transfer_engine.py
View file @
a9499885
File moved
python/sglang/srt/disaggregation/prefill.py
View file @
a9499885
...
@@ -24,10 +24,19 @@ from typing import TYPE_CHECKING, List, Optional
...
@@ -24,10 +24,19 @@ from typing import TYPE_CHECKING, List, Optional
import
torch
import
torch
from
sglang.srt.disaggregation.conn
import
KVArgs
,
KVManager
,
KVPoll
,
KVSender
from
sglang.srt.disaggregation.base
import
(
BaseKVManager
,
BaseKVReceiver
,
BaseKVSender
,
KVArgs
,
KVPoll
,
)
from
sglang.srt.disaggregation.utils
import
(
from
sglang.srt.disaggregation.utils
import
(
DisaggregationMode
,
DisaggregationMode
,
KVClassType
,
ReqToMetadataIdxAllocator
,
ReqToMetadataIdxAllocator
,
TransferBackend
,
get_kv_class
,
poll_and_all_reduce
,
poll_and_all_reduce
,
)
)
from
sglang.srt.managers.schedule_batch
import
FINISH_LENGTH
,
Req
,
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
FINISH_LENGTH
,
Req
,
ScheduleBatch
...
@@ -38,6 +47,7 @@ if TYPE_CHECKING:
...
@@ -38,6 +47,7 @@ if TYPE_CHECKING:
from
sglang.srt.managers.scheduler
import
GenerationBatchResult
,
Scheduler
from
sglang.srt.managers.scheduler
import
GenerationBatchResult
,
Scheduler
from
sglang.srt.mem_cache.memory_pool
import
KVCache
from
sglang.srt.mem_cache.memory_pool
import
KVCache
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -56,6 +66,7 @@ class PrefillBootstrapQueue:
...
@@ -56,6 +66,7 @@ class PrefillBootstrapQueue:
tp_size
:
int
,
tp_size
:
int
,
bootstrap_port
:
int
,
bootstrap_port
:
int
,
gloo_group
:
ProcessGroup
,
gloo_group
:
ProcessGroup
,
transfer_backend
:
TransferBackend
,
):
):
self
.
token_to_kv_pool
=
token_to_kv_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
self
.
aux_dtype
=
aux_dtype
self
.
aux_dtype
=
aux_dtype
...
@@ -64,6 +75,7 @@ class PrefillBootstrapQueue:
...
@@ -64,6 +75,7 @@ class PrefillBootstrapQueue:
self
.
req_to_metadata_buffer_idx_allocator
=
req_to_metadata_buffer_idx_allocator
self
.
req_to_metadata_buffer_idx_allocator
=
req_to_metadata_buffer_idx_allocator
self
.
tp_rank
=
tp_rank
self
.
tp_rank
=
tp_rank
self
.
tp_size
=
tp_size
self
.
tp_size
=
tp_size
self
.
transfer_backend
=
transfer_backend
self
.
kv_manager
=
self
.
_init_kv_manager
()
self
.
kv_manager
=
self
.
_init_kv_manager
()
self
.
queue
:
List
[
Req
]
=
[]
self
.
queue
:
List
[
Req
]
=
[]
self
.
gloo_group
=
gloo_group
self
.
gloo_group
=
gloo_group
...
@@ -74,7 +86,7 @@ class PrefillBootstrapQueue:
...
@@ -74,7 +86,7 @@ class PrefillBootstrapQueue:
output_id_buffer
=
self
.
metadata_buffers
[
0
]
output_id_buffer
=
self
.
metadata_buffers
[
0
]
output_id_buffer
[
idx
]
=
token_id
output_id_buffer
[
idx
]
=
token_id
def
_init_kv_manager
(
self
)
->
KVManager
:
def
_init_kv_manager
(
self
)
->
Base
KVManager
:
kv_args
=
KVArgs
()
kv_args
=
KVArgs
()
kv_args
.
engine_rank
=
self
.
tp_rank
kv_args
.
engine_rank
=
self
.
tp_rank
kv_data_ptrs
,
kv_data_lens
,
kv_item_lens
=
(
kv_data_ptrs
,
kv_data_lens
,
kv_item_lens
=
(
...
@@ -96,11 +108,13 @@ class PrefillBootstrapQueue:
...
@@ -96,11 +108,13 @@ class PrefillBootstrapQueue:
metadata_buffer
[
0
].
nbytes
for
metadata_buffer
in
self
.
metadata_buffers
metadata_buffer
[
0
].
nbytes
for
metadata_buffer
in
self
.
metadata_buffers
]
]
kv_args
.
ib_device
=
"mock-ib-device"
kv_args
.
ib_device
=
"mock-ib-device"
kv_manager
=
KVManager
(
kv_args
,
DisaggregationMode
(
"prefill"
))
kv_manager_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
MANAGER
)
kv_manager
=
kv_manager_class
(
kv_args
,
DisaggregationMode
.
PREFILL
)
return
kv_manager
return
kv_manager
def
add
(
self
,
req
:
Req
)
->
None
:
def
add
(
self
,
req
:
Req
)
->
None
:
req
.
disagg_kv_sender
=
KVSender
(
kv_sender_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
SENDER
)
req
.
disagg_kv_sender
=
kv_sender_class
(
mgr
=
self
.
kv_manager
,
mgr
=
self
.
kv_manager
,
bootstrap_addr
=
f
"
{
req
.
bootstrap_host
}
:
{
self
.
bootstrap_port
}
"
,
bootstrap_addr
=
f
"
{
req
.
bootstrap_host
}
:
{
self
.
bootstrap_port
}
"
,
bootstrap_room
=
req
.
bootstrap_room
,
bootstrap_room
=
req
.
bootstrap_room
,
...
...
python/sglang/srt/disaggregation/utils.py
View file @
a9499885
...
@@ -42,3 +42,34 @@ class ReqToMetadataIdxAllocator:
...
@@ -42,3 +42,34 @@ class ReqToMetadataIdxAllocator:
def
free
(
self
,
free_index
:
int
):
def
free
(
self
,
free_index
:
int
):
self
.
free_slots
.
append
(
free_index
)
self
.
free_slots
.
append
(
free_index
)
class
TransferBackend
(
Enum
):
MOONCAKE
=
"mooncake"
FAKE
=
"fake"
class
KVClassType
(
Enum
):
MANAGER
=
"manager"
SENDER
=
"sender"
RECEIVER
=
"receiver"
BOOTSTRAP_SERVER
=
"bootstrap_server"
def
get_kv_class
(
transfer_backend
:
TransferBackend
,
class_type
:
KVClassType
):
if
transfer_backend
==
TransferBackend
.
MOONCAKE
:
from
sglang.srt.disaggregation.mooncake
import
(
MooncakeKVBootstrapServer
,
MooncakeKVManager
,
MooncakeKVReceiver
,
MooncakeKVSender
,
)
class_mapping
=
{
KVClassType
.
MANAGER
:
MooncakeKVManager
,
KVClassType
.
SENDER
:
MooncakeKVSender
,
KVClassType
.
RECEIVER
:
MooncakeKVReceiver
,
KVClassType
.
BOOTSTRAP_SERVER
:
MooncakeKVBootstrapServer
,
}
return
class_mapping
.
get
(
class_type
)
raise
ValueError
(
f
"Unsupported transfer backend:
{
transfer_backend
}
"
)
python/sglang/srt/managers/schedule_batch.py
View file @
a9499885
...
@@ -45,7 +45,7 @@ import triton.language as tl
...
@@ -45,7 +45,7 @@ import triton.language as tl
from
sglang.global_config
import
global_config
from
sglang.global_config
import
global_config
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.disaggregation.
conn
import
KVSender
from
sglang.srt.disaggregation.
base
import
Base
KVSender
from
sglang.srt.disaggregation.decode
import
ScheduleBatchDisaggregationDecodeMixin
from
sglang.srt.disaggregation.decode
import
ScheduleBatchDisaggregationDecodeMixin
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
...
@@ -525,7 +525,7 @@ class Req:
...
@@ -525,7 +525,7 @@ class Req:
# For disaggregation
# For disaggregation
self
.
bootstrap_host
:
str
=
bootstrap_host
self
.
bootstrap_host
:
str
=
bootstrap_host
self
.
bootstrap_room
:
Optional
[
int
]
=
bootstrap_room
self
.
bootstrap_room
:
Optional
[
int
]
=
bootstrap_room
self
.
disagg_kv_sender
:
Optional
[
KVSender
]
=
None
self
.
disagg_kv_sender
:
Optional
[
Base
KVSender
]
=
None
# used for warmup because we don't have a pair yet when init
# used for warmup because we don't have a pair yet when init
self
.
skip_kv_transfer
:
bool
=
False
self
.
skip_kv_transfer
:
bool
=
False
...
...
python/sglang/srt/managers/scheduler.py
View file @
a9499885
...
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.prefill import (
...
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.prefill import (
from
sglang.srt.disaggregation.utils
import
(
from
sglang.srt.disaggregation.utils
import
(
DisaggregationMode
,
DisaggregationMode
,
ReqToMetadataIdxAllocator
,
ReqToMetadataIdxAllocator
,
TransferBackend
,
)
)
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.layers.dp_attention
import
compute_dp_attention_world_info
from
sglang.srt.layers.dp_attention
import
compute_dp_attention_world_info
...
@@ -530,6 +531,10 @@ class Scheduler(
...
@@ -530,6 +531,10 @@ class Scheduler(
)
)
def
init_disaggregation
(
self
):
def
init_disaggregation
(
self
):
self
.
transfer_backend
=
TransferBackend
(
self
.
server_args
.
disaggregation_transfer_backend
)
if
(
if
(
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
):
# *2 for the headroom.
):
# *2 for the headroom.
...
@@ -567,6 +572,7 @@ class Scheduler(
...
@@ -567,6 +572,7 @@ class Scheduler(
tp_rank
=
self
.
tp_rank
,
tp_rank
=
self
.
tp_rank
,
tp_size
=
self
.
tp_size
,
tp_size
=
self
.
tp_size
,
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
transfer_backend
=
self
.
transfer_backend
,
)
)
elif
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
elif
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
# *2 for the headroom.
# *2 for the headroom.
...
@@ -592,6 +598,7 @@ class Scheduler(
...
@@ -592,6 +598,7 @@ class Scheduler(
tp_size
=
self
.
tp_size
,
tp_size
=
self
.
tp_size
,
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
bootstrap_port
=
self
.
server_args
.
disaggregation_bootstrap_port
,
gloo_group
=
self
.
tp_worker
.
get_attention_tp_cpu_group
(),
gloo_group
=
self
.
tp_worker
.
get_attention_tp_cpu_group
(),
transfer_backend
=
self
.
transfer_backend
,
)
)
# The prefill requests that are in the middle of kv sending
# The prefill requests that are in the middle of kv sending
self
.
disagg_prefill_inflight_queue
:
List
[
Req
]
=
[]
self
.
disagg_prefill_inflight_queue
:
List
[
Req
]
=
[]
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
a9499885
...
@@ -48,8 +48,12 @@ from fastapi import BackgroundTasks
...
@@ -48,8 +48,12 @@ from fastapi import BackgroundTasks
from
sglang.srt.aio_rwlock
import
RWLock
from
sglang.srt.aio_rwlock
import
RWLock
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.disaggregation.conn
import
KVBootstrapServer
from
sglang.srt.disaggregation.utils
import
(
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
DisaggregationMode
,
KVClassType
,
TransferBackend
,
get_kv_class
,
)
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
AbortReq
,
...
@@ -329,10 +333,16 @@ class TokenizerManager:
...
@@ -329,10 +333,16 @@ class TokenizerManager:
self
.
disaggregation_mode
=
DisaggregationMode
(
self
.
disaggregation_mode
=
DisaggregationMode
(
self
.
server_args
.
disaggregation_mode
self
.
server_args
.
disaggregation_mode
)
)
self
.
transfer_backend
=
TransferBackend
(
self
.
server_args
.
disaggregation_transfer_backend
)
# for disaggregtion, start kv boostrap server on prefill
# for disaggregtion, start kv boostrap server on prefill
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
# only start bootstrap server on prefill tm
# only start bootstrap server on prefill tm
self
.
bootstrap_server
=
KVBootstrapServer
(
kv_bootstrap_server_class
=
get_kv_class
(
self
.
transfer_backend
,
KVClassType
.
BOOTSTRAP_SERVER
)
self
.
bootstrap_server
=
kv_bootstrap_server_class
(
self
.
server_args
.
disaggregation_bootstrap_port
self
.
server_args
.
disaggregation_bootstrap_port
)
)
...
...
python/sglang/srt/server_args.py
View file @
a9499885
...
@@ -195,6 +195,7 @@ class ServerArgs:
...
@@ -195,6 +195,7 @@ class ServerArgs:
# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode
:
str
=
"null"
disaggregation_mode
:
str
=
"null"
disaggregation_bootstrap_port
:
int
=
8998
disaggregation_bootstrap_port
:
int
=
8998
disaggregation_transfer_backend
:
str
=
"mooncake"
# multimodal
# multimodal
disable_fast_image_processor
:
bool
=
False
disable_fast_image_processor
:
bool
=
False
...
@@ -1173,6 +1174,12 @@ class ServerArgs:
...
@@ -1173,6 +1174,12 @@ class ServerArgs:
default
=
ServerArgs
.
disaggregation_bootstrap_port
,
default
=
ServerArgs
.
disaggregation_bootstrap_port
,
help
=
"Bootstrap server port on the prefill server. Default is 8998."
,
help
=
"Bootstrap server port on the prefill server. Default is 8998."
,
)
)
parser
.
add_argument
(
"--disaggregation-transfer-backend"
,
type
=
str
,
default
=
ServerArgs
.
disaggregation_transfer_backend
,
help
=
"The backend for disaggregation transfer. Default is mooncake."
,
)
# Multimodal
# Multimodal
parser
.
add_argument
(
parser
.
add_argument
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment