"tests/pytorch/test_basics.py" did not exist on "5e75f5dbfcc30d9170dc1a2999ef19efc10246d7"
Unverified Commit a9499885 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[PD] Add transfer backend abstraction (#5328)

parent f7655790
from .conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
import numpy.typing as npt
from sglang.srt.disaggregation.utils import DisaggregationMode
class KVArgs:
engine_rank: int
kv_data_ptrs: list[int]
kv_data_lens: list[int]
kv_item_lens: list[int]
aux_data_ptrs: list[int]
aux_data_lens: list[int]
aux_item_lens: list[int]
ib_device: str
class KVPoll:
Failed = 0
Bootstrapping = 1
WaitingForInput = 2
Transferring = 3
Success = 4
class BaseKVManager(ABC):
"""Base class for managing transfers states"""
@abstractmethod
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): ...
class BaseKVSender(ABC):
@abstractmethod
def __init__(
self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
): ...
@abstractmethod
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
"""
Notify the decoder server about the kv indices length and aux index
"""
...
@abstractmethod
def send(self, kv_indices: npt.NDArray[np.int64]):
"""
Send the kv cache at the given kv indices to the decoder server
"""
...
@abstractmethod
def poll(self) -> KVPoll:
"""
Check the status of the kv cache transfer
"""
...
@abstractmethod
def failure_exception(self):
"""
Raise an exception if the kv cache transfer fails
"""
...
class BaseKVReceiver(ABC):
@abstractmethod
def __init__(
self,
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
): ...
@abstractmethod
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
"""
Notify the prefill server about the kv indices and aux index
"""
...
@abstractmethod
def poll(self) -> KVPoll:
"""
Check the status of the kv cache transfer
"""
...
@abstractmethod
def failure_exception(self):
"""
Raise an exception if the kv cache transfer fails
"""
...
class BaseKVBootstrapServer(ABC):
@abstractmethod
def __init__(self, port: int): ...
...@@ -28,10 +28,19 @@ import numpy as np ...@@ -28,10 +28,19 @@ import numpy as np
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver from sglang.srt.disaggregation.base import (
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
DisaggregationMode, DisaggregationMode,
KVClassType,
ReqToMetadataIdxAllocator, ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
poll_and_all_reduce, poll_and_all_reduce,
) )
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
...@@ -51,7 +60,7 @@ if TYPE_CHECKING: ...@@ -51,7 +60,7 @@ if TYPE_CHECKING:
@dataclass @dataclass
class DecodeRequest: class DecodeRequest:
req: Req req: Req
kv_receiver: KVReceiver kv_receiver: BaseKVReceiver
waiting_for_input: bool = False waiting_for_input: bool = False
metadata_buffer_index: int = -1 metadata_buffer_index: int = -1
...@@ -75,6 +84,7 @@ class DecodePreallocQueue: ...@@ -75,6 +84,7 @@ class DecodePreallocQueue:
tp_rank: int, tp_rank: int,
tp_size: int, tp_size: int,
bootstrap_port: int, bootstrap_port: int,
transfer_backend: TransferBackend,
): ):
self.req_to_token_pool = req_to_token_pool self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
...@@ -94,9 +104,10 @@ class DecodePreallocQueue: ...@@ -94,9 +104,10 @@ class DecodePreallocQueue:
# Queue for requests pending pre-allocation # Queue for requests pending pre-allocation
self.queue: List[DecodeRequest] = [] self.queue: List[DecodeRequest] = []
self.transfer_backend = transfer_backend
self.kv_manager = self._init_kv_manager() self.kv_manager = self._init_kv_manager()
def _init_kv_manager(self) -> KVManager: def _init_kv_manager(self) -> BaseKVManager:
kv_args = KVArgs() kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank kv_args.engine_rank = self.tp_rank
kv_data_ptrs, kv_data_lens, kv_item_lens = ( kv_data_ptrs, kv_data_lens, kv_item_lens = (
...@@ -117,13 +128,15 @@ class DecodePreallocQueue: ...@@ -117,13 +128,15 @@ class DecodePreallocQueue:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
] ]
kv_args.ib_device = "mock-ib-device" kv_args.ib_device = "mock-ib-device"
kv_manager = KVManager(kv_args, DisaggregationMode("decode")) kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class(kv_args, DisaggregationMode.DECODE)
return kv_manager return kv_manager
def add(self, req: Req) -> None: def add(self, req: Req) -> None:
"""Add a request to the pending queue.""" """Add a request to the pending queue."""
kv_receiver = KVReceiver( kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
kv_receiver = kv_receiver_class(
mgr=self.kv_manager, mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}", bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_room=req.bootstrap_room, bootstrap_room=req.bootstrap_room,
......
from .conn import (
MooncakeKVBootstrapServer,
MooncakeKVManager,
MooncakeKVReceiver,
MooncakeKVSender,
)
...@@ -12,7 +12,15 @@ import numpy.typing as npt ...@@ -12,7 +12,15 @@ import numpy.typing as npt
import zmq import zmq
from aiohttp import web from aiohttp import web
from sglang.srt.disaggregation.transfer_engine.mooncake import MooncakeTransferEngine from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -44,25 +52,6 @@ def group_concurrent_contiguous( ...@@ -44,25 +52,6 @@ def group_concurrent_contiguous(
return src_groups, dst_groups return src_groups, dst_groups
class KVArgs:
engine_rank: int
kv_data_ptrs: list[int]
kv_data_lens: list[int]
kv_item_lens: list[int]
aux_data_ptrs: list[int]
aux_data_lens: list[int]
aux_item_lens: list[int]
ib_device: str
class KVPoll:
Failed = 0
Bootstrapping = 1
WaitingForInput = 2
Transferring = 3
Success = 4
RequestPoolType = Dict[int, Tuple[npt.NDArray[np.int64], Optional[int]]] RequestPoolType = Dict[int, Tuple[npt.NDArray[np.int64], Optional[int]]]
WaitingPoolType = Dict[ WaitingPoolType = Dict[
int, Tuple[str, list[int], npt.NDArray[np.int64], list[int], int] int, Tuple[str, list[int], npt.NDArray[np.int64], list[int], int]
...@@ -71,8 +60,7 @@ KVSENDER_POLLING_PORT = 17788 ...@@ -71,8 +60,7 @@ KVSENDER_POLLING_PORT = 17788
KVRECEIVER_POLLING_PORT = 27788 KVRECEIVER_POLLING_PORT = 27788
class KVManager: class MooncakeKVManager(BaseKVManager):
# TODO: make it general and support multiple transfer backend before merging
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode):
self.engine = MooncakeTransferEngine() self.engine = MooncakeTransferEngine()
self.kv_args = args self.kv_args = args
...@@ -331,9 +319,11 @@ class KVManager: ...@@ -331,9 +319,11 @@ class KVManager:
return self.engine.get_session_id() return self.engine.get_session_id()
class KVSender: class MooncakeKVSender(BaseKVSender):
def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int): def __init__(
self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
):
self.kv_mgr = mgr self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room self.bootstrap_room = bootstrap_room
self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput) self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput)
...@@ -353,10 +343,13 @@ class KVSender: ...@@ -353,10 +343,13 @@ class KVSender:
raise Exception("Fake KVSender Exception") raise Exception("Fake KVSender Exception")
class KVReceiver: class MooncakeKVReceiver(BaseKVReceiver):
def __init__( def __init__(
self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None self,
mgr: MooncakeKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
): ):
self.bootstrap_room = bootstrap_room self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr self.bootstrap_addr = bootstrap_addr
...@@ -403,7 +396,7 @@ class KVReceiver: ...@@ -403,7 +396,7 @@ class KVReceiver:
raise Exception("Fake KVReceiver Exception") raise Exception("Fake KVReceiver Exception")
class KVBootstrapServer: class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
def __init__(self, port: int): def __init__(self, port: int):
self.port = port self.port = port
self.app = web.Application() self.app = web.Application()
......
...@@ -24,10 +24,19 @@ from typing import TYPE_CHECKING, List, Optional ...@@ -24,10 +24,19 @@ from typing import TYPE_CHECKING, List, Optional
import torch import torch
from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender from sglang.srt.disaggregation.base import (
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
DisaggregationMode, DisaggregationMode,
KVClassType,
ReqToMetadataIdxAllocator, ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
poll_and_all_reduce, poll_and_all_reduce,
) )
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
...@@ -38,6 +47,7 @@ if TYPE_CHECKING: ...@@ -38,6 +47,7 @@ if TYPE_CHECKING:
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
from sglang.srt.mem_cache.memory_pool import KVCache from sglang.srt.mem_cache.memory_pool import KVCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -56,6 +66,7 @@ class PrefillBootstrapQueue: ...@@ -56,6 +66,7 @@ class PrefillBootstrapQueue:
tp_size: int, tp_size: int,
bootstrap_port: int, bootstrap_port: int,
gloo_group: ProcessGroup, gloo_group: ProcessGroup,
transfer_backend: TransferBackend,
): ):
self.token_to_kv_pool = token_to_kv_pool self.token_to_kv_pool = token_to_kv_pool
self.aux_dtype = aux_dtype self.aux_dtype = aux_dtype
...@@ -64,6 +75,7 @@ class PrefillBootstrapQueue: ...@@ -64,6 +75,7 @@ class PrefillBootstrapQueue:
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.tp_size = tp_size self.tp_size = tp_size
self.transfer_backend = transfer_backend
self.kv_manager = self._init_kv_manager() self.kv_manager = self._init_kv_manager()
self.queue: List[Req] = [] self.queue: List[Req] = []
self.gloo_group = gloo_group self.gloo_group = gloo_group
...@@ -74,7 +86,7 @@ class PrefillBootstrapQueue: ...@@ -74,7 +86,7 @@ class PrefillBootstrapQueue:
output_id_buffer = self.metadata_buffers[0] output_id_buffer = self.metadata_buffers[0]
output_id_buffer[idx] = token_id output_id_buffer[idx] = token_id
def _init_kv_manager(self) -> KVManager: def _init_kv_manager(self) -> BaseKVManager:
kv_args = KVArgs() kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank kv_args.engine_rank = self.tp_rank
kv_data_ptrs, kv_data_lens, kv_item_lens = ( kv_data_ptrs, kv_data_lens, kv_item_lens = (
...@@ -96,11 +108,13 @@ class PrefillBootstrapQueue: ...@@ -96,11 +108,13 @@ class PrefillBootstrapQueue:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
] ]
kv_args.ib_device = "mock-ib-device" kv_args.ib_device = "mock-ib-device"
kv_manager = KVManager(kv_args, DisaggregationMode("prefill")) kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class(kv_args, DisaggregationMode.PREFILL)
return kv_manager return kv_manager
def add(self, req: Req) -> None: def add(self, req: Req) -> None:
req.disagg_kv_sender = KVSender( kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
req.disagg_kv_sender = kv_sender_class(
mgr=self.kv_manager, mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}", bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_room=req.bootstrap_room, bootstrap_room=req.bootstrap_room,
......
...@@ -42,3 +42,34 @@ class ReqToMetadataIdxAllocator: ...@@ -42,3 +42,34 @@ class ReqToMetadataIdxAllocator:
def free(self, free_index: int): def free(self, free_index: int):
self.free_slots.append(free_index) self.free_slots.append(free_index)
class TransferBackend(Enum):
MOONCAKE = "mooncake"
FAKE = "fake"
class KVClassType(Enum):
MANAGER = "manager"
SENDER = "sender"
RECEIVER = "receiver"
BOOTSTRAP_SERVER = "bootstrap_server"
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
if transfer_backend == TransferBackend.MOONCAKE:
from sglang.srt.disaggregation.mooncake import (
MooncakeKVBootstrapServer,
MooncakeKVManager,
MooncakeKVReceiver,
MooncakeKVSender,
)
class_mapping = {
KVClassType.MANAGER: MooncakeKVManager,
KVClassType.SENDER: MooncakeKVSender,
KVClassType.RECEIVER: MooncakeKVReceiver,
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
}
return class_mapping.get(class_type)
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
...@@ -45,7 +45,7 @@ import triton.language as tl ...@@ -45,7 +45,7 @@ import triton.language as tl
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.conn import KVSender from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.chunk_cache import ChunkCache
...@@ -525,7 +525,7 @@ class Req: ...@@ -525,7 +525,7 @@ class Req:
# For disaggregation # For disaggregation
self.bootstrap_host: str = bootstrap_host self.bootstrap_host: str = bootstrap_host
self.bootstrap_room: Optional[int] = bootstrap_room self.bootstrap_room: Optional[int] = bootstrap_room
self.disagg_kv_sender: Optional[KVSender] = None self.disagg_kv_sender: Optional[BaseKVSender] = None
# used for warmup because we don't have a pair yet when init # used for warmup because we don't have a pair yet when init
self.skip_kv_transfer: bool = False self.skip_kv_transfer: bool = False
......
...@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.prefill import ( ...@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.prefill import (
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
DisaggregationMode, DisaggregationMode,
ReqToMetadataIdxAllocator, ReqToMetadataIdxAllocator,
TransferBackend,
) )
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
...@@ -530,6 +531,10 @@ class Scheduler( ...@@ -530,6 +531,10 @@ class Scheduler(
) )
def init_disaggregation(self): def init_disaggregation(self):
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
if ( if (
self.disaggregation_mode == DisaggregationMode.DECODE self.disaggregation_mode == DisaggregationMode.DECODE
): # *2 for the headroom. ): # *2 for the headroom.
...@@ -567,6 +572,7 @@ class Scheduler( ...@@ -567,6 +572,7 @@ class Scheduler(
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
tp_size=self.tp_size, tp_size=self.tp_size,
bootstrap_port=self.server_args.disaggregation_bootstrap_port, bootstrap_port=self.server_args.disaggregation_bootstrap_port,
transfer_backend=self.transfer_backend,
) )
elif self.disaggregation_mode == DisaggregationMode.PREFILL: elif self.disaggregation_mode == DisaggregationMode.PREFILL:
# *2 for the headroom. # *2 for the headroom.
...@@ -592,6 +598,7 @@ class Scheduler( ...@@ -592,6 +598,7 @@ class Scheduler(
tp_size=self.tp_size, tp_size=self.tp_size,
bootstrap_port=self.server_args.disaggregation_bootstrap_port, bootstrap_port=self.server_args.disaggregation_bootstrap_port,
gloo_group=self.tp_worker.get_attention_tp_cpu_group(), gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
transfer_backend=self.transfer_backend,
) )
# The prefill requests that are in the middle of kv sending # The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = [] self.disagg_prefill_inflight_queue: List[Req] = []
......
...@@ -48,8 +48,12 @@ from fastapi import BackgroundTasks ...@@ -48,8 +48,12 @@ from fastapi import BackgroundTasks
from sglang.srt.aio_rwlock import RWLock from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.conn import KVBootstrapServer from sglang.srt.disaggregation.utils import (
from sglang.srt.disaggregation.utils import DisaggregationMode DisaggregationMode,
KVClassType,
TransferBackend,
get_kv_class,
)
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
...@@ -329,10 +333,16 @@ class TokenizerManager: ...@@ -329,10 +333,16 @@ class TokenizerManager:
self.disaggregation_mode = DisaggregationMode( self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode self.server_args.disaggregation_mode
) )
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# for disaggregtion, start kv boostrap server on prefill # for disaggregtion, start kv boostrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm # only start bootstrap server on prefill tm
self.bootstrap_server = KVBootstrapServer( kv_bootstrap_server_class = get_kv_class(
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port self.server_args.disaggregation_bootstrap_port
) )
......
...@@ -195,6 +195,7 @@ class ServerArgs: ...@@ -195,6 +195,7 @@ class ServerArgs:
# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only) # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode: str = "null" disaggregation_mode: str = "null"
disaggregation_bootstrap_port: int = 8998 disaggregation_bootstrap_port: int = 8998
disaggregation_transfer_backend: str = "mooncake"
# multimodal # multimodal
disable_fast_image_processor: bool = False disable_fast_image_processor: bool = False
...@@ -1173,6 +1174,12 @@ class ServerArgs: ...@@ -1173,6 +1174,12 @@ class ServerArgs:
default=ServerArgs.disaggregation_bootstrap_port, default=ServerArgs.disaggregation_bootstrap_port,
help="Bootstrap server port on the prefill server. Default is 8998.", help="Bootstrap server port on the prefill server. Default is 8998.",
) )
parser.add_argument(
"--disaggregation-transfer-backend",
type=str,
default=ServerArgs.disaggregation_transfer_backend,
help="The backend for disaggregation transfer. Default is mooncake.",
)
# Multimodal # Multimodal
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment