Commit 17827e1d authored by ptarasiewiczNV's avatar ptarasiewiczNV Committed by GitHub
Browse files

feat: Decode -> Prefill cached kv transfer (#340)

parent 405222ce
diff --git a/vllm/config.py b/vllm/config.py diff --git a/vllm/config.py b/vllm/config.py
index 9ba49757..5e1cf249 100644 index 9ba49757..a4df0019 100644
--- a/vllm/config.py --- a/vllm/config.py
+++ b/vllm/config.py +++ b/vllm/config.py
@@ -2620,6 +2620,9 @@ class KVTransferConfig(BaseModel): @@ -2620,6 +2620,9 @@ class KVTransferConfig(BaseModel):
...@@ -7,7 +7,7 @@ index 9ba49757..5e1cf249 100644 ...@@ -7,7 +7,7 @@ index 9ba49757..5e1cf249 100644
kv_connector: Optional[str] = None kv_connector: Optional[str] = None
+ # Whether to use NIXL prepped xfer for KV cache transfer. + # Whether to use NIXL prepped xfer for KV cache transfer.
+ use_prepped_xfer: bool = False + use_prepped_xfer: bool = True
+ +
# The device used by kv connector to buffer the KV cache. # The device used by kv connector to buffer the KV cache.
# Currently only support 'cuda'. # Currently only support 'cuda'.
...@@ -36,7 +36,7 @@ index 9ba49757..5e1cf249 100644 ...@@ -36,7 +36,7 @@ index 9ba49757..5e1cf249 100644
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
@@ -2680,11 +2691,12 @@ class KVTransferConfig(BaseModel): @@ -2680,11 +2691,16 @@ class KVTransferConfig(BaseModel):
f"Supported roles are `kv_producer`, `kv_consumer`, " f"Supported roles are `kv_producer`, `kv_consumer`, "
f"and `kv_both`") f"and `kv_both`")
...@@ -46,11 +46,15 @@ index 9ba49757..5e1cf249 100644 ...@@ -46,11 +46,15 @@ index 9ba49757..5e1cf249 100644
"is set, supported roles are `kv_producer`, " "is set, supported roles are `kv_producer`, "
"`kv_consumer`, and `kv_both`") "`kv_consumer`, and `kv_both`")
+ if self.use_prepped_xfer is False:
+ logger.warning("`use_prepped_xfer` parameter is deprecated. All transfers will be done using prepped xfer.")
+ self.use_prepped_xfer = True
+
+ +
@property @property
def is_kv_transfer_instance(self) -> bool: def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \ return self.kv_connector is not None and \
@@ -2694,6 +2706,8 @@ class KVTransferConfig(BaseModel): @@ -2694,6 +2710,8 @@ class KVTransferConfig(BaseModel):
def need_kv_parallel_group(self) -> bool: def need_kv_parallel_group(self) -> bool:
# for those database-based connector, vLLM does not need to create # for those database-based connector, vLLM does not need to create
# parallel group, and in that case the kv parallel size will be 1. # parallel group, and in that case the kv parallel size will be 1.
...@@ -59,7 +63,7 @@ index 9ba49757..5e1cf249 100644 ...@@ -59,7 +63,7 @@ index 9ba49757..5e1cf249 100644
return self.kv_connector is not None and self.kv_parallel_size > 1 return self.kv_connector is not None and self.kv_parallel_size > 1
@property @property
@@ -2706,6 +2720,18 @@ class KVTransferConfig(BaseModel): @@ -2706,6 +2724,18 @@ class KVTransferConfig(BaseModel):
return self.kv_connector is not None and \ return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"] self.kv_role in ["kv_consumer", "kv_both"]
...@@ -185,7 +189,7 @@ index c388366b..31ed7aa4 100644 ...@@ -185,7 +189,7 @@ index c388366b..31ed7aa4 100644
def free(self, block: Block, keep_block_object: bool = False) -> None: def free(self, block: Block, keep_block_object: bool = False) -> None:
# Release the physical block id # Release the physical block id
diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py
index 1ca9e49d..b1591c0c 100644 index 1ca9e49d..cd780f69 100644
--- a/vllm/core/block/prefix_caching_block.py --- a/vllm/core/block/prefix_caching_block.py
+++ b/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py
@@ -4,7 +4,7 @@ import sys @@ -4,7 +4,7 @@ import sys
...@@ -246,8 +250,23 @@ index 1ca9e49d..b1591c0c 100644 ...@@ -246,8 +250,23 @@ index 1ca9e49d..b1591c0c 100644
return block.block_id return block.block_id
# Reuse the cached content hash # Reuse the cached content hash
@@ -579,9 +593,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
# Mark all touched blocks as computed.
- for block_id in self._touched_blocks:
- self._block_tracker[block_id].computed = True
- self._touched_blocks.clear()
+ for block_id in block_ids:
+ if block_id in self._touched_blocks:
+ logger.debug("Mark block as computed: %s", block_id)
+ self._block_tracker[block_id].computed = True
+ self._touched_blocks.remove(block_id)
def _track_block_id(self, block_id: Optional[BlockId],
computed: bool) -> None:
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index c5b3b04f..12cd4dc9 100644 index c5b3b04f..21fe0fc8 100644
--- a/vllm/core/block_manager.py --- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py
@@ -10,7 +10,10 @@ from vllm.core.block.interfaces import Block @@ -10,7 +10,10 @@ from vllm.core.block.interfaces import Block
...@@ -299,6 +318,27 @@ index c5b3b04f..12cd4dc9 100644 ...@@ -299,6 +318,27 @@ index c5b3b04f..12cd4dc9 100644
) )
self.block_tables: Dict[SeqId, BlockTable] = {} self.block_tables: Dict[SeqId, BlockTable] = {}
@@ -108,7 +130,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
def can_allocate(self,
seq_group: SequenceGroup,
- num_lookahead_slots: int = 0) -> AllocStatus:
+ num_lookahead_slots: int = 0,
+ is_remote_decode: bool = False) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
@@ -121,6 +144,10 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
num_lookahead_slots=num_lookahead_slots,
)
+ # if remote decode, we need to allocate twice as many blocks for staging
+ if is_remote_decode:
+ num_required_blocks *= 2
+
if seq_group.is_encoder_decoder():
encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None
diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py
new file mode 100644 new file mode 100644
index 00000000..a27af580 index 00000000..a27af580
...@@ -414,7 +454,7 @@ index 00000000..a27af580 ...@@ -414,7 +454,7 @@ index 00000000..a27af580
+ +
+ self.event_id_counter += 1 + self.event_id_counter += 1
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f507847a..9e6443bf 100644 index f507847a..170a359f 100644
--- a/vllm/core/scheduler.py --- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py
@@ -4,22 +4,22 @@ import enum @@ -4,22 +4,22 @@ import enum
...@@ -574,7 +614,19 @@ index f507847a..9e6443bf 100644 ...@@ -574,7 +614,19 @@ index f507847a..9e6443bf 100644
leftover_waiting_sequences: Deque[SequenceGroup] = deque() leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue: while self._passed_delay(time.time()) and waiting_queue:
@@ -1008,7 +1060,18 @@ class Scheduler: @@ -961,8 +1013,10 @@ class Scheduler:
True, enable_chunking)
# If the sequence group cannot be allocated, stop.
+ is_remote_decode = seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode
can_allocate = self.block_manager.can_allocate(
- seq_group, num_lookahead_slots=num_lookahead_slots)
+ seq_group, num_lookahead_slots=num_lookahead_slots,
+ is_remote_decode=is_remote_decode)
if can_allocate == AllocStatus.LATER:
break
elif can_allocate == AllocStatus.NEVER:
@@ -1008,7 +1062,18 @@ class Scheduler:
if curr_loras is not None and lora_int_id > 0: if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id) curr_loras.add(lora_int_id)
waiting_queue.popleft() waiting_queue.popleft()
...@@ -587,14 +639,14 @@ index f507847a..9e6443bf 100644 ...@@ -587,14 +639,14 @@ index f507847a..9e6443bf 100644
+ logger.debug("Seq id: %s", seq_group.seqs[0].seq_id) + logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
+ is_remote_prefill = self._allocate_and_set_running_or_remote_prefill(seq_group) + is_remote_prefill = self._allocate_and_set_running_or_remote_prefill(seq_group)
+ num_remote_prefill_groups += is_remote_prefill + num_remote_prefill_groups += is_remote_prefill
+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode: + if is_remote_decode:
+ logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id) + logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group_copy) + self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
+ self.prefill_sending.append(seq_group_copy) + self.prefill_sending.append(seq_group_copy)
if enable_chunking and self.scheduler_config.is_multi_step: if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = [] blocks_to_copy: List[Tuple[int, int]] = []
@@ -1046,9 +1109,11 @@ class Scheduler: @@ -1046,9 +1111,11 @@ class Scheduler:
seq_groups=seq_groups, seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups, ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots( num_lookahead_slots=self._get_num_lookahead_slots(
...@@ -608,22 +660,24 @@ index f507847a..9e6443bf 100644 ...@@ -608,22 +660,24 @@ index f507847a..9e6443bf 100644
"""Schedule queued requests. """Schedule queued requests.
The current policy is designed to optimize the throughput. First, The current policy is designed to optimize the throughput. First,
@@ -1066,6 +1131,9 @@ class Scheduler: @@ -1066,9 +1133,13 @@ class Scheduler:
for seq_group in self.running: for seq_group in self.running:
budget.add_num_seqs(seq_group.request_id, budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs()) seq_group.get_max_num_running_seqs())
+ for seq_group in self.prefill_sending: - curr_loras = set(
+ for seq_group in self.remote_prefilling:
+ budget.add_num_seqs(seq_group.request_id, + budget.add_num_seqs(seq_group.request_id,
+ seq_group.get_max_num_running_seqs()) + seq_group.get_max_num_running_seqs())
curr_loras = set( +
+ curr_loras = (set(
seq_group.lora_int_id for seq_group in self.running seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None - if seq_group.lora_int_id > 0) if self.lora_enabled else None
@@ -1087,10 +1155,12 @@ class Scheduler: + if seq_group.lora_int_id > 0) if self.lora_enabled else None)
# Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running prefills = SchedulerPrefillOutputs.create_empty()
# only contains decode requests, not chunked prefills. running_scheduled = SchedulerRunningOutputs.create_empty()
- if len(prefills.seq_groups) == 0: @@ -1090,7 +1161,9 @@ class Scheduler:
+ if len(prefills.seq_groups) == prefills.num_remote_prefill_groups: if len(prefills.seq_groups) == 0:
running_scheduled = self._schedule_running(budget, running_scheduled = self._schedule_running(budget,
curr_loras, curr_loras,
- enable_chunking=False) - enable_chunking=False)
...@@ -633,7 +687,7 @@ index f507847a..9e6443bf 100644 ...@@ -633,7 +687,7 @@ index f507847a..9e6443bf 100644
# If any sequence group is preempted, do not swap in any sequence # If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests. # group. because it means there's no slot for new running requests.
@@ -1106,7 +1176,12 @@ class Scheduler: @@ -1106,7 +1179,12 @@ class Scheduler:
self.waiting.extendleft(running_scheduled.preempted) self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests. # Update new running requests.
if len(prefills.seq_groups) > 0: if len(prefills.seq_groups) > 0:
...@@ -647,7 +701,7 @@ index f507847a..9e6443bf 100644 ...@@ -647,7 +701,7 @@ index f507847a..9e6443bf 100644
self.running.extend(running_scheduled.decode_seq_groups_list) self.running.extend(running_scheduled.decode_seq_groups_list)
@@ -1248,12 +1323,14 @@ class Scheduler: @@ -1248,12 +1326,14 @@ class Scheduler:
len(running_scheduled.swapped_out)), len(running_scheduled.swapped_out)),
) )
...@@ -664,7 +718,7 @@ index f507847a..9e6443bf 100644 ...@@ -664,7 +718,7 @@ index f507847a..9e6443bf 100644
def _can_append_slots(self, seq_group: SequenceGroup, def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool: enable_chunking: bool) -> bool:
@@ -1287,14 +1364,16 @@ class Scheduler: @@ -1287,14 +1367,16 @@ class Scheduler:
return no_single_seq return no_single_seq
def schedule( def schedule(
...@@ -684,7 +738,7 @@ index f507847a..9e6443bf 100644 ...@@ -684,7 +738,7 @@ index f507847a..9e6443bf 100644
now = time.time() now = time.time()
if not self.cache_config.enable_prefix_caching: if not self.cache_config.enable_prefix_caching:
@@ -1333,7 +1412,8 @@ class Scheduler: @@ -1333,7 +1415,8 @@ class Scheduler:
encoder_seq_data = None encoder_seq_data = None
cross_block_table = None cross_block_table = None
...@@ -694,15 +748,40 @@ index f507847a..9e6443bf 100644 ...@@ -694,15 +748,40 @@ index f507847a..9e6443bf 100644
seq_id = seq.seq_id seq_id = seq.seq_id
seq_data[seq_id] = seq.data seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq) block_tables[seq_id] = self.block_manager.get_block_table(seq)
@@ -1364,9 +1444,16 @@ class Scheduler: @@ -1342,7 +1425,9 @@ class Scheduler:
if self.cache_config.enable_prefix_caching:
common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
- seq_group.get_seqs(status=SequenceStatus.RUNNING)))
+ running_or_remote_prefilling_seqs
+ )
+ )
do_sample = True
is_prompt = seq_group.is_prefill()
@@ -1364,9 +1449,30 @@ class Scheduler:
< seqs[0].data.get_len()): < seqs[0].data.get_len()):
do_sample = False do_sample = False
+ is_remote_prefill = False + is_remote_prefill = False
+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: + if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
+ is_remote_prefill = True + is_remote_prefill = True
+ logger.debug("Remote prefill, computed block nums: %s", common_computed_block_nums)
+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode: + if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+ block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids + block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids
+
+ # Since we know that prefill is scheduled we can
+ # assume that the blocks computed on decode
+ # will be fetched by the time we run prefill
+ logger.debug("Computed decode blocks: %s", seq_group.remote_prefill_params.decode_computed_block_ids)
+ if seq_group.remote_prefill_params.decode_computed_block_ids:
+ computed_block_ids = set(seq_group.remote_prefill_params.decode_computed_block_ids)
+ prefill_block_ids = block_tables[seq_group.seqs[0].seq_id]
+ prefill_fetched_block_ids = [prefill_block_ids[i] for i, block_id in enumerate(seq_group.remote_prefill_params.decode_block_ids) if block_id in computed_block_ids and i < len(prefill_block_ids)]
+
+ assert len(common_computed_block_nums) == 0, "common_computed_block_nums should be empty for remote prefill as it doesn't suport prefix caching"
+ common_computed_block_nums = prefill_fetched_block_ids
+
+ +
# It assumes the scheduled_seq_groups is ordered by # It assumes the scheduled_seq_groups is ordered by
# prefill < decoding. # prefill < decoding.
...@@ -711,7 +790,7 @@ index f507847a..9e6443bf 100644 ...@@ -711,7 +790,7 @@ index f507847a..9e6443bf 100644
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id, request_id=seq_group.request_id,
is_prompt=is_prompt, is_prompt=is_prompt,
@@ -1392,6 +1479,7 @@ class Scheduler: @@ -1392,6 +1498,7 @@ class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None, if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs, mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request, prompt_adapter_request=seq_group.prompt_adapter_request,
...@@ -719,7 +798,7 @@ index f507847a..9e6443bf 100644 ...@@ -719,7 +798,7 @@ index f507847a..9e6443bf 100644
) )
else: else:
# When SPMD mode is enabled, we only send delta data except for # When SPMD mode is enabled, we only send delta data except for
@@ -1490,11 +1578,17 @@ class Scheduler: @@ -1490,11 +1597,17 @@ class Scheduler:
self._async_stopped.clear() self._async_stopped.clear()
...@@ -742,16 +821,50 @@ index f507847a..9e6443bf 100644 ...@@ -742,16 +821,50 @@ index f507847a..9e6443bf 100644
blocks_to_copy: List[Tuple[int, int]], blocks_to_copy: List[Tuple[int, int]],
diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py
new file mode 100644 new file mode 100644
index 00000000..9b938039 index 00000000..b9485bd5
--- /dev/null --- /dev/null
+++ b/vllm/distributed/device_communicators/kv_rearrange.py +++ b/vllm/distributed/device_communicators/kv_rearrange.py
@@ -0,0 +1,61 @@ @@ -0,0 +1,110 @@
+import torch +import torch
+import triton +import triton
+import triton.language as tl +import triton.language as tl
+ +
+@triton.jit +@triton.jit
+def rearrange_kernel( +def rearrange_kernel_read(
+ t1_ptr,
+ t2_ptr,
+ N,
+ B,
+ H,
+ C,
+ d,
+ tensor_subset_size,
+ block_size,
+ token_size,
+ BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+
+ curr_n = offsets // block_size
+ curr_b = offsets // token_size % B
+ curr_h = offsets // C % H
+ curr_c = offsets % C
+
+ src_pos = offsets
+
+ tp_group = curr_h * d // H
+ dst_h = curr_h % (H // d)
+ tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
+
+ dst_pos = tensor_subset_size * tp_group + tp_group_offset
+
+ tl.store(t1_ptr + src_pos, tl.load(t2_ptr + dst_pos))
+
+@triton.jit
+def rearrange_kernel_write(
+ t1_ptr, + t1_ptr,
+ t2_ptr, + t2_ptr,
+ N, + N,
...@@ -783,8 +896,10 @@ index 00000000..9b938039 ...@@ -783,8 +896,10 @@ index 00000000..9b938039
+ dst_pos = tensor_subset_size * tp_group + tp_group_offset + dst_pos = tensor_subset_size * tp_group + tp_group_offset
+ +
+ tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos)) + tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos))
+
+
+ +
+def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int): +def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int, direction: str):
+ N, B, H, C = t1.shape + N, B, H, C = t1.shape
+ +
+ assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source" + assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
...@@ -798,22 +913,35 @@ index 00000000..9b938039 ...@@ -798,22 +913,35 @@ index 00000000..9b938039
+ BLOCK_SIZE = 1024 + BLOCK_SIZE = 1024
+ grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,) + grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)
+ +
+ rearrange_kernel[grid]( + if direction == "read":
+ t1, t2, + rearrange_kernel_read[grid](
+ N, B, H, C, + t1, t2,
+ d, + N, B, H, C,
+ tensor_subset_size, + d,
+ block_size, + tensor_subset_size,
+ token_size, + block_size,
+ BLOCK_SIZE=BLOCK_SIZE + token_size,
+ ) + BLOCK_SIZE=BLOCK_SIZE
+ )
+ elif direction == "write":
+ rearrange_kernel_write[grid](
+ t1, t2,
+ N, B, H, C,
+ d,
+ tensor_subset_size,
+ block_size,
+ token_size,
+ BLOCK_SIZE=BLOCK_SIZE
+ )
+ else:
+ raise ValueError(f"Invalid direction: {direction}")
\ No newline at end of file \ No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644 new file mode 100644
index 00000000..d972252a index 00000000..a8bd202f
--- /dev/null --- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py +++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,400 @@ @@ -0,0 +1,379 @@
+import torch +import torch
+from typing import List, Tuple +from typing import List, Tuple
+from vllm.config import VllmConfig +from vllm.config import VllmConfig
...@@ -919,41 +1047,8 @@ index 00000000..d972252a ...@@ -919,41 +1047,8 @@ index 00000000..d972252a
+ self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) + self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle)
+ for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): + for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
+ for dst_xfer_side_handle in dst_xfer_side_handles.values(): + for dst_xfer_side_handle in dst_xfer_side_handles.values():
+ self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) + self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle)
+
+ def get_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ if block_ids == "all":
+ block_ids = list(range(self.num_blocks))
+ descs_ids = []
+ for layer_id in layer_ids:
+ for block_id in block_ids:
+ assert block_id < self.num_blocks, f"Block id {block_id} is greater than the number of blocks {self.num_blocks}"
+ descs_ids.append(2 * (self.num_blocks * layer_id + block_id))
+ descs_ids.append(2 * (self.num_blocks * layer_id + block_id) + 1)
+ return descs_ids
+ +
+ def _get_range_descs(self, ranges, layer_ids, kv_caches_base_addr, tp_multiplier=1, rank=None, i=0, staging_ranges=None):
+ if rank is None:
+ rank = self.rank
+ block_len = self.block_len // tp_multiplier
+ logger.debug("Getting range descs for layer ids: %s, ranges: %s, tp_multiplier: %s, rank: %s, i: %s", layer_ids, ranges, tp_multiplier, rank, i)
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ blocks_data = []
+ for layer_id in layer_ids:
+ for range_idx, (range_start, range_end) in enumerate(ranges):
+ range_len = range_end - range_start + 1
+ key_base_addr, value_base_addr = kv_caches_base_addr[layer_id]
+ if staging_ranges is not None:
+ start_offset = staging_ranges[range_idx][0] * self.block_len + i * block_len * (staging_ranges[range_idx][1] - staging_ranges[range_idx][0] + 1) + (range_start - staging_ranges[range_idx][0]) * block_len
+ else:
+ start_offset = range_start * block_len
+ blocks_data.append((key_base_addr + start_offset, range_len * block_len, rank))
+ blocks_data.append((value_base_addr + start_offset, range_len * block_len, rank))
+ return self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+
+ def _get_ranges(self, block_ids): + def _get_ranges(self, block_ids):
+ # This function should return a list of ranges of block ids that are contiguous + # This function should return a list of ranges of block ids that are contiguous
+ # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]] + # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]]
...@@ -968,6 +1063,35 @@ index 00000000..d972252a ...@@ -968,6 +1063,35 @@ index 00000000..d972252a
+ ranges[-1][1] = block_ids[i] + ranges[-1][1] = block_ids[i]
+ return ranges + return ranges
+ +
+ def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_ranges=None):
+
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ if block_ids == "all":
+ block_ids = list(range(self.num_blocks))
+
+ descs_ids = []
+
+
+ if i is not None:
+ num_blocks = self.num_blocks
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ staging_range_idx = 0
+ for block_id in block_ids:
+ if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]:
+ staging_range_idx += 1
+ start_offset = staging_ranges[staging_range_idx][0]
+ i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1)
+ descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
+ else:
+ num_blocks = self.dst_num_blocks[engine_id]
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id)
+ return descs_ids
+
+ def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False): + def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False):
+ # This function should return a list of ranges for both src and dst so that corresponding ranges are the same length + # This function should return a list of ranges for both src and dst so that corresponding ranges are the same length
+ # For example, if src_ranges is [[0, 2] [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]] + # For example, if src_ranges is [[0, 2] [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]]
...@@ -1016,136 +1140,119 @@ index 00000000..d972252a ...@@ -1016,136 +1140,119 @@ index 00000000..d972252a
+ if return_original_src_ranges: + if return_original_src_ranges:
+ return src_overlapping_ranges, dst_overlapping_ranges, original_src_ranges + return src_overlapping_ranges, dst_overlapping_ranges, original_src_ranges
+ return src_overlapping_ranges, dst_overlapping_ranges + return src_overlapping_ranges, dst_overlapping_ranges
+
+ +
+ def read_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id):
+ logger.debug("Reading %d blocks from %s to %s", len(local_block_ids), self.agent_name, dst_engine_id)
+ +
+ def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_ranges=None): + assert len(local_block_ids) == len(staging_block_ids) == len(remote_block_ids)
+ +
+ if layer_ids == "all": + if len(local_block_ids) == 0:
+ layer_ids = list(range(self.num_layers)) + logger.debug("No blocks to read")
+ if block_ids == "all": + return
+ block_ids = list(range(self.num_blocks))
+ +
+ descs_ids = [] + start_time = time.perf_counter()
+
+ local_ranges = self._get_ranges(local_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+ +
+ local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
+ +
+ if i is not None: + tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+ num_blocks = self.num_blocks + remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
+ for layer_id in layer_ids: + local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
+ for is_value in [0, 1]: + handles = []
+ staging_range_idx = 0
+ for block_id in block_ids:
+ if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]:
+ staging_range_idx += 1
+ start_offset = staging_ranges[staging_range_idx][0]
+ i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1)
+ descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
+ else:
+ num_blocks = self.dst_num_blocks[engine_id]
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id)
+ return descs_ids
+
+
+ def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+ +
+ if self.use_prepped_xfer: + logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000)
+ self._transfer_mem_prepped_xfer(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg) + create_xfer_start_time = time.perf_counter()
+ else:
+ self._transfer_mem_create_xfer(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+ +
+ def _transfer_mem_prepped_xfer(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg): + for i in range(tp_multiplier):
+ start_time = time.perf_counter() + staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg) + assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
+ remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
+ handle = self.nixl_wrapper.make_prepped_xfer("READ", local_xfer_side_handle, staging_block_descs_ids,
+ remote_xfer_side_handle, remote_block_descs_ids,
+ "")
+ handles.append(handle)
+ status = self.nixl_wrapper.transfer(handle)
+ +
+ # hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last + logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000)
+ # isl[-1] token is calculated in the first iteration in decode.
+ # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \
+ # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)]
+ +
+ assert len(staging_block_ids) == len(src_block_ids) + transfer_start_time = time.perf_counter()
+ src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+ +
+ src_staging_overlapping_ranges, staging_src_overlapping_ranges = self._get_same_length_ranges(src_ranges, staging_ranges) + for handle in handles:
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] + while (status := self.nixl_wrapper.check_xfer_state(handle)) != "DONE":
+ + if status == "PROC":
+ for src_range, staging_range in zip(src_staging_overlapping_ranges, staging_src_overlapping_ranges): + time.sleep(0.001)
+ logger.debug("Rearranging tensors for cache: %s, src_range: %s, staging_range: %s", self.kv_caches[0].shape, src_range, staging_range) + else:
+ for kv_cache in self.kv_caches: + raise RuntimeError("Read transfer failed with state %s", status)
+ for cache in kv_cache: + # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
+ rearrange_tensors(cache[src_range[0]:src_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier)
+ +
+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000) + logger.debug("Time to transfer: %s ms", (time.perf_counter() - transfer_start_time) * 1000)
+ +
+ # getting block descs ids + rearrange_start_time = time.perf_counter()
+ dst_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", dst_block_ids)
+ src_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
+
+ for i in range(tp_multiplier):
+ staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_src_overlapping_ranges)
+ assert len(staging_block_descs_ids) == len(dst_block_descs_ids)
+ dst_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
+ +
+ for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
+ logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
+ for kv_cache in self.kv_caches:
+ for cache in kv_cache:
+ rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read")
+ +
+ logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000) + logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000)
+ handle = self.nixl_wrapper.make_prepped_xfer("WRITE", src_xfer_side_handle, staging_block_descs_ids, + logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000)
+ dst_xfer_side_handle, dst_block_descs_ids, +
+ notify_msg) + def write_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id, notify_msg):
+ self._transfers[notify_msg].append(handle) + logger.debug("Writing %d blocks to %s from %s with notify message %s", len(local_block_ids), dst_engine_id, self.agent_name, notify_msg)
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+
+ def _transfer_mem_create_xfer(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+ start_time = time.perf_counter()
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+ +
+ # hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last + # hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last
+ # isl[-1] token is calculated in the first iteration in decode. + # isl[-1] token is calculated in the first iteration in decode.
+ # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \ + # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \
+ # one less block due to the missing last token. + # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)] + remote_block_ids = remote_block_ids[:len(local_block_ids)]
+ assert len(staging_block_ids) == len(src_block_ids)
+ src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+ dst_ranges = self._get_ranges(dst_block_ids)
+ +
+ staging_src_overlapping_ranges, src_staging_overlapping_ranges = self._get_same_length_ranges(staging_ranges, src_ranges) + assert len(staging_block_ids) == len(local_block_ids)
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] + tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+
+ if len(local_block_ids) == 0:
+ logger.debug("No blocks to write")
+ for i in range(tp_multiplier):
+ self.nixl_wrapper.send_notif(self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i], notify_msg)
+ return
+
+ start_time = time.perf_counter()
+
+ local_ranges = self._get_ranges(local_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+
+ local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
+ +
+ for src_range, staging_range in zip(src_staging_overlapping_ranges, staging_src_overlapping_ranges): + for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
+ logger.debug("Rearranging tensors for cache: %s, src_range: %s, staging_range: %s", self.kv_caches[0].shape, src_range, staging_range) + logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
+ for kv_cache in self.kv_caches: + for kv_cache in self.kv_caches:
+ for cache in kv_cache: + for cache in kv_cache:
+ rearrange_tensors(cache[src_range[0]:src_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier) + rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write")
+ +
+ staging_overlapping_ranges, dst_overlapping_ranges, original_src_ranges = self._get_same_length_ranges(staging_src_overlapping_ranges, dst_ranges, return_original_src_ranges=True) + logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000)
+ assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges)
+ +
+ logger.debug("Time to get same length ranges: %s ms", (time.perf_counter() - start_time) * 1000) + create_xfer_start_time = time.perf_counter()
+ +
+ # getting block descs ids
+ remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
+ local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
+
+ for i in range(tp_multiplier): + for i in range(tp_multiplier):
+ + staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
+ src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i, staging_ranges=original_src_ranges) + assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
+ dst_descs = self._get_range_descs(dst_overlapping_ranges, "all", self.kv_caches_base_addr[dst_engine_id][self.rank * tp_multiplier + i], tp_multiplier, rank=self.rank * tp_multiplier + i) + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000) + handle = self.nixl_wrapper.make_prepped_xfer("WRITE", local_xfer_side_handle, staging_block_descs_ids,
+ + remote_xfer_side_handle, remote_block_descs_ids,
+ logger.debug("Transfering to agent %s", self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i])
+ handle = self.nixl_wrapper.initialize_xfer("WRITE", src_descs, dst_descs,
+ self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i],
+ notify_msg) + notify_msg)
+ self._transfers[notify_msg].append(handle) + self._transfers[notify_msg].append(handle)
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle) + status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000) +
+ logger.debug("Transfer status: %s", status) + logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000)
+
+ transfer_start_time = time.perf_counter()
+ logger.debug("Total time for write: %s ms", (time.perf_counter() - start_time) * 1000)
+ +
+ def get_notifs(self): + def get_notifs(self):
+ return self.nixl_wrapper.update_notifs() + return self.nixl_wrapper.update_notifs()
...@@ -2572,7 +2679,7 @@ index 321902d1..b8937ef8 100644 ...@@ -2572,7 +2679,7 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d82d9ad9..931784f8 100644 index d82d9ad9..03896aa6 100644
--- a/vllm/engine/llm_engine.py --- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@ @@ -2,13 +2,17 @@
...@@ -2598,7 +2705,7 @@ index d82d9ad9..931784f8 100644 ...@@ -2598,7 +2705,7 @@ index d82d9ad9..931784f8 100644
usage_message) usage_message)
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
+from vllm.remote_prefill import RemotePrefillRequest, RemotePrefillParams, MemoryTransferRequest +from vllm.remote_prefill import RemotePrefillRequest, RemotePrefillParams, MemoryTransferRequest, MemoryOpType
+from vllm.distributed.device_communicators.nixl import NixlMetadata +from vllm.distributed.device_communicators.nixl import NixlMetadata
+ +
...@@ -2795,7 +2902,7 @@ index d82d9ad9..931784f8 100644 ...@@ -2795,7 +2902,7 @@ index d82d9ad9..931784f8 100644
# Sanity check # Sanity check
assert len(seq_group_metadata_list) == len( assert len(seq_group_metadata_list) == len(
@@ -1325,15 +1387,49 @@ class LLMEngine: @@ -1325,15 +1387,55 @@ class LLMEngine:
# Clear outputs for each new scheduler iteration # Clear outputs for each new scheduler iteration
ctx.request_outputs.clear() ctx.request_outputs.clear()
...@@ -2836,18 +2943,24 @@ index d82d9ad9..931784f8 100644 ...@@ -2836,18 +2943,24 @@ index d82d9ad9..931784f8 100644
+ assert self._nixl_agents_names + assert self._nixl_agents_names
+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id + seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
+ block_table = seq_group_metadata.block_tables[seq_id] + block_table = seq_group_metadata.block_tables[seq_id]
+ if len(block_table) == len(seq_group_metadata.computed_block_nums):
+ logger.debug("No blocks to prefill")
+ self._finished_prefills.add(seq_group_metadata.request_id)
+ continue
+ remote_prefill_request = RemotePrefillRequest( + remote_prefill_request = RemotePrefillRequest(
+ request_id=seq_group_metadata.request_id, + request_id=seq_group_metadata.request_id,
+ prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids[:-1], # last one will be decoded on decode for sampling anyway + # prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids[:-1], # last one will be decoded on decode for sampling anyway
+ prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids, # TODO ptarasiewicz do not send the last token when NIXL fixes send notif (needed for writing 0 blocks)
+ sampling_params=scheduled_seq_group.seq_group.sampling_params, + sampling_params=scheduled_seq_group.seq_group.sampling_params,
+ block_ids=block_table, + block_ids=block_table,
+ engine_id=self.engine_id, + engine_id=self.engine_id,
+ computed_block_ids=seq_group_metadata.computed_block_nums,
+ ) + )
+ scheduled_seq_group.seq_group.remote_prefill_params.remote_prefill_request_callback(remote_prefill_request) + scheduled_seq_group.seq_group.remote_prefill_params.remote_prefill_request_callback(remote_prefill_request)
ctx.seq_group_metadata_list = seq_group_metadata_list ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs ctx.scheduler_outputs = scheduler_outputs
@@ -1383,9 +1479,31 @@ class LLMEngine: @@ -1383,9 +1485,46 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[ execute_model_req.async_callback = self.async_callbacks[
virtual_engine] virtual_engine]
...@@ -2862,17 +2975,32 @@ index d82d9ad9..931784f8 100644 ...@@ -2862,17 +2975,32 @@ index d82d9ad9..931784f8 100644
+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id + seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
+ block_table = seq_group_metadata.block_tables[seq_id] + block_table = seq_group_metadata.block_tables[seq_id]
+ staging_block_ids = seq_group_metadata.block_tables[seq_id + 1] + staging_block_ids = seq_group_metadata.block_tables[seq_id + 1]
+ memory_transfer_req = MemoryTransferRequest( +
+ num_computed_blocks = len(seq_group_metadata.computed_block_nums)
+ computed_decode_block_ids = remote_prefill_params.decode_block_ids[:num_computed_blocks]
+
+ if computed_decode_block_ids:
+ kv_recv_req = MemoryTransferRequest(
+ request_id=req_id,
+ local_block_ids=block_table[:num_computed_blocks],
+ staging_block_ids=staging_block_ids[:num_computed_blocks],
+ remote_block_ids=computed_decode_block_ids,
+ remote_engine_id=remote_prefill_params.decode_engine_id,
+ notify_msg=req_id,
+ op_type=MemoryOpType.READ
+ )
+ memory_transfer_reqs.append(kv_recv_req)
+
+ kv_send_req = MemoryTransferRequest(
+ request_id=req_id, + request_id=req_id,
+ src_block_ids=block_table, + local_block_ids=block_table[num_computed_blocks:],
+ staging_block_ids=staging_block_ids, + staging_block_ids=staging_block_ids[num_computed_blocks:],
+ dst_block_ids=remote_prefill_params.decode_block_ids, + remote_block_ids=remote_prefill_params.decode_block_ids[num_computed_blocks:],
+ dst_engine_id=remote_prefill_params.decode_engine_id, + remote_engine_id=remote_prefill_params.decode_engine_id,
+ notify_msg=req_id, + notify_msg=req_id,
+ op_type=MemoryOpType.WRITE
+ ) + )
+ + memory_transfer_reqs.append(kv_send_req)
+ memory_transfer_reqs.append(memory_transfer_req)
+
+ execute_model_req.memory_transfer_requests = memory_transfer_reqs + execute_model_req.memory_transfer_requests = memory_transfer_reqs
+ +
+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model( + outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
...@@ -2881,7 +3009,7 @@ index d82d9ad9..931784f8 100644 ...@@ -2881,7 +3009,7 @@ index d82d9ad9..931784f8 100644
# We need to do this here so that last step's sampled_token_ids can # We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP. # be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
@@ -1396,7 +1514,26 @@ class LLMEngine: @@ -1396,7 +1535,26 @@ class LLMEngine:
if len(ctx.output_queue) > 0: if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx) self._process_model_outputs(ctx=ctx)
# No outputs in this case # No outputs in this case
...@@ -2909,7 +3037,7 @@ index d82d9ad9..931784f8 100644 ...@@ -2909,7 +3037,7 @@ index d82d9ad9..931784f8 100644
# Finish the current step for all the sequence groups. # Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
@@ -1456,7 +1593,7 @@ class LLMEngine: @@ -1456,7 +1614,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters. # queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.") logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
...@@ -3583,12 +3711,13 @@ index 786380c3..56a7cf89 100644 ...@@ -3583,12 +3711,13 @@ index 786380c3..56a7cf89 100644
"""The output data of one completion output of a request. """The output data of one completion output of a request.
diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py
new file mode 100644 new file mode 100644
index 00000000..957f55de index 00000000..3f9711ef
--- /dev/null --- /dev/null
+++ b/vllm/remote_prefill.py +++ b/vllm/remote_prefill.py
@@ -0,0 +1,54 @@ @@ -0,0 +1,67 @@
+from dataclasses import dataclass +from dataclasses import dataclass
+from typing import Callable, Optional, List, Coroutine +from typing import Callable, Optional, List
+from enum import Enum
+ +
+import msgspec +import msgspec
+ +
...@@ -3603,14 +3732,24 @@ index 00000000..957f55de ...@@ -3603,14 +3732,24 @@ index 00000000..957f55de
+ """The request data of one remote prefill output of a request. + """The request data of one remote prefill output of a request.
+ +
+ Args: + Args:
+ engine_id: The unique ID of the engine.
+ request_id: The unique ID of the request. + request_id: The unique ID of the request.
+ prompt: The prompt string of the request. + prompt_token_ids: The token IDs of the prompt.
+ sampling_params: The sampling parameters.
+ block_ids: The block IDs of the request.
+ computed_block_ids: The computed block IDs of the request.
+ """ + """
+ engine_id: str
+ request_id: str + request_id: str
+ prompt_token_ids: List[int] + prompt_token_ids: List[int]
+ sampling_params: SamplingParams + sampling_params: SamplingParams
+ block_ids: List[int] + block_ids: List[int]
+ engine_id: str + computed_block_ids: List[int]
+
+
+class MemoryOpType(str, Enum):
+ WRITE = "WRITE"
+ READ = "READ"
+ +
+ +
+class MemoryTransferRequest( +class MemoryTransferRequest(
...@@ -3623,11 +3762,12 @@ index 00000000..957f55de ...@@ -3623,11 +3762,12 @@ index 00000000..957f55de
+ request_id: The unique ID of the request. + request_id: The unique ID of the request.
+ """ + """
+ request_id: str + request_id: str
+ src_block_ids: List[int] + local_block_ids: List[int]
+ staging_block_ids: List[int] + staging_block_ids: List[int]
+ dst_block_ids: List[int] + remote_block_ids: List[int]
+ dst_engine_id: str + remote_engine_id: str
+ notify_msg: str + notify_msg: str
+ op_type: MemoryOpType
+ +
+ +
+RemotePrefillRequestCallback = Callable[[RemotePrefillRequest], None] +RemotePrefillRequestCallback = Callable[[RemotePrefillRequest], None]
...@@ -3639,6 +3779,7 @@ index 00000000..957f55de ...@@ -3639,6 +3779,7 @@ index 00000000..957f55de
+ is_remote_prefill: bool = False + is_remote_prefill: bool = False
+ is_remote_decode: bool = False + is_remote_decode: bool = False
+ decode_block_ids: Optional[List[int]] = None + decode_block_ids: Optional[List[int]] = None
+ decode_computed_block_ids: Optional[List[int]] = None
+ decode_engine_id: Optional[str] = None + decode_engine_id: Optional[str] = None
+ remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None + remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None
\ No newline at end of file \ No newline at end of file
...@@ -3784,7 +3925,7 @@ index 12baecde..a3f2c464 100644 ...@@ -3784,7 +3925,7 @@ index 12baecde..a3f2c464 100644
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 582aa460..76c2e6ab 100644 index 582aa460..876329d6 100644
--- a/vllm/worker/worker.py --- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py +++ b/vllm/worker/worker.py
@@ -2,7 +2,7 @@ @@ -2,7 +2,7 @@
...@@ -3796,16 +3937,17 @@ index 582aa460..76c2e6ab 100644 ...@@ -3796,16 +3937,17 @@ index 582aa460..76c2e6ab 100644
import torch import torch
import torch.distributed import torch.distributed
@@ -31,6 +31,8 @@ from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner @@ -31,6 +31,9 @@ from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.pooling_model_runner import PoolingModelRunner from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput) WorkerInput)
+from vllm.distributed.device_communicators.nixl import DynamoNixlConnector +from vllm.distributed.device_communicators.nixl import DynamoNixlConnector
+from vllm.remote_prefill import MemoryOpType
+ +
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -306,6 +308,46 @@ class Worker(LocalOrDistributedWorkerBase): @@ -306,6 +309,46 @@ class Worker(LocalOrDistributedWorkerBase):
self._init_cache_engine() self._init_cache_engine()
self._warm_up_model() self._warm_up_model()
...@@ -3828,22 +3970,22 @@ index 582aa460..76c2e6ab 100644 ...@@ -3828,22 +3970,22 @@ index 582aa460..76c2e6ab 100644
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata), kv_caches_base_addr, num_blocks) # TODO ptarasiewicz: rank or local_rank? + agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata), kv_caches_base_addr, num_blocks) # TODO ptarasiewicz: rank or local_rank?
+ return agent_name + return agent_name
+ +
+ def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name, notify_msg) # TODO ptarasiewicz: rank or local_rank?
+
+ def get_nixl_kv_caches_base_addr(self) -> List[bytes]: + def get_nixl_kv_caches_base_addr(self) -> List[bytes]:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ return self.nixl_connector.kv_caches_base_addr[self.nixl_connector.engine_id] + return self.nixl_connector.kv_caches_base_addr[self.nixl_connector.engine_id]
+ +
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None: + def _read_blocks(self, worker_input: WorkerInput) -> None:
+ + for i, op_type in enumerate(worker_input.op_type):
+ if op_type == MemoryOpType.READ:
+ self.nixl_connector.read_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i])
+
+ def _write_blocks(self, worker_input: WorkerInput) -> None:
+ if not self.is_driver_worker: + if not self.is_driver_worker:
+ torch.cuda.synchronize() # to make sure that the blocks are ready, on driver worker we transfer after sampling, so there's no need to synchronize + torch.cuda.synchronize() # to make sure that the blocks are ready, on driver worker we transfer after sampling, so there's no need to synchronize
+ +
+ if worker_input.src_block_ids is not None: + for i, op_type in enumerate(worker_input.op_type):
+ for src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.staging_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg): + if op_type == MemoryOpType.WRITE:
+ self.nixl_connector.transfer_mem(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg) + self.nixl_connector.write_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i], worker_input.notify_msg[i])
+ +
+ def shutdown_nixl(self) -> None: + def shutdown_nixl(self) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
...@@ -3852,7 +3994,7 @@ index 582aa460..76c2e6ab 100644 ...@@ -3852,7 +3994,7 @@ index 582aa460..76c2e6ab 100644
def _init_cache_engine(self): def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [ self.cache_engine = [
@@ -367,6 +409,8 @@ class Worker(LocalOrDistributedWorkerBase): @@ -367,6 +410,8 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device, device=self.device,
dtype=torch.int64).view(-1, 2) dtype=torch.int64).view(-1, 2)
...@@ -3861,20 +4003,21 @@ index 582aa460..76c2e6ab 100644 ...@@ -3861,20 +4003,21 @@ index 582aa460..76c2e6ab 100644
return WorkerInput( return WorkerInput(
num_seq_groups=num_seq_groups, num_seq_groups=num_seq_groups,
@@ -375,6 +419,11 @@ class Worker(LocalOrDistributedWorkerBase): @@ -375,6 +420,12 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
num_steps=num_steps, num_steps=num_steps,
+ src_block_ids=[r.src_block_ids for r in mem_transfer_reqs], + local_block_ids=[r.local_block_ids for r in mem_transfer_reqs],
+ staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs], + staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs],
+ dst_block_ids=[r.dst_block_ids for r in mem_transfer_reqs], + remote_block_ids=[r.remote_block_ids for r in mem_transfer_reqs],
+ dst_engine_id=[r.dst_engine_id for r in mem_transfer_reqs], + remote_engine_id=[r.remote_engine_id for r in mem_transfer_reqs],
+ notify_msg=[r.notify_msg for r in mem_transfer_reqs], + notify_msg=[r.notify_msg for r in mem_transfer_reqs],
+ op_type=[r.op_type for r in mem_transfer_reqs],
) )
@torch.inference_mode() @torch.inference_mode()
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 819b81fb..ff43dadc 100644 index 819b81fb..2891854b 100644
--- a/vllm/worker/worker_base.py --- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
...@@ -3885,15 +4028,17 @@ index 819b81fb..ff43dadc 100644 ...@@ -3885,15 +4028,17 @@ index 819b81fb..ff43dadc 100644
from vllm.config import (ObservabilityConfig, VllmConfig, from vllm.config import (ObservabilityConfig, VllmConfig,
set_current_vllm_config) set_current_vllm_config)
@@ -23,6 +24,7 @@ from vllm.utils import (enable_trace_function_call_for_thread, @@ -23,6 +24,9 @@ from vllm.utils import (enable_trace_function_call_for_thread,
from vllm.worker.model_runner_base import (BroadcastableModelInput, from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase, ModelRunnerBase,
ModelRunnerInputBase) ModelRunnerInputBase)
+from vllm.distributed.device_communicators.nixl import DynamoNixlConnector +from vllm.distributed.device_communicators.nixl import DynamoNixlConnector
+from vllm.remote_prefill import MemoryOpType
+
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -53,6 +55,8 @@ class WorkerBase(ABC): @@ -53,6 +57,8 @@ class WorkerBase(ABC):
from vllm.platforms import current_platform from vllm.platforms import current_platform
self.current_platform = current_platform self.current_platform = current_platform
...@@ -3902,44 +4047,47 @@ index 819b81fb..ff43dadc 100644 ...@@ -3902,44 +4047,47 @@ index 819b81fb..ff43dadc 100644
@abstractmethod @abstractmethod
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device """Initialize device state, such as loading the model or other on-device
@@ -216,6 +220,12 @@ class WorkerInput: @@ -216,6 +222,13 @@ class WorkerInput:
virtual_engine: int = 0 virtual_engine: int = 0
num_steps: int = 1 num_steps: int = 1
+ src_block_ids: Optional[List[List[int]]] = None + local_block_ids: Optional[List[List[int]]] = None
+ staging_block_ids: Optional[List[List[int]]] = None + staging_block_ids: Optional[List[List[int]]] = None
+ dst_block_ids: Optional[List[List[int]]] = None + remote_block_ids: Optional[List[List[int]]] = None
+ dst_engine_id: Optional[List[str]] = None + remote_engine_id: Optional[List[str]] = None
+ notify_msg: Optional[List[str]] = None + notify_msg: Optional[List[str]] = None
+ op_type: Optional[List[MemoryOpType]] = None
+ +
@classmethod @classmethod
def from_broadcasted_tensor_dict( def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"], cls: Type["WorkerInput"],
@@ -232,6 +242,11 @@ class WorkerInput: @@ -232,6 +245,12 @@ class WorkerInput:
blocks_to_copy=tensor_dict.pop("blocks_to_copy"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"], virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"), num_steps=tensor_dict.pop("num_steps"),
+ src_block_ids=tensor_dict.pop("src_block_ids"), + local_block_ids=tensor_dict.pop("local_block_ids"),
+ staging_block_ids=tensor_dict.pop("staging_block_ids"), + staging_block_ids=tensor_dict.pop("staging_block_ids"),
+ dst_block_ids=tensor_dict.pop("dst_block_ids"), + remote_block_ids=tensor_dict.pop("remote_block_ids"),
+ dst_engine_id=tensor_dict.pop("dst_engine_id"), + remote_engine_id=tensor_dict.pop("remote_engine_id"),
+ notify_msg=tensor_dict.pop("notify_msg"), + notify_msg=tensor_dict.pop("notify_msg"),
+ op_type=tensor_dict.pop("op_type"),
) )
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
@@ -246,6 +261,11 @@ class WorkerInput: @@ -246,6 +265,12 @@ class WorkerInput:
"blocks_to_copy": self.blocks_to_copy, "blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"num_steps": self.num_steps, "num_steps": self.num_steps,
+ "src_block_ids": self.src_block_ids, + "local_block_ids": self.local_block_ids,
+ "staging_block_ids": self.staging_block_ids, + "staging_block_ids": self.staging_block_ids,
+ "dst_block_ids": self.dst_block_ids, + "remote_block_ids": self.remote_block_ids,
+ "dst_engine_id": self.dst_engine_id, + "remote_engine_id": self.remote_engine_id,
+ "notify_msg": self.notify_msg, + "notify_msg": self.notify_msg,
+ "op_type": self.op_type,
} }
return tensor_dict return tensor_dict
@@ -316,13 +336,16 @@ class LocalOrDistributedWorkerBase(WorkerBase): @@ -316,13 +341,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return None return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
...@@ -3961,7 +4109,7 @@ index 819b81fb..ff43dadc 100644 ...@@ -3961,7 +4109,7 @@ index 819b81fb..ff43dadc 100644
def _get_driver_input_and_broadcast( def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest self, execute_model_req: ExecuteModelRequest
@@ -396,49 +419,87 @@ class LocalOrDistributedWorkerBase(WorkerBase): @@ -396,49 +424,88 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.execute_worker(worker_input) self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
...@@ -3976,6 +4124,8 @@ index 819b81fb..ff43dadc 100644 ...@@ -3976,6 +4124,8 @@ index 819b81fb..ff43dadc 100644
- all_gather_group=get_tp_group())) - all_gather_group=get_tp_group()))
+ if worker_input.num_seq_groups > 0: + if worker_input.num_seq_groups > 0:
+ +
+ self._read_blocks(worker_input)
+
+ intermediate_tensors = None + intermediate_tensors = None
+ orig_model_execute_time = 0.0 + orig_model_execute_time = 0.0
+ if not get_pp_group().is_first_rank: + if not get_pp_group().is_first_rank:
...@@ -4011,12 +4161,7 @@ index 819b81fb..ff43dadc 100644 ...@@ -4011,12 +4161,7 @@ index 819b81fb..ff43dadc 100644
- and self.observability_config.collect_model_execute_time): - and self.observability_config.collect_model_execute_time):
- orig_model_execute_time = intermediate_tensors.tensors.get( - orig_model_execute_time = intermediate_tensors.tensors.get(
- "model_execute_time", torch.tensor(0)).item() - "model_execute_time", torch.tensor(0)).item()
+ and self.observability_config.collect_model_execute_time -
+ and output is not None):
+ for o in output:
+ o.model_execute_time = (orig_model_execute_time +
+ model_execute_time)
- output = self.model_runner.execute_model( - output = self.model_runner.execute_model(
- model_input=model_input, - model_input=model_input,
- kv_caches=self.kv_cache[worker_input.virtual_engine] - kv_caches=self.kv_cache[worker_input.virtual_engine]
...@@ -4025,7 +4170,11 @@ index 819b81fb..ff43dadc 100644 ...@@ -4025,7 +4170,11 @@ index 819b81fb..ff43dadc 100644
- num_steps=num_steps, - num_steps=num_steps,
- **kwargs, - **kwargs,
- ) - )
+ self._transfer_blocks(worker_input) + and self.observability_config.collect_model_execute_time
+ and output is not None):
+ for o in output:
+ o.model_execute_time = (orig_model_execute_time +
+ model_execute_time)
- model_execute_time = time.perf_counter() - start_time - model_execute_time = time.perf_counter() - start_time
- if not get_pp_group().is_last_rank: - if not get_pp_group().is_last_rank:
...@@ -4044,6 +4193,8 @@ index 819b81fb..ff43dadc 100644 ...@@ -4044,6 +4193,8 @@ index 819b81fb..ff43dadc 100644
- for o in output: - for o in output:
- o.model_execute_time = (orig_model_execute_time + - o.model_execute_time = (orig_model_execute_time +
- model_execute_time) - model_execute_time)
+ self._write_blocks(worker_input)
+ else: + else:
+ output = [] + output = []
+ +
...@@ -4058,7 +4209,7 @@ index 819b81fb..ff43dadc 100644 ...@@ -4058,7 +4209,7 @@ index 819b81fb..ff43dadc 100644
+ else: + else:
+ for i in range(1, get_tp_group().world_size): + for i in range(1, get_tp_group().world_size):
+ all_new_notifs.append(get_tp_group().recv_object(src=i)) + all_new_notifs.append(get_tp_group().recv_object(src=i))
+
+ request_notif_counter = defaultdict(int) + request_notif_counter = defaultdict(int)
+ for notifs in all_new_notifs: + for notifs in all_new_notifs:
+ for req_ids in notifs.values(): + for req_ids in notifs.values():
...@@ -4071,18 +4222,17 @@ index 819b81fb..ff43dadc 100644 ...@@ -4071,18 +4222,17 @@ index 819b81fb..ff43dadc 100644
+ request_done_counter = defaultdict(int) + request_done_counter = defaultdict(int)
+ for req_id in self.nixl_connector.get_done_tranfers(): + for req_id in self.nixl_connector.get_done_tranfers():
+ request_done_counter[req_id] += 1 + request_done_counter[req_id] += 1
+
+ if request_done_counter:
+ logger.debug("Request done counter: %s", request_done_counter)
+
+ else: + else:
+ request_notif_counter = {} + request_notif_counter = {}
+ request_done_counter = {} + request_done_counter = {}
# output is List[SamplerOutput] # output is List[SamplerOutput]
- return output - return output
+ return output, request_notif_counter, request_done_counter + return output, request_notif_counter, request_done_counter
+ +
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None: + def _read_blocks(self, worker_input: WorkerInput) -> None:
+ pass
+
+ def _write_blocks(self, worker_input: WorkerInput) -> None:
+ pass + pass
def _execute_model_spmd( def _execute_model_spmd(
......
...@@ -64,6 +64,12 @@ class PrefillWorker: ...@@ -64,6 +64,12 @@ class PrefillWorker:
print("Prefill must be done eagerly, setting to True") print("Prefill must be done eagerly, setting to True")
self.engine_args.enforce_eager = True self.engine_args.enforce_eager = True
if self.engine_args.enable_prefix_caching is not False:
print(
"Prefix caching is not supported yet in prefill worker, setting to False"
)
self.engine_args.enable_prefix_caching = False
@async_on_start @async_on_start
async def async_init(self): async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args( self._engine_context = build_async_engine_client_from_engine_args(
...@@ -115,6 +121,7 @@ class PrefillWorker: ...@@ -115,6 +121,7 @@ class PrefillWorker:
is_remote_decode=True, is_remote_decode=True,
decode_block_ids=request.block_ids, decode_block_ids=request.block_ids,
decode_engine_id=request.engine_id, decode_engine_id=request.engine_id,
decode_computed_block_ids=request.computed_block_ids,
) )
# TODO check if metadata has changed # TODO check if metadata has changed
......
...@@ -30,22 +30,25 @@ Router: ...@@ -30,22 +30,25 @@ Router:
VllmWorker: VllmWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}' kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64 block-size: 64
max-model-len: 16384 max-model-len: 16384
max-num-batched-tokens: 16384 max-num-batched-tokens: 16384
remote-prefill: true
conditional-disagg: true conditional-disagg: true
max-local-prefill-length: 10 max-local-prefill-length: 10
max-prefill-queue-size: 2 max-prefill-queue-size: 2
tensor-parallel-size: 1 tensor-parallel-size: 1
router: kv router: kv
enable-prefix-caching: true enable-prefix-caching: true
ServiceArgs:
workers: 1
resources:
gpu: 1
# TODO - set all of these but model as default # TODO - set all of these but model as default
PrefillWorker: PrefillWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}' kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64 block-size: 64
max-model-len: 16384 max-model-len: 16384
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment