Commit ea78a424 authored by Neelay Shah, committed by GitHub

chore: regenerate patch (#29)

parent 46ed649c
diff --git a/vllm/config.py b/vllm/config.py
index 9ba49757..cbfeb715 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2629,7 +2629,7 @@ class KVTransferConfig(BaseModel):
...@@ -261,7 +261,7 @@ index c5b3b04f..c72001f7 100644
self.block_tables: Dict[SeqId, BlockTable] = {}
diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py
new file mode 100644
index 00000000..8699ca06
--- /dev/null
+++ b/vllm/core/event_manager.py
@@ -0,0 +1,102 @@
...@@ -368,10 +368,15 @@ index 00000000..350453cd
+
+ self.event_id_counter += 1
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f507847a..abe574d1 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -4,22 +4,22 @@ import enum
import os
import random
import time
+import copy
from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
...@@ -393,7 +398,7 @@ index f507847a..ee20d50c 100644
logger = init_logger(__name__)
# Test-only. If configured, decode is preempted with
@@ -325,12 +325,14 @@ class Scheduler:
def __init__(
self,
...@@ -408,7 +413,7 @@ index f507847a..ee20d50c 100644
self.scheduler_config = scheduler_config
self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
@@ -356,6 +358,7 @@ class Scheduler:
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
...@@ -416,7 +421,7 @@ index f507847a..ee20d50c 100644
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
@@ -371,6 +374,16 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()
...@@ -424,6 +429,8 @@ index f507847a..ee20d50c 100644
+ # Sequence groups in the REMOTE_PREFILLING state.
+ # Contain requests that are being prefilled by a remote worker.
+ self.remote_prefilling: Deque[SequenceGroup] = deque()
+ # Contain requests that are being prefilled by a local worker.
+ self.prefill_sending: Deque[SequenceGroup] = deque()
+
+ self._remote_prefill_outputs: Dict[str, int] = {}
+
...@@ -431,24 +438,25 @@ index f507847a..ee20d50c 100644
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
@@ -501,7 +514,7 @@ class Scheduler:
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
- self.swapped) != 0
+ self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
@@ -523,6 +536,8 @@ class Scheduler:
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
+ finished_prefills: Optional[Set[str]] = None,
+ finished_transfers: Optional[Set[str]] = None
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
@@ -537,6 +552,8 @@ class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
...@@ -457,7 +465,7 @@ index f507847a..ee20d50c 100644
Returns:
SchedulerRunningOutputs.
@@ -566,6 +583,38 @@ class Scheduler:
preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out
...@@ -468,6 +476,7 @@ index f507847a..ee20d50c 100644
+ if seq_group.request_id not in finished_prefills:
+ leftover_remote_prefilling_sequences.append(seq_group)
+ continue
+
+ else:
+ finished_prefills.remove(seq_group.request_id)
+ assert len(seq_group.seqs) == 1
...@@ -478,39 +487,63 @@ index f507847a..ee20d50c 100644
+ seq.data._stage = SequenceStage.DECODE
+ self.running.appendleft(seq_group)
+ remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences)
+
+ remote_transfers_queue = self.prefill_sending
+ leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque()
+ while remote_transfers_queue:
+ seq_group = remote_transfers_queue.popleft()
+ if seq_group.request_id not in finished_transfers:
+ leftover_remote_transfers_sequences.append(seq_group)
+ else:
+ finished_transfers.remove(seq_group.request_id)
+ assert len(seq_group.seqs) == 1
+ seq = seq_group.seqs[0]
+ self.free_seq(seq)
+ remote_transfers_queue.extendleft(leftover_remote_transfers_sequences)
+
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
@@ -1008,7 +1057,17 @@ class Scheduler:
if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id)
waiting_queue.popleft()
- self._allocate_and_set_running(seq_group)
+
+ seq_group_copy = copy.deepcopy(seq_group)
+ seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1
+
+ logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id)
+ logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group)
+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+ logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
+ self.prefill_sending.append(seq_group_copy)
if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = []
@@ -1048,7 +1107,7 @@ class Scheduler:
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking))
- def _schedule_default(self) -> SchedulerOutputs:
+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests.
The current policy is designed to optimize the throughput. First,
@@ -1090,7 +1149,9 @@ class Scheduler:
if len(prefills.seq_groups) == 0:
running_scheduled = self._schedule_running(budget,
curr_loras,
- enable_chunking=False)
+ enable_chunking=False,
+ finished_prefills=finished_prefills,
+ finished_transfers=finished_transfers)
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
@@ -1106,7 +1167,12 @@ class Scheduler:
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
if len(prefills.seq_groups) > 0:
...@@ -524,30 +557,31 @@ index f507847a..ee20d50c 100644
self.running.extend(running_scheduled.decode_seq_groups_list)
@@ -1248,12 +1314,14 @@ class Scheduler:
len(running_scheduled.swapped_out)),
)
- def _schedule(self) -> SchedulerOutputs:
+ def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests."""
if self.scheduler_config.chunked_prefill_enabled:
+ if finished_prefills or finished_transfers:
+ raise ValueError("Chunked prefill does not support remote prefills")
return self._schedule_chunked_prefill()
else:
- return self._schedule_default()
+ return self._schedule_default(finished_prefills, finished_transfers)
def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool:
@@ -1287,14 +1355,16 @@ class Scheduler:
return no_single_seq
def schedule(
- self
+ self,
+ finished_prefills: Optional[Set[str]] = None,
+ finished_transfers: Optional[Set[str]] = None
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
...@@ -556,11 +590,11 @@ index f507847a..ee20d50c 100644
- scheduler_outputs: SchedulerOutputs = self._schedule()
+ scheduler_start_time = time.perf_counter()
+ scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers)
now = time.time()
if not self.cache_config.enable_prefix_caching:
@@ -1333,7 +1403,8 @@ class Scheduler:
encoder_seq_data = None
cross_block_table = None
...@@ -570,18 +604,24 @@ index f507847a..ee20d50c 100644
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
@@ -1364,9 +1435,16 @@ class Scheduler:
< seqs[0].data.get_len()):
do_sample = False
+ is_remote_prefill = False
+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
+ is_remote_prefill = True
+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+ block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids
+
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
if is_first_prefill or not self.scheduler_config.send_delta_data:
+ logger.debug("Assinged blocks: %s", block_tables)
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
@@ -1392,6 +1470,7 @@ class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
...@@ -589,7 +629,7 @@ index f507847a..ee20d50c 100644
)
else:
# When SPMD mode is enabled, we only send delta data except for
@@ -1490,10 +1569,13 @@ class Scheduler:
self._async_stopped.clear()
...@@ -605,12 +645,80 @@ index f507847a..ee20d50c 100644
def _append_slots(self,
seq_group: SequenceGroup,
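A minimal sketch of how an engine step might feed the two new arguments into the patched scheduler; `scheduler`, `connector`, and `newly_finished_remote_prefills` are assumed names for illustration (the connector is assumed to expose the get_done_tranfers() method added to the NIXL connector below), not identifiers taken from the patch:

# Hypothetical engine-side step, for illustration only.
def engine_step(scheduler, connector, newly_finished_remote_prefills):
    # Request ids whose remote prefill completed since the last step.
    finished_prefills = set(newly_finished_remote_prefills)
    # Request ids whose outgoing KV-cache transfers have completed.
    finished_transfers = set(connector.get_done_tranfers())
    seq_group_metadata_list, scheduler_outputs, allow_async = scheduler.schedule(
        finished_prefills=finished_prefills,
        finished_transfers=finished_transfers,
    )
    return seq_group_metadata_list, scheduler_outputs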
diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py
new file mode 100644
index 00000000..9b938039
--- /dev/null
+++ b/vllm/distributed/device_communicators/kv_rearrange.py
@@ -0,0 +1,61 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def rearrange_kernel(
+ t1_ptr,
+ t2_ptr,
+ N,
+ B,
+ H,
+ C,
+ d,
+ tensor_subset_size,
+ block_size,
+ token_size,
+ BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+
+ curr_n = offsets // block_size
+ curr_b = offsets // token_size % B
+ curr_h = offsets // C % H
+ curr_c = offsets % C
+
+ src_pos = offsets
+
+ tp_group = curr_h * d // H
+ dst_h = curr_h % (H // d)
+ tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
+
+ dst_pos = tensor_subset_size * tp_group + tp_group_offset
+
+ tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos))
+
+def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int):
+ N, B, H, C = t1.shape
+
+ assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
+ assert H % d == 0, "H must be divisible by d"
+
+ block_size = B * H * C
+ token_size = H * C
+ tensor_size = N * block_size
+ tensor_subset_size = tensor_size // d
+
+ BLOCK_SIZE = 1024
+ grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)
+
+ rearrange_kernel[grid](
+ t1, t2,
+ N, B, H, C,
+ d,
+ tensor_subset_size,
+ block_size,
+ token_size,
+ BLOCK_SIZE=BLOCK_SIZE
+ )
\ No newline at end of file
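A minimal usage sketch for the Triton kernel above, assuming a CUDA device and Triton are available; the shapes below are illustrative only:

# Illustration: regroup the head dimension of one paged KV cache layer into
# tp_multiplier contiguous tensor-parallel subsets inside a staging buffer.
import torch
from vllm.distributed.device_communicators.kv_rearrange import rearrange_tensors

if torch.cuda.is_available():
    src = torch.randn(4, 16, 8, 64, device="cuda")  # (N blocks, B tokens/block, H heads, C head dim)
    dst = torch.empty_like(src)
    rearrange_tensors(src, dst, 2)  # d=2: heads split for a TP multiplier of 2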
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
index 00000000..86248e7b
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,318 @@
+import torch
+from typing import List, Tuple
+from vllm.config import VllmConfig
...@@ -618,39 +726,18 @@ index 00000000..bc962726
+import msgspec
+import time
+import uuid
+from collections import defaultdict
+from .kv_rearrange import rearrange_tensors
+
+logger = init_logger(__name__)
+
+# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
+try:
+ from nixl_wrapper import nixl_wrapper as NixlWrapper # type: ignore
+ logger.info("NIXL is available")
+except ImportError:
+ logger.warning("NIXL is not available")
+ NixlWrapper = None
+ init = nixl.nixlUcxInitParams()
+
+ self.name = agent_name
+ self.notifs = {}
+ self.backends = {}
+ self.agent = nixl.nixlAgent(agent_name, devices)
+ self.backends["UCX"] = self.agent.createBackend(init)
+
+ self.nixl_mems = {"DRAM": nixl.DRAM_SEG,
+ "VRAM": nixl.VRAM_SEG,
+ "cpu": nixl.DRAM_SEG,
+ "cuda": nixl.VRAM_SEG}
+ self.nixl_ops = {"WRITE": nixl.NIXL_WR_FLUSH,
+ "READ": nixl.NIXL_RD_FLUSH,
+ "WRITE_NOTIF": nixl.NIXL_WR_NOTIF,
+ "READ_NOTIF": nixl.NIXL_RD_NOTIF}
+
+ print("Initializied NIXL agent:", agent_name)
+
+NixlWrapper.__init__ = nixl_wrapper_init_patch
+
+
+
+class NixlMetadata(
+ msgspec.Struct,
...@@ -665,11 +752,17 @@ index 00000000..bc962726
+class DynemoNixlConnector:
+ def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int):
+ self.vllm_config = vllm_config
+ if NixlWrapper is None:
+ logger.error("NIXL is not available")
+ raise RuntimeError("NIXL is not available")
+ logger.info("Initializing NIXL wrapper")
+ self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
+
+ self.num_layers = None
+ self.num_blocks = None
+ self.num_heads = None
+ self.block_len = None
+ self.kv_caches = None
+ self.kv_caches_base_addr = {}
+ self.kv_cache_shape = {}
+
...@@ -678,33 +771,51 @@ index 00000000..bc962726
+ self.engine_id = engine_id
+ self.rank = rank
+ self.notifs = {}
+ self._tp_size = {}
+ self._block_descs = {}
+ self._xfer_side_handles = {}
+
+
+ self._transfers = defaultdict(list)
+
+
+ self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
+
+
+ @property
+ def agent_name(self):
+ return self.nixl_wrapper.name
+
+ def register_kv_caches(self, kv_caches: List[torch.Tensor]):
+ _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
+ self.num_layers = len(kv_caches)
+ _, _, block_size, num_heads, head_dim = kv_caches[0].shape
+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+ self.num_layers = len(kv_caches)
+ self.num_blocks = num_blocks
+ self.num_heads = num_heads
+ self.kv_caches = kv_caches
+ kv_caches_base_addr = []
+ caches_data = []
+ blocks_data = []
+ for key_cache, value_cache in kv_caches:
+ for cache in [key_cache, value_cache]:
+ base_addr = cache.data_ptr()
+ region_len = num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank))
+ for block_id in range(self.num_blocks):
+ blocks_data.append((base_addr + block_id * self.block_len, self.block_len, self.rank))
+
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_descs(("VRAM", caches_data))
+ logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+
+ self._block_descs[self.engine_id] = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self._xfer_side_handles[self.engine_id] = self.nixl_wrapper.prep_xfer_side(self._block_descs[self.engine_id])
+
+ def get_agent_metadata(self):
+ return self.nixl_wrapper.get_agent_metadata()
+
...@@ -714,10 +825,14 @@ index 00000000..bc962726
+ for agent_name in self._remote_agents.values():
+ self.nixl_wrapper.remove_remote_agent(agent_name)
+
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp):
+ self._tp_size[engine_id] = agent_tp
+ agent_names = []
+ for agent_meta in agent_metadata:
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
+ agent_names.append(agent_name)
+ self._remote_agents[engine_id] = agent_names
+ return agent_names
+
+ def get_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
...@@ -732,17 +847,29 @@ index 00000000..bc962726
+ descs_ids.append(2 * (self.num_blocks * layer_id + block_id) + 1)
+ return descs_ids
+
+ def _get_range_descs(self, ranges, layer_ids, kv_caches_base_addr, tp_multiplier=1, rank=None, i=0):
+ if rank is None:
+ rank = self.rank
+ offset_block_len = self.block_len
+ block_len = self.block_len // tp_multiplier
+ tp_offset = i * block_len
+ else:
+ offset_block_len = self.block_len // tp_multiplier
+ block_len = self.block_len // tp_multiplier
+ tp_offset = 0
+ logger.debug("Getting range descs for layer ids: %s, ranges: %s, tp_multiplier: %s, rank: %s, i: %s", layer_ids, ranges, tp_multiplier, rank, i)
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ blocks_data = []
+ for layer_id in layer_ids:
+ for range_start, range_end in ranges:
+ range_len = range_end - range_start + 1
+ key_base_addr, value_base_addr = kv_caches_base_addr[layer_id]
+ start_offset = range_start * offset_block_len + tp_offset * range_len
+ blocks_len = range_len * block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, rank))
+ logger.debug("Blocks data: %s", blocks_data)
+ return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+
+ def _get_ranges(self, block_ids):
...@@ -755,9 +882,9 @@ index 00000000..bc962726
+ ranges = []
+ for i in range(len(sorted_block_ids)):
+ if i == 0 or sorted_block_ids[i] != sorted_block_ids[i-1] + 1:
+ ranges.append([sorted_block_ids[i], sorted_block_ids[i]])
+ else:
+ ranges[-1][1] = sorted_block_ids[i]
+ return ranges
+
+ def _get_same_length_ranges(self, src_ranges, dst_ranges):
...@@ -797,11 +924,24 @@ index 00000000..bc962726
+ src_idx += 1
+
+ return src_overlapping_ranges, dst_overlapping_ranges
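For example, the inclusive [start, end] pairs produced by _get_ranges (and consumed by _get_same_length_ranges above) can be illustrated with a standalone sketch of the same merging rule:

# Standalone illustration of the block-id merging used by _get_ranges:
# consecutive block ids collapse into inclusive [start, end] pairs.
def merge_block_ids(block_ids):
    ranges = []
    for block_id in sorted(block_ids):
        if ranges and block_id == ranges[-1][1] + 1:
            ranges[-1][1] = block_id  # extend the current run
        else:
            ranges.append([block_id, block_id])  # start a new run
    return ranges

assert merge_block_ids([7, 3, 4, 5, 9]) == [[3, 5], [7, 7], [9, 9]]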
+
+
+
+ def _get_block_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ if block_ids == "all":
+ block_ids = list(range(self.num_blocks))
+ descs_ids = []
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * self.num_blocks + is_value * self.num_blocks + block_id)
+ return descs_ids
+
+
+
+ def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg, use_prepped_xfer=False):
+
+ start_time = time.perf_counter()
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+
...@@ -810,44 +950,62 @@ index 00000000..bc962726
+ # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \
+ # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)]
+ assert len(staging_block_ids) == len(src_block_ids)
+
+ if use_prepped_xfer:
+ raise NotImplementedError("Prepped xfer is not implemented")
+ # src_block_descs_ids = self._get_block_descs_ids("all", src_block_ids)
+ # dst_block_descs_ids = self._get_block_descs_ids("all", dst_block_ids)
+ +
+ # src_xfer_side_handle = self._xfer_side_handles[self.engine_id]
+ # dst_xfer_side_handle = self._xfer_side_handles[dst_engine_id]
+
+ # logger.debug("Time to get block desc ids: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ # handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, src_block_descs_ids,
+ # dst_xfer_side_handle, dst_block_descs_ids,
+ # notify_msg, "WRITE", no_check=True)
+ # else:
+ # Legacy path using range-based transfers
+ src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+ dst_ranges = self._get_ranges(dst_block_ids)
+ src_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(src_ranges, dst_ranges)
+ +
+ logger.debug("Got %s overlapping ranges for %s blocks", len(src_overlapping_ranges), len(src_block_ids)) + assert len(src_ranges) == 1
+ assert len(staging_ranges) == 1
+ +
+ logger.debug("Time to get ranges: %s ms", time.perf_counter() - start_time) + tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+
+ src_range_start, src_range_end = src_ranges[0]
+ src_range_len = src_range_end - src_range_start + 1
+ staging_range_start, staging_range_end = staging_ranges[0]
+ staging_range_len = staging_range_end - staging_range_start + 1
+ +
+ src_descs = self._get_range_descs(self.engine_id, src_overlapping_ranges, "all") + logger.debug("Rearranging tensors for cache: %s, src_ranges: %s of len %s, staging_ranges: %s of len %s", self.kv_caches[0].shape, src_ranges, src_range_len, staging_ranges, staging_range_len)
+ dst_descs = self._get_range_descs(dst_engine_id, dst_overlapping_ranges, "all") + for kv_cache in self.kv_caches:
+ for cache in kv_cache:
+ rearrange_tensors(cache[src_range_start:src_range_start + src_range_len], cache[staging_range_start:staging_range_start + staging_range_len], tp_multiplier)
+ +
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000) + staging_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(staging_ranges, dst_ranges)
+ assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges)
+ +
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs, self._remote_agents[dst_engine_id], notify_msg, "WRITE") + for i in range(tp_multiplier):
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+ # TODO ptarasiewicz: remove blocking transfer mem
+ # add scheduler check for transfer done
+ while True:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "ERR":
+ raise RuntimeError("Transfer failed")
+ elif xfer_state == "DONE":
+ logger.debug("Transfer done")
+ break
+ elif xfer_state == "PROC":
+ time.sleep(0.01)
+ else:
+ raise RuntimeError("Unknown transfer state")
+ logger.debug("Time to wait for transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ self.nixl_wrapper.abort_xfer(handle)
+ logger.debug("Time to abort xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer time: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i)
+ dst_descs = self._get_range_descs(dst_overlapping_ranges, "all", self.kv_caches_base_addr[dst_engine_id][self.rank * tp_multiplier + i], tp_multiplier, rank=self.rank * tp_multiplier + i)
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ logger.debug("Transfering to agent %s", self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i])
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs,
+ self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i],
+ notify_msg, "WRITE")
+ self._transfers[notify_msg].append(handle)
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+
+ def deserialize_descs(self, serialized_descs):
+ return self.nixl_wrapper.deserialize_descs(serialized_descs)
+
...@@ -860,153 +1018,138 @@ index 00000000..bc962726
+
+ def add_remote_kv_caches_base_addr(self, engine_id, kv_caches_base_addr):
+ self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
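A minimal sketch of the metadata handshake implied by the methods above, with two hypothetical connector instances standing in for a prefill and a decode worker (the variable names, the single-rank decode side, and the out-of-band channel are assumptions for illustration):

# Illustration of pairing two DynemoNixlConnector instances; not part of the patch.
# prefill_conn and decode_conn are assumed to be already-constructed connectors
# whose register_kv_caches() has been called on their local KV caches.
def exchange_metadata(prefill_conn, decode_conn, decode_tp_size):
    # Shipped out of band (e.g. over a control-plane channel) by the decode side:
    agent_meta = [decode_conn.get_agent_metadata()]  # one entry per decode rank
    base_addrs = decode_conn.kv_caches_base_addr[decode_conn.engine_id]
    # The prefill side then registers the remote agent(s) and their cache layout.
    prefill_conn.add_remote_agent(decode_conn.engine_id, agent_meta, decode_tp_size)
    prefill_conn.add_remote_kv_caches_base_addr(decode_conn.engine_id, base_addrs)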
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..61a357d0 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -27,13 +27,13 @@ class KVConnectorFactory:
@classmethod
def create_connector(cls, rank: int, local_rank: int,
- config: "VllmConfig") -> KVConnectorBase:
+ config: "VllmConfig", world_group) -> KVConnectorBase:
connector_name = config.kv_transfer_config.kv_connector
if connector_name not in cls._registry:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_cls = cls._registry[connector_name]()
- return connector_cls(rank, local_rank, config)
+ return connector_cls(rank, local_rank, config, world_group)
# Register various connectors here.
@@ -48,3 +48,8 @@ KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
+
+ def get_done_tranfers(self) -> List[str]:
+ done_req_ids = []
+ for req_id, handles in self._transfers.items():
+ running_reqs = []
+ for handle in handles:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "DONE":
+ # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
+ continue
+ if xfer_state == "PROC":
+ running_reqs.append(handle)
+ else:
+ raise RuntimeError("Transfer failed with state %s", xfer_state)
+ if len(running_reqs) == 0:
+ done_req_ids.append(req_id)
+ else:
+ self._transfers[req_id] = running_reqs
+ return done_req_ids
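A minimal sketch of how completed transfers could be drained and handed back to the scheduler; it assumes `connector` is a DynemoNixlConnector instance and that the notify messages used as transfer keys are request ids:

# Illustrative polling loop over get_done_tranfers(); not part of the patch.
import time

def drain_finished_transfers(connector, pending_request_ids, poll_interval_s=0.001):
    pending = set(pending_request_ids)
    finished = set()
    while pending:
        for req_id in connector.get_done_tranfers():
            if req_id in pending:
                pending.discard(req_id)
                finished.add(req_id)
        if pending:
            time.sleep(poll_interval_s)
    return finished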
diff --git a/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
new file mode 100644
index 00000000..2319867a
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
@@ -0,0 +1,350 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Simple KV Cache Connector for Distributed Machine Learning Inference
+
+The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
+producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
+MooncakePipe.
+
+But the logic can be extended to support other pipe and lookup buffer.
+"""
+import re
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import torch
+
+from vllm import _custom_ops as ops
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, KVTransferConfig
+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
+ SimpleBuffer)
+from vllm.logger import init_logger
+from vllm.sequence import IntermediateTensors
+
+if TYPE_CHECKING:
+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
+logger = init_logger(__name__)
+
+
+class DynemoConnector(KVConnectorBase):
+
+ def __init__(
+ self,
+ rank: int,
+ local_rank: int,
+ config: VllmConfig,
+ world_group,
+ ):
+
+ self.config = config.kv_transfer_config
+ self.tp_size = config.parallel_config.tensor_parallel_size
+ self.rank = rank
+
+ if self.config.kv_connector != "DynemoNcclConnector":
+ raise NotImplementedError("Only DynemoNcclConnector is supported by the DynemoConnector class")
+
+ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
+ PyNcclPipe)
+ from vllm.distributed.kv_transfer.kv_pipe.dynemo_nccl_pipe import (
+ DynemoNcclDataPlane)
+
+ logger.info(
+ "Initializing DynemoNcclConnector under kv_transfer_config %s",
+ self.config)
+
+ self.lookup_buffer_size = self.config.kv_buffer_size
+
+ self.producer_data_pipe: PyNcclPipe
+ self.consumer_data_pipe: PyNcclPipe
+ self.producer_signal_pipe: PyNcclPipe
+ self.consumer_signal_pipe: PyNcclPipe
+
+ self._broadcast_and_enhance_kv_config(rank, config, world_group)
+
+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
+ self.tp_size = config.parallel_config.tensor_parallel_size
+
+ # 2 pipes for every rank in the world
- port_offset_base = 2 * rank
+ if self.config.is_kv_producer:
+ port_offset_base = rank + 1
+ else:
+ port_offset_base = rank // self.config.tensor_parallel_multiplier + 1
+
+
+ self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
+ self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config)
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:
if self.config.kv_connector == "PyNcclConnector":
self.producer_data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -108,11 +122,13 @@ class SimpleConnector(KVConnectorBase):
# its recv pipe to the send pipe of KV producder
if self.config.kv_connector == "PyNcclConnector":
self.consumer_data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -131,21 +147,25 @@ class SimpleConnector(KVConnectorBase):
self.config.kv_buffer_size,
)
- def select(self, input_tokens: Optional[torch.Tensor],
+ def select(self, source_rank: int, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+ logger.info("Selecting KV caches and hidden states for source rank %d", source_rank)
+
+ self.data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
+ local_rank=local_rank,
+ config=self.config,
+ port_offset=port_offset_base,
+ )
+ def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
+ logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank)
+
+ self.data_plane = DynemoNcclDataPlane(
+ data_pipe=self.data_pipe,
+ port=self._get_data_plane_port(self.global_kv_rank),
+ )
+
+ def send_kv_caches_and_hidden_states(
+ self,
+ model_executable: torch.nn.Module,
+ model_input: "ModelInputForGPUWithSamplingMetadata",
+ kv_caches: List[torch.Tensor],
+ hidden_or_intermediate_states: Union[torch.Tensor,
+ IntermediateTensors],
+ ) -> None:
+
+ input_tokens_tensor = model_input.input_tokens
+ seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
+ start_layer = model_executable.model.start_layer
+ end_layer = model_executable.model.end_layer
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
+
+ model_config = model_executable.model.config
- num_heads = int(model_config.num_key_value_heads / self.tp_size)
- hidden_size = model_config.hidden_size
- num_attention_heads = model_config.num_attention_heads
- head_size = int(hidden_size / num_attention_heads)
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+ if not is_deepseek:
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
...@@ -1018,38 +1161,31 @@ index 2033e976..e33919c1 100644
+ hidden_size = model_config.hidden_size
+ num_attention_heads = model_config.num_attention_heads
+ head_size = int(4.5 * hidden_size / num_attention_heads)
+
+ # query_lens contains new KV caches that are added to vLLM.
+ # so we will send them to decode instance
+ # FIXME(Kuntai): This assume that all requests are prefill.
+ for idx, slen in enumerate(seq_lens):
+ start_pos = sum(seq_lens[:idx])
+ end_pos = start_pos + slen
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id)
+ decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config)
+
+ for target_rank in range(self.config.tensor_parallel_multiplier):
+
- keys, values = [], []
+ keys, values = [], []
+
- for layer_id in range(start_layer, end_layer):
- kv_cache = kv_caches[layer_id - start_layer]
+ for layer_id in range(start_layer, end_layer):
+ kv_cache = kv_caches[layer_id - start_layer]
+
- key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
- value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+
- current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+ num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
+ head_start = target_rank * num_heads_per_rank
+ head_end = head_start + num_heads_per_rank
+
- keys.append(key_cache[current_slot_mapping].unsqueeze(0))
- values.append(value_cache[current_slot_mapping].unsqueeze(0))
+ if not is_deepseek:
+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
...@@ -1059,75 +1195,70 @@ index 2033e976..e33919c1 100644
+ key_cache = kv_cache
+ keys.append(key_cache[current_slot_mapping].unsqueeze(0))
+ values.append(torch.empty(0))
+
- keys = torch.cat(keys, dim=0)
- values = torch.cat(values, dim=0)
+ keys = torch.cat(keys, dim=0)
+ values = torch.cat(values, dim=0)
+
+ decode_global_rank = decode_first_global_rank + target_rank
+ decode_port = self._get_data_plane_port(decode_global_rank)
+ partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos]
+ self._send(decode_hostname, decode_port, current_request_id, keys, values,
+ partial_hidden_or_intermediate_states)
+
+ logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
+
+ def recv_kv_caches_and_hidden_states(
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + self, model_executable: torch.nn.Module,
+ model_input: "ModelInputForGPUWithSamplingMetadata",
@@ -215,6 +256,7 @@ class SimpleConnector(KVConnectorBase): + kv_caches: List[torch.Tensor]
input_tokens_tensor = model_input.input_tokens + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
seq_lens = model_input.attn_metadata.seq_lens + "ModelInputForGPUWithSamplingMetadata"]:
slot_mapping = model_input.attn_metadata.slot_mapping.flatten() +
+ # When bypass_model_exec is set to False, it means that at least for one
+ # request its corresponding KV cache or hidden state is missing.
+ # In this case we need to do prefilling to recompute missing KV cache
+ # and hidden states.
+ bypass_model_exec = True
+
+ input_tokens_tensor = model_input.input_tokens
+ seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
+
+ hidden_or_intermediate_states_for_one_req = []
+
+ input_tokens_list = []
+ start_pos_list = []
+
+ model_config = model_executable.model.config
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+
+ # enumerate different requests
+ # FIXME(Kuntai): This impl assumes that all requests are prefill.
+ for idx, slen in enumerate(seq_lens):
+
+ start_pos = sum(seq_lens[:idx])
+ end_pos = start_pos + slen
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ num_tokens = slen
+
+ # collecting data for rebuilding the input
+ input_tokens_list.append(current_tokens)
+ start_pos_list.append(start_pos)
+
+ ret = self._recv(current_request_id)
+ keys: torch.Tensor = ret[0]
+ values: torch.Tensor = ret[1]
+ hidden: torch.Tensor = ret[2]
+
+ # put received KV caches into paged memory
+ for i in range(model_executable.model.start_layer,
+ model_executable.model.end_layer):
+
+ kv_cache = kv_caches[i - model_executable.model.start_layer]
+ layer = model_executable.model.layers[i]
+
- keys[i - model_executable.model.start_layer].to(
- key_cache.device),
- values[i - model_executable.model.start_layer].to(
- value_cache.device),
- key_cache,
- value_cache,
- slot_mapping[start_pos:end_pos],
- layer.self_attn.attn.kv_cache_dtype,
- layer.self_attn.attn._k_scale,
- layer.self_attn.attn._v_scale,
- )
+ if not is_deepseek:
+ key_cache, value_cache = kv_cache[0], kv_cache[1]
+ ops.reshape_and_cache_flash(
...@@ -1147,32 +1278,58 @@ index 2033e976..e33919c1 100644
+ copy_from =keys[i - model_executable.model.start_layer].to(
+ key_cache.device)
+ kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
+
+ hidden_or_intermediate_states_for_one_req.append(hidden)
+
+ if not bypass_model_exec:
+ # Some of the KV cache is not retrieved
+ # Here we will fall back to normal model forwarding
+ # But optionally you can adjust model_input so that you only do
+ # prefilling on those tokens that are missing KV caches.
+ logger.debug(
+ "[rank%d]: Failed to receive all KVs and hidden "
+ "states, redo model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = None
+
+ else:
+ logger.debug(
+ "[rank%d]: Successfully received all KVs and hidden "
+ "states, skip model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = torch.cat(
+ hidden_or_intermediate_states_for_one_req, dim=0)
+
+ return hidden_or_intermediate_states, bypass_model_exec, model_input
+
+ def close(self):
+ self.data_pipe.close()
+ # self.data_plane.close()
+ +
+ @staticmethod
+ def parse_request_id(request_id: str) -> Tuple[str, int]:
+ # Regular expression to match the string hostname and integer decode_kv_rank
+ pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"
+
+ # Use re.search to find the pattern in the request_id
+ match = re.search(pattern, request_id)
+
+ if match:
+ # Extract the ranks
+ decode_hostname = match.group(1)
+ decode_rank = int(match.group(2))
+
+ return decode_hostname, decode_rank
+ raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")
+ return None, None
+ +
+ + def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor):
+ remote_address = f"{hostname}:{port}"
+ self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address)
+ self.data_plane.send_tensor(values, f"{request_id}_values", remote_address)
+ self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address)
+
+ def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ keys = self.data_plane.recv_tensor(f"{request_id}_keys")
+ values = self.data_plane.recv_tensor(f"{request_id}_values")
+ hidden = self.data_plane.recv_tensor(f"{request_id}_hidden")
+ return keys, values, hidden
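As a rough illustration of the request-id and tensor-id conventions used by the helpers above (every concrete value below is made up):

import re

request_id = "cmpl-42___decode_hostname_decode-node-0___decode_kv_rank_3"  # hypothetical id

pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"
match = re.search(pattern, request_id)
if match is None:
    raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")

decode_hostname = match.group(1)       # "decode-node-0"
decode_kv_rank = int(match.group(2))   # 3

# Per request, the data plane moves three tensors, keyed by suffix:
tensor_ids = [f"{request_id}_keys", f"{request_id}_values", f"{request_id}_hidden"]
print(decode_hostname, decode_kv_rank, tensor_ids)
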
+ +
+ def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank < config.kv_producers_parallel_size: + if kv_rank < config.kv_producers_parallel_size:
...@@ -1180,45 +1337,53 @@ index 2033e976..e33919c1 100644 ...@@ -1180,45 +1337,53 @@ index 2033e976..e33919c1 100644
+ +
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier + return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
+
+ +
+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): + def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if rank == 0: + if kv_rank <= config.kv_producers_parallel_size:
+ if self.config.kv_connector == "PyNcclConnector": + return kv_rank * config.kv_producers_tensor_parallel_size + rank
+ config_group = StatelessProcessGroup.create( +
+ host=self.config.kv_ip, + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ port=self.config.kv_port, + return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank
+ rank=self.config.kv_rank,
+ world_size=self.config.kv_parallel_size,
+ )
+ parallel_configs = config_group.all_gather_obj({
+ "kv_role": self.config.kv_role,
+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
+ })
+ logger.debug("parallel_configs: %s", parallel_configs)
+ kv_config_enhanced = {
+ "kv_producers_tensor_parallel_size": None,
+ "kv_consumers_tensor_parallel_size": None,
+ "kv_producers_pipeline_parallel_size": None,
+ "kv_consumers_pipeline_parallel_size": None,
+ "kv_producers_parallel_size": 0,
+ }
+ for parallel_config in parallel_configs:
+ kv_role = parallel_config["kv_role"]
+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
+
+ if kv_role == "kv_producer":
+ kv_config_enhanced["kv_producers_parallel_size"] += 1
+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
+ else:
+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
+ world_group.broadcast_object(kv_config_enhanced)
+ +
+ else: +
+ raise NotImplementedError("MooncakeConnector is not supported in Dynemo Distributed vllm patch") + def _get_data_plane_port(self, global_kv_rank: int) -> int:
+ return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank
+
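A worked example of the three helpers above, assuming a hypothetical deployment with two prefill (kv_producer) instances at TP=1 and one decode (kv_consumer) instance at TP=2, and taking tensor_parallel_multiplier as the decode/prefill TP ratio (2 here); the base port is made up:

kv_producers_parallel_size = 2
kv_producers_tensor_parallel_size = 1
kv_consumers_tensor_parallel_size = 2
tensor_parallel_multiplier = 2
kv_port = 14000  # hypothetical base port

def kv_group_rank(kv_rank: int, rank: int) -> int:
    if kv_rank < kv_producers_parallel_size:
        return kv_rank
    consumer = kv_rank - kv_producers_parallel_size
    return (kv_producers_parallel_size + consumer * tensor_parallel_multiplier
            + rank % tensor_parallel_multiplier)

def global_kv_rank(kv_rank: int, rank: int) -> int:
    if kv_rank <= kv_producers_parallel_size:
        return kv_rank * kv_producers_tensor_parallel_size + rank
    consumer = kv_rank - kv_producers_parallel_size
    return (kv_producers_parallel_size * kv_producers_tensor_parallel_size
            + consumer * kv_consumers_tensor_parallel_size + rank)

def data_plane_port(global_rank: int) -> int:
    return kv_port + kv_producers_tensor_parallel_size + 1 + global_rank

# kv_rank 0 and 1 are the producers, kv_rank 2 is the consumer (local ranks 0 and 1)
for kv_rank, rank in [(0, 0), (1, 0), (2, 0), (2, 1)]:
    g = global_kv_rank(kv_rank, rank)
    print(kv_rank, rank, kv_group_rank(kv_rank, rank), g, data_plane_port(g))
# -> group ranks 0, 1, 2, 3; data plane ports 14002..14005
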
+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
+ if rank == 0:
+ config_group = StatelessProcessGroup.create(
+ host=self.config.kv_ip,
+ port=self.config.kv_port,
+ rank=self.config.kv_rank,
+ world_size=self.config.kv_parallel_size,
+ )
+ parallel_configs = config_group.all_gather_obj({
+ "kv_role": self.config.kv_role,
+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
+ })
+ logger.debug("parallel_configs: %s", parallel_configs)
+ kv_config_enhanced = {
+ "kv_producers_tensor_parallel_size": None,
+ "kv_consumers_tensor_parallel_size": None,
+ "kv_producers_pipeline_parallel_size": None,
+ "kv_consumers_pipeline_parallel_size": None,
+ "kv_producers_parallel_size": 0,
+ }
+ for parallel_config in parallel_configs:
+ kv_role = parallel_config["kv_role"]
+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
+
+ if kv_role == "kv_producer":
+ kv_config_enhanced["kv_producers_parallel_size"] += 1
+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
+ else:
+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
+ world_group.broadcast_object(kv_config_enhanced)
+ else: + else:
+ kv_config_enhanced = world_group.broadcast_object() + kv_config_enhanced = world_group.broadcast_object()
+ logger.info("kv_config_enhanced: %s", kv_config_enhanced) + logger.info("kv_config_enhanced: %s", kv_config_enhanced)
...@@ -1228,120 +1393,152 @@ index 2033e976..e33919c1 100644 ...@@ -1228,120 +1393,152 @@ index 2033e976..e33919c1 100644
+ self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] + self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] + self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] + self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
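Under the same made-up topology, the enhanced config that rank 0 gathers over the StatelessProcessGroup and then broadcasts would look roughly like:

kv_config_enhanced = {
    "kv_producers_tensor_parallel_size": 1,    # both prefill instances run TP=1
    "kv_consumers_tensor_parallel_size": 2,    # the decode instance runs TP=2
    "kv_producers_pipeline_parallel_size": 1,  # PP=1 is the only supported value
    "kv_consumers_pipeline_parallel_size": 1,
    "kv_producers_parallel_size": 2,           # number of kv_producer instances seen
}
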
\ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
diff --git a/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py index fe480533..f4775663 100644
new file mode 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py
index 00000000..cb3b3660 +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
--- /dev/null @@ -27,13 +27,13 @@ class KVConnectorFactory:
+++ b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
@@ -0,0 +1,350 @@ @classmethod
+# SPDX-License-Identifier: Apache-2.0 def create_connector(cls, rank: int, local_rank: int,
+""" - config: "VllmConfig") -> KVConnectorBase:
+Simple KV Cache Connector for Distributed Machine Learning Inference + config: "VllmConfig", world_group) -> KVConnectorBase:
+ connector_name = config.kv_transfer_config.kv_connector
+The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache if connector_name not in cls._registry:
+producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or raise ValueError(f"Unsupported connector type: {connector_name}")
+MooncakePipe.
connector_cls = cls._registry[connector_name]()
- return connector_cls(rank, local_rank, config)
+ return connector_cls(rank, local_rank, config, world_group)
# Register various connectors here.
@@ -48,3 +48,8 @@ KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
+ +
+But the logic can be extended to support other pipe and lookup buffer. +KVConnectorFactory.register_connector(
+""" + "DynemoNcclConnector",
+ "vllm.distributed.kv_transfer.kv_connector.dynemo_connector",
+ "DynemoConnector")
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 2033e976..e0537903 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -8,13 +8,15 @@ MooncakePipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
+import re +import re
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import torch import torch
+
+from vllm import _custom_ops as ops from vllm import _custom_ops as ops
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, KVTransferConfig +from vllm.config import VllmConfig, KVTransferConfig
+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.utils import StatelessProcessGroup +from vllm.distributed.utils import StatelessProcessGroup
+from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
+ SimpleBuffer) SimpleBuffer)
+from vllm.logger import init_logger from vllm.logger import init_logger
+from vllm.sequence import IntermediateTensors @@ -33,6 +35,7 @@ class SimpleConnector(KVConnectorBase):
+ rank: int,
+if TYPE_CHECKING: local_rank: int,
+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata config: VllmConfig,
+
+logger = init_logger(__name__)
+
+
+class DynemoConnector(KVConnectorBase):
+
+ def __init__(
+ self,
+ rank: int,
+ local_rank: int,
+ config: VllmConfig,
+ world_group, + world_group,
+ ): ):
+
+ self.config = config.kv_transfer_config self.config = config.kv_transfer_config
+ self.tp_size = config.parallel_config.tensor_parallel_size @@ -71,20 +74,31 @@ class SimpleConnector(KVConnectorBase):
+ self.rank = rank self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
+ self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
+ if self.config.kv_connector != "DynemoNcclConnector":
+ raise NotImplementedError("Only DynemoNcclConnector is supported by the DynemoConnector class")
+
+ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
+ PyNcclPipe)
+ from vllm.distributed.kv_transfer.kv_pipe.dynemo_nccl_pipe import (
+ DynemoNcclDataPlane)
+
+ logger.info(
+ "Initializing DynemoNcclConnector under kv_transfer_config %s",
+ self.config)
+
+ self.lookup_buffer_size = self.config.kv_buffer_size
+
+ self.producer_data_pipe: PyNcclPipe
+ self.consumer_data_pipe: PyNcclPipe
+ self.producer_signal_pipe: PyNcclPipe
+ self.consumer_signal_pipe: PyNcclPipe
+
+ self._broadcast_and_enhance_kv_config(rank, config, world_group) + self._broadcast_and_enhance_kv_config(rank, config, world_group)
+ +
+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) + self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
+ self.tp_size = config.parallel_config.tensor_parallel_size + self.tp_size = config.parallel_config.tensor_parallel_size
+ +
+ # 2 pipes for every rank in the world # 2 pipes for every rank in the world
- port_offset_base = 2 * rank
+ if self.config.is_kv_producer: + if self.config.is_kv_producer:
+ port_offset_base = rank + 1 + port_offset_base = 2 * rank + 1
+ else: + else:
+ port_offset_base = rank // self.config.tensor_parallel_multiplier + 1 + port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1
+
+
+ self.local_kv_rank = rank % self.config.tensor_parallel_multiplier + self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
+ self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config) # In disaggregated prefill, the prefill vLLM only uses send pipe
+ # and the decode vLLM only uses recv pipe
+ self.data_pipe = PyNcclPipe( if self.config.is_kv_producer:
+ kv_group_rank=self.kv_group_rank,
+ local_rank=local_rank, if self.config.kv_connector == "PyNcclConnector":
+ config=self.config, self.producer_data_pipe = PyNcclPipe(
+ port_offset=port_offset_base, + kv_group_rank=self.kv_group_rank,
+ ) local_rank=local_rank,
+ config=self.config,
+ self.data_plane = DynemoNcclDataPlane( port_offset=port_offset_base,
+ data_pipe=self.data_pipe, )
+ port=self._get_data_plane_port(self.global_kv_rank), self.producer_signal_pipe = PyNcclPipe(
+ ) + kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -108,11 +122,13 @@ class SimpleConnector(KVConnectorBase):
# its recv pipe to the send pipe of KV producder
if self.config.kv_connector == "PyNcclConnector":
self.consumer_data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -131,21 +147,25 @@ class SimpleConnector(KVConnectorBase):
self.config.kv_buffer_size,
)
- def select(self, input_tokens: Optional[torch.Tensor],
+ def select(self, source_rank: int, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+ logger.info("Selecting KV caches and hidden states for source rank %d", source_rank)
+ +
+ def send_kv_caches_and_hidden_states( assert self.consumer_buffer is not None, "Please initialize the "\
+ self, "consumer buffer before calling select."
+ model_executable: torch.nn.Module, - return self.consumer_buffer.drop_select(input_tokens, roi)
+ model_input: "ModelInputForGPUWithSamplingMetadata", + return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi)
+ kv_caches: List[torch.Tensor],
+ hidden_or_intermediate_states: Union[torch.Tensor, - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
+ IntermediateTensors], + def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
+ ) -> None: key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
+ logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank)
+ +
+ input_tokens_tensor = model_input.input_tokens assert self.producer_buffer is not None, "Please initialize the "\
+ seq_lens = model_input.attn_metadata.seq_lens "producer buffer before calling insert."
+ slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
+ start_layer = model_executable.model.start_layer - self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
+ end_layer = model_executable.model.end_layer + self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden)
def send_kv_caches_and_hidden_states(
self,
@@ -161,12 +181,20 @@ class SimpleConnector(KVConnectorBase):
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
+ request_ids = list(model_input.request_ids_to_seq_ids.keys()) + request_ids = list(model_input.request_ids_to_seq_ids.keys())
+
+ model_config = model_executable.model.config model_config = model_executable.model.config
- num_heads = int(model_config.num_key_value_heads / self.tp_size)
- hidden_size = model_config.hidden_size
- num_attention_heads = model_config.num_attention_heads
- head_size = int(hidden_size / num_attention_heads)
+ is_deepseek = "deepseek" in model_config.architectures[0].lower() + is_deepseek = "deepseek" in model_config.architectures[0].lower()
+ if not is_deepseek: + if not is_deepseek:
+ num_heads = int(model_config.num_key_value_heads / self.tp_size) + num_heads = int(model_config.num_key_value_heads / self.tp_size)
...@@ -1353,31 +1550,38 @@ index 00000000..cb3b3660 ...@@ -1353,31 +1550,38 @@ index 00000000..cb3b3660
+ hidden_size = model_config.hidden_size + hidden_size = model_config.hidden_size
+ num_attention_heads = model_config.num_attention_heads + num_attention_heads = model_config.num_attention_heads
+ head_size = int(4.5 * hidden_size / num_attention_heads) + head_size = int(4.5 * hidden_size / num_attention_heads)
+
+ # query_lens contains new KV caches that are added to vLLM. # query_lens contains new KV caches that are added to vLLM.
+ # so we will send them to decode instance # so we will send them to decode instance
+ # FIXME(Kuntai): This assume that all requests are prefill. @@ -175,27 +203,40 @@ class SimpleConnector(KVConnectorBase):
+ for idx, slen in enumerate(seq_lens): start_pos = sum(seq_lens[:idx])
+ start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen
+ end_pos = start_pos + slen current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx] + current_request_id = request_ids[idx]
+ decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id) + _, decode_kv_rank = self.parse_request_id(current_request_id)
+ decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config) + starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config)
+ +
+ for target_rank in range(self.config.tensor_parallel_multiplier): + for target_rank in range(self.config.tensor_parallel_multiplier):
+
- keys, values = [], []
+ keys, values = [], [] + keys, values = [], []
+
- for layer_id in range(start_layer, end_layer):
- kv_cache = kv_caches[layer_id - start_layer]
+ for layer_id in range(start_layer, end_layer): + for layer_id in range(start_layer, end_layer):
+ kv_cache = kv_caches[layer_id - start_layer] + kv_cache = kv_caches[layer_id - start_layer]
+
- key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
- value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+
- current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+ num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier + num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
+ head_start = target_rank * num_heads_per_rank + head_start = target_rank * num_heads_per_rank
+ head_end = head_start + num_heads_per_rank + head_end = head_start + num_heads_per_rank
+
- keys.append(key_cache[current_slot_mapping].unsqueeze(0))
- values.append(value_cache[current_slot_mapping].unsqueeze(0))
+ if not is_deepseek: + if not is_deepseek:
+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
...@@ -1387,70 +1591,75 @@ index 00000000..cb3b3660 ...@@ -1387,70 +1591,75 @@ index 00000000..cb3b3660
+ key_cache = kv_cache + key_cache = kv_cache
+ keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + keys.append(key_cache[current_slot_mapping].unsqueeze(0))
+ values.append(torch.empty(0)) + values.append(torch.empty(0))
+
- keys = torch.cat(keys, dim=0)
- values = torch.cat(values, dim=0)
+ keys = torch.cat(keys, dim=0) + keys = torch.cat(keys, dim=0)
+ values = torch.cat(values, dim=0) + values = torch.cat(values, dim=0)
+
+ decode_global_rank = decode_first_global_rank + target_rank - self.insert(current_tokens,
+ decode_port = self._get_data_plane_port(decode_global_rank) - torch.ones_like(current_tokens,
+ partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos] - dtype=bool), keys, values,
+ self._send(decode_hostname, decode_port, current_request_id, keys, values, - hidden_or_intermediate_states[start_pos:end_pos])
+ partial_hidden_or_intermediate_states) + self.insert(starting_kv_group_rank, target_rank, current_tokens,
+ + torch.ones_like(current_tokens,
+ logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + dtype=bool), keys, values,
+ + hidden_or_intermediate_states[start_pos:end_pos])
+ def recv_kv_caches_and_hidden_states(
+ self, model_executable: torch.nn.Module, logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
+ model_input: "ModelInputForGPUWithSamplingMetadata",
+ kv_caches: List[torch.Tensor] @@ -215,6 +256,7 @@ class SimpleConnector(KVConnectorBase):
+ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, input_tokens_tensor = model_input.input_tokens
+ "ModelInputForGPUWithSamplingMetadata"]: seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+ # When bypass_model_exec is set to False, it means that at least for one
+ # request its corresponding KV cache or hidden state is missing.
+ # In this case we need to do prefilling to recompute missing KV cache
+ # and hidden states.
+ bypass_model_exec = True
+
+ input_tokens_tensor = model_input.input_tokens
+ seq_lens = model_input.attn_metadata.seq_lens
+ slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+ request_ids = list(model_input.request_ids_to_seq_ids.keys()) + request_ids = list(model_input.request_ids_to_seq_ids.keys())
+
+ hidden_or_intermediate_states_for_one_req = [] hidden_or_intermediate_states_for_one_req = []
+
+ input_tokens_list = [] @@ -222,6 +264,9 @@ class SimpleConnector(KVConnectorBase):
+ start_pos_list = [] num_computed_tokens_list = []
+ start_pos_list = []
+ model_config = model_executable.model.config + model_config = model_executable.model.config
+ is_deepseek = "deepseek" in model_config.architectures[0].lower() + is_deepseek = "deepseek" in model_config.architectures[0].lower()
+ +
+ # enumerate different requests # enumerate different requests
+ # FIXME(Kuntai): This impl assumes that all requests are prefill. # FIXME(Kuntai): This impl assumes that all requests are prefill.
+ for idx, slen in enumerate(seq_lens): for idx, slen in enumerate(seq_lens):
+ @@ -229,13 +274,15 @@ class SimpleConnector(KVConnectorBase):
+ start_pos = sum(seq_lens[:idx]) start_pos = sum(seq_lens[:idx])
+ end_pos = start_pos + slen end_pos = start_pos + slen
+ current_tokens = input_tokens_tensor[start_pos:end_pos] current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx] + current_request_id = request_ids[idx]
+ num_tokens = slen + prefill_rank, _ = self.parse_request_id(current_request_id)
+ num_tokens = slen
+ # collecting data for rebuilding the input
+ input_tokens_list.append(current_tokens) # collecting data for rebuilding the input
+ start_pos_list.append(start_pos) input_tokens_list.append(current_tokens)
+ start_pos_list.append(start_pos)
+ ret = self._recv(current_request_id)
+ keys: torch.Tensor = ret[0] - ret = self.select(current_tokens,
+ values: torch.Tensor = ret[1] + ret = self.select(prefill_rank, current_tokens,
+ hidden: torch.Tensor = ret[2] torch.ones_like(current_tokens, dtype=bool))
+ if ret[0] is None:
+ # put received KV caches into paged memory # didn't find any match.
+ for i in range(model_executable.model.start_layer, @@ -267,19 +314,25 @@ class SimpleConnector(KVConnectorBase):
+ model_executable.model.end_layer): kv_cache = kv_caches[i - model_executable.model.start_layer]
+ layer = model_executable.model.layers[i]
+ kv_cache = kv_caches[i - model_executable.model.start_layer]
+ layer = model_executable.model.layers[i] - key_cache, value_cache = kv_cache[0], kv_cache[1]
+ - ops.reshape_and_cache_flash(
- keys[i - model_executable.model.start_layer].to(
- key_cache.device),
- values[i - model_executable.model.start_layer].to(
- value_cache.device),
- key_cache,
- value_cache,
- slot_mapping[start_pos:end_pos],
- layer.self_attn.attn.kv_cache_dtype,
- layer.self_attn.attn._k_scale,
- layer.self_attn.attn._v_scale,
- )
+ if not is_deepseek: + if not is_deepseek:
+ key_cache, value_cache = kv_cache[0], kv_cache[1] + key_cache, value_cache = kv_cache[0], kv_cache[1]
+ ops.reshape_and_cache_flash( + ops.reshape_and_cache_flash(
...@@ -1470,58 +1679,32 @@ index 00000000..cb3b3660 ...@@ -1470,58 +1679,32 @@ index 00000000..cb3b3660
+ copy_from = keys[i - model_executable.model.start_layer].to( + copy_from = keys[i - model_executable.model.start_layer].to(
+ key_cache.device) + key_cache.device)
+ kv_cache[slot_mapping[start_pos:end_pos]] = copy_from + kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
+
+ hidden_or_intermediate_states_for_one_req.append(hidden) hidden_or_intermediate_states_for_one_req.append(hidden)
+
+ if not bypass_model_exec: @@ -312,3 +365,77 @@ class SimpleConnector(KVConnectorBase):
+ # Some of the KV cache is not retrieved # MooncakePipe reuses data_pipe for signal_pipe, so we only have to
+ # Here we will fall back to normal model forwarding # close the data_pipe.
+ # But optionally you can adjust model_input so that you only do pass
+ # prefilling on those tokens that are missing KV caches.
+ logger.debug(
+ "[rank%d]: Failed to receive all KVs and hidden "
+ "states, redo model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = None
+
+ else:
+ logger.debug(
+ "[rank%d]: Successfully received all KVs and hidden "
+ "states, skip model forwarding.", torch.distributed.get_rank())
+ hidden_or_intermediate_states = torch.cat(
+ hidden_or_intermediate_states_for_one_req, dim=0)
+
+ return hidden_or_intermediate_states, bypass_model_exec, model_input
+
+ def close(self):
+ self.data_pipe.close()
+ # self.data_plane.close()
+ +
+ @staticmethod + @staticmethod
+ def parse_request_id(request_id: str) -> Tuple[str, int]: + def parse_request_id(request_id):
+ # Regular expression to match the string hostname and integer decode_kv_rank + # Regular expression to match the ranks
+ pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)" + pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)"
+ +
+ # Use re.search to find the pattern in the request_id + # Use re.search to find the pattern in the request_id
+ match = re.search(pattern, request_id) + match = re.search(pattern, request_id)
+
+ if match: + if match:
+ # Extract the ranks + # Extract the ranks
+ decode_hostname = match.group(1) + prefill_rank = int(match.group(1))
+ decode_rank = int(match.group(2)) + decode_rank = int(match.group(2))
+ +
+ return decode_hostname, decode_rank + return prefill_rank, decode_rank
+ raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank") + else:
+ + return None, None
+ def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor):
+ remote_address = f"{hostname}:{port}"
+ self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address)
+ self.data_plane.send_tensor(values, f"{request_id}_values", remote_address)
+ self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address)
+ +
+ def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +
+ keys = self.data_plane.recv_tensor(f"{request_id}_keys")
+ values = self.data_plane.recv_tensor(f"{request_id}_values")
+ hidden = self.data_plane.recv_tensor(f"{request_id}_hidden")
+ return keys, values, hidden
+ +
+ def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank < config.kv_producers_parallel_size: + if kv_rank < config.kv_producers_parallel_size:
...@@ -1529,53 +1712,45 @@ index 00000000..cb3b3660 ...@@ -1529,53 +1712,45 @@ index 00000000..cb3b3660
+ +
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier + return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
+
+
+ def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank <= config.kv_producers_parallel_size:
+ return kv_rank * config.kv_producers_tensor_parallel_size + rank
+
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank
+
+
+ def _get_data_plane_port(self, global_kv_rank: int) -> int:
+ return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank
+ +
+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): + def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
+ if rank == 0: + if rank == 0:
+ config_group = StatelessProcessGroup.create( + if self.config.kv_connector == "PyNcclConnector":
+ host=self.config.kv_ip, + config_group = StatelessProcessGroup.create(
+ port=self.config.kv_port, + host=self.config.kv_ip,
+ rank=self.config.kv_rank, + port=self.config.kv_port,
+ world_size=self.config.kv_parallel_size, + rank=self.config.kv_rank,
+ ) + world_size=self.config.kv_parallel_size,
+ parallel_configs = config_group.all_gather_obj({ + )
+ "kv_role": self.config.kv_role, + parallel_configs = config_group.all_gather_obj({
+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size, + "kv_role": self.config.kv_role,
+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, + "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
+ }) + "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
+ logger.debug("parallel_configs: %s", parallel_configs) + })
+ kv_config_enhanced = { + logger.debug("parallel_configs: %s", parallel_configs)
+ "kv_producers_tensor_parallel_size": None, + kv_config_enhanced = {
+ "kv_consumers_tensor_parallel_size": None, + "kv_producers_tensor_parallel_size": None,
+ "kv_producers_pipeline_parallel_size": None, + "kv_consumers_tensor_parallel_size": None,
+ "kv_consumers_pipeline_parallel_size": None, + "kv_producers_pipeline_parallel_size": None,
+ "kv_producers_parallel_size": 0, + "kv_consumers_pipeline_parallel_size": None,
+ } + "kv_producers_parallel_size": 0,
+ for parallel_config in parallel_configs: + }
+ kv_role = parallel_config["kv_role"] + for parallel_config in parallel_configs:
+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" + kv_role = parallel_config["kv_role"]
+ + assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
+ if kv_role == "kv_producer": +
+ kv_config_enhanced["kv_producers_parallel_size"] += 1 + if kv_role == "kv_producer":
+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: + kv_config_enhanced["kv_producers_parallel_size"] += 1
+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] + if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] + kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
+ else: + kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" + else:
+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" + assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
+ world_group.broadcast_object(kv_config_enhanced) + assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
+ world_group.broadcast_object(kv_config_enhanced)
+
+ else:
+ raise NotImplementedError("MooncakeConnector is not supported in Dynemo patch")
+ else: + else:
+ kv_config_enhanced = world_group.broadcast_object() + kv_config_enhanced = world_group.broadcast_object()
+ logger.info("kv_config_enhanced: %s", kv_config_enhanced) + logger.info("kv_config_enhanced: %s", kv_config_enhanced)
...@@ -1585,7 +1760,6 @@ index 00000000..cb3b3660 ...@@ -1585,7 +1760,6 @@ index 00000000..cb3b3660
+ self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] + self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] + self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] + self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
index 5e1b6235..b4506877 100644 index 5e1b6235..b4506877 100644
--- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
...@@ -1787,84 +1961,214 @@ index 5e1b6235..b4506877 100644 ...@@ -1787,84 +1961,214 @@ index 5e1b6235..b4506877 100644
- self.data_pipe.send_tensor(roi) - self.data_pipe.send_tensor(roi)
+ self.signal_pipe.send_tensor(self.normal_signal, rank) + self.signal_pipe.send_tensor(self.normal_signal, rank)
+ +
+ self.data_pipe.send_tensor(torch.tensor(kv_rank), rank) + self.data_pipe.send_tensor(torch.tensor(kv_rank), rank)
+ self.data_pipe.send_tensor(input_tokens, rank) + self.data_pipe.send_tensor(input_tokens, rank)
+ self.data_pipe.send_tensor(roi, rank) + self.data_pipe.send_tensor(roi, rank)
- input_tokens = self.data_pipe.recv_tensor() - input_tokens = self.data_pipe.recv_tensor()
- roi = self.data_pipe.recv_tensor() - roi = self.data_pipe.recv_tensor()
+ input_tokens = self.data_pipe.recv_tensor(rank) + input_tokens = self.data_pipe.recv_tensor(rank)
+ roi = self.data_pipe.recv_tensor(rank) + roi = self.data_pipe.recv_tensor(rank)
if roi is not None: if roi is not None:
# convert from float tensor to bool tensor # convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor # as PyNccl does not support sending bool tensor
roi = (roi > 0.5) roi = (roi > 0.5)
- key = self.data_pipe.recv_tensor() - key = self.data_pipe.recv_tensor()
- value = self.data_pipe.recv_tensor() - value = self.data_pipe.recv_tensor()
- hidden = self.data_pipe.recv_tensor() - hidden = self.data_pipe.recv_tensor()
+ key = self.data_pipe.recv_tensor(rank) + key = self.data_pipe.recv_tensor(rank)
+ value = self.data_pipe.recv_tensor(rank) + value = self.data_pipe.recv_tensor(rank)
+ hidden = self.data_pipe.recv_tensor(rank) + hidden = self.data_pipe.recv_tensor(rank)
return [input_tokens, roi, key, value, hidden] return [input_tokens, roi, key, value, hidden]
def full_handler(self): def full_handler(self):
time.sleep(0.001) time.sleep(0.001)
- def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
+ def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, + def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None: hidden: torch.Tensor) -> None:
@@ -222,20 +232,19 @@ class SimpleBuffer(KVLookupBufferBase): @@ -222,20 +232,19 @@ class SimpleBuffer(KVLookupBufferBase):
while self.buffer_size > self.buffer_size_threshold: while self.buffer_size > self.buffer_size_threshold:
self.full_handler() self.full_handler()
- self._add_to_buffer(input_tokens, roi, key, value, hidden) - self._add_to_buffer(input_tokens, roi, key, value, hidden)
+ self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden) + self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender # when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request. # need to launch the request handler and start listening to request.
+ target_rank_global = target_rank + kv_group_rank + target_rank_global = target_rank + kv_group_rank
if self.request_handling_thread is None: if self.request_handling_thread is None:
- self.request_handling_thread = threading.Thread( - self.request_handling_thread = threading.Thread(
- target=self.drop_select_handler) - target=self.drop_select_handler)
- self.request_handling_thread.start() - self.request_handling_thread.start()
+ self.request_handling_thread = ThreadPoolExecutor(max_workers=1) + self.request_handling_thread = ThreadPoolExecutor(max_workers=1)
+ self.request_handling_thread.submit(self.drop_select_handler, target_rank_global) + self.request_handling_thread.submit(self.drop_select_handler, target_rank_global)
def close(self): def close(self):
- if hasattr(self, "request_handling_thread" - if hasattr(self, "request_handling_thread"
- ) and self.request_handling_thread is not None: - ) and self.request_handling_thread is not None:
- self.request_handling_thread.join() - self.request_handling_thread.join()
+ if hasattr(self, "request_handling_thread") and self.request_handling_thread: + if hasattr(self, "request_handling_thread") and self.request_handling_thread:
+ self.request_handling_thread.shutdown() + self.request_handling_thread.shutdown()
else: else:
# TODO: have a explicit close signal and have a explicit way to # TODO: have a explicit close signal and have a explicit way to
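The change above swaps the bare handler Thread for a single-worker executor, so the buffer can hand the computed global target rank to drop_select_handler and shut the worker down from close(); a minimal sketch of the pattern with made-up values:

from concurrent.futures import ThreadPoolExecutor

def drop_select_handler(target_rank_global: int) -> None:
    print(f"serving drop_select requests aimed at global rank {target_rank_global}")

executor = ThreadPoolExecutor(max_workers=1)
executor.submit(drop_select_handler, 3)  # target_rank + kv_group_rank, value made up
executor.shutdown(wait=True)             # what close() now does instead of thread.join()
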
diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py
index 40589fb3..da2829cf 100644 index 40589fb3..da2829cf 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/base.py --- a/vllm/distributed/kv_transfer/kv_pipe/base.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py
@@ -23,7 +23,7 @@ class KVPipeBase(ABC): @@ -23,7 +23,7 @@ class KVPipeBase(ABC):
""" """
@abstractmethod @abstractmethod
- def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
+ def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None: + def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None:
"""Send a tensor, or None, via the pipe. """Send a tensor, or None, via the pipe.
Need to support sending None -- important for error handling. Need to support sending None -- important for error handling.
@@ -41,7 +41,7 @@ class KVPipeBase(ABC): @@ -41,7 +41,7 @@ class KVPipeBase(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
- def recv_tensor(self) -> Optional[torch.Tensor]: - def recv_tensor(self) -> Optional[torch.Tensor]:
+ def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: + def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
"""Receive a tensor (can be None) from the pipeline. """Receive a tensor (can be None) from the pipeline.
Returns: Returns:
diff --git a/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
new file mode 100644
index 00000000..58d0d28c
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
@@ -0,0 +1,124 @@
+import logging
+import threading
+import typing
+import zmq
+import socket
+import time
+import torch
+
+from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
+
+
+logger = logging.getLogger(__name__)
+
+
+class DynemoNcclDataPlane:
+ def __init__(
+ self,
+ data_pipe: PyNcclPipe,
+ hostname: str = "",
+ port: int = 0,
+ ) -> None:
+
+ self.data_pipe = data_pipe
+ if not hostname:
+ hostname = socket.gethostname()
+ if port == 0:
+ raise ValueError("Port cannot be 0")
+ self._hostname = hostname
+ self._port = port
+ self.store = {}
+ self.context = zmq.Context()
+ self.rep_socket = self.context.socket(zmq.REP)
+ logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}")
+ self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
+ self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True)
+ self._listener_thread.start()
+ self.req_sockets = {}
+ logger.info(f"Rank {self.rank} connected to the server")
+
+ @property
+ def rank(self):
+ return self.data_pipe.kv_group_rank
+
+ def send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}")
+ return self._send_tensor(tensor, tensor_id, remote_address)
+
+ def recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ ret = self._recv_tensor(tensor_id, remote_address)
+ return ret
+
+ def _send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}")
+ if remote_address is None:
+ self.store[tensor_id] = tensor
+ else:
+ # tensor_shape = "_".join(str(dim) for dim in tensor.shape)
+ # tensor_dtype = str(tensor.dtype)
+ if remote_address not in self.req_sockets:
+ self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
+ self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
+
+ req_socket = self.req_sockets[remote_address]
+ # req_socket.connect(f"tcp://{remote_address}")
+ req_socket.send_string(f"PUT {self.rank} {tensor_id}")
+ dst_rank = req_socket.recv_string()
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}")
+ self.data_pipe.send_tensor(tensor, int(dst_rank))
+
+ def _recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ logger.debug(f"Rank {self.rank} receiving tensor")
+ if remote_address is not None:
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ if tensor_id in self.store:
+ logger.debug(f"Popping tensor {tensor_id} from store")
+ future = self.store.pop(tensor_id)
+ tensor = future.result() # TODO ptarasiewicz we should run other request instead of wait
+ logger.debug(f"Rank {self.rank} received tensor")
+ return tensor
+
+ logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}")
+ time.sleep(0.001)
+ return self._recv_tensor(tensor_id, remote_address)
+ # raise NotImplementedError("Tensor not found in store")
+
+ def _receive_tensor(
+ self,
+ tensor_id: str,
+ rank: int,
+ ):
+ future = self.data_pipe.recv_tensor(rank)
+ logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store")
+ self.store[tensor_id] = future
+
+ def listen_for_requests(self):
+ while True:
+ cmd, rank, tensor_id = self.rep_socket.recv_string().split()
+ logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}")
+ self.rep_socket.send_string(f"{self.rank}")
+ if cmd == "GET":
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ elif cmd == "PUT":
+ rank = int(rank)
+ # shape = [int(dim) for dim in shape.split("_")]
+ # dtype = getattr(torch, dtype)
+ self._receive_tensor(tensor_id, rank)
diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
index 7aa53d07..f5dd50b7 100644 index 7aa53d07..f5dd50b7 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
...@@ -2055,136 +2359,6 @@ index 7aa53d07..f5dd50b7 100644 ...@@ -2055,136 +2359,6 @@ index 7aa53d07..f5dd50b7 100644
def close(self): def close(self):
""" """
diff --git a/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
new file mode 100644
index 00000000..8a356504
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
@@ -0,0 +1,124 @@
+import logging
+import threading
+import typing
+import zmq
+import socket
+import time
+import torch
+
+from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
+
+
+logger = logging.getLogger(__name__)
+
+
+class DynemoNcclDataPlane:
+ def __init__(
+ self,
+ data_pipe: PyNcclPipe,
+ hostname: str = "",
+ port: int = 0,
+ ) -> None:
+
+ self.data_pipe = data_pipe
+ if not hostname:
+ hostname = socket.gethostname()
+ if port == 0:
+ raise ValueError("Port cannot be 0")
+ self._hostname = hostname
+ self._port = port
+ self.store = {}
+ self.context = zmq.Context()
+ self.rep_socket = self.context.socket(zmq.REP)
+ logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}")
+ self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
+ self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True)
+ self._listener_thread.start()
+ self.req_sockets = {}
+ logger.info(f"Rank {self.rank} connected to the server")
+
+ @property
+ def rank(self):
+ return self.data_pipe.kv_group_rank
+
+ def send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}")
+ return self._send_tensor(tensor, tensor_id, remote_address)
+
+ def recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ ret = self._recv_tensor(tensor_id, remote_address)
+ return ret
+
+ def _send_tensor(
+ self,
+ tensor: torch.Tensor,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ):
+ logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}")
+ if remote_address is None:
+ self.store[tensor_id] = tensor
+ else:
+ # tensor_shape = "_".join(str(dim) for dim in tensor.shape)
+ # tensor_dtype = str(tensor.dtype)
+ if remote_address not in self.req_sockets:
+ self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
+ self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
+
+ req_socket = self.req_sockets[remote_address]
+ # req_socket.connect(f"tcp://{remote_address}")
+ req_socket.send_string(f"PUT {self.rank} {tensor_id}")
+ dst_rank = req_socket.recv_string()
+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}")
+ self.data_pipe.send_tensor(tensor, int(dst_rank))
+
+ def _recv_tensor(
+ self,
+ tensor_id: str,
+ remote_address: typing.Optional[str] = None,
+ ) -> torch.Tensor:
+ logger.debug(f"Rank {self.rank} receiving tensor")
+ if remote_address is not None:
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ if tensor_id in self.store:
+ logger.debug(f"Popping tensor {tensor_id} from store")
+ future = self.store.pop(tensor_id)
+ tensor = future.result() # TODO ptarasiewicz we should run other request instead of wait
+ logger.debug(f"Rank {self.rank} received tensor")
+ return tensor
+
+ logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}")
+ time.sleep(0.001)
+ return self._recv_tensor(tensor_id, remote_address)
+ # raise NotImplementedError("Tensor not found in store")
+
+ def _receive_tensor(
+ self,
+ tensor_id: str,
+ rank: int,
+ ):
+ future = self.data_pipe.recv_tensor(rank)
+ logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store")
+ self.store[tensor_id] = future
+
+ def listen_for_requests(self):
+ while True:
+ cmd, rank, tensor_id = self.rep_socket.recv_string().split()
+ logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}")
+ self.rep_socket.send_string(f"{self.rank}")
+ if cmd == "GET":
+ raise NotImplementedError("Getting tensor from remote rank not implemented")
+ elif cmd == "PUT":
+ rank = int(rank)
+ # shape = [int(dim) for dim in shape.split("_")]
+ # dtype = getattr(torch, dtype)
+ self._receive_tensor(tensor_id, rank)
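The listener above implements a small ZeroMQ side channel: the sender announces a tensor with "PUT <src_rank> <tensor_id>", and the receiver replies with its own rank before posting the matching NCCL receive. A toy reproduction of just that handshake, with in-process sockets and the NCCL transfer replaced by a print:

import threading
import zmq

ctx = zmq.Context.instance()
rep = ctx.socket(zmq.REP)
rep.bind("inproc://data-plane-demo")  # the receiving rank listens first

def sender(src_rank: int = 0) -> None:
    req = ctx.socket(zmq.REQ)
    req.connect("inproc://data-plane-demo")
    req.send_string(f"PUT {src_rank} cmpl-42_keys")  # announce the tensor id
    dst_rank = int(req.recv_string())                # receiver answers with its rank
    print(f"rank {src_rank}: would now push cmpl-42_keys to rank {dst_rank} over NCCL")
    req.close()

t = threading.Thread(target=sender)
t.start()
cmd, src_rank, tensor_id = rep.recv_string().split()
rep.send_string("1")  # this toy receiver is rank 1
print(f"rank 1: got {cmd} for {tensor_id}, will receive it from rank {src_rank}")
t.join()
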
diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py
index 1e80e0bd..cd90206f 100644 index 1e80e0bd..cd90206f 100644
--- a/vllm/distributed/kv_transfer/kv_transfer_agent.py --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py
...@@ -2221,7 +2395,7 @@ index 321902d1..b8937ef8 100644 ...@@ -2221,7 +2395,7 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d82d9ad9..9ba1a326 100644 index d82d9ad9..62dbbd6e 100644
--- a/vllm/engine/llm_engine.py --- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@ @@ -2,13 +2,17 @@
...@@ -2308,7 +2482,9 @@ index d82d9ad9..9ba1a326 100644 ...@@ -2308,7 +2482,9 @@ index d82d9ad9..9ba1a326 100644
+ self._nixl_agents_names = self._initialize_nixl() + self._nixl_agents_names = self._initialize_nixl()
+ +
+ self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size) + self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+ self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+ self._finished_prefills = set() + self._finished_prefills = set()
+ self._finished_transfers = set()
+ +
+ @property + @property
+ def is_nixl_initialized(self) -> bool: + def is_nixl_initialized(self) -> bool:
...@@ -2327,8 +2503,6 @@ index d82d9ad9..9ba1a326 100644 ...@@ -2327,8 +2503,6 @@ index d82d9ad9..9ba1a326 100644
+ engine_id = nixl_metadata.engine_id + engine_id = nixl_metadata.engine_id
+ agents_metadata = nixl_metadata.agent_metadata + agents_metadata = nixl_metadata.agent_metadata
+ kv_caches_base_addr = nixl_metadata.kv_caches_base_addr + kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
+ if len(agents_metadata) != len(self._nixl_agents_names):
+ raise ValueError("Number of agents does not match. Make sure all engines are initialized with the same parallel sizes.")
+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr)) + return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr))
+ +
+ def _initialize_nixl(self) -> List[bytes]: + def _initialize_nixl(self) -> List[bytes]:
...@@ -2338,16 +2512,7 @@ index d82d9ad9..9ba1a326 100644 ...@@ -2338,16 +2512,7 @@ index d82d9ad9..9ba1a326 100644
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
@@ -500,6 +545,8 @@ class LLMEngine: @@ -552,11 +597,14 @@ class LLMEngine:
# Shutdown model executor when engine is garbage collected
# Use getattr since __init__ can fail before the field is set
if model_executor := getattr(self, "model_executor", None):
+ if self._nixl_agents_names:
+ model_executor.collective_rpc("shutdown_nixl")
model_executor.shutdown()
def get_tokenizer_group(
@@ -552,11 +599,14 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
...@@ -2362,6 +2527,15 @@ index d82d9ad9..9ba1a326 100644 ...@@ -2362,6 +2527,15 @@ index d82d9ad9..9ba1a326 100644
ParallelSampleSequenceGroup.add_request( ParallelSampleSequenceGroup.add_request(
request_id, request_id,
self, self,
@@ -574,6 +622,8 @@ class LLMEngine:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
+ if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
+ next(self.seq_counter) # empty sequence for staging
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
@@ -584,7 +634,7 @@ class LLMEngine: @@ -584,7 +634,7 @@ class LLMEngine:
encoder_inputs = None encoder_inputs = None
...@@ -2454,7 +2628,7 @@ index d82d9ad9..9ba1a326 100644 ...@@ -2454,7 +2628,7 @@ index d82d9ad9..9ba1a326 100644
(seq_group_metadata_list, scheduler_outputs, (seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc allow_async_output_proc
- ) = self.scheduler[virtual_engine].schedule() - ) = self.scheduler[virtual_engine].schedule()
+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills) + ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers)
+ +
+ +
+ # Separate remote prefill and running seq groups + # Separate remote prefill and running seq groups
...@@ -2486,7 +2660,7 @@ index d82d9ad9..9ba1a326 100644 ...@@ -2486,7 +2660,7 @@ index d82d9ad9..9ba1a326 100644
ctx.seq_group_metadata_list = seq_group_metadata_list ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs ctx.scheduler_outputs = scheduler_outputs
@@ -1383,9 +1476,29 @@ class LLMEngine: @@ -1383,9 +1476,31 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[ execute_model_req.async_callback = self.async_callbacks[
virtual_engine] virtual_engine]
...@@ -2500,9 +2674,11 @@ index d82d9ad9..9ba1a326 100644 ...@@ -2500,9 +2674,11 @@ index d82d9ad9..9ba1a326 100644
+ req_id = scheduled_seq_group.seq_group.request_id + req_id = scheduled_seq_group.seq_group.request_id
+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id + seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
+ block_table = seq_group_metadata.block_tables[seq_id] + block_table = seq_group_metadata.block_tables[seq_id]
+ staging_block_ids = seq_group_metadata.block_tables[seq_id + 1]
+ memory_transfer_req = MemoryTransferRequest( + memory_transfer_req = MemoryTransferRequest(
+ request_id=req_id, + request_id=req_id,
+ src_block_ids=block_table, + src_block_ids=block_table,
+ staging_block_ids=staging_block_ids,
+ dst_block_ids=remote_prefill_params.decode_block_ids, + dst_block_ids=remote_prefill_params.decode_block_ids,
+ dst_engine_id=remote_prefill_params.decode_engine_id, + dst_engine_id=remote_prefill_params.decode_engine_id,
+ notify_msg=req_id, + notify_msg=req_id,
...@@ -2512,13 +2688,13 @@ index d82d9ad9..9ba1a326 100644 ...@@ -2512,13 +2688,13 @@ index d82d9ad9..9ba1a326 100644
+ +
+ execute_model_req.memory_transfer_requests = memory_transfer_reqs + execute_model_req.memory_transfer_requests = memory_transfer_reqs
+ +
+ outputs, request_notif_counter = self.model_executor.execute_model( + outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
- -
# We need to do this here so that last step's sampled_token_ids can # We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP. # be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
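An illustrative sketch of how the per-request transfer descriptors above might be assembled; the MemoryTransferRequest fields mirror the dataclass added in vllm/remote_prefill.py later in this patch, while the shape of the `scheduled` tuples is invented for the example:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class MemoryTransferRequest:
        request_id: str
        src_block_ids: List[int]
        staging_block_ids: List[int]
        dst_block_ids: List[int]
        dst_engine_id: str
        notify_msg: str

    def build_transfer_reqs(scheduled):
        """scheduled: iterable of (req_id, seq_id, block_tables,
        decode_block_ids, decode_engine_id) tuples, where block_tables maps
        seq_id -> block ids and seq_id + 1 holds the staging table."""
        reqs = []
        for req_id, seq_id, block_tables, decode_block_ids, decode_engine_id in scheduled:
            reqs.append(MemoryTransferRequest(
                request_id=req_id,
                src_block_ids=block_tables[seq_id],
                staging_block_ids=block_tables[seq_id + 1],
                dst_block_ids=decode_block_ids,
                dst_engine_id=decode_engine_id,
                notify_msg=req_id,  # the request id doubles as the notification tag
            ))
        return reqs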
@@ -1396,7 +1509,20 @@ class LLMEngine: @@ -1396,7 +1511,26 @@ class LLMEngine:
if len(ctx.output_queue) > 0: if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx) self._process_model_outputs(ctx=ctx)
# No outputs in this case # No outputs in this case
...@@ -2529,7 +2705,7 @@ index d82d9ad9..9ba1a326 100644 ...@@ -2529,7 +2705,7 @@ index d82d9ad9..9ba1a326 100644
+ blocks_to_swap_out=[], + blocks_to_swap_out=[],
+ blocks_to_copy=[]) + blocks_to_copy=[])
+ +
+ outputs, request_notif_counter = self.model_executor.execute_model( + outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
+ execute_model_req=execute_model_req) + execute_model_req=execute_model_req)
+ +
+ for req_id, notif_count in request_notif_counter.items(): + for req_id, notif_count in request_notif_counter.items():
...@@ -2537,10 +2713,16 @@ index d82d9ad9..9ba1a326 100644 ...@@ -2537,10 +2713,16 @@ index d82d9ad9..9ba1a326 100644
+ if self._request_notif_counter[req_id] > -1: + if self._request_notif_counter[req_id] > -1:
+ self._finished_prefills.add(req_id) + self._finished_prefills.add(req_id)
+ del self._request_notif_counter[req_id] + del self._request_notif_counter[req_id]
+
+ for req_id, done_count in request_done_counter.items():
+ self._request_done_counter[req_id] += done_count
+ if self._request_done_counter[req_id] > -1:
+ self._finished_transfers.add(req_id)
+ del self._request_done_counter[req_id]
# Finish the current step for all the sequence groups. # Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
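A condensed sketch of the completion bookkeeping above, assuming each request's balance starts below zero (the negative initialisation happens outside this hunk) so that crossing -1 marks it finished:

    from collections import defaultdict

    finished_prefills, finished_transfers = set(), set()
    notif_counter = defaultdict(int)  # request_id -> notification balance
    done_counter = defaultdict(int)   # request_id -> completed-transfer balance

    def account(new_notifs, new_dones):
        """Fold per-step counts into the running balances and promote any
        request whose balance has crossed the completion threshold (> -1)."""
        for req_id, n in new_notifs.items():
            notif_counter[req_id] += n
            if notif_counter[req_id] > -1:
                finished_prefills.add(req_id)
                del notif_counter[req_id]
        for req_id, n in new_dones.items():
            done_counter[req_id] += n
            if done_counter[req_id] > -1:
                finished_transfers.add(req_id)
                del done_counter[req_id]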
@@ -1456,7 +1582,7 @@ class LLMEngine: @@ -1456,7 +1590,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters. # queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.") logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
...@@ -2618,7 +2800,7 @@ index 3cf1850e..6b90ece7 100644 ...@@ -2618,7 +2800,7 @@ index 3cf1850e..6b90ece7 100644
+ kv_active_blocks: int + kv_active_blocks: int
+ kv_total_blocks: int + kv_total_blocks: int
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..d33d546a 100644 index 85b5f31e..3f8b8fad 100644
--- a/vllm/engine/multiprocessing/client.py --- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, @@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
...@@ -2746,7 +2928,7 @@ index 85b5f31e..d33d546a 100644 ...@@ -2746,7 +2928,7 @@ index 85b5f31e..d33d546a 100644
+ kv_metrics.kv_active_blocks, + kv_metrics.kv_active_blocks,
+ kv_metrics.kv_total_blocks) + kv_metrics.kv_total_blocks)
+ +
+ logger.debug("Metircs successful.") + logger.debug("Metircs successful.")
+ +
+ except asyncio.CancelledError: + except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.") + logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
...@@ -3151,10 +3333,10 @@ index 786380c3..56a7cf89 100644 ...@@ -3151,10 +3333,10 @@ index 786380c3..56a7cf89 100644
"""The output data of one completion output of a request. """The output data of one completion output of a request.
diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py
new file mode 100644 new file mode 100644
index 00000000..03f02006 index 00000000..957f55de
--- /dev/null --- /dev/null
+++ b/vllm/remote_prefill.py +++ b/vllm/remote_prefill.py
@@ -0,0 +1,53 @@ @@ -0,0 +1,54 @@
+from dataclasses import dataclass +from dataclasses import dataclass
+from typing import Callable, Optional, List, Coroutine +from typing import Callable, Optional, List, Coroutine
+ +
...@@ -3192,6 +3374,7 @@ index 00000000..03f02006 ...@@ -3192,6 +3374,7 @@ index 00000000..03f02006
+ """ + """
+ request_id: str + request_id: str
+ src_block_ids: List[int] + src_block_ids: List[int]
+ staging_block_ids: List[int]
+ dst_block_ids: List[int] + dst_block_ids: List[int]
+ dst_engine_id: str + dst_engine_id: str
+ notify_msg: str + notify_msg: str
...@@ -3327,7 +3510,7 @@ index 534b9e60..18675d2f 100644 ...@@ -3327,7 +3510,7 @@ index 534b9e60..18675d2f 100644
@property @property
def is_first_multi_step(self) -> bool: def is_first_multi_step(self) -> bool:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 12baecde..cbada27f 100644 index 12baecde..489d3b77 100644
--- a/vllm/worker/model_runner.py --- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py
@@ -1824,6 +1824,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): @@ -1824,6 +1824,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...@@ -3351,7 +3534,7 @@ index 12baecde..cbada27f 100644 ...@@ -3351,7 +3534,7 @@ index 12baecde..cbada27f 100644
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 582aa460..ffb7b403 100644 index 582aa460..c01cfe00 100644
--- a/vllm/worker/worker.py --- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py +++ b/vllm/worker/worker.py
@@ -2,7 +2,7 @@ @@ -2,7 +2,7 @@
...@@ -3392,13 +3575,13 @@ index 582aa460..ffb7b403 100644 ...@@ -3392,13 +3575,13 @@ index 582aa460..ffb7b403 100644
+ +
+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str: + def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata[self.local_rank]) # TODO ptarasiewicz: rank or local_rank? + agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata)) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr[self.local_rank]) + self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr)
+ return agent_name + return agent_name
+ +
+ def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None: + def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name[self.local_rank], notify_msg) # TODO ptarasiewicz: rank or local_rank? + self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name, notify_msg) # TODO ptarasiewicz: rank or local_rank?
+ +
+ def get_nixl_kv_caches_base_addr(self) -> List[bytes]: + def get_nixl_kv_caches_base_addr(self) -> List[bytes]:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
...@@ -3406,8 +3589,8 @@ index 582aa460..ffb7b403 100644 ...@@ -3406,8 +3589,8 @@ index 582aa460..ffb7b403 100644
+ +
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None: + def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ if worker_input.src_block_ids is not None: + if worker_input.src_block_ids is not None:
+ for src_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg): + for src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.staging_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, dst_block_ids, dst_engine_id, notify_msg) + self.nixl_connector.transfer_mem(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+ +
+ def shutdown_nixl(self) -> None: + def shutdown_nixl(self) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
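A minimal sketch of the fan-out that _transfer_blocks performs over the parallel per-request lists; `connector` stands in for the worker's nixl_connector and the transfer_mem argument order follows the hunk above:

    from typing import List, Optional

    def transfer_blocks(connector,
                        src_block_ids: Optional[List[List[int]]],
                        staging_block_ids: Optional[List[List[int]]],
                        dst_block_ids: Optional[List[List[int]]],
                        dst_engine_ids: Optional[List[str]],
                        notify_msgs: Optional[List[str]]) -> None:
        """Issue one NIXL transfer per scheduled request; the five lists are
        parallel, one entry per request."""
        if src_block_ids is None:
            return
        for src, staging, dst, engine_id, msg in zip(
                src_block_ids, staging_block_ids, dst_block_ids,
                dst_engine_ids, notify_msgs):
            connector.transfer_mem(src, staging, dst, engine_id, msg)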
...@@ -3425,11 +3608,12 @@ index 582aa460..ffb7b403 100644 ...@@ -3425,11 +3608,12 @@ index 582aa460..ffb7b403 100644
return WorkerInput( return WorkerInput(
num_seq_groups=num_seq_groups, num_seq_groups=num_seq_groups,
@@ -375,6 +416,10 @@ class Worker(LocalOrDistributedWorkerBase): @@ -375,6 +416,11 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
num_steps=num_steps, num_steps=num_steps,
+ src_block_ids=[r.src_block_ids for r in mem_transfer_reqs], + src_block_ids=[r.src_block_ids for r in mem_transfer_reqs],
+ staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs],
+ dst_block_ids=[r.dst_block_ids for r in mem_transfer_reqs], + dst_block_ids=[r.dst_block_ids for r in mem_transfer_reqs],
+ dst_engine_id=[r.dst_engine_id for r in mem_transfer_reqs], + dst_engine_id=[r.dst_engine_id for r in mem_transfer_reqs],
+ notify_msg=[r.notify_msg for r in mem_transfer_reqs], + notify_msg=[r.notify_msg for r in mem_transfer_reqs],
...@@ -3437,7 +3621,7 @@ index 582aa460..ffb7b403 100644 ...@@ -3437,7 +3621,7 @@ index 582aa460..ffb7b403 100644
@torch.inference_mode() @torch.inference_mode()
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 819b81fb..d9c039eb 100644 index 819b81fb..8dfdadde 100644
--- a/vllm/worker/worker_base.py --- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
...@@ -3465,11 +3649,12 @@ index 819b81fb..d9c039eb 100644 ...@@ -3465,11 +3649,12 @@ index 819b81fb..d9c039eb 100644
@abstractmethod @abstractmethod
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device """Initialize device state, such as loading the model or other on-device
@@ -216,6 +220,11 @@ class WorkerInput: @@ -216,6 +220,12 @@ class WorkerInput:
virtual_engine: int = 0 virtual_engine: int = 0
num_steps: int = 1 num_steps: int = 1
+ src_block_ids: Optional[List[List[int]]] = None + src_block_ids: Optional[List[List[int]]] = None
+ staging_block_ids: Optional[List[List[int]]] = None
+ dst_block_ids: Optional[List[List[int]]] = None + dst_block_ids: Optional[List[List[int]]] = None
+ dst_engine_id: Optional[List[str]] = None + dst_engine_id: Optional[List[str]] = None
+ notify_msg: Optional[List[str]] = None + notify_msg: Optional[List[str]] = None
...@@ -3477,29 +3662,31 @@ index 819b81fb..d9c039eb 100644 ...@@ -3477,29 +3662,31 @@ index 819b81fb..d9c039eb 100644
@classmethod @classmethod
def from_broadcasted_tensor_dict( def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"], cls: Type["WorkerInput"],
@@ -232,6 +241,10 @@ class WorkerInput: @@ -232,6 +242,11 @@ class WorkerInput:
blocks_to_copy=tensor_dict.pop("blocks_to_copy"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"], virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"), num_steps=tensor_dict.pop("num_steps"),
+ src_block_ids=tensor_dict.pop("src_block_ids"), + src_block_ids=tensor_dict.pop("src_block_ids"),
+ staging_block_ids=tensor_dict.pop("staging_block_ids"),
+ dst_block_ids=tensor_dict.pop("dst_block_ids"), + dst_block_ids=tensor_dict.pop("dst_block_ids"),
+ dst_engine_id=tensor_dict.pop("dst_engine_id"), + dst_engine_id=tensor_dict.pop("dst_engine_id"),
+ notify_msg=tensor_dict.pop("notify_msg"), + notify_msg=tensor_dict.pop("notify_msg"),
) )
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
@@ -246,6 +259,10 @@ class WorkerInput: @@ -246,6 +261,11 @@ class WorkerInput:
"blocks_to_copy": self.blocks_to_copy, "blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"num_steps": self.num_steps, "num_steps": self.num_steps,
+ "src_block_ids": self.src_block_ids, + "src_block_ids": self.src_block_ids,
+ "staging_block_ids": self.staging_block_ids,
+ "dst_block_ids": self.dst_block_ids, + "dst_block_ids": self.dst_block_ids,
+ "dst_engine_id": self.dst_engine_id, + "dst_engine_id": self.dst_engine_id,
+ "notify_msg": self.notify_msg, + "notify_msg": self.notify_msg,
} }
return tensor_dict return tensor_dict
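A simplified sketch of the round-trip these two hunks implement, with WorkerInput reduced to just the new optional transfer fields; the real class also carries the tensor fields:

    from dataclasses import asdict, dataclass
    from typing import Dict, List, Optional

    @dataclass
    class TransferFields:
        src_block_ids: Optional[List[List[int]]] = None
        staging_block_ids: Optional[List[List[int]]] = None
        dst_block_ids: Optional[List[List[int]]] = None
        dst_engine_id: Optional[List[str]] = None
        notify_msg: Optional[List[str]] = None

        def as_broadcastable_dict(self) -> Dict:
            # Plain Python lists travel alongside the tensor fields when the
            # driver broadcasts worker input to the other ranks.
            return asdict(self)

        @classmethod
        def from_broadcasted_dict(cls, d: Dict) -> "TransferFields":
            return cls(src_block_ids=d.pop("src_block_ids"),
                       staging_block_ids=d.pop("staging_block_ids"),
                       dst_block_ids=d.pop("dst_block_ids"),
                       dst_engine_id=d.pop("dst_engine_id"),
                       notify_msg=d.pop("notify_msg"))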
@@ -316,13 +333,16 @@ class LocalOrDistributedWorkerBase(WorkerBase): @@ -316,13 +336,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return None return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
...@@ -3521,7 +3708,7 @@ index 819b81fb..d9c039eb 100644 ...@@ -3521,7 +3708,7 @@ index 819b81fb..d9c039eb 100644
def _get_driver_input_and_broadcast( def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest self, execute_model_req: ExecuteModelRequest
@@ -396,49 +416,79 @@ class LocalOrDistributedWorkerBase(WorkerBase): @@ -396,49 +419,87 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.execute_worker(worker_input) self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
...@@ -3618,7 +3805,7 @@ index 819b81fb..d9c039eb 100644 ...@@ -3618,7 +3805,7 @@ index 819b81fb..d9c039eb 100644
+ else: + else:
+ for i in range(1, get_tp_group().world_size): + for i in range(1, get_tp_group().world_size):
+ all_new_notifs.append(get_tp_group().recv_object(src=i)) + all_new_notifs.append(get_tp_group().recv_object(src=i))
+
+ request_notif_counter = defaultdict(int) + request_notif_counter = defaultdict(int)
+ for notifs in all_new_notifs: + for notifs in all_new_notifs:
+ for req_ids in notifs.values(): + for req_ids in notifs.values():
...@@ -3627,12 +3814,20 @@ index 819b81fb..d9c039eb 100644 ...@@ -3627,12 +3814,20 @@ index 819b81fb..d9c039eb 100644
+ +
+ if request_notif_counter: + if request_notif_counter:
+ logger.debug("Request notif counter: %s", request_notif_counter) + logger.debug("Request notif counter: %s", request_notif_counter)
+
+ request_done_counter = defaultdict(int)
+ for req_id in self.nixl_connector.get_done_tranfers():
+ request_done_counter[req_id] += 1
+
+ if request_done_counter:
+ logger.debug("Request done counter: %s", request_done_counter)
+
+ else: + else:
+ request_notif_counter = {} + request_notif_counter = {}
+ request_done_counter = {}
# output is List[SamplerOutput] # output is List[SamplerOutput]
- return output - return output
+ return output, request_notif_counter + return output, request_notif_counter, request_done_counter
+ +
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None: + def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ pass + pass
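A small sketch of the per-request counting the driver does after gathering notifications from every tensor-parallel rank; the gather itself (send/recv_object on the TP group) is assumed to have already produced `per_rank_notifs`:

    from collections import defaultdict
    from typing import Dict, List

    def count_notifications(per_rank_notifs: List[Dict[str, List[str]]]) -> Dict[str, int]:
        """Each rank contributes a mapping of remote agent -> request ids it
        was notified about; collapse them into a per-request count."""
        counter: Dict[str, int] = defaultdict(int)
        for notifs in per_rank_notifs:
            for req_ids in notifs.values():
                for req_id in req_ids:
                    counter[req_id] += 1
        return dict(counter)

    # Example: the same request acknowledged on two ranks.
    print(count_notifications([{"agent0": ["req-1"]}, {"agent0": ["req-1"]}]))  # {'req-1': 2}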
......