Unverified commit 4a16a71c, authored by Teng Ma, committed by GitHub
Browse files

[PD] feat: mooncake use batch reg/dereg (#8910)


Co-authored-by: Shangming Cai <csmthu@gmail.com>
parent a16923ef
...@@ -257,15 +257,17 @@ class MooncakeKVManager(BaseKVManager): ...@@ -257,15 +257,17 @@ class MooncakeKVManager(BaseKVManager):
) )
def register_buffer_to_engine(self): def register_buffer_to_engine(self):
for kv_data_ptr, kv_data_len in zip( # Batch register KV data buffers
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens if self.kv_args.kv_data_ptrs and self.kv_args.kv_data_lens:
): self.engine.batch_register(
self.engine.register(kv_data_ptr, kv_data_len) self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
)
for aux_data_ptr, aux_data_len in zip( # Batch register auxiliary data buffers
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens if self.kv_args.aux_data_ptrs and self.kv_args.aux_data_lens:
): self.engine.batch_register(
self.engine.register(aux_data_ptr, aux_data_len) self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
)
@cache @cache
def _connect(self, endpoint: str, is_ipv6: bool = False): def _connect(self, endpoint: str, is_ipv6: bool = False):
......
...@@ -51,6 +51,35 @@ class MooncakeTransferEngine: ...@@ -51,6 +51,35 @@ class MooncakeTransferEngine:
if ret_value != 0: if ret_value != 0:
logger.debug("Mooncake memory deregistration %s failed.", ptr) logger.debug("Mooncake memory deregistration %s failed.", ptr)
def batch_register(self, ptrs: List[int], lengths: List[int]) -> int:
    """Batch register multiple memory regions with the Mooncake engine.

    Args:
        ptrs: Addresses of the memory regions to register.
        lengths: Lengths passed alongside ``ptrs`` to the engine.

    Returns:
        The engine's return code. Any nonzero value is treated as a
        failure and logged; ``-1`` is returned when the engine call
        itself raised an exception.

    Raises:
        RuntimeError: If the underlying engine object does not expose
            ``batch_register_memory``.
    """
    # Check capability up front instead of relying on the broad
    # ``except`` below to swallow the AttributeError first.
    if not hasattr(self.engine, "batch_register_memory"):
        raise RuntimeError(
            "Mooncake's batch register requires a newer version of mooncake-transfer-engine. "
            "Please upgrade Mooncake."
        )
    try:
        ret_value = self.engine.batch_register_memory(ptrs, lengths)
    except Exception:
        # Keep the best-effort contract (report -1) but preserve the
        # cause in the log so failures are diagnosable.
        logger.debug("Mooncake batch memory registration failed.", exc_info=True)
        return -1
    if ret_value != 0:
        logger.debug("Mooncake batch memory registration failed.")
    return ret_value
def batch_deregister(self, ptrs: List[int]) -> int:
    """Batch deregister multiple memory regions from the Mooncake engine.

    Args:
        ptrs: Addresses of the memory regions to deregister.

    Returns:
        The engine's return code. Any nonzero value is treated as a
        failure and logged; ``-1`` is returned when the engine call
        itself raised an exception.
    """
    try:
        ret_value = self.engine.batch_unregister_memory(ptrs)
    except Exception:
        # Deregistration stays best-effort (no raise), but the cause is
        # kept in the log instead of being silently discarded.
        logger.debug("Mooncake batch memory deregistration failed.", exc_info=True)
        return -1
    if ret_value != 0:
        logger.debug("Mooncake batch memory deregistration failed.")
    return ret_value
def initialize( def initialize(
self, self,
hostname: str, hostname: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment