Commit ea78a424 authored by Neelay Shah and committed by GitHub

chore: regenerate patch (#29)

parent 46ed649c
diff --git a/vllm/config.py b/vllm/config.py
index 9ba49757..a2f88854 100644
index 9ba49757..cbfeb715 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2629,7 +2629,7 @@ class KVTransferConfig(BaseModel):
......@@ -261,7 +261,7 @@ index c5b3b04f..c72001f7 100644
self.block_tables: Dict[SeqId, BlockTable] = {}
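Editorial note: the KVTransferConfig hunk at the top of this patch is collapsed in this view. For orientation, the attributes that later hunks read from the kv_transfer_config object are sketched below; the list is compiled from those later hunks, and the declaration style is an assumption, not the actual hunk content.

from typing import Optional

class KVTransferConfigFieldsSketch:
    # Pre-existing fields the patch reads.
    kv_connector: Optional[str]   # e.g. "DynemoNcclConnector"
    kv_role: Optional[str]        # "kv_producer" or "kv_consumer"
    kv_rank: int
    kv_parallel_size: int
    kv_ip: str
    kv_port: int
    kv_buffer_size: int
    is_kv_producer: bool
    # Fields the patch relies on or fills in at runtime
    # (see _broadcast_and_enhance_kv_config in dynemo_connector.py below).
    tensor_parallel_multiplier: int   # consumer TP size / producer TP size
    kv_producers_parallel_size: int
    kv_producers_tensor_parallel_size: Optional[int]
    kv_consumers_tensor_parallel_size: Optional[int]
    kv_producers_pipeline_parallel_size: Optional[int]
    kv_consumers_pipeline_parallel_size: Optional[int]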
diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py
new file mode 100644
index 00000000..350453cd
index 00000000..8699ca06
--- /dev/null
+++ b/vllm/core/event_manager.py
@@ -0,0 +1,102 @@
......@@ -368,10 +368,15 @@ index 00000000..350453cd
+
+ self.event_id_counter += 1
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f507847a..ee20d50c 100644
index f507847a..abe574d1 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -8,18 +8,17 @@ from collections import deque
@@ -4,22 +4,22 @@ import enum
import os
import random
import time
+import copy
from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
......@@ -393,7 +398,7 @@ index f507847a..ee20d50c 100644
logger = init_logger(__name__)
# Test-only. If configured, decode is preempted with
@@ -325,12 +324,14 @@ class Scheduler:
@@ -325,12 +325,14 @@ class Scheduler:
def __init__(
self,
......@@ -408,7 +413,7 @@ index f507847a..ee20d50c 100644
self.scheduler_config = scheduler_config
self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
@@ -356,6 +357,7 @@ class Scheduler:
@@ -356,6 +358,7 @@ class Scheduler:
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
......@@ -416,7 +421,7 @@ index f507847a..ee20d50c 100644
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
@@ -371,6 +373,14 @@ class Scheduler:
@@ -371,6 +374,16 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()
......@@ -424,6 +429,8 @@ index f507847a..ee20d50c 100644
+ # Sequence groups in the REMOTE_PREFILLING state.
+ # Contain requests that are being prefilled by a remote worker.
+ self.remote_prefilling: Deque[SequenceGroup] = deque()
+ # Contain requests that are being prefilled by a local worker.
+ self.prefill_sending: Deque[SequenceGroup] = deque()
+
+ self._remote_prefill_outputs: Dict[str, int] = {}
+
......@@ -431,24 +438,25 @@ index f507847a..ee20d50c 100644
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
@@ -501,7 +511,7 @@ class Scheduler:
@@ -501,7 +514,7 @@ class Scheduler:
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
- self.swapped) != 0
+ self.swapped) != 0 or len(self.remote_prefilling) != 0
+ self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
@@ -523,6 +533,7 @@ class Scheduler:
@@ -523,6 +536,8 @@ class Scheduler:
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
+ finished_prefills: Optional[Set[str]] = None
+ finished_prefills: Optional[Set[str]] = None,
+ finished_transfers: Optional[Set[str]] = None
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
@@ -537,6 +548,8 @@ class Scheduler:
@@ -537,6 +552,8 @@ class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
......@@ -457,7 +465,7 @@ index f507847a..ee20d50c 100644
Returns:
SchedulerRunningOutputs.
@@ -566,6 +579,24 @@ class Scheduler:
@@ -566,6 +583,38 @@ class Scheduler:
preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out
......@@ -468,6 +476,7 @@ index f507847a..ee20d50c 100644
+ if seq_group.request_id not in finished_prefills:
+ leftover_remote_prefilling_sequences.append(seq_group)
+ continue
+
+ else:
+ finished_prefills.remove(seq_group.request_id)
+ assert len(seq_group.seqs) == 1
......@@ -478,39 +487,63 @@ index f507847a..ee20d50c 100644
+ seq.data._stage = SequenceStage.DECODE
+ self.running.appendleft(seq_group)
+ remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences)
+
+ remote_transfers_queue = self.prefill_sending
+ leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque()
+ while remote_transfers_queue:
+ seq_group = remote_transfers_queue.popleft()
+ if seq_group.request_id not in finished_transfers:
+ leftover_remote_transfers_sequences.append(seq_group)
+ else:
+ finished_transfers.remove(seq_group.request_id)
+ assert len(seq_group.seqs) == 1
+ seq = seq_group.seqs[0]
+ self.free_seq(seq)
+ remote_transfers_queue.extendleft(leftover_remote_transfers_sequences)
+
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
@@ -1008,7 +1039,7 @@ class Scheduler:
@@ -1008,7 +1057,17 @@ class Scheduler:
if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id)
waiting_queue.popleft()
- self._allocate_and_set_running(seq_group)
+
+ seq_group_copy = copy.deepcopy(seq_group)
+ seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1
+
+ logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id)
+ logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group)
+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+ logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
+ self.prefill_sending.append(seq_group_copy)
if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = []
@@ -1048,7 +1079,7 @@ class Scheduler:
@@ -1048,7 +1107,7 @@ class Scheduler:
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking))
- def _schedule_default(self) -> SchedulerOutputs:
+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests.
The current policy is designed to optimize the throughput. First,
@@ -1090,7 +1121,8 @@ class Scheduler:
@@ -1090,7 +1149,9 @@ class Scheduler:
if len(prefills.seq_groups) == 0:
running_scheduled = self._schedule_running(budget,
curr_loras,
- enable_chunking=False)
+ enable_chunking=False,
+ finished_prefills=finished_prefills)
+ finished_prefills=finished_prefills,
+ finished_transfers=finished_transfers)
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
@@ -1106,7 +1138,12 @@ class Scheduler:
@@ -1106,7 +1167,12 @@ class Scheduler:
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
if len(prefills.seq_groups) > 0:
......@@ -524,30 +557,31 @@ index f507847a..ee20d50c 100644
self.running.extend(running_scheduled.decode_seq_groups_list)
@@ -1248,12 +1285,14 @@ class Scheduler:
@@ -1248,12 +1314,14 @@ class Scheduler:
len(running_scheduled.swapped_out)),
)
- def _schedule(self) -> SchedulerOutputs:
+ def _schedule(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
+ def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests."""
if self.scheduler_config.chunked_prefill_enabled:
+ if finished_prefills:
+ if finished_prefills or finished_transfers:
+ raise ValueError("Chunked prefill does not support remote prefills")
return self._schedule_chunked_prefill()
else:
- return self._schedule_default()
+ return self._schedule_default(finished_prefills)
+ return self._schedule_default(finished_prefills, finished_transfers)
def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool:
@@ -1287,14 +1326,15 @@ class Scheduler:
@@ -1287,14 +1355,16 @@ class Scheduler:
return no_single_seq
def schedule(
- self
+ self,
+ finished_prefills: Optional[Set[str]] = None
+ finished_prefills: Optional[Set[str]] = None,
+ finished_transfers: Optional[Set[str]] = None
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
......@@ -556,11 +590,11 @@ index f507847a..ee20d50c 100644
- scheduler_outputs: SchedulerOutputs = self._schedule()
+ scheduler_start_time = time.perf_counter()
+ scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills)
+ scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers)
now = time.time()
if not self.cache_config.enable_prefix_caching:
@@ -1333,7 +1373,8 @@ class Scheduler:
@@ -1333,7 +1403,8 @@ class Scheduler:
encoder_seq_data = None
cross_block_table = None
......@@ -570,18 +604,24 @@ index f507847a..ee20d50c 100644
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
@@ -1364,6 +1405,10 @@ class Scheduler:
@@ -1364,9 +1435,16 @@ class Scheduler:
< seqs[0].data.get_len()):
do_sample = False
+ is_remote_prefill = False
+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
+ is_remote_prefill = True
+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+ block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids
+
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
if is_first_prefill or not self.scheduler_config.send_delta_data:
@@ -1392,6 +1437,7 @@ class Scheduler:
+ logger.debug("Assinged blocks: %s", block_tables)
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
@@ -1392,6 +1470,7 @@ class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
......@@ -589,7 +629,7 @@ index f507847a..ee20d50c 100644
)
else:
# When SPMD mode is enabled, we only send delta data except for
@@ -1490,10 +1536,13 @@ class Scheduler:
@@ -1490,10 +1569,13 @@ class Scheduler:
self._async_stopped.clear()
......@@ -605,12 +645,80 @@ index f507847a..ee20d50c 100644
def _append_slots(self,
seq_group: SequenceGroup,
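Editorial note: the scheduler hunks above thread two new optional arguments, finished_prefills and finished_transfers, from schedule() through _schedule() and _schedule_default() into _schedule_running(), where they drain the new remote_prefilling and prefill_sending queues. A minimal sketch (not part of the patch) of how an engine step could feed those sets into the patched scheduler; detection of finished remote prefills is outside this excerpt, so it is passed in as an argument:

def engine_step(scheduler, nixl_connector, finished_prefills):
    # Request ids whose outbound KV transfers completed, as reported by the
    # nixl connector's get_done_tranfers() (spelling as in the patch).
    finished_transfers = set(nixl_connector.get_done_tranfers())

    # schedule() keeps its original return shape:
    # (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc).
    return scheduler.schedule(
        finished_prefills=set(finished_prefills),
        finished_transfers=finished_transfers,
    )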
diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py
new file mode 100644
index 00000000..9b938039
--- /dev/null
+++ b/vllm/distributed/device_communicators/kv_rearrange.py
@@ -0,0 +1,61 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def rearrange_kernel(
+ t1_ptr,
+ t2_ptr,
+ N,
+ B,
+ H,
+ C,
+ d,
+ tensor_subset_size,
+ block_size,
+ token_size,
+ BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+
+ curr_n = offsets // block_size
+ curr_b = offsets // token_size % B
+ curr_h = offsets // C % H
+ curr_c = offsets % C
+
+ src_pos = offsets
+
+ tp_group = curr_h * d // H
+ dst_h = curr_h % (H // d)
+ tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
+
+ dst_pos = tensor_subset_size * tp_group + tp_group_offset
+
+ tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos))
+
+def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int):
+ N, B, H, C = t1.shape
+
+ assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
+ assert H % d == 0, "H must be divisible by d"
+
+ block_size = B * H * C
+ token_size = H * C
+ tensor_size = N * block_size
+ tensor_subset_size = tensor_size // d
+
+ BLOCK_SIZE = 1024
+ grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)
+
+ rearrange_kernel[grid](
+ t1, t2,
+ N, B, H, C,
+ d,
+ tensor_subset_size,
+ block_size,
+ token_size,
+ BLOCK_SIZE=BLOCK_SIZE
+ )
\ No newline at end of file
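Editorial note: the Triton kernel above splits the H heads of each (N, B, H, C) cache tensor into d tensor-parallel groups and writes each group's slice contiguously into the destination buffer. A pure-PyTorch reference of the same index mapping, written here from the kernel's arithmetic as a sketch for checking it (not part of the patch):

import torch

def rearrange_tensors_ref(t1: torch.Tensor, d: int) -> torch.Tensor:
    # Group g holds heads [g * H//d, (g+1) * H//d) of every block, flattened in
    # (N, B, H//d, C) order, matching dst_pos = tensor_subset_size * tp_group +
    # tp_group_offset in rearrange_kernel.
    N, B, H, C = t1.shape
    assert H % d == 0, "H must be divisible by d"
    grouped = t1.reshape(N, B, d, H // d, C).permute(2, 0, 1, 3, 4).contiguous()
    # The destination keeps the source's nominal shape.
    return grouped.view(N, B, H, C)

# Possible check against the kernel (assumes a CUDA device and Triton installed):
# t1 = torch.randn(4, 16, 8, 64, device="cuda")
# t2 = torch.empty_like(t1)
# rearrange_tensors(t1, t2, 2)
# assert torch.equal(t2, rearrange_tensors_ref(t1, 2))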
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
index 00000000..bc962726
index 00000000..86248e7b
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,249 @@
@@ -0,0 +1,318 @@
+import torch
+from typing import List, Tuple
+from vllm.config import VllmConfig
......@@ -618,39 +726,18 @@ index 00000000..bc962726
+import msgspec
+import time
+import uuid
+from nixl_wrapper import nixl_wrapper as NixlWrapper
+from collections import defaultdict
+from .kv_rearrange import rearrange_tensors
+
+logger = init_logger(__name__)
+
+
+def nixl_wrapper_init_patch(self, agent_name, nixl_config):
+ logger.info("Initializing patched NixlWrapper")
+ import nixl_bindings as nixl
+ # Read available backends and device info from nixl_config
+ # For now setting the multithreading to enabled.
+ devices = nixl.nixlAgentConfig(False)
+ init = nixl.nixlUcxInitParams()
+
+ self.name = agent_name
+ self.notifs = {}
+ self.backends = {}
+ self.agent = nixl.nixlAgent(agent_name, devices)
+ self.backends["UCX"] = self.agent.createBackend(init)
+
+ self.nixl_mems = {"DRAM": nixl.DRAM_SEG,
+ "VRAM": nixl.VRAM_SEG,
+ "cpu": nixl.DRAM_SEG,
+ "cuda": nixl.VRAM_SEG}
+ self.nixl_ops = {"WRITE": nixl.NIXL_WR_FLUSH,
+ "READ": nixl.NIXL_RD_FLUSH,
+ "WRITE_NOTIF": nixl.NIXL_WR_NOTIF,
+ "READ_NOTIF": nixl.NIXL_RD_NOTIF}
+
+ print("Initializied NIXL agent:", agent_name)
+
+NixlWrapper.__init__ = nixl_wrapper_init_patch
+
+
+# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
+try:
+ from nixl_wrapper import nixl_wrapper as NixlWrapper # type: ignore
+ logger.info("NIXL is available")
+except ImportError:
+ logger.warning("NIXL is not available")
+ NixlWrapper = None
+
+class NixlMetadata(
+ msgspec.Struct,
......@@ -665,11 +752,17 @@ index 00000000..bc962726
+class DynemoNixlConnector:
+ def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int):
+ self.vllm_config = vllm_config
+ if NixlWrapper is None:
+ logger.error("NIXL is not available")
+ raise RuntimeError("NIXL is not available")
+ logger.info("Initializing NIXL wrapper")
+ self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
+
+ self.num_layers = None
+ self.num_blocks = None
+ self.num_heads = None
+ self.block_len = None
+ self.kv_caches = None
+ self.kv_caches_base_addr = {}
+ self.kv_cache_shape = {}
+
......@@ -678,33 +771,51 @@ index 00000000..bc962726
+ self.engine_id = engine_id
+ self.rank = rank
+ self.notifs = {}
+ self._tp_size = {}
+ self._block_descs = {}
+ self._xfer_side_handles = {}
+
+
+ self._transfers = defaultdict(list)
+
+
+ self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
+
+
+ @property
+ def agent_name(self):
+ return self.nixl_wrapper.name
+
+ def register_kv_caches(self, kv_caches: List[torch.Tensor]):
+ caches_data = []
+ self.num_layers = len(kv_caches)
+ _, _, block_size, num_heads, head_dim = kv_caches[0].shape
+ _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+
+ self.num_layers = len(kv_caches)
+ self.num_blocks = num_blocks
+ self.num_heads = num_heads
+ self.kv_caches = kv_caches
+ kv_caches_base_addr = []
+ caches_data = []
+ blocks_data = []
+ for key_cache, value_cache in kv_caches:
+ for cache in [key_cache, value_cache]:
+ base_addr = cache.data_ptr()
+ region_len = cache.numel() * cache.element_size()
+ gpu_id = cache.get_device()
+ assert gpu_id > -1, "Tensor is not on GPU"
+ caches_data.append((base_addr, region_len, gpu_id))
+ region_len = num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank))
+ for block_id in range(self.num_blocks):
+ blocks_data.append((base_addr + block_id * self.block_len, self.block_len, self.rank))
+
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_descs(("VRAM", caches_data))
+ logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+
+ self._block_descs[self.engine_id] = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self._xfer_side_handles[self.engine_id] = self.nixl_wrapper.prep_xfer_side(self._block_descs[self.engine_id])
+
+ def get_agent_metadata(self):
+ return self.nixl_wrapper.get_agent_metadata()
+
......@@ -714,10 +825,14 @@ index 00000000..bc962726
+ for agent_name in self._remote_agents.values():
+ self.nixl_wrapper.remove_remote_agent(agent_name)
+
+ def add_remote_agent(self, engine_id, agent_metadata):
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_metadata)
+ self._remote_agents[engine_id] = agent_name
+ return agent_name
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp):
+ self._tp_size[engine_id] = agent_tp
+ agent_names = []
+ for agent_meta in agent_metadata:
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
+ agent_names.append(agent_name)
+ self._remote_agents[engine_id] = agent_names
+ return agent_names
+
+ def get_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
......@@ -732,17 +847,29 @@ index 00000000..bc962726
+ descs_ids.append(2 * (self.num_blocks * layer_id + block_id) + 1)
+ return descs_ids
+
+ def _get_range_descs(self, engine_id, ranges, layer_ids):
+ def _get_range_descs(self, ranges, layer_ids, kv_caches_base_addr, tp_multiplier=1, rank=None, i=0):
+ if rank is None:
+ rank = self.rank
+ offset_block_len = self.block_len
+ block_len = self.block_len // tp_multiplier
+ tp_offset = i * block_len
+ else:
+ offset_block_len = self.block_len // tp_multiplier
+ block_len = self.block_len // tp_multiplier
+ tp_offset = 0
+ logger.debug("Getting range descs for layer ids: %s, ranges: %s, tp_multiplier: %s, rank: %s, i: %s", layer_ids, ranges, tp_multiplier, rank, i)
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ blocks_data = []
+ for layer_id in layer_ids:
+ for range_start, range_end in ranges:
+ key_base_addr, value_base_addr = self.kv_caches_base_addr[engine_id][layer_id]
+ start_offset = range_start * self.block_len
+ blocks_len = (range_end - range_start + 1) * self.block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, self.rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, self.rank))
+ range_len = range_end - range_start + 1
+ key_base_addr, value_base_addr = kv_caches_base_addr[layer_id]
+ start_offset = range_start * offset_block_len + tp_offset * range_len
+ blocks_len = range_len * block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, rank))
+ logger.debug("Blocks data: %s", blocks_data)
+ return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+
+ def _get_ranges(self, block_ids):
......@@ -755,9 +882,9 @@ index 00000000..bc962726
+ ranges = []
+ for i in range(len(sorted_block_ids)):
+ if i == 0 or sorted_block_ids[i] != sorted_block_ids[i-1] + 1:
+ ranges.append([sorted_block_ids[i]])
+ ranges.append([sorted_block_ids[i], sorted_block_ids[i]])
+ else:
+ ranges[-1].append(sorted_block_ids[i])
+ ranges[-1][1] = sorted_block_ids[i]
+ return ranges
+
+ def _get_same_length_ranges(self, src_ranges, dst_ranges):
......@@ -797,11 +924,24 @@ index 00000000..bc962726
+ src_idx += 1
+
+ return src_overlapping_ranges, dst_overlapping_ranges
+
+
+
+ def _get_block_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ if block_ids == "all":
+ block_ids = list(range(self.num_blocks))
+ descs_ids = []
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * self.num_blocks + is_value * self.num_blocks + block_id)
+ return descs_ids
+
+
+
+ def transfer_mem(self, src_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+
+ def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg, use_prepped_xfer=False):
+ start_time = time.perf_counter()
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+
......@@ -810,44 +950,62 @@ index 00000000..bc962726
+ # If isl equals a multiple of tokens_per_block + 1, the prefill engine will have
+ # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)]
+ assert len(staging_block_ids) == len(src_block_ids)
+
+ if use_prepped_xfer:
+ raise NotImplementedError("Prepped xfer is not implemented")
+ # src_block_descs_ids = self._get_block_descs_ids("all", src_block_ids)
+ # dst_block_descs_ids = self._get_block_descs_ids("all", dst_block_ids)
+
+ # src_xfer_side_handle = self._xfer_side_handles[self.engine_id]
+ # dst_xfer_side_handle = self._xfer_side_handles[dst_engine_id]
+
+ # logger.debug("Time to get block desc ids: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ # handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, src_block_descs_ids,
+ # dst_xfer_side_handle, dst_block_descs_ids,
+ # notify_msg, "WRITE", no_check=True)
+ # else:
+ # Legacy path using range-based transfers
+ src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+ dst_ranges = self._get_ranges(dst_block_ids)
+ src_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(src_ranges, dst_ranges)
+
+ logger.debug("Got %s overlapping ranges for %s blocks", len(src_overlapping_ranges), len(src_block_ids))
+ assert len(src_ranges) == 1
+ assert len(staging_ranges) == 1
+
+ logger.debug("Time to get ranges: %s ms", time.perf_counter() - start_time)
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+
+ src_range_start, src_range_end = src_ranges[0]
+ src_range_len = src_range_end - src_range_start + 1
+ staging_range_start, staging_range_end = staging_ranges[0]
+ staging_range_len = staging_range_end - staging_range_start + 1
+
+ src_descs = self._get_range_descs(self.engine_id, src_overlapping_ranges, "all")
+ dst_descs = self._get_range_descs(dst_engine_id, dst_overlapping_ranges, "all")
+ logger.debug("Rearranging tensors for cache: %s, src_ranges: %s of len %s, staging_ranges: %s of len %s", self.kv_caches[0].shape, src_ranges, src_range_len, staging_ranges, staging_range_len)
+ for kv_cache in self.kv_caches:
+ for cache in kv_cache:
+ rearrange_tensors(cache[src_range_start:src_range_start + src_range_len], cache[staging_range_start:staging_range_start + staging_range_len], tp_multiplier)
+
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+ staging_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(staging_ranges, dst_ranges)
+ assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges)
+
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs, self._remote_agents[dst_engine_id], notify_msg, "WRITE")
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+ # TODO ptarasiewicz: remove blocking transfer mem
+ # add scheduler check for transfer done
+ while True:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "ERR":
+ raise RuntimeError("Transfer failed")
+ elif xfer_state == "DONE":
+ logger.debug("Transfer done")
+ break
+ elif xfer_state == "PROC":
+ time.sleep(0.01)
+ else:
+ raise RuntimeError("Unknown transfer state")
+ logger.debug("Time to wait for transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ self.nixl_wrapper.abort_xfer(handle)
+ logger.debug("Time to abort xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer time: %s ms", (time.perf_counter() - start_time) * 1000)
+ for i in range(tp_multiplier):
+
+ src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i)
+ dst_descs = self._get_range_descs(dst_overlapping_ranges, "all", self.kv_caches_base_addr[dst_engine_id][self.rank * tp_multiplier + i], tp_multiplier, rank=self.rank * tp_multiplier + i)
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ logger.debug("Transfering to agent %s", self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i])
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs,
+ self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i],
+ notify_msg, "WRITE")
+ self._transfers[notify_msg].append(handle)
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+
+ def deserialize_descs(self, serialized_descs):
+ return self.nixl_wrapper.deserialize_descs(serialized_descs)
+
......@@ -860,153 +1018,138 @@ index 00000000..bc962726
+
+ def add_remote_kv_caches_base_addr(self, engine_id, kv_caches_base_addr):
+ self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
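Editorial note: DynemoNixlConnector turns block-id lists into inclusive [start, end] ranges (_get_ranges) before building NIXL descriptors, and pairs source and destination ranges of equal length (_get_same_length_ranges). A standalone sketch of the first helper's behaviour, rewritten outside the class purely for illustration:

def get_ranges(block_ids):
    # Collapse block ids into sorted, inclusive [start, end] ranges,
    # mirroring DynemoNixlConnector._get_ranges.
    ranges = []
    for i, block_id in enumerate(sorted(block_ids)):
        if i == 0 or block_id != ranges[-1][1] + 1:
            ranges.append([block_id, block_id])
        else:
            ranges[-1][1] = block_id
    return ranges

# get_ranges([7, 3, 4, 5, 9]) -> [[3, 5], [7, 7], [9, 9]]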
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..61a357d0 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -27,13 +27,13 @@ class KVConnectorFactory:
@classmethod
def create_connector(cls, rank: int, local_rank: int,
- config: "VllmConfig") -> KVConnectorBase:
+ config: "VllmConfig", world_group) -> KVConnectorBase:
connector_name = config.kv_transfer_config.kv_connector
if connector_name not in cls._registry:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_cls = cls._registry[connector_name]()
- return connector_cls(rank, local_rank, config)
+ return connector_cls(rank, local_rank, config, world_group)
# Register various connectors here.
@@ -48,3 +48,8 @@ KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
+
+KVConnectorFactory.register_connector(
+ "DynemoNcclConnector",
+ "vllm.distributed.kv_transfer.kv_connector.dynemo_connector",
+ "DynemoConnector")
\ No newline at end of file
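Editorial note: the factory hunk registers the new connector under the name "DynemoNcclConnector" and adds a world_group parameter to create_connector(). A minimal sketch of how a worker might obtain the connector through the patched factory; the vllm_config and world_group objects are assumed to come from the usual vLLM initialization path:

from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory

def build_connector(rank, local_rank, vllm_config, world_group):
    # vllm_config.kv_transfer_config.kv_connector is expected to be
    # "DynemoNcclConnector", which the registration above maps to DynemoConnector.
    return KVConnectorFactory.create_connector(rank, local_rank, vllm_config, world_group)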
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 2033e976..e33919c1 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -8,13 +8,15 @@ MooncakePipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
+ def get_done_tranfers(self) -> List[str]:
+ done_req_ids = []
+ for req_id, handles in self._transfers.items():
+ running_reqs = []
+ for handle in handles:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "DONE":
+ # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
+ continue
+ if xfer_state == "PROC":
+ running_reqs.append(handle)
+ else:
+ raise RuntimeError("Transfer failed with state %s", xfer_state)
+ if len(running_reqs) == 0:
+ done_req_ids.append(req_id)
+ else:
+ self._transfers[req_id] = running_reqs
+ return done_req_ids
diff --git a/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
new file mode 100644
index 00000000..2319867a
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
@@ -0,0 +1,350 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Simple KV Cache Connector for Distributed Machine Learning Inference
+
+The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
+producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
+MooncakePipe.
+
+But the logic can be extended to support other pipe and lookup buffer.
+"""
+import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
-from vllm.config import VllmConfig
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.config import VllmConfig, KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.utils import StatelessProcessGroup
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.logger import init_logger
@@ -33,6 +35,7 @@ class SimpleConnector(KVConnectorBase):
rank: int,
local_rank: int,
config: VllmConfig,
+from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
+ SimpleBuffer)
+from vllm.logger import init_logger
+from vllm.sequence import IntermediateTensors
+
+if TYPE_CHECKING:
+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
+logger = init_logger(__name__)
+
+
+class DynemoConnector(KVConnectorBase):
+
+ def __init__(
+ self,
+ rank: int,
+ local_rank: int,
+ config: VllmConfig,
+ world_group,
):
self.config = config.kv_transfer_config
@@ -71,20 +74,31 @@ class SimpleConnector(KVConnectorBase):
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
+ ):
+
+ self.config = config.kv_transfer_config
+ self.tp_size = config.parallel_config.tensor_parallel_size
+ self.rank = rank
+
+ if self.config.kv_connector != "DynemoNcclConnector":
+ raise NotImplementedError("Only DynemoNcclConnector is supported by the DynemoConnector class")
+
+ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
+ PyNcclPipe)
+ from vllm.distributed.kv_transfer.kv_pipe.dynemo_nccl_pipe import (
+ DynemoNcclDataPlane)
+
+ logger.info(
+ "Initializing DynemoNcclConnector under kv_transfer_config %s",
+ self.config)
+
+ self.lookup_buffer_size = self.config.kv_buffer_size
+
+ self.producer_data_pipe: PyNcclPipe
+ self.consumer_data_pipe: PyNcclPipe
+ self.producer_signal_pipe: PyNcclPipe
+ self.consumer_signal_pipe: PyNcclPipe
+
+ self._broadcast_and_enhance_kv_config(rank, config, world_group)
+
+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
+ self.tp_size = config.parallel_config.tensor_parallel_size
+
# 2 pipes for every rank in the world
- port_offset_base = 2 * rank
+ # 2 pipes for every rank in the world
+ if self.config.is_kv_producer:
+ port_offset_base = 2 * rank + 1
+ port_offset_base = rank + 1
+ else:
+ port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1
+ port_offset_base = rank // self.config.tensor_parallel_multiplier + 1
+
+
+ self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:
if self.config.kv_connector == "PyNcclConnector":
self.producer_data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -108,11 +122,13 @@ class SimpleConnector(KVConnectorBase):
# its recv pipe to the send pipe of KV producer
if self.config.kv_connector == "PyNcclConnector":
self.consumer_data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -131,21 +147,25 @@ class SimpleConnector(KVConnectorBase):
self.config.kv_buffer_size,
)
- def select(self, input_tokens: Optional[torch.Tensor],
+ def select(self, source_rank: int, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+ logger.info("Selecting KV caches and hidden states for source rank %d", source_rank)
+ self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config)
+
assert self.consumer_buffer is not None, "Please initialize the "\
"consumer buffer before calling select."
- return self.consumer_buffer.drop_select(input_tokens, roi)
+ return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi)
- def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
+ def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
+ logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank)
+ self.data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
+ local_rank=local_rank,
+ config=self.config,
+ port_offset=port_offset_base,
+ )
+
assert self.producer_buffer is not None, "Please initialize the "\
"producer buffer before calling insert."
- self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
+ self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden)
def send_kv_caches_and_hidden_states(
self,
@@ -161,12 +181,20 @@ class SimpleConnector(KVConnectorBase):
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
+ self.data_plane = DynemoNcclDataPlane(
+ data_pipe=self.data_pipe,
+ port=self._get_data_plane_port(self.global_kv_rank),
+ )
+
+ def send_kv_caches_and_hidden_states(
+ self,
+ model_executable: torch.nn.Module,
+ model_input: "ModelInputForGPUWithSamplingMetadata",
+ kv_caches: List[torch.Tensor],
+ hidden_or_intermediate_states: Union[torch.Tensor,
+ IntermediateTensors],
+ ) -> None:
+
+ input_tokens_tensor = model_input.input_tokens
+ seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
+ start_layer = model_executable.model.start_layer
+ end_layer = model_executable.model.end_layer
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
model_config = model_executable.model.config
- num_heads = int(model_config.num_key_value_heads / self.tp_size)
- hidden_size = model_config.hidden_size
- num_attention_heads = model_config.num_attention_heads
- head_size = int(hidden_size / num_attention_heads)
+
+ model_config = model_executable.model.config
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+ if not is_deepseek:
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
......@@ -1018,38 +1161,31 @@ index 2033e976..e33919c1 100644
+ hidden_size = model_config.hidden_size
+ num_attention_heads = model_config.num_attention_heads
+ head_size = int(4.5 * hidden_size / num_attention_heads)
# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
@@ -175,27 +203,40 @@ class SimpleConnector(KVConnectorBase):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
+
+ # query_lens contains new KV caches that are added to vLLM.
+ # so we will send them to decode instance
+ # FIXME(Kuntai): This assumes that all requests are prefill.
+ for idx, slen in enumerate(seq_lens):
+ start_pos = sum(seq_lens[:idx])
+ end_pos = start_pos + slen
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ _, decode_kv_rank = self.parse_request_id(current_request_id)
+ starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config)
+ decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id)
+ decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config)
+
+ for target_rank in range(self.config.tensor_parallel_multiplier):
- keys, values = [], []
+
+ keys, values = [], []
- for layer_id in range(start_layer, end_layer):
- kv_cache = kv_caches[layer_id - start_layer]
+
+ for layer_id in range(start_layer, end_layer):
+ kv_cache = kv_caches[layer_id - start_layer]
- key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
- value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+
+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
- current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+
+ num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
+ head_start = target_rank * num_heads_per_rank
+ head_end = head_start + num_heads_per_rank
- keys.append(key_cache[current_slot_mapping].unsqueeze(0))
- values.append(value_cache[current_slot_mapping].unsqueeze(0))
+
+ if not is_deepseek:
+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
......@@ -1059,75 +1195,70 @@ index 2033e976..e33919c1 100644
+ key_cache = kv_cache
+ keys.append(key_cache[current_slot_mapping].unsqueeze(0))
+ values.append(torch.empty(0))
- keys = torch.cat(keys, dim=0)
- values = torch.cat(values, dim=0)
+
+ keys = torch.cat(keys, dim=0)
+ values = torch.cat(values, dim=0)
- self.insert(current_tokens,
- torch.ones_like(current_tokens,
- dtype=bool), keys, values,
- hidden_or_intermediate_states[start_pos:end_pos])
+ self.insert(starting_kv_group_rank, target_rank, current_tokens,
+ torch.ones_like(current_tokens,
+ dtype=bool), keys, values,
+ hidden_or_intermediate_states[start_pos:end_pos])
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
@@ -215,6 +256,7 @@ class SimpleConnector(KVConnectorBase):
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+
+ decode_global_rank = decode_first_global_rank + target_rank
+ decode_port = self._get_data_plane_port(decode_global_rank)
+ partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos]
+ self._send(decode_hostname, decode_port, current_request_id, keys, values,
+ partial_hidden_or_intermediate_states)
+
+ logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
+
+ def recv_kv_caches_and_hidden_states(
+ self, model_executable: torch.nn.Module,
+ model_input: "ModelInputForGPUWithSamplingMetadata",
+ kv_caches: List[torch.Tensor]
+ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+ "ModelInputForGPUWithSamplingMetadata"]:
+
+ # When bypass_model_exec is set to False, it means that at least for one
+ # request its corresponding KV cache or hidden state is missing.
+ # In this case we need to do prefilling to recompute missing KV cache
+ # and hidden states.
+ bypass_model_exec = True
+
+ input_tokens_tensor = model_input.input_tokens
+ seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
hidden_or_intermediate_states_for_one_req = []
@@ -222,6 +264,9 @@ class SimpleConnector(KVConnectorBase):
num_computed_tokens_list = []
start_pos_list = []
+
+ hidden_or_intermediate_states_for_one_req = []
+
+ input_tokens_list = []
+ start_pos_list = []
+
+ model_config = model_executable.model.config
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+
# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
@@ -229,13 +274,15 @@ class SimpleConnector(KVConnectorBase):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
+ # enumerate different requests
+ # FIXME(Kuntai): This impl assumes that all requests are prefill.
+ for idx, slen in enumerate(seq_lens):
+
+ start_pos = sum(seq_lens[:idx])
+ end_pos = start_pos + slen
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ prefill_rank, _ = self.parse_request_id(current_request_id)
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
- ret = self.select(current_tokens,
+ ret = self.select(prefill_rank, current_tokens,
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
@@ -267,19 +314,25 @@ class SimpleConnector(KVConnectorBase):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
- key_cache, value_cache = kv_cache[0], kv_cache[1]
- ops.reshape_and_cache_flash(
- keys[i - model_executable.model.start_layer].to(
- key_cache.device),
- values[i - model_executable.model.start_layer].to(
- value_cache.device),
- key_cache,
- value_cache,
- slot_mapping[start_pos:end_pos],
- layer.self_attn.attn.kv_cache_dtype,
- layer.self_attn.attn._k_scale,
- layer.self_attn.attn._v_scale,
- )
+ num_tokens = slen
+
+ # collecting data for rebuilding the input
+ input_tokens_list.append(current_tokens)
+ start_pos_list.append(start_pos)
+
+ ret = self._recv(current_request_id)
+ keys: torch.Tensor = ret[0]
+ values: torch.Tensor = ret[1]
+ hidden: torch.Tensor = ret[2]
+
+ # put received KV caches into paged memory
+ for i in range(model_executable.model.start_layer,
+ model_executable.model.end_layer):
+
+ kv_cache = kv_caches[i - model_executable.model.start_layer]
+ layer = model_executable.model.layers[i]
+
+ if not is_deepseek:
+ key_cache, value_cache = kv_cache[0], kv_cache[1]
+ ops.reshape_and_cache_flash(
......@@ -1147,32 +1278,58 @@ index 2033e976..e33919c1 100644
+ copy_from = keys[i - model_executable.model.start_layer].to(
+ key_cache.device)
+ kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
hidden_or_intermediate_states_for_one_req.append(hidden)
@@ -312,3 +365,77 @@ class SimpleConnector(KVConnectorBase):
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
# close the data_pipe.
pass
+
+ hidden_or_intermediate_states_for_one_req.append(hidden)
+
+ if not bypass_model_exec:
+ # Some of the KV cache is not retrieved
+ # Here we will fall back to normal model forwarding
+ # But optionally you can adjust model_input so that you only do
+ # prefilling on those tokens that are missing KV caches.
+ logger.debug(
+ "[rank%d]: Failed to receive all KVs and hidden "
+ "states, redo model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = None
+
+ else:
+ logger.debug(
+ "[rank%d]: Successfully received all KVs and hidden "
+ "states, skip model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = torch.cat(
+ hidden_or_intermediate_states_for_one_req, dim=0)
+
+ return hidden_or_intermediate_states, bypass_model_exec, model_input
+
+ def close(self):
+ self.data_pipe.close()
+ # self.data_plane.close()
+
+ @staticmethod
+ def parse_request_id(request_id):
+ # Regular expression to match the ranks
+ pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)"
+ def parse_request_id(request_id: str) -> Tuple[str, int]:
+ # Regular expression to match the string hostname and integer decode_kv_rank
+ pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"
+
+ # Use re.search to find the pattern in the request_id
+ match = re.search(pattern, request_id)
+
+ if match:
+ # Extract the ranks
+ prefill_rank = int(match.group(1))
+ decode_hostname = match.group(1)
+ decode_rank = int(match.group(2))
+
+ return prefill_rank, decode_rank
+ else:
+ return None, None
+ return decode_hostname, decode_rank
+ raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")
+
+
+ def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor):
+ remote_address = f"{hostname}:{port}"
+ self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address)
+ self.data_plane.send_tensor(values, f"{request_id}_values", remote_address)
+ self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address)
+
+ def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ keys = self.data_plane.recv_tensor(f"{request_id}_keys")
+ values = self.data_plane.recv_tensor(f"{request_id}_values")
+ hidden = self.data_plane.recv_tensor(f"{request_id}_hidden")
+ return keys, values, hidden
+
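Editorial note: parse_request_id() above expects the decode hostname and decode kv_rank to be embedded in the request id, and _send()/_recv() key the data-plane tensors by that id. A small sketch of the naming convention implied by the regex; the component that actually builds these ids is outside this patch, so make_request_id below is an assumption:

import re

def make_request_id(base_id: str, decode_hostname: str, decode_kv_rank: int) -> str:
    # Assumed builder for the suffix matched by parse_request_id.
    return f"{base_id}___decode_hostname_{decode_hostname}___decode_kv_rank_{decode_kv_rank}"

def parse_request_id(request_id: str):
    # Same pattern as DynemoConnector.parse_request_id.
    match = re.search(r"___decode_hostname_(.*)___decode_kv_rank_(\d+)", request_id)
    if match:
        return match.group(1), int(match.group(2))
    raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")

# parse_request_id(make_request_id("req-42", "decode-node-0", 3)) -> ("decode-node-0", 3)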
+ def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank < config.kv_producers_parallel_size:
......@@ -1180,45 +1337,53 @@ index 2033e976..e33919c1 100644
+
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
+
+
+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
+ if rank == 0:
+ if self.config.kv_connector == "PyNcclConnector":
+ config_group = StatelessProcessGroup.create(
+ host=self.config.kv_ip,
+ port=self.config.kv_port,
+ rank=self.config.kv_rank,
+ world_size=self.config.kv_parallel_size,
+ )
+ parallel_configs = config_group.all_gather_obj({
+ "kv_role": self.config.kv_role,
+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
+ })
+ logger.debug("parallel_configs: %s", parallel_configs)
+ kv_config_enhanced = {
+ "kv_producers_tensor_parallel_size": None,
+ "kv_consumers_tensor_parallel_size": None,
+ "kv_producers_pipeline_parallel_size": None,
+ "kv_consumers_pipeline_parallel_size": None,
+ "kv_producers_parallel_size": 0,
+ }
+ for parallel_config in parallel_configs:
+ kv_role = parallel_config["kv_role"]
+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
+
+ if kv_role == "kv_producer":
+ kv_config_enhanced["kv_producers_parallel_size"] += 1
+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
+ else:
+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
+ world_group.broadcast_object(kv_config_enhanced)
+ def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank <= config.kv_producers_parallel_size:
+ return kv_rank * config.kv_producers_tensor_parallel_size + rank
+
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank
+
+ else:
+ raise NotImplementedError("MooncakeConnector is not supported in Dynemo Distributed vllm patch")
+
+ def _get_data_plane_port(self, global_kv_rank: int) -> int:
+ return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank
+
+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
+ if rank == 0:
+ config_group = StatelessProcessGroup.create(
+ host=self.config.kv_ip,
+ port=self.config.kv_port,
+ rank=self.config.kv_rank,
+ world_size=self.config.kv_parallel_size,
+ )
+ parallel_configs = config_group.all_gather_obj({
+ "kv_role": self.config.kv_role,
+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
+ })
+ logger.debug("parallel_configs: %s", parallel_configs)
+ kv_config_enhanced = {
+ "kv_producers_tensor_parallel_size": None,
+ "kv_consumers_tensor_parallel_size": None,
+ "kv_producers_pipeline_parallel_size": None,
+ "kv_consumers_pipeline_parallel_size": None,
+ "kv_producers_parallel_size": 0,
+ }
+ for parallel_config in parallel_configs:
+ kv_role = parallel_config["kv_role"]
+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
+
+ if kv_role == "kv_producer":
+ kv_config_enhanced["kv_producers_parallel_size"] += 1
+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
+ else:
+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
+ world_group.broadcast_object(kv_config_enhanced)
+ else:
+ kv_config_enhanced = world_group.broadcast_object()
+ logger.info("kv_config_enhanced: %s", kv_config_enhanced)
......@@ -1228,120 +1393,152 @@ index 2033e976..e33919c1 100644
+ self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
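Editorial note: _get_global_kv_rank() and _get_data_plane_port() above assign each (kv_rank, tensor-parallel rank) pair a unique global rank and a unique data-plane port. A worked sketch of those two formulas as free functions; the example sizes and base port below are assumptions for illustration:

def global_kv_rank(kv_rank, rank, producers_parallel_size, producers_tp, consumers_tp):
    # Mirrors DynemoConnector._get_global_kv_rank.
    if kv_rank <= producers_parallel_size:
        return kv_rank * producers_tp + rank
    kv_consumer_rank = kv_rank - producers_parallel_size
    return producers_parallel_size * producers_tp + kv_consumer_rank * consumers_tp + rank

def data_plane_port(kv_port, producers_tp, global_rank):
    # Mirrors DynemoConnector._get_data_plane_port.
    return kv_port + producers_tp + 1 + global_rank

# With one producer (TP=1), one consumer (TP=2) and kv_port=14579 (assumed):
#   producer kv_rank=0, rank=0 -> global rank 0, port 14581
#   consumer kv_rank=1, rank=0 -> global rank 1, port 14582
#   consumer kv_rank=1, rank=1 -> global rank 2, port 14583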
diff --git a/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
new file mode 100644
index 00000000..cb3b3660
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
@@ -0,0 +1,350 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Simple KV Cache Connector for Distributed Machine Learning Inference
+
+The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
+producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
+MooncakePipe.
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..f4775663 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -27,13 +27,13 @@ class KVConnectorFactory:
@classmethod
def create_connector(cls, rank: int, local_rank: int,
- config: "VllmConfig") -> KVConnectorBase:
+ config: "VllmConfig", world_group) -> KVConnectorBase:
connector_name = config.kv_transfer_config.kv_connector
if connector_name not in cls._registry:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_cls = cls._registry[connector_name]()
- return connector_cls(rank, local_rank, config)
+ return connector_cls(rank, local_rank, config, world_group)
# Register various connectors here.
@@ -48,3 +48,8 @@ KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
+
+But the logic can be extended to support other pipe and lookup buffer.
+"""
+KVConnectorFactory.register_connector(
+ "DynemoNcclConnector",
+ "vllm.distributed.kv_transfer.kv_connector.dynemo_connector",
+ "DynemoConnector")
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 2033e976..e0537903 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -8,13 +8,15 @@ MooncakePipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
+import re
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import torch
+
+from vllm import _custom_ops as ops
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, KVTransferConfig
+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
+ SimpleBuffer)
+from vllm.logger import init_logger
+from vllm.sequence import IntermediateTensors
+
+if TYPE_CHECKING:
+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
+logger = init_logger(__name__)
+
+
+class DynemoConnector(KVConnectorBase):
+
+ def __init__(
+ self,
+ rank: int,
+ local_rank: int,
+ config: VllmConfig,
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.logger import init_logger
@@ -33,6 +35,7 @@ class SimpleConnector(KVConnectorBase):
rank: int,
local_rank: int,
config: VllmConfig,
+ world_group,
+ ):
+
+ self.config = config.kv_transfer_config
+ self.tp_size = config.parallel_config.tensor_parallel_size
+ self.rank = rank
+
+ if self.config.kv_connector != "DynemoNcclConnector":
+ raise NotImplementedError("Only DynemoNcclConnector is supported by the DynemoConnector class")
+
+ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
+ PyNcclPipe)
+ from vllm.distributed.kv_transfer.kv_pipe.dynemo_nccl_pipe import (
+ DynemoNcclDataPlane)
+
+ logger.info(
+ "Initializing DynemoNcclConnector under kv_transfer_config %s",
+ self.config)
+
+ self.lookup_buffer_size = self.config.kv_buffer_size
+
+ self.producer_data_pipe: PyNcclPipe
+ self.consumer_data_pipe: PyNcclPipe
+ self.producer_signal_pipe: PyNcclPipe
+ self.consumer_signal_pipe: PyNcclPipe
+
):
self.config = config.kv_transfer_config
@@ -71,20 +74,31 @@ class SimpleConnector(KVConnectorBase):
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
+ self._broadcast_and_enhance_kv_config(rank, config, world_group)
+
+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
+ self.tp_size = config.parallel_config.tensor_parallel_size
+
+ # 2 pipes for every rank in the world
# 2 pipes for every rank in the world
- port_offset_base = 2 * rank
+ if self.config.is_kv_producer:
+ port_offset_base = rank + 1
+ port_offset_base = 2 * rank + 1
+ else:
+ port_offset_base = rank // self.config.tensor_parallel_multiplier + 1
+
+
+ port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1
+ self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
+ self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config)
+
+ self.data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
+ local_rank=local_rank,
+ config=self.config,
+ port_offset=port_offset_base,
+ )
+
+ self.data_plane = DynemoNcclDataPlane(
+ data_pipe=self.data_pipe,
+ port=self._get_data_plane_port(self.global_kv_rank),
+ )
# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:
if self.config.kv_connector == "PyNcclConnector":
self.producer_data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -108,11 +122,13 @@ class SimpleConnector(KVConnectorBase):
# its recv pipe to the send pipe of KV producer
if self.config.kv_connector == "PyNcclConnector":
self.consumer_data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -131,21 +147,25 @@ class SimpleConnector(KVConnectorBase):
self.config.kv_buffer_size,
)
- def select(self, input_tokens: Optional[torch.Tensor],
+ def select(self, source_rank: int, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+ logger.info("Selecting KV caches and hidden states for source rank %d", source_rank)
+
+ def send_kv_caches_and_hidden_states(
+ self,
+ model_executable: torch.nn.Module,
+ model_input: "ModelInputForGPUWithSamplingMetadata",
+ kv_caches: List[torch.Tensor],
+ hidden_or_intermediate_states: Union[torch.Tensor,
+ IntermediateTensors],
+ ) -> None:
assert self.consumer_buffer is not None, "Please initialize the "\
"consumer buffer before calling select."
- return self.consumer_buffer.drop_select(input_tokens, roi)
+ return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi)
- def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
+ def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
+ logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank)
+
+ input_tokens_tensor = model_input.input_tokens
+ seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
+ start_layer = model_executable.model.start_layer
+ end_layer = model_executable.model.end_layer
assert self.producer_buffer is not None, "Please initialize the "\
"producer buffer before calling insert."
- self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
+ self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden)
def send_kv_caches_and_hidden_states(
self,
@@ -161,12 +181,20 @@ class SimpleConnector(KVConnectorBase):
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
+
+ model_config = model_executable.model.config
model_config = model_executable.model.config
- num_heads = int(model_config.num_key_value_heads / self.tp_size)
- hidden_size = model_config.hidden_size
- num_attention_heads = model_config.num_attention_heads
- head_size = int(hidden_size / num_attention_heads)
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+ if not is_deepseek:
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
......@@ -1353,31 +1550,38 @@ index 00000000..cb3b3660
+ hidden_size = model_config.hidden_size
+ num_attention_heads = model_config.num_attention_heads
+ head_size = int(4.5 * hidden_size / num_attention_heads)
+
+ # query_lens contains new KV caches that are added to vLLM.
+ # so we will send them to decode instance
+ # FIXME(Kuntai): This assume that all requests are prefill.
+ for idx, slen in enumerate(seq_lens):
+ start_pos = sum(seq_lens[:idx])
+ end_pos = start_pos + slen
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
@@ -175,27 +203,40 @@ class SimpleConnector(KVConnectorBase):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id)
+ decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config)
+ _, decode_kv_rank = self.parse_request_id(current_request_id)
+ starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config)
+
+ for target_rank in range(self.config.tensor_parallel_multiplier):
+
- keys, values = [], []
+ keys, values = [], []
+
- for layer_id in range(start_layer, end_layer):
- kv_cache = kv_caches[layer_id - start_layer]
+ for layer_id in range(start_layer, end_layer):
+ kv_cache = kv_caches[layer_id - start_layer]
+
- key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
- value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+
- current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+ num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
+ head_start = target_rank * num_heads_per_rank
+ head_end = head_start + num_heads_per_rank
+
- keys.append(key_cache[current_slot_mapping].unsqueeze(0))
- values.append(value_cache[current_slot_mapping].unsqueeze(0))
+ if not is_deepseek:
+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
......@@ -1387,70 +1591,75 @@ index 00000000..cb3b3660
+ key_cache = kv_cache
+ keys.append(key_cache[current_slot_mapping].unsqueeze(0))
+ values.append(torch.empty(0))
+
- keys = torch.cat(keys, dim=0)
- values = torch.cat(values, dim=0)
+ keys = torch.cat(keys, dim=0)
+ values = torch.cat(values, dim=0)
+
+ decode_global_rank = decode_first_global_rank + target_rank
+ decode_port = self._get_data_plane_port(decode_global_rank)
+ partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos]
+ self._send(decode_hostname, decode_port, current_request_id, keys, values,
+ partial_hidden_or_intermediate_states)
+
+ logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
+
+ def recv_kv_caches_and_hidden_states(
+ self, model_executable: torch.nn.Module,
+ model_input: "ModelInputForGPUWithSamplingMetadata",
+ kv_caches: List[torch.Tensor]
+ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+ "ModelInputForGPUWithSamplingMetadata"]:
+
+ # When bypass_model_exec is set to False, it means that at least for one
+ # request its corresponding KV cache or hidden state is missing.
+ # In this case we need to do prefilling to recompute missing KV cache
+ # and hidden states.
+ bypass_model_exec = True
+
+ input_tokens_tensor = model_input.input_tokens
+ seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
- self.insert(current_tokens,
- torch.ones_like(current_tokens,
- dtype=bool), keys, values,
- hidden_or_intermediate_states[start_pos:end_pos])
+ self.insert(starting_kv_group_rank, target_rank, current_tokens,
+ torch.ones_like(current_tokens,
+ dtype=bool), keys, values,
+ hidden_or_intermediate_states[start_pos:end_pos])
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
@@ -215,6 +256,7 @@ class SimpleConnector(KVConnectorBase):
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
+
+ hidden_or_intermediate_states_for_one_req = []
+
+ input_tokens_list = []
+ start_pos_list = []
+
hidden_or_intermediate_states_for_one_req = []
@@ -222,6 +264,9 @@ class SimpleConnector(KVConnectorBase):
num_computed_tokens_list = []
start_pos_list = []
+ model_config = model_executable.model.config
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+
+ # enumerate different requests
+ # FIXME(Kuntai): This impl assumes that all requests are prefill.
+ for idx, slen in enumerate(seq_lens):
+
+ start_pos = sum(seq_lens[:idx])
+ end_pos = start_pos + slen
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
@@ -229,13 +274,15 @@ class SimpleConnector(KVConnectorBase):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ num_tokens = slen
+
+ # collecting data for rebuilding the input
+ input_tokens_list.append(current_tokens)
+ start_pos_list.append(start_pos)
+
+ ret = self._recv(current_request_id)
+ keys: torch.Tensor = ret[0]
+ values: torch.Tensor = ret[1]
+ hidden: torch.Tensor = ret[2]
+
+ # put received KV caches into paged memory
+ for i in range(model_executable.model.start_layer,
+ model_executable.model.end_layer):
+
+ kv_cache = kv_caches[i - model_executable.model.start_layer]
+ layer = model_executable.model.layers[i]
+
+ prefill_rank, _ = self.parse_request_id(current_request_id)
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
- ret = self.select(current_tokens,
+ ret = self.select(prefill_rank, current_tokens,
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
@@ -267,19 +314,25 @@ class SimpleConnector(KVConnectorBase):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
- key_cache, value_cache = kv_cache[0], kv_cache[1]
- ops.reshape_and_cache_flash(
- keys[i - model_executable.model.start_layer].to(
- key_cache.device),
- values[i - model_executable.model.start_layer].to(
- value_cache.device),
- key_cache,
- value_cache,
- slot_mapping[start_pos:end_pos],
- layer.self_attn.attn.kv_cache_dtype,
- layer.self_attn.attn._k_scale,
- layer.self_attn.attn._v_scale,
- )
+ if not is_deepseek:
+ key_cache, value_cache = kv_cache[0], kv_cache[1]
+ ops.reshape_and_cache_flash(
......@@ -1470,58 +1679,32 @@ index 00000000..cb3b3660
+ copy_from = keys[i - model_executable.model.start_layer].to(
+ key_cache.device)
+ kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
+
+ hidden_or_intermediate_states_for_one_req.append(hidden)
+
+ if not bypass_model_exec:
+ # Some of the KV cache is not retrieved
+ # Here we will fall back to normal model forwarding
+ # But optionally you can adjust model_input so that you only do
+ # prefilling on those tokens that are missing KV caches.
+ logger.debug(
+ "[rank%d]: Failed to receive all KVs and hidden "
+ "states, redo model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = None
+
+ else:
+ logger.debug(
+ "[rank%d]: Successfully received all KVs and hidden "
+ "states, skip model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = torch.cat(
+ hidden_or_intermediate_states_for_one_req, dim=0)
+
+ return hidden_or_intermediate_states, bypass_model_exec, model_input
+
+ def close(self):
+ self.data_pipe.close()
+ # self.data_plane.close()
hidden_or_intermediate_states_for_one_req.append(hidden)
@@ -312,3 +365,77 @@ class SimpleConnector(KVConnectorBase):
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
# close the data_pipe.
pass
+
+ @staticmethod
+ def parse_request_id(request_id: str) -> Tuple[str, int]:
+ # Regular expression to match the string hostname and integer decode_kv_rank
+ pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"
+ def parse_request_id(request_id):
+ # Regular expression to match the ranks
+ pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)"
+
+ # Use re.search to find the pattern in the request_id
+ match = re.search(pattern, request_id)
+
+ if match:
+ # Extract the ranks
+ decode_hostname = match.group(1)
+ prefill_rank = int(match.group(1))
+ decode_rank = int(match.group(2))
+
+ return decode_hostname, decode_rank
+ raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")
+
+ def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor):
+ remote_address = f"{hostname}:{port}"
+ self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address)
+ self.data_plane.send_tensor(values, f"{request_id}_values", remote_address)
+ self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address)
+
+ return prefill_rank, decode_rank
+ else:
+ return None, None
+
+ def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ keys = self.data_plane.recv_tensor(f"{request_id}_keys")
+ values = self.data_plane.recv_tensor(f"{request_id}_values")
+ hidden = self.data_plane.recv_tensor(f"{request_id}_hidden")
+ return keys, values, hidden
+
+
+ def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank < config.kv_producers_parallel_size:
......@@ -1529,53 +1712,45 @@ index 00000000..cb3b3660
+
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
+
+
+ def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank <= config.kv_producers_parallel_size:
+ return kv_rank * config.kv_producers_tensor_parallel_size + rank
+
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank
+
+
+ def _get_data_plane_port(self, global_kv_rank: int) -> int:
+ return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank
+
+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
+ if rank == 0:
+ config_group = StatelessProcessGroup.create(
+ host=self.config.kv_ip,
+ port=self.config.kv_port,
+ rank=self.config.kv_rank,
+ world_size=self.config.kv_parallel_size,
+ )
+ parallel_configs = config_group.all_gather_obj({
+ "kv_role": self.config.kv_role,
+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
+ })
+ logger.debug("parallel_configs: %s", parallel_configs)
+ kv_config_enhanced = {
+ "kv_producers_tensor_parallel_size": None,
+ "kv_consumers_tensor_parallel_size": None,
+ "kv_producers_pipeline_parallel_size": None,
+ "kv_consumers_pipeline_parallel_size": None,
+ "kv_producers_parallel_size": 0,
+ }
+ for parallel_config in parallel_configs:
+ kv_role = parallel_config["kv_role"]
+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
+
+ if kv_role == "kv_producer":
+ kv_config_enhanced["kv_producers_parallel_size"] += 1
+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
+ else:
+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
+ world_group.broadcast_object(kv_config_enhanced)
+ if self.config.kv_connector == "PyNcclConnector":
+ config_group = StatelessProcessGroup.create(
+ host=self.config.kv_ip,
+ port=self.config.kv_port,
+ rank=self.config.kv_rank,
+ world_size=self.config.kv_parallel_size,
+ )
+ parallel_configs = config_group.all_gather_obj({
+ "kv_role": self.config.kv_role,
+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
+ })
+ logger.debug("parallel_configs: %s", parallel_configs)
+ kv_config_enhanced = {
+ "kv_producers_tensor_parallel_size": None,
+ "kv_consumers_tensor_parallel_size": None,
+ "kv_producers_pipeline_parallel_size": None,
+ "kv_consumers_pipeline_parallel_size": None,
+ "kv_producers_parallel_size": 0,
+ }
+ for parallel_config in parallel_configs:
+ kv_role = parallel_config["kv_role"]
+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
+
+ if kv_role == "kv_producer":
+ kv_config_enhanced["kv_producers_parallel_size"] += 1
+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
+ else:
+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
+ world_group.broadcast_object(kv_config_enhanced)
+
+ else:
+ raise NotImplementedError("MooncakeConnector is not supported in Dynemo patch")
+ else:
+ kv_config_enhanced = world_group.broadcast_object()
+ logger.info("kv_config_enhanced: %s", kv_config_enhanced)
......@@ -1585,7 +1760,6 @@ index 00000000..cb3b3660
+ self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
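For reference, a minimal sketch of the request-id convention the connector relies on. The regex matches the hostname-based variant of parse_request_id shown above; make_request_id is a hypothetical helper for illustration, as the real ids are built by the caller. The PyNccl path uses the other variant, which encodes ___prefill_kv_rank_N___decode_kv_rank_M instead of a hostname.

    import re

    def make_request_id(base_id, decode_hostname, decode_kv_rank):
        # Hypothetical helper: append the routing suffix the connector expects.
        return f"{base_id}___decode_hostname_{decode_hostname}___decode_kv_rank_{decode_kv_rank}"

    def parse_request_id(request_id):
        match = re.search(r"___decode_hostname_(.*)___decode_kv_rank_(\d+)", request_id)
        if match is None:
            raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")
        return match.group(1), int(match.group(2))

    print(parse_request_id(make_request_id("req-0", "decode-node-3", 1)))   # ('decode-node-3', 1)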
diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
index 5e1b6235..b4506877 100644
--- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
......@@ -1787,84 +1961,214 @@ index 5e1b6235..b4506877 100644
- self.data_pipe.send_tensor(roi)
+ self.signal_pipe.send_tensor(self.normal_signal, rank)
+
+ self.data_pipe.send_tensor(torch.tensor(kv_rank), rank)
+ self.data_pipe.send_tensor(input_tokens, rank)
+ self.data_pipe.send_tensor(roi, rank)
- input_tokens = self.data_pipe.recv_tensor()
- roi = self.data_pipe.recv_tensor()
+ input_tokens = self.data_pipe.recv_tensor(rank)
+ roi = self.data_pipe.recv_tensor(rank)
if roi is not None:
# convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor
roi = (roi > 0.5)
- key = self.data_pipe.recv_tensor()
- value = self.data_pipe.recv_tensor()
- hidden = self.data_pipe.recv_tensor()
+ key = self.data_pipe.recv_tensor(rank)
+ value = self.data_pipe.recv_tensor(rank)
+ hidden = self.data_pipe.recv_tensor(rank)
return [input_tokens, roi, key, value, hidden]
def full_handler(self):
time.sleep(0.001)
- def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
+ def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
@@ -222,20 +232,19 @@ class SimpleBuffer(KVLookupBufferBase):
while self.buffer_size > self.buffer_size_threshold:
self.full_handler()
- self._add_to_buffer(input_tokens, roi, key, value, hidden)
+ self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
+ target_rank_global = target_rank + kv_group_rank
if self.request_handling_thread is None:
- self.request_handling_thread = threading.Thread(
- target=self.drop_select_handler)
- self.request_handling_thread.start()
+ self.request_handling_thread = ThreadPoolExecutor(max_workers=1)
+ self.request_handling_thread.submit(self.drop_select_handler, target_rank_global)
def close(self):
- if hasattr(self, "request_handling_thread"
- ) and self.request_handling_thread is not None:
- self.request_handling_thread.join()
+ if hasattr(self, "request_handling_thread") and self.request_handling_thread:
+ self.request_handling_thread.shutdown()
else:
# TODO: have a explicit close signal and have a explicit way to
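A small sketch of the handler change above (names illustrative): a single-worker ThreadPoolExecutor replaces the bare Thread, so the buffer can submit one drop_select handler per global target rank and later shut the pool down instead of joining a thread.

    from concurrent.futures import ThreadPoolExecutor

    def drop_select_handler(target_rank_global):
        # Stand-in for SimpleBuffer.drop_select_handler: serve requests for one rank.
        return f"serving drop_select requests for global rank {target_rank_global}"

    executor = ThreadPoolExecutor(max_workers=1)
    future = executor.submit(drop_select_handler, 3)
    print(future.result())
    executor.shutdown()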
diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py
index 40589fb3..da2829cf 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/base.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/base.py
@@ -23,7 +23,7 @@ class KVPipeBase(ABC):
"""
@abstractmethod
- def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
+ def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None:
"""Send a tensor, or None, via the pipe.
Need to support sending None -- important for error handling.
@@ -41,7 +41,7 @@ class KVPipeBase(ABC):
raise NotImplementedError
@abstractmethod
- def recv_tensor(self) -> Optional[torch.Tensor]:
+ def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
"""Receive a tensor (can be None) from the pipeline.
Returns:
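To make the signature change concrete, a toy in-memory pipe that follows the rank-aware API above (illustration only; the real implementations are PyNcclPipe and MooncakePipe). Since both ends share one object here, the target rank doubles as the lookup key on receive.

    from collections import defaultdict, deque
    from typing import Optional
    import torch

    class InMemoryPipe:
        def __init__(self):
            self._queues = defaultdict(deque)   # one FIFO per rank

        def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None:
            self._queues[target_rank].append(tensor)

        def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
            return self._queues[src_rank].popleft()

    pipe = InMemoryPipe()
    pipe.send_tensor(torch.ones(2), target_rank=1)
    print(pipe.recv_tensor(1))   # tensor([1., 1.])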
diff --git a/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
new file mode 100644
index 00000000..58d0d28c
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
@@ -0,0 +1,124 @@
+import logging
+import threading
+import typing
+import zmq
+import socket
+import time
+import torch
+
+from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
+
+
+logger = logging.getLogger(__name__)
+
+
+class DynemoNcclDataPlane:
+ def __init__(
+ self,
+ data_pipe: PyNcclPipe,
+ hostname: str = "",
+ port: int = 0,
+ ) -> None:
+
+ self.data_pipe = data_pipe
+ if not hostname:
+ hostname = socket.gethostname()
+ if port == 0:
+ raise ValueError("Port cannot be 0")
+ self._hostname = hostname
+ self._port = port
+ self.store = {}
+ self.context = zmq.Context()
+ self.rep_socket = self.context.socket(zmq.REP)
+ logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}")
+ self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
+ self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True)
+ self._listener_thread.start()
+ self.req_sockets = {}
+ logger.info(f"Rank {self.rank} connected to the server")
+
+ @property
+ def rank(self):
+ return self.data_pipe.kv_group_rank
+
+ def send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}")
+ return self._send_tensor(tensor, tensor_id, remote_address)
+
+ def recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ ret = self._recv_tensor(tensor_id, remote_address)
+ return ret
+
+ def _send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}")
+ if remote_address is None:
+ self.store[tensor_id] = tensor
+ else:
+ # tensor_shape = "_".join(str(dim) for dim in tensor.shape)
+ # tensor_dtype = str(tensor.dtype)
+ if remote_address not in self.req_sockets:
+ self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
+ self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
+
+ req_socket = self.req_sockets[remote_address]
+ # req_socket.connect(f"tcp://{remote_address}")
+ req_socket.send_string(f"PUT {self.rank} {tensor_id}")
+ dst_rank = req_socket.recv_string()
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}")
+ self.data_pipe.send_tensor(tensor, int(dst_rank))
+
+ def _recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ logger.debug(f"Rank {self.rank} receiving tensor")
+ if remote_address is not None:
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ if tensor_id in self.store:
+ logger.debug(f"Popping tensor {tensor_id} from store")
+ future = self.store.pop(tensor_id)
+ tensor = future.result() # TODO ptarasiewicz we should run other request instead of wait
+ logger.debug(f"Rank {self.rank} received tensor")
+ return tensor
+
+ logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}")
+ time.sleep(0.001)
+ return self._recv_tensor(tensor_id, remote_address)
+ # raise NotImplementedError("Tensor not found in store")
+
+ def _receive_tensor(
+ self,
+ tensor_id: str,
+ rank: int,
+ ):
+ future = self.data_pipe.recv_tensor(rank)
+ logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store")
+ self.store[tensor_id] = future
+
+ def listen_for_requests(self):
+ while True:
+ cmd, rank, tensor_id = self.rep_socket.recv_string().split()
+ logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}")
+ self.rep_socket.send_string(f"{self.rank}")
+ if cmd == "GET":
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ elif cmd == "PUT":
+ rank = int(rank)
+ # shape = [int(dim) for dim in shape.split("_")]
+ # dtype = getattr(torch, dtype)
+ self._receive_tensor(tensor_id, rank)
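The data plane above pairs a ZMQ REQ/REP control handshake with the NCCL pipe: the sender announces "PUT <src_rank> <tensor_id>", the receiver replies with the rank to address on the pipe, and the tensor is parked in a store until recv_tensor(tensor_id) pops it. Below is a toy of just the control path (the NCCL pipe is replaced by a dict; all names are illustrative, not the patch's implementation).

    import threading
    import zmq

    ctx = zmq.Context.instance()
    store = {}

    rep = ctx.socket(zmq.REP)
    rep.bind("inproc://data-plane")            # bind before any connect

    def listener(my_rank):
        # One PUT handshake: reply with our rank, then "receive" the payload.
        cmd, src_rank, tensor_id = rep.recv_string().split()
        rep.send_string(str(my_rank))
        store[tensor_id] = f"payload pushed by rank {src_rank}"   # stand-in for data_pipe.recv_tensor(rank)

    t = threading.Thread(target=listener, args=(7,))
    t.start()

    req = ctx.socket(zmq.REQ)
    req.connect("inproc://data-plane")
    req.send_string("PUT 3 req-0_keys")        # "PUT <src_rank> <tensor_id>"
    dst_rank = req.recv_string()               # rank to target on the NCCL pipe
    t.join()
    print(dst_rank, store["req-0_keys"])       # 7 payload pushed by rank 3
    req.close(); rep.close()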
diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
index 7aa53d07..f5dd50b7 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
......@@ -2055,136 +2359,6 @@ index 7aa53d07..f5dd50b7 100644
def close(self):
"""
diff --git a/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
new file mode 100644
index 00000000..8a356504
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
@@ -0,0 +1,124 @@
+import logging
+import threading
+import typing
+import zmq
+import socket
+import time
+import torch
+
+from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
+
+
+logger = logging.getLogger(__name__)
+
+
+class DynemoNcclDataPlane:
+ def __init__(
+ self,
+ data_pipe: PyNcclPipe,
+ hostname: str = "",
+ port: int = 0,
+ ) -> None:
+
+ self.data_pipe = data_pipe
+ if not hostname:
+ hostname = socket.gethostname()
+ if port == 0:
+ raise ValueError("Port cannot be 0")
+ self._hostname = hostname
+ self._port = port
+ self.store = {}
+ self.context = zmq.Context()
+ self.rep_socket = self.context.socket(zmq.REP)
+ logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}")
+ self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
+ self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True)
+ self._listener_thread.start()
+ self.req_sockets = {}
+ logger.info(f"Rank {self.rank} connected to the server")
+
+ @property
+ def rank(self):
+ return self.data_pipe.kv_group_rank
+
+ def send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}")
+ return self._send_tensor(tensor, tensor_id, remote_address)
+
+ def recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ ret = self._recv_tensor(tensor_id, remote_address)
+ return ret
+
+ def _send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}")
+ if remote_address is None:
+ self.store[tensor_id] = tensor
+ else:
+ # tensor_shape = "_".join(str(dim) for dim in tensor.shape)
+ # tensor_dtype = str(tensor.dtype)
+ if remote_address not in self.req_sockets:
+ self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
+ self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
+
+ req_socket = self.req_sockets[remote_address]
+ # req_socket.connect(f"tcp://{remote_address}")
+ req_socket.send_string(f"PUT {self.rank} {tensor_id}")
+ dst_rank = req_socket.recv_string()
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}")
+ self.data_pipe.send_tensor(tensor, int(dst_rank))
+
+ def _recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ logger.debug(f"Rank {self.rank} receiving tensor")
+ if remote_address is not None:
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ if tensor_id in self.store:
+ logger.debug(f"Popping tensor {tensor_id} from store")
+ future = self.store.pop(tensor_id)
+ tensor = future.result() # TODO ptarasiewicz we should run other request instead of wait
+ logger.debug(f"Rank {self.rank} received tensor")
+ return tensor
+
+ logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}")
+ time.sleep(0.001)
+ return self._recv_tensor(tensor_id, remote_address)
+ # raise NotImplementedError("Tensor not found in store")
+
+ def _receive_tensor(
+ self,
+ tensor_id: str,
+ rank: int,
+ ):
+ future = self.data_pipe.recv_tensor(rank)
+ logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store")
+ self.store[tensor_id] = future
+
+ def listen_for_requests(self):
+ while True:
+ cmd, rank, tensor_id = self.rep_socket.recv_string().split()
+ logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}")
+ self.rep_socket.send_string(f"{self.rank}")
+ if cmd == "GET":
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ elif cmd == "PUT":
+ rank = int(rank)
+ # shape = [int(dim) for dim in shape.split("_")]
+ # dtype = getattr(torch, dtype)
+ self._receive_tensor(tensor_id, rank)
diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py
index 1e80e0bd..cd90206f 100644
--- a/vllm/distributed/kv_transfer/kv_transfer_agent.py
......@@ -2221,7 +2395,7 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d82d9ad9..9ba1a326 100644
index d82d9ad9..62dbbd6e 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@
......@@ -2308,7 +2482,9 @@ index d82d9ad9..9ba1a326 100644
+ self._nixl_agents_names = self._initialize_nixl()
+
+ self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+ self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+ self._finished_prefills = set()
+ self._finished_transfers = set()
+
+ @property
+ def is_nixl_initialized(self) -> bool:
......@@ -2327,8 +2503,6 @@ index d82d9ad9..9ba1a326 100644
+ engine_id = nixl_metadata.engine_id
+ agents_metadata = nixl_metadata.agent_metadata
+ kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
+ if len(agents_metadata) != len(self._nixl_agents_names):
+ raise ValueError("Number of agents does not match. Make sure all engines are initialized with the same parallel sizes.")
+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr))
+
+ def _initialize_nixl(self) -> List[bytes]:
......@@ -2338,16 +2512,7 @@ index d82d9ad9..9ba1a326 100644
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@@ -500,6 +545,8 @@ class LLMEngine:
# Shutdown model executor when engine is garbage collected
# Use getattr since __init__ can fail before the field is set
if model_executor := getattr(self, "model_executor", None):
+ if self._nixl_agents_names:
+ model_executor.collective_rpc("shutdown_nixl")
model_executor.shutdown()
def get_tokenizer_group(
@@ -552,11 +599,14 @@ class LLMEngine:
@@ -552,11 +597,14 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
......@@ -2362,6 +2527,15 @@ index d82d9ad9..9ba1a326 100644
ParallelSampleSequenceGroup.add_request(
request_id,
self,
@@ -574,6 +622,8 @@ class LLMEngine:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
+ if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
+ next(self.seq_counter) # empty sequence for staging
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
@@ -584,7 +634,7 @@ class LLMEngine:
encoder_inputs = None
......@@ -2454,7 +2628,7 @@ index d82d9ad9..9ba1a326 100644
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
- ) = self.scheduler[virtual_engine].schedule()
+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills)
+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers)
+
+
+ # Separate remote prefill and running seq groups
......@@ -2486,7 +2660,7 @@ index d82d9ad9..9ba1a326 100644
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
@@ -1383,9 +1476,29 @@ class LLMEngine:
@@ -1383,9 +1476,31 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
......@@ -2500,9 +2674,11 @@ index d82d9ad9..9ba1a326 100644
+ req_id = scheduled_seq_group.seq_group.request_id
+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
+ block_table = seq_group_metadata.block_tables[seq_id]
+ staging_block_ids = seq_group_metadata.block_tables[seq_id + 1]
+ memory_transfer_req = MemoryTransferRequest(
+ request_id=req_id,
+ src_block_ids=block_table,
+ staging_block_ids=staging_block_ids,
+ dst_block_ids=remote_prefill_params.decode_block_ids,
+ dst_engine_id=remote_prefill_params.decode_engine_id,
+ notify_msg=req_id,
......@@ -2512,13 +2688,13 @@ index d82d9ad9..9ba1a326 100644
+
+ execute_model_req.memory_transfer_requests = memory_transfer_reqs
+
+ outputs, request_notif_counter = self.model_executor.execute_model(
+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
execute_model_req=execute_model_req)
-
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
@@ -1396,7 +1509,20 @@ class LLMEngine:
@@ -1396,7 +1511,26 @@ class LLMEngine:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
......@@ -2529,7 +2705,7 @@ index d82d9ad9..9ba1a326 100644
+ blocks_to_swap_out=[],
+ blocks_to_copy=[])
+
+ outputs, request_notif_counter = self.model_executor.execute_model(
+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
+ execute_model_req=execute_model_req)
+
+ for req_id, notif_count in request_notif_counter.items():
......@@ -2537,10 +2713,16 @@ index d82d9ad9..9ba1a326 100644
+ if self._request_notif_counter[req_id] > -1:
+ self._finished_prefills.add(req_id)
+ del self._request_notif_counter[req_id]
+
+ for req_id, done_count in request_done_counter.items():
+ self._request_done_counter[req_id] += done_count
+ if self._request_done_counter[req_id] > -1:
+ self._finished_transfers.add(req_id)
+ del self._request_done_counter[req_id]
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
@@ -1456,7 +1582,7 @@ class LLMEngine:
@@ -1456,7 +1590,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
......@@ -2618,7 +2800,7 @@ index 3cf1850e..6b90ece7 100644
+ kv_active_blocks: int
+ kv_total_blocks: int
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..d33d546a 100644
index 85b5f31e..3f8b8fad 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
......@@ -2746,7 +2928,7 @@ index 85b5f31e..d33d546a 100644
+ kv_metrics.kv_active_blocks,
+ kv_metrics.kv_total_blocks)
+
+ logger.debug("Metircs successful.")
+ logger.debug("Metircs successful.")
+
+ except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
......@@ -3151,10 +3333,10 @@ index 786380c3..56a7cf89 100644
"""The output data of one completion output of a request.
diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py
new file mode 100644
index 00000000..03f02006
index 00000000..957f55de
--- /dev/null
+++ b/vllm/remote_prefill.py
@@ -0,0 +1,53 @@
@@ -0,0 +1,54 @@
+from dataclasses import dataclass
+from typing import Callable, Optional, List, Coroutine
+
......@@ -3192,6 +3374,7 @@ index 00000000..03f02006
+ """
+ request_id: str
+ src_block_ids: List[int]
+ staging_block_ids: List[int]
+ dst_block_ids: List[int]
+ dst_engine_id: str
+ notify_msg: str
......@@ -3327,7 +3510,7 @@ index 534b9e60..18675d2f 100644
@property
def is_first_multi_step(self) -> bool:
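For context, a standalone sketch of the MemoryTransferRequest added in vllm/remote_prefill.py above (the field list is taken from the patch; the dataclass form and example values are assumptions for illustration). The engine fills src_block_ids from the sequence's block table, staging_block_ids from the extra staging sequence (seq_id + 1), and the destination fields from remote_prefill_params.

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class MemoryTransferRequest:
        request_id: str
        src_block_ids: List[int]
        staging_block_ids: List[int]
        dst_block_ids: List[int]
        dst_engine_id: str
        notify_msg: str

    req = MemoryTransferRequest(
        request_id="req-0",
        src_block_ids=[4, 5, 6],        # prefill engine's block table
        staging_block_ids=[7, 8, 9],    # blocks of the staging sequence
        dst_block_ids=[0, 1, 2],        # decode engine blocks from remote_prefill_params
        dst_engine_id="decode-engine-0",
        notify_msg="req-0",             # decode side is notified with the request id
    )
    print(req.dst_engine_id)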
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 12baecde..cbada27f 100644
index 12baecde..489d3b77 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1824,6 +1824,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
......@@ -3351,7 +3534,7 @@ index 12baecde..cbada27f 100644
prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 582aa460..ffb7b403 100644
index 582aa460..c01cfe00 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -2,7 +2,7 @@
......@@ -3392,13 +3575,13 @@ index 582aa460..ffb7b403 100644
+
+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata[self.local_rank]) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr[self.local_rank])
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata)) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr)
+ return agent_name
+
+ def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name[self.local_rank], notify_msg) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name, notify_msg) # TODO ptarasiewicz: rank or local_rank?
+
+ def get_nixl_kv_caches_base_addr(self) -> List[bytes]:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
......@@ -3406,8 +3589,8 @@ index 582aa460..ffb7b403 100644
+
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ if worker_input.src_block_ids is not None:
+ for src_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+ for src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.staging_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+
+ def shutdown_nixl(self) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
......@@ -3425,11 +3608,12 @@ index 582aa460..ffb7b403 100644
return WorkerInput(
num_seq_groups=num_seq_groups,
@@ -375,6 +416,10 @@ class Worker(LocalOrDistributedWorkerBase):
@@ -375,6 +416,11 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
+ src_block_ids=[r.src_block_ids for r in mem_transfer_reqs],
+ staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs],
+ dst_block_ids=[r.dst_block_ids for r in mem_transfer_reqs],
+ dst_engine_id=[r.dst_engine_id for r in mem_transfer_reqs],
+ notify_msg=[r.notify_msg for r in mem_transfer_reqs],
......@@ -3437,7 +3621,7 @@ index 582aa460..ffb7b403 100644
@torch.inference_mode()
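A sketch of how the per-request lists added to WorkerInput line up during _transfer_blocks (the data is illustrative; in the patch the loop passes these five values per request to nixl_connector.transfer_mem):

    src_block_ids = [[4, 5], [10]]
    staging_block_ids = [[6, 7], [11]]
    dst_block_ids = [[0, 1], [2]]
    dst_engine_id = ["decode-0", "decode-0"]
    notify_msg = ["req-0", "req-1"]

    for src, staging, dst, engine, msg in zip(
            src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg):
        # Stand-in for self.nixl_connector.transfer_mem(src, staging, dst, engine, msg)
        print(f"transfer {src} via staging {staging} -> {dst} on {engine}, notify {msg}")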
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 819b81fb..d9c039eb 100644
index 819b81fb..8dfdadde 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
......@@ -3465,11 +3649,12 @@ index 819b81fb..d9c039eb 100644
@abstractmethod
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
@@ -216,6 +220,11 @@ class WorkerInput:
@@ -216,6 +220,12 @@ class WorkerInput:
virtual_engine: int = 0
num_steps: int = 1
+ src_block_ids: Optional[List[List[int]]] = None
+ staging_block_ids: Optional[List[List[int]]] = None
+ dst_block_ids: Optional[List[List[int]]] = None
+ dst_engine_id: Optional[List[str]] = None
+ notify_msg: Optional[List[str]] = None
......@@ -3477,29 +3662,31 @@ index 819b81fb..d9c039eb 100644
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
@@ -232,6 +241,10 @@ class WorkerInput:
@@ -232,6 +242,11 @@ class WorkerInput:
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
+ src_block_ids=tensor_dict.pop("src_block_ids"),
+ staging_block_ids=tensor_dict.pop("staging_block_ids"),
+ dst_block_ids=tensor_dict.pop("dst_block_ids"),
+ dst_engine_id=tensor_dict.pop("dst_engine_id"),
+ notify_msg=tensor_dict.pop("notify_msg"),
)
def as_broadcastable_tensor_dict(
@@ -246,6 +259,10 @@ class WorkerInput:
@@ -246,6 +261,11 @@ class WorkerInput:
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
+ "src_block_ids": self.src_block_ids,
+ "staging_block_ids": self.staging_block_ids,
+ "dst_block_ids": self.dst_block_ids,
+ "dst_engine_id": self.dst_engine_id,
+ "notify_msg": self.notify_msg,
}
return tensor_dict
@@ -316,13 +333,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@@ -316,13 +336,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
......@@ -3521,7 +3708,7 @@ index 819b81fb..d9c039eb 100644
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
@@ -396,49 +416,79 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@@ -396,49 +419,87 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
......@@ -3618,7 +3805,7 @@ index 819b81fb..d9c039eb 100644
+ else:
+ for i in range(1, get_tp_group().world_size):
+ all_new_notifs.append(get_tp_group().recv_object(src=i))
+
+ request_notif_counter = defaultdict(int)
+ for notifs in all_new_notifs:
+ for req_ids in notifs.values():
......@@ -3627,12 +3814,20 @@ index 819b81fb..d9c039eb 100644
+
+ if request_notif_counter:
+ logger.debug("Request notif counter: %s", request_notif_counter)
+
+ request_done_counter = defaultdict(int)
+ for req_id in self.nixl_connector.get_done_tranfers():
+ request_done_counter[req_id] += 1
+
+ if request_done_counter:
+ logger.debug("Request done counter: %s", request_done_counter)
+
+ else:
+ request_notif_counter = {}
+ request_done_counter = {}
# output is List[SamplerOutput]
- return output
+ return output, request_notif_counter
+ return output, request_notif_counter, request_done_counter
+
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ pass
......