Unverified commit b745e8b5, authored by yjz and committed by GitHub
Browse files

[KVTransfer][Mooncake] Add heterogeneous TP support for disaggregated P/D in...


[KVTransfer][Mooncake] Add heterogeneous TP support for disaggregated P/D in MooncakeConnector (#36869)
Signed-off-by: JianDan0212 <zhangyj0212@gmail.com>
parent d215d1ef
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import logging
import threading import threading
import time import time
from collections import defaultdict from collections import defaultdict
...@@ -66,6 +67,179 @@ TransferId = str # KV transfer coordination ID (shared by P/D) ...@@ -66,6 +67,179 @@ TransferId = str # KV transfer coordination ID (shared by P/D)
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass(frozen=True)
class TransferRegion:
    """Immutable description of one KV memory region used in transfers.

    A region is either a whole registered KV tensor or, for layouts that
    keep K and V in the same block, one of the two halves of that tensor.
    All lengths are byte counts.
    """

    # Address of the region's first byte.
    base_addr: int
    # Distance between the starts of two consecutive blocks in the region.
    block_len: int
    # Number of bytes inside each block that belong to this region.
    kv_block_len: int
def _get_tp_ratio(local_tp_size: int, remote_tp_size: int) -> int:
"""Return the TP ratio used by heterogeneous TP transfer planning.
Positive values mean one local rank maps into a larger remote KV region.
Negative values mean one local rank must gather from multiple remote KV
regions.
"""
if local_tp_size >= remote_tp_size:
assert local_tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {local_tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return local_tp_size // remote_tp_size
assert remote_tp_size % local_tp_size == 0, (
f"Remote tensor parallel size {remote_tp_size} is not divisible "
f"by local tensor parallel size {local_tp_size}."
)
return -(remote_tp_size // local_tp_size)
def _expand_transfer_regions(
    base_addrs: list[int],
    block_lens: list[int],
    is_kv_layout_blocks_first: bool,
) -> list[TransferRegion]:
    """Turn registered KV tensors into the list of regions Mooncake copies.

    Args:
        base_addrs: Base address of each registered KV tensor.
        block_lens: Per-block byte length of each tensor, parallel to
            ``base_addrs``.
        is_kv_layout_blocks_first: When true, every tensor is split into two
            regions of half a block each, the second one starting
            ``block_len // 2`` bytes after the first.

    Returns:
        One region per tensor, or two per tensor for the blocks-first
        layout, in input order. Every region keeps the tensor's full
        ``block_len`` as its block stride.
    """
    assert len(base_addrs) == len(block_lens), (
        "Mooncake transfer regions require matching numbers of base addresses "
        f"and block lengths, got {len(base_addrs)} and {len(block_lens)}."
    )

    expanded: list[TransferRegion] = []
    for addr, length in zip(base_addrs, block_lens):
        if not is_kv_layout_blocks_first:
            # The whole block belongs to a single region.
            expanded.append(
                TransferRegion(base_addr=addr, block_len=length, kv_block_len=length)
            )
            continue

        # Two halves share the block stride but start half a block apart.
        half_len = length // 2
        expanded.extend(
            TransferRegion(
                base_addr=addr + half_offset,
                block_len=length,
                kv_block_len=half_len,
            )
            for half_offset in (0, half_len)
        )
    return expanded
def _compute_sender_transfer_plan(
    local_tp_rank: int,
    local_tp_size: int,
    remote_tp_rank: int,
    remote_tp_size: int,
    local_kv_block_len: int,
    remote_kv_block_len: int,
    producer_cache_replicated: bool,
) -> tuple[bool, int, int, int]:
    """Plan the per-block copy from one producer rank to one consumer rank.

    Returns:
        ``(should_transfer, src_offset, dst_offset, transfer_len)`` where the
        offsets are byte offsets inside a source / destination block and
        ``transfer_len`` is the number of bytes copied per block.
    """
    ratio = _get_tp_ratio(local_tp_size, remote_tp_size)

    if ratio == 1:
        # Same TP size on both sides: copy the local block as-is.
        return True, 0, 0, local_kv_block_len

    if producer_cache_replicated:
        # The whole local block is copied. When the producer side is wider,
        # only the rank whose index is a multiple of the ratio sends it.
        is_sender = ratio < 1 or local_tp_rank % ratio == 0
        return is_sender, 0, 0, local_kv_block_len

    if ratio > 0:
        # Producer is wider: the local block lands in a slice of the remote one.
        dst_offset = (local_tp_rank % ratio) * local_kv_block_len
        return True, 0, dst_offset, local_kv_block_len

    # Consumer is wider: send only the slice of the local block it needs.
    fan_in = -ratio
    src_offset = (remote_tp_rank % fan_in) * remote_kv_block_len
    return True, src_offset, 0, remote_kv_block_len
def _can_coalesce_block_transfers(
local_region_block_len: int,
remote_region_block_len: int,
src_region_offset: int,
dst_region_offset: int,
transfer_len: int,
) -> bool:
"""Whether a contiguous block group can be emitted as one larger copy."""
return (
src_region_offset == 0
and dst_region_offset == 0
and transfer_len == local_region_block_len
and transfer_len == remote_region_block_len
)
def _validate_asymmetric_region_lengths(
    local_regions: list[TransferRegion],
    remote_regions: list[TransferRegion],
    local_tp_size: int,
    remote_tp_size: int,
    producer_cache_replicated: bool,
) -> str | None:
    """Validate transfer-region metadata for a fixed producer/consumer pair.

    This checks registered KV regions, not per-request block counts. A region
    corresponds to one registered KV tensor, or one K/V half after expansion
    for layouts that store K and V together.

    Args:
        local_regions: Regions registered on the producer side.
        remote_regions: Regions advertised by the consumer side.
        local_tp_size: Tensor parallel size of the producer.
        remote_tp_size: Tensor parallel size reported by the consumer.
        producer_cache_replicated: When true, only the region counts are
            compared and the per-region length checks are skipped.

    Returns:
        ``None`` when the metadata is consistent, otherwise a human-readable
        error message describing the first problem found.
    """
    if len(local_regions) != len(remote_regions):
        return (
            "Mooncake asymmetric TP requires matching KV region counts between "
            "producer and consumer."
        )
    if producer_cache_replicated:
        return None

    # The remote TP size comes from peer metadata. Report incompatible values
    # as a validation error instead of tripping the assert (or a division by
    # zero) inside _get_tp_ratio, which would also be skipped under ``-O``.
    if local_tp_size <= 0 or remote_tp_size <= 0:
        return (
            "Mooncake asymmetric TP requires positive tensor parallel sizes, "
            f"got local={local_tp_size}, remote={remote_tp_size}."
        )
    if max(local_tp_size, remote_tp_size) % min(local_tp_size, remote_tp_size):
        return (
            "Mooncake asymmetric TP requires one tensor parallel size to be a "
            f"multiple of the other, got local={local_tp_size}, "
            f"remote={remote_tp_size}."
        )

    tp_ratio = _get_tp_ratio(local_tp_size, remote_tp_size)
    for idx, (local_region, remote_region) in enumerate(
        zip(local_regions, remote_regions)
    ):
        if tp_ratio == 1:
            # Same TP size: both sides must hold identically sized regions.
            if local_region.kv_block_len != remote_region.kv_block_len:
                return (
                    "Mooncake KV region length mismatch for homogeneous TP at "
                    f"region {idx}: local={local_region.kv_block_len}, "
                    f"remote={remote_region.kv_block_len}."
                )
        elif tp_ratio > 0:
            # Producer is wider: tp_ratio local blocks fill one remote block.
            if remote_region.kv_block_len != local_region.kv_block_len * tp_ratio:
                return (
                    "Mooncake destination KV region length does not match the "
                    "producer TP ratio at region "
                    f"{idx}: local={local_region.kv_block_len}, "
                    f"remote={remote_region.kv_block_len}, tp_ratio={tp_ratio}."
                )
        else:
            # Consumer is wider: one local block is split over several
            # remote blocks.
            ratio_abs = -tp_ratio
            if local_region.kv_block_len != remote_region.kv_block_len * ratio_abs:
                return (
                    "Mooncake source KV region length does not match the "
                    "consumer TP ratio at region "
                    f"{idx}: local={local_region.kv_block_len}, "
                    f"remote={remote_region.kv_block_len}, tp_ratio={tp_ratio}."
                )
    return None
def _get_tensor_dense_flag(tensor: torch.Tensor) -> bool | None:
is_dense = getattr(tensor, "is_non_overlapping_and_dense", None)
if callable(is_dense):
return bool(is_dense())
return None
class MooncakeXferMetadata( class MooncakeXferMetadata(
msgspec.Struct, msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
...@@ -76,6 +250,7 @@ class MooncakeXferMetadata( ...@@ -76,6 +250,7 @@ class MooncakeXferMetadata(
remote_tp_rank: int remote_tp_rank: int
req_blocks: dict[ReqId, tuple[TransferId, list[int]]] req_blocks: dict[ReqId, tuple[TransferId, list[int]]]
kv_caches_base_addr: list[int] kv_caches_base_addr: list[int]
block_lens: list[int]
class MooncakeXferResponseStatus(IntEnum): class MooncakeXferResponseStatus(IntEnum):
...@@ -173,6 +348,24 @@ class MooncakeConnector(KVConnectorBase_V1): ...@@ -173,6 +348,24 @@ class MooncakeConnector(KVConnectorBase_V1):
self.connector_scheduler = None self.connector_scheduler = None
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id) self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
@classmethod
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
    """Return the KV cache layout this connector requires, if any.

    Returns:
        ``"HND"`` when a model config is present and does not use MLA;
        ``None`` when there is no model config or the model uses MLA.
    """
    model_config = vllm_config.model_config
    if model_config is None:
        # Mostly hit by unit tests that build the connector without a fully
        # populated model config.
        logger.warning_once(
            "Unable to detect current VLLM config. "
            "Fallback to default kv cache layout."
        )
        return None

    if not model_config.use_mla:
        logger.info_once(
            "MooncakeConnector setting KV cache layout to HND for "
            "heterogeneous TP-safe KV transfer."
        )
        return "HND"

    # MLA models: no layout is requested.
    return None
############################################################ ############################################################
# Scheduler Side Methods # Scheduler Side Methods
############################################################ ############################################################
...@@ -487,6 +680,8 @@ class MooncakeConnectorWorker: ...@@ -487,6 +680,8 @@ class MooncakeConnectorWorker:
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_blocks = 0 self.num_blocks = 0
self.block_len_per_layer: list[int] = []
self.seen_base_addresses: list[int] = []
assert (parallel_config := vllm_config.parallel_config) assert (parallel_config := vllm_config.parallel_config)
dp_rank = parallel_config.data_parallel_index dp_rank = parallel_config.data_parallel_index
...@@ -685,9 +880,13 @@ class MooncakeConnectorWorker: ...@@ -685,9 +880,13 @@ class MooncakeConnectorWorker:
): ):
pending_reqs: dict[ReqId, SendBlockMeta] = {} pending_reqs: dict[ReqId, SendBlockMeta] = {}
remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size) remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size)
if self.tp_rank not in remote_tp_ranks: if meta.remote_tp_rank not in remote_tp_ranks:
# This D worker does not pair with the P worker. # This D worker does not pair with the P worker.
msg = f"This P tp_rank {self.tp_rank} not in remote D target ranks {remote_tp_ranks}" # noqa: E501 msg = (
"This D tp_rank "
f"{meta.remote_tp_rank} is not paired with P tp_rank "
f"{self.tp_rank}; expected one of {remote_tp_ranks}."
)
logger.error(msg) logger.error(msg)
response = MooncakeXferResponse( response = MooncakeXferResponse(
status=MooncakeXferResponseStatus.ERROR, status=MooncakeXferResponseStatus.ERROR,
...@@ -695,6 +894,26 @@ class MooncakeConnectorWorker: ...@@ -695,6 +894,26 @@ class MooncakeConnectorWorker:
) )
await sock.send_multipart((identity, self._encoder.encode(response))) await sock.send_multipart((identity, self._encoder.encode(response)))
return return
local_regions = self._get_transfer_regions(
self.kv_caches_base_addr, self.block_len_per_layer
)
remote_regions = self._get_transfer_regions(
meta.kv_caches_base_addr, meta.block_lens
)
validation_err = _validate_asymmetric_region_lengths(
local_regions=local_regions,
remote_regions=remote_regions,
local_tp_size=self.tp_size,
remote_tp_size=meta.remote_tp_size,
producer_cache_replicated=self._producer_cache_is_replicated(),
)
if validation_err is not None:
response = MooncakeXferResponse(
status=MooncakeXferResponseStatus.ERROR,
err_msg=validation_err,
)
await sock.send_multipart((identity, self._encoder.encode(response)))
return
for d_req_id, (transfer_id, _) in meta.req_blocks.items(): for d_req_id, (transfer_id, _) in meta.req_blocks.items():
if transfer_id not in self.reqs_need_send: if transfer_id not in self.reqs_need_send:
# This req is not enqueued in P side yet, create it here. # This req is not enqueued in P side yet, create it here.
...@@ -763,17 +982,24 @@ class MooncakeConnectorWorker: ...@@ -763,17 +982,24 @@ class MooncakeConnectorWorker:
"Request %s expired before sending on P side.", d_req_id "Request %s expired before sending on P side.", d_req_id
) )
src_ptrs, dst_ptrs, lengths, err_reqs = await self._build_transfer_params( (
ready_reqs, meta src_ptrs,
dst_ptrs,
lengths,
err_reqs,
err_msg,
) = await self._build_transfer_params(
ready_reqs,
meta,
local_regions,
remote_regions,
) )
err_req_set = set(err_reqs)
if err_reqs: ok_ready_reqs = [
response = MooncakeXferResponse( (d_req_id, send_meta)
status=response_status, for d_req_id, send_meta in ready_reqs
err_reqs=err_reqs, if d_req_id not in err_req_set
err_msg="P num blocks less than D", ]
)
await sock.send_multipart((identity, self._encoder.encode(response)))
if src_ptrs: if src_ptrs:
remote_session = f"{meta.remote_hostname}:{meta.remote_port}" remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
...@@ -787,58 +1013,61 @@ class MooncakeConnectorWorker: ...@@ -787,58 +1013,61 @@ class MooncakeConnectorWorker:
) )
if ret_value != 0: if ret_value != 0:
err_reqs = [] transfer_err_msg = f"Mooncake transfer engine returned {ret_value}"
for d_req_id, send_meta in ready_reqs: err_msg = (
send_meta.sending -= 1 transfer_err_msg
err_reqs.append(d_req_id) if err_msg is None
# Do best effort to transfer the remaining reqs. else f"{err_msg}; {transfer_err_msg}"
response = MooncakeXferResponse(
status=response_status,
err_reqs=err_reqs,
err_msg=f"Mooncake transfer engine returned {ret_value}",
) )
await sock.send_multipart( err_reqs = list(err_reqs)
(identity, self._encoder.encode(response)) for d_req_id, _ in ok_ready_reqs:
) err_reqs.append(d_req_id)
continue err_req_set.add(d_req_id)
ok_ready_reqs = []
for d_req_id, send_meta in ready_reqs: for d_req_id, send_meta in ready_reqs:
# TODO: for heterogeneous TP (one P pairs to multiple D),
# we need to check whether all headers are sent.
# If not, we should set expire_time to normal and skip the below.
send_meta.sending -= 1 send_meta.sending -= 1
if d_req_id in err_req_set:
continue
send_meta.sent += 1 send_meta.sent += 1
if send_meta.sent == send_meta.need_send: if (
del self.reqs_need_send[send_meta.transfer_id] send_meta.sent == send_meta.need_send
and self.reqs_need_send.pop(send_meta.transfer_id, None) is not None
):
self.finished_sending_reqs.add(send_meta.p_req_id) self.finished_sending_reqs.add(send_meta.p_req_id)
response = MooncakeXferResponse( response = MooncakeXferResponse(
status=response_status, status=response_status,
ok_reqs=[d_req_id for d_req_id, _ in ready_reqs], ok_reqs=[d_req_id for d_req_id, _ in ok_ready_reqs] or None,
err_reqs=err_reqs or None,
err_msg=err_msg,
) )
await sock.send_multipart((identity, self._encoder.encode(response))) await sock.send_multipart((identity, self._encoder.encode(response)))
def resolve_need_send(self, send_meta: SendBlockMeta, remote_tp_ranks: list[int]): def resolve_need_send(self, send_meta: SendBlockMeta, remote_tp_ranks: list[int]):
# Prepare for heterogeneous TP (one P pairs to multiple D) # Prepare for heterogeneous TP (one P pairs to multiple D)
send_meta.need_send = len(remote_tp_ranks) send_meta.need_send = len(remote_tp_ranks)
if send_meta.need_send != 1: logger.debug(
logger.error("Mooncake: Heterogeneous TP is not supported yet.") "Mooncake request %s will be served by %d consumer TP workers: %s",
raise NotImplementedError( send_meta.transfer_id,
"Mooncake: Heterogeneous TP is not supported yet." send_meta.need_send,
) remote_tp_ranks,
)
async def _build_transfer_params( async def _build_transfer_params(
self, self,
ready_reqs: list[tuple[ReqId, SendBlockMeta]], ready_reqs: list[tuple[ReqId, SendBlockMeta]],
agent_meta: MooncakeXferMetadata, agent_meta: MooncakeXferMetadata,
) -> tuple[list[int], list[int], list[int], list[ReqId]]: local_regions: list[TransferRegion],
remote_regions: list[TransferRegion],
) -> tuple[list[int], list[int], list[int], list[ReqId], str | None]:
src_ptrs = [] src_ptrs = []
dst_ptrs = [] dst_ptrs = []
lengths = [] lengths = []
err_reqs: list[ReqId] = [] err_reqs: list[ReqId] = []
local_base_addr = self.kv_caches_base_addr err_msg: str | None = None
remote_base_addr = agent_meta.kv_caches_base_addr
block_len = self.block_len
remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}" remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"
for d_req_id, send_meta in ready_reqs: for d_req_id, send_meta in ready_reqs:
...@@ -858,6 +1087,8 @@ class MooncakeConnectorWorker: ...@@ -858,6 +1087,8 @@ class MooncakeConnectorWorker:
num_remote_blocks, num_remote_blocks,
) )
err_reqs.append(d_req_id) err_reqs.append(d_req_id)
if err_msg is None:
err_msg = "P num blocks less than D"
continue continue
if num_local_blocks > num_remote_blocks: if num_local_blocks > num_remote_blocks:
local_block_ids = local_block_ids[-num_remote_blocks:] local_block_ids = local_block_ids[-num_remote_blocks:]
...@@ -867,19 +1098,87 @@ class MooncakeConnectorWorker: ...@@ -867,19 +1098,87 @@ class MooncakeConnectorWorker:
local_block_ids, remote_block_ids local_block_ids, remote_block_ids
) )
for local_layer_addr, remote_layer_addr in zip( for local_region, remote_region in zip(local_regions, remote_regions):
local_base_addr, remote_base_addr should_transfer, src_region_offset, dst_region_offset, transfer_len = (
): self._get_sender_transfer_plan(
local_kv_block_len=local_region.kv_block_len,
remote_kv_block_len=remote_region.kv_block_len,
remote_tp_rank=agent_meta.remote_tp_rank,
remote_tp_size=agent_meta.remote_tp_size,
)
)
if not should_transfer:
# Replicated KV cache: only one producer rank in the TP group
# needs to send the actual bytes for this paired decoder rank.
# TODO: Account for replicated producer KV in
# get_target_remote_ranks() so we can avoid sending
# unnecessary ZMQ requests and remove this branch.
continue
assert src_region_offset + transfer_len <= local_region.kv_block_len, (
"Computed source transfer region exceeds local KV block size."
)
assert dst_region_offset + transfer_len <= remote_region.kv_block_len, (
"Computed destination transfer region exceeds remote KV block size."
)
# Collapse one contiguous block group into a single larger
# transfer descriptor when the per-block copy is identical.
can_coalesce = _can_coalesce_block_transfers(
local_region_block_len=local_region.block_len,
remote_region_block_len=remote_region.block_len,
src_region_offset=src_region_offset,
dst_region_offset=dst_region_offset,
transfer_len=transfer_len,
)
for group_local_block_id, group_remote_block_id in zip( for group_local_block_id, group_remote_block_id in zip(
group_local_block_ids, group_remote_block_ids group_local_block_ids, group_remote_block_ids
): ):
src_ptrs.append( if can_coalesce:
local_layer_addr + group_local_block_id[0] * block_len src_ptrs.append(
) local_region.base_addr
dst_ptrs.append( + group_local_block_id[0] * local_region.block_len
remote_layer_addr + group_remote_block_id[0] * block_len + src_region_offset
)
dst_ptrs.append(
remote_region.base_addr
+ group_remote_block_id[0] * remote_region.block_len
+ dst_region_offset
)
lengths.append(transfer_len * len(group_local_block_id))
else:
for local_block_id, remote_block_id in zip(
group_local_block_id, group_remote_block_id
):
src_ptrs.append(
local_region.base_addr
+ local_block_id * local_region.block_len
+ src_region_offset
)
dst_ptrs.append(
remote_region.base_addr
+ remote_block_id * remote_region.block_len
+ dst_region_offset
)
lengths.append(transfer_len)
if local_region is local_regions[0]:
logger.debug(
"Mooncake transfer plan for request %s: local_tp=%d "
"remote_tp=%d remote_tp_rank=%d local_block_len=%d "
"remote_block_len=%d src_offset=%d dst_offset=%d "
"transfer_len=%d coalesce=%s",
d_req_id,
self.tp_size,
agent_meta.remote_tp_size,
agent_meta.remote_tp_rank,
local_region.block_len,
remote_region.block_len,
src_region_offset,
dst_region_offset,
transfer_len,
can_coalesce,
) )
lengths.append(block_len * len(group_local_block_id))
logger.debug( logger.debug(
"Sending kv_caches for request %s (%d blocks) to %s", "Sending kv_caches for request %s (%d blocks) to %s",
...@@ -888,7 +1187,7 @@ class MooncakeConnectorWorker: ...@@ -888,7 +1187,7 @@ class MooncakeConnectorWorker:
remote_session, remote_session,
) )
return src_ptrs, dst_ptrs, lengths, err_reqs return src_ptrs, dst_ptrs, lengths, err_reqs, err_msg
def _send_blocks( def _send_blocks(
self, self,
...@@ -917,16 +1216,20 @@ class MooncakeConnectorWorker: ...@@ -917,16 +1216,20 @@ class MooncakeConnectorWorker:
kv_data_ptrs = [] kv_data_ptrs = []
kv_data_lens = [] kv_data_lens = []
seen_base_addresses = [] seen_base_addresses = []
self.block_len_per_layer = []
split_k_and_v = self.kv_topo.split_k_and_v split_k_and_v = self.kv_topo.split_k_and_v
tensor_size_bytes = None tensor_size_bytes = None
for layer_name, cache_or_caches in kv_caches.items(): for layer_name, cache_or_caches in kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
logger.debug( logger.debug(
"registering layer %s with shape %s", layer_name, cache_or_caches.shape "registering layer %s with %d cache tensor(s)",
layer_name,
len(cache_list),
) )
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list: for cache in cache_list:
self._log_debug_cache_registration(layer_name, cache)
base_addr = cache.data_ptr() base_addr = cache.data_ptr()
if base_addr in seen_base_addresses: if base_addr in seen_base_addresses:
continue continue
...@@ -937,16 +1240,24 @@ class MooncakeConnectorWorker: ...@@ -937,16 +1240,24 @@ class MooncakeConnectorWorker:
if tensor_size_bytes is None: if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0] self.num_blocks = cache.shape[0]
assert cache.shape[0] == self.num_blocks, (
assert tensor_size_bytes == curr_tensor_size_bytes, ( "All kv cache tensors must have the same number of blocks"
"All kv cache tensors must have the same size" )
assert curr_tensor_size_bytes % self.num_blocks == 0, (
"Mooncake expects each kv cache tensor size to be "
"divisible by the number of blocks."
)
self.block_len_per_layer.append(
curr_tensor_size_bytes // self.num_blocks
) )
kernel_block_size = cache.shape[-2 if self.use_mla else -3] kernel_block_size = cache.shape[-2 if self.use_mla else -3]
assert self.block_size == kernel_block_size assert self.block_size == kernel_block_size
kv_data_ptrs.append(base_addr) kv_data_ptrs.append(base_addr)
kv_data_lens.append(tensor_size_bytes) kv_data_lens.append(curr_tensor_size_bytes)
self.kv_caches_base_addr = seen_base_addresses self.kv_caches_base_addr = seen_base_addresses
self.seen_base_addresses = seen_base_addresses
ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens) ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens)
if ret_value != 0: if ret_value != 0:
...@@ -954,11 +1265,11 @@ class MooncakeConnectorWorker: ...@@ -954,11 +1265,11 @@ class MooncakeConnectorWorker:
assert tensor_size_bytes is not None assert tensor_size_bytes is not None
assert self.num_blocks != 0 assert self.num_blocks != 0
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.device_kv_caches = kv_caches self.device_kv_caches = kv_caches
logger.debug( logger.debug(
"registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len "registered num_blocks=%d block_lens=%s",
self.num_blocks,
self.block_len_per_layer,
) )
# No need to launch server for D node. # No need to launch server for D node.
...@@ -1052,6 +1363,7 @@ class MooncakeConnectorWorker: ...@@ -1052,6 +1363,7 @@ class MooncakeConnectorWorker:
for req_id, pull_meta in pull_metas.items() for req_id, pull_meta in pull_metas.items()
}, },
kv_caches_base_addr=self.kv_caches_base_addr, kv_caches_base_addr=self.kv_caches_base_addr,
block_lens=self.block_len_per_layer,
) )
encoded_data = self._encoder.encode(metadata) encoded_data = self._encoder.encode(metadata)
...@@ -1152,11 +1464,11 @@ class MooncakeConnectorWorker: ...@@ -1152,11 +1464,11 @@ class MooncakeConnectorWorker:
remote_engine_id remote_engine_id
) )
count = len(remote_tp_ranks) count = len(remote_tp_ranks)
if count != 1: logger.debug(
logger.error("Mooncake: Heterogeneous TP is not supported yet.") "Receiving Mooncake KV for engine %s from producer TP ranks %s",
raise NotImplementedError( remote_engine_id,
"Mooncake: Heterogeneous TP is not supported yet." remote_tp_ranks,
) )
for pull_meta in pull_metas.values(): for pull_meta in pull_metas.values():
pull_meta.pull_tasks_count = count pull_meta.pull_tasks_count = count
for remote_tp_rank in remote_tp_ranks: for remote_tp_rank in remote_tp_ranks:
...@@ -1239,6 +1551,52 @@ class MooncakeConnectorWorker: ...@@ -1239,6 +1551,52 @@ class MooncakeConnectorWorker:
self.record_send_reqs(metadata), self.sender_loop self.record_send_reqs(metadata), self.sender_loop
) )
def _producer_cache_is_replicated(self) -> bool:
return self.kv_topo.replicates_kv_cache(self.engine_id)
def _get_transfer_regions(
    self, base_addrs: list[int], block_lens: list[int]
) -> list[TransferRegion]:
    """Expand base addresses and block lengths into transfer regions.

    Delegates to ``_expand_transfer_regions`` using the blocks-first flag
    taken from ``self.kv_topo``.
    """
    blocks_first = self.kv_topo.is_kv_layout_blocks_first
    return _expand_transfer_regions(
        base_addrs=base_addrs,
        block_lens=block_lens,
        is_kv_layout_blocks_first=blocks_first,
    )
def _get_sender_transfer_plan(
    self,
    local_kv_block_len: int,
    remote_kv_block_len: int,
    remote_tp_rank: int,
    remote_tp_size: int,
) -> tuple[bool, int, int, int]:
    """Build the sender-side copy plan for this worker and one remote rank.

    Fills in this worker's TP rank, TP size and replication flag, then
    delegates to ``_compute_sender_transfer_plan``.

    Returns:
        The ``(should_transfer, src_offset, dst_offset, transfer_len)`` tuple
        produced by ``_compute_sender_transfer_plan``.
    """
    plan_args = {
        "local_tp_rank": self.tp_rank,
        "local_tp_size": self.tp_size,
        "remote_tp_rank": remote_tp_rank,
        "remote_tp_size": remote_tp_size,
        "local_kv_block_len": local_kv_block_len,
        "remote_kv_block_len": remote_kv_block_len,
        "producer_cache_replicated": self._producer_cache_is_replicated(),
    }
    return _compute_sender_transfer_plan(**plan_args)
def _log_debug_cache_registration(
    self, layer_name: str, cache: torch.Tensor
) -> None:
    """Emit a DEBUG record describing a cache tensor being registered.

    The tensor is only inspected when DEBUG logging is enabled, so the
    queries below are skipped otherwise.
    """
    if logger.isEnabledFor(logging.DEBUG):
        shape = tuple(cache.shape)
        stride = tuple(cache.stride())
        logger.debug(
            "Mooncake register view layer=%s shape=%s stride=%s "
            "storage_offset=%d contiguous=%s dense=%s data_ptr=%d",
            layer_name,
            shape,
            stride,
            cache.storage_offset(),
            cache.is_contiguous(),
            _get_tensor_dense_flag(cache),
            cache.data_ptr(),
        )
def group_concurrent_contiguous( def group_concurrent_contiguous(
src_indices: list[int], dst_indices: list[int] src_indices: list[int], dst_indices: list[int]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment