"examples/python_rs/vscode:/vscode.git/clone" did not exist on "d99b188da6684230fe2fe43fbf91d0e3180ac3fa"
Commit 7567620f authored by ptarasiewiczNV's avatar ptarasiewiczNV Committed by GitHub
Browse files

fix: vLLM disagg fix incorrect block ids order (#63)

Co-authored-by: ptarasiewicz@nvidia.com <Piotr Tarasiewicz>
parent cbd20c30
...@@ -144,17 +144,46 @@ index 359b5b26..d52ee050 100644 ...@@ -144,17 +144,46 @@ index 359b5b26..d52ee050 100644
self._swap_mapping: Dict[int, int] = {} self._swap_mapping: Dict[int, int] = {}
self._null_block: Optional[Block] = None self._null_block: Optional[Block] = None
diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py
index c388366b..c1883736 100644 index c388366b..31ed7aa4 100644
--- a/vllm/core/block/naive_block.py --- a/vllm/core/block/naive_block.py
+++ b/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py
@@ -135,6 +135,7 @@ class NaiveBlockAllocator(BlockAllocator): @@ -2,7 +2,7 @@
from collections import deque
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
-
+import heapq
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
@@ -38,7 +38,7 @@ class NaiveBlockAllocator(BlockAllocator):
if block_ids is None:
block_ids = range(num_blocks)
- self._free_block_indices: Deque[BlockId] = deque(block_ids)
+ self._free_block_indices: List[BlockId] = list(block_ids)
self._all_block_indices = frozenset(block_ids)
assert len(self._all_block_indices) == num_blocks
@@ -134,7 +134,8 @@ class NaiveBlockAllocator(BlockAllocator):
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError() raise BlockAllocator.NoFreeBlocksError()
block_id = self._free_block_indices.popleft() - block_id = self._free_block_indices.popleft()
+ block_id = heapq.heappop(self._free_block_indices)
+ # TODO: figure out why sometime block_id is None + # TODO: figure out why sometime block_id is None
self._refcounter.incr(block_id) self._refcounter.incr(block_id)
return block_id return block_id
@@ -148,7 +149,7 @@ class NaiveBlockAllocator(BlockAllocator):
refcount = self._refcounter.decr(block_id)
if refcount == 0:
- self._free_block_indices.appendleft(block_id)
+ heapq.heappush(self._free_block_indices, block_id)
def free(self, block: Block, keep_block_object: bool = False) -> None:
# Release the physical block id
diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py
index 1ca9e49d..b1591c0c 100644 index 1ca9e49d..b1591c0c 100644
--- a/vllm/core/block/prefix_caching_block.py --- a/vllm/core/block/prefix_caching_block.py
...@@ -378,7 +407,7 @@ index 00000000..d3706700 ...@@ -378,7 +407,7 @@ index 00000000..d3706700
+ +
+ self.event_id_counter += 1 + self.event_id_counter += 1
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f507847a..4d299d7f 100644 index f507847a..9e6443bf 100644
--- a/vllm/core/scheduler.py --- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py
@@ -4,22 +4,22 @@ import enum @@ -4,22 +4,22 @@ import enum
...@@ -572,7 +601,17 @@ index f507847a..4d299d7f 100644 ...@@ -572,7 +601,17 @@ index f507847a..4d299d7f 100644
"""Schedule queued requests. """Schedule queued requests.
The current policy is designed to optimize the throughput. First, The current policy is designed to optimize the throughput. First,
@@ -1087,10 +1152,12 @@ class Scheduler: @@ -1066,6 +1131,9 @@ class Scheduler:
for seq_group in self.running:
budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs())
+ for seq_group in self.prefill_sending:
+ budget.add_num_seqs(seq_group.request_id,
+ seq_group.get_max_num_running_seqs())
curr_loras = set(
seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None
@@ -1087,10 +1155,12 @@ class Scheduler:
# Don't schedule decodes if prefills are scheduled. # Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills. # only contains decode requests, not chunked prefills.
...@@ -587,7 +626,7 @@ index f507847a..4d299d7f 100644 ...@@ -587,7 +626,7 @@ index f507847a..4d299d7f 100644
# If any sequence group is preempted, do not swap in any sequence # If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests. # group. because it means there's no slot for new running requests.
@@ -1106,7 +1173,12 @@ class Scheduler: @@ -1106,7 +1176,12 @@ class Scheduler:
self.waiting.extendleft(running_scheduled.preempted) self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests. # Update new running requests.
if len(prefills.seq_groups) > 0: if len(prefills.seq_groups) > 0:
...@@ -601,7 +640,7 @@ index f507847a..4d299d7f 100644 ...@@ -601,7 +640,7 @@ index f507847a..4d299d7f 100644
self.running.extend(running_scheduled.decode_seq_groups_list) self.running.extend(running_scheduled.decode_seq_groups_list)
@@ -1248,12 +1320,14 @@ class Scheduler: @@ -1248,12 +1323,14 @@ class Scheduler:
len(running_scheduled.swapped_out)), len(running_scheduled.swapped_out)),
) )
...@@ -618,7 +657,7 @@ index f507847a..4d299d7f 100644 ...@@ -618,7 +657,7 @@ index f507847a..4d299d7f 100644
def _can_append_slots(self, seq_group: SequenceGroup, def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool: enable_chunking: bool) -> bool:
@@ -1287,14 +1361,16 @@ class Scheduler: @@ -1287,14 +1364,16 @@ class Scheduler:
return no_single_seq return no_single_seq
def schedule( def schedule(
...@@ -638,7 +677,7 @@ index f507847a..4d299d7f 100644 ...@@ -638,7 +677,7 @@ index f507847a..4d299d7f 100644
now = time.time() now = time.time()
if not self.cache_config.enable_prefix_caching: if not self.cache_config.enable_prefix_caching:
@@ -1333,7 +1409,8 @@ class Scheduler: @@ -1333,7 +1412,8 @@ class Scheduler:
encoder_seq_data = None encoder_seq_data = None
cross_block_table = None cross_block_table = None
...@@ -648,7 +687,7 @@ index f507847a..4d299d7f 100644 ...@@ -648,7 +687,7 @@ index f507847a..4d299d7f 100644
seq_id = seq.seq_id seq_id = seq.seq_id
seq_data[seq_id] = seq.data seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq) block_tables[seq_id] = self.block_manager.get_block_table(seq)
@@ -1364,9 +1441,16 @@ class Scheduler: @@ -1364,9 +1444,16 @@ class Scheduler:
< seqs[0].data.get_len()): < seqs[0].data.get_len()):
do_sample = False do_sample = False
...@@ -665,7 +704,7 @@ index f507847a..4d299d7f 100644 ...@@ -665,7 +704,7 @@ index f507847a..4d299d7f 100644
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id, request_id=seq_group.request_id,
is_prompt=is_prompt, is_prompt=is_prompt,
@@ -1392,6 +1476,7 @@ class Scheduler: @@ -1392,6 +1479,7 @@ class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None, if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs, mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request, prompt_adapter_request=seq_group.prompt_adapter_request,
...@@ -673,7 +712,7 @@ index f507847a..4d299d7f 100644 ...@@ -673,7 +712,7 @@ index f507847a..4d299d7f 100644
) )
else: else:
# When SPMD mode is enabled, we only send delta data except for # When SPMD mode is enabled, we only send delta data except for
@@ -1490,11 +1575,17 @@ class Scheduler: @@ -1490,11 +1578,17 @@ class Scheduler:
self._async_stopped.clear() self._async_stopped.clear()
...@@ -764,10 +803,10 @@ index 00000000..9b938039 ...@@ -764,10 +803,10 @@ index 00000000..9b938039
\ No newline at end of file \ No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644 new file mode 100644
index 00000000..87020367 index 00000000..f1459cf9
--- /dev/null --- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py +++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,405 @@ @@ -0,0 +1,404 @@
+import torch +import torch
+from typing import List, Tuple +from typing import List, Tuple
+from vllm.config import VllmConfig +from vllm.config import VllmConfig
...@@ -888,28 +927,24 @@ index 00000000..87020367 ...@@ -888,28 +927,24 @@ index 00000000..87020367
+ descs_ids.append(2 * (self.num_blocks * layer_id + block_id) + 1) + descs_ids.append(2 * (self.num_blocks * layer_id + block_id) + 1)
+ return descs_ids + return descs_ids
+ +
+ def _get_range_descs(self, ranges, layer_ids, kv_caches_base_addr, tp_multiplier=1, rank=None, i=0): + def _get_range_descs(self, ranges, layer_ids, kv_caches_base_addr, tp_multiplier=1, rank=None, i=0, staging_ranges=None):
+ if rank is None: + if rank is None:
+ rank = self.rank + rank = self.rank
+ offset_block_len = self.block_len
+ block_len = self.block_len // tp_multiplier
+ tp_offset = i * block_len
+ else:
+ offset_block_len = self.block_len // tp_multiplier
+ block_len = self.block_len // tp_multiplier + block_len = self.block_len // tp_multiplier
+ tp_offset = 0
+ logger.debug("Getting range descs for layer ids: %s, ranges: %s, tp_multiplier: %s, rank: %s, i: %s", layer_ids, ranges, tp_multiplier, rank, i) + logger.debug("Getting range descs for layer ids: %s, ranges: %s, tp_multiplier: %s, rank: %s, i: %s", layer_ids, ranges, tp_multiplier, rank, i)
+ if layer_ids == "all": + if layer_ids == "all":
+ layer_ids = list(range(self.num_layers)) + layer_ids = list(range(self.num_layers))
+ blocks_data = [] + blocks_data = []
+ for layer_id in layer_ids: + for layer_id in layer_ids:
+ for range_start, range_end in ranges: + for range_idx, (range_start, range_end) in enumerate(ranges):
+ range_len = range_end - range_start + 1 + range_len = range_end - range_start + 1
+ key_base_addr, value_base_addr = kv_caches_base_addr[layer_id] + key_base_addr, value_base_addr = kv_caches_base_addr[layer_id]
+ start_offset = range_start * offset_block_len + tp_offset * range_len + if staging_ranges is not None:
+ blocks_len = range_len * block_len + start_offset = staging_ranges[range_idx][0] * self.block_len + i * block_len * (staging_ranges[range_idx][1] - staging_ranges[range_idx][0] + 1) + (range_start - staging_ranges[range_idx][0]) * block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, rank)) + else:
+ blocks_data.append((value_base_addr + start_offset, blocks_len, rank)) + start_offset = range_start * block_len
+ blocks_data.append((key_base_addr + start_offset, range_len * block_len, rank))
+ blocks_data.append((value_base_addr + start_offset, range_len * block_len, rank))
+ return self.nixl_wrapper.get_descs(("VRAM", blocks_data)) + return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ +
+ def _get_ranges(self, block_ids): + def _get_ranges(self, block_ids):
...@@ -918,21 +953,23 @@ index 00000000..87020367 ...@@ -918,21 +953,23 @@ index 00000000..87020367
+ # The ranges are sorted by the starting block id + # The ranges are sorted by the starting block id
+ # The function should also make sure that the block ids are contiguous + # The function should also make sure that the block ids are contiguous
+ # If the block ids are not contiguous, the function should raise an error + # If the block ids are not contiguous, the function should raise an error
+ sorted_block_ids = sorted(block_ids)
+ ranges = [] + ranges = []
+ for i in range(len(sorted_block_ids)): + for i in range(len(block_ids)):
+ if i == 0 or sorted_block_ids[i] != sorted_block_ids[i-1] + 1: + if i == 0 or block_ids[i] != block_ids[i-1] + 1:
+ ranges.append([sorted_block_ids[i], sorted_block_ids[i]]) + ranges.append([block_ids[i], block_ids[i]])
+ else: + else:
+ ranges[-1][1] = sorted_block_ids[i] + ranges[-1][1] = block_ids[i]
+ return ranges + return ranges
+ +
+ def _get_same_length_ranges(self, src_ranges, dst_ranges): + def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False):
+ # This function should return a list of ranges for both src and dst so that corresponding ranges are the same length + # This function should return a list of ranges for both src and dst so that corresponding ranges are the same length
+ # For example, if src_ranges is [[0, 2] [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]] + # For example, if src_ranges is [[0, 2] [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]]
+ # The function should return ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]]) + # The function should return ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]])
+ src_overlapping_ranges, dst_overlapping_ranges = [], [] + src_overlapping_ranges, dst_overlapping_ranges = [], []
+ +
+ original_src_ranges = []
+ org_src_range = tuple(src_ranges[0])
+
+ src_idx, dst_idx = 0, 0 + src_idx, dst_idx = 0, 0
+ while src_idx < len(src_ranges) and dst_idx < len(dst_ranges): + while src_idx < len(src_ranges) and dst_idx < len(dst_ranges):
+ src_range = src_ranges[src_idx] + src_range = src_ranges[src_idx]
...@@ -946,12 +983,16 @@ index 00000000..87020367 ...@@ -946,12 +983,16 @@ index 00000000..87020367
+ if src_len == dst_len: + if src_len == dst_len:
+ src_overlapping_ranges.append([src_range[0], src_range[-1]]) + src_overlapping_ranges.append([src_range[0], src_range[-1]])
+ dst_overlapping_ranges.append([dst_range[0], dst_range[-1]]) + dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
+ original_src_ranges.append(org_src_range)
+ src_idx += 1 + src_idx += 1
+ dst_idx += 1 + dst_idx += 1
+ if src_idx < len(src_ranges):
+ org_src_range = tuple(src_ranges[src_idx])
+ # If source range is longer, split it + # If source range is longer, split it
+ elif src_len > dst_len: + elif src_len > dst_len:
+ src_overlapping_ranges.append([src_range[0], src_range[0] + dst_len - 1]) + src_overlapping_ranges.append([src_range[0], src_range[0] + dst_len - 1])
+ dst_overlapping_ranges.append([dst_range[0], dst_range[-1]]) + dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
+ original_src_ranges.append(org_src_range)
+ # Update source range for next iteration + # Update source range for next iteration
+ src_ranges[src_idx] = [src_range[0] + dst_len, src_range[-1]] + src_ranges[src_idx] = [src_range[0] + dst_len, src_range[-1]]
+ dst_idx += 1 + dst_idx += 1
...@@ -959,15 +1000,19 @@ index 00000000..87020367 ...@@ -959,15 +1000,19 @@ index 00000000..87020367
+ else: # src_len < dst_len + else: # src_len < dst_len
+ src_overlapping_ranges.append([src_range[0], src_range[-1]]) + src_overlapping_ranges.append([src_range[0], src_range[-1]])
+ dst_overlapping_ranges.append([dst_range[0], dst_range[0] + src_len - 1]) + dst_overlapping_ranges.append([dst_range[0], dst_range[0] + src_len - 1])
+ original_src_ranges.append(org_src_range)
+ # Update destination range for next iteration + # Update destination range for next iteration
+ dst_ranges[dst_idx] = [dst_range[0] + src_len, dst_range[-1]] + dst_ranges[dst_idx] = [dst_range[0] + src_len, dst_range[-1]]
+ src_idx += 1 + src_idx += 1
+ + if src_idx < len(src_ranges):
+ org_src_range = tuple(src_ranges[src_idx])
+ if return_original_src_ranges:
+ return src_overlapping_ranges, dst_overlapping_ranges, original_src_ranges
+ return src_overlapping_ranges, dst_overlapping_ranges + return src_overlapping_ranges, dst_overlapping_ranges
+ +
+ +
+ +
+ def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_range=None): + def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_ranges=None):
+ +
+ if layer_ids == "all": + if layer_ids == "all":
+ layer_ids = list(range(self.num_layers)) + layer_ids = list(range(self.num_layers))
...@@ -976,13 +1021,17 @@ index 00000000..87020367 ...@@ -976,13 +1021,17 @@ index 00000000..87020367
+ +
+ descs_ids = [] + descs_ids = []
+ +
+
+ if i is not None: + if i is not None:
+ num_blocks = self.num_blocks + num_blocks = self.num_blocks
+ start_offset = staging_range[0]
+ i_offset = i * (staging_range[-1] - start_offset + 1)
+ for layer_id in layer_ids: + for layer_id in layer_ids:
+ for is_value in [0, 1]: + for is_value in [0, 1]:
+ staging_range_idx = 0
+ for block_id in block_ids: + for block_id in block_ids:
+ if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]:
+ staging_range_idx += 1
+ start_offset = staging_ranges[staging_range_idx][0]
+ i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1)
+ descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset)) + descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
+ else: + else:
+ num_blocks = self.dst_num_blocks[engine_id] + num_blocks = self.dst_num_blocks[engine_id]
...@@ -1009,24 +1058,19 @@ index 00000000..87020367 ...@@ -1009,24 +1058,19 @@ index 00000000..87020367
+ # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \ + # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \
+ # one less block due to the missing last token. + # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)] + dst_block_ids = dst_block_ids[:len(src_block_ids)]
+
+ assert len(staging_block_ids) == len(src_block_ids) + assert len(staging_block_ids) == len(src_block_ids)
+ src_ranges = self._get_ranges(src_block_ids) + src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids) + staging_ranges = self._get_ranges(staging_block_ids)
+ +
+ assert len(src_ranges) == 1 + src_staging_overlapping_ranges, staging_src_overlapping_ranges = self._get_same_length_ranges(src_ranges, staging_ranges)
+ assert len(staging_ranges) == 1
+
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] + tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+ +
+ src_range_start, src_range_end = src_ranges[0] + for src_range, staging_range in zip(src_staging_overlapping_ranges, staging_src_overlapping_ranges):
+ src_range_len = src_range_end - src_range_start + 1 + logger.debug("Rearranging tensors for cache: %s, src_range: %s, staging_range: %s", self.kv_caches[0].shape, src_range, staging_range)
+ staging_range_start, staging_range_end = staging_ranges[0]
+ staging_range_len = staging_range_end - staging_range_start + 1
+
+ logger.debug("Rearranging tensors for cache: %s, src_ranges: %s of len %s, staging_ranges: %s of len %s", self.kv_caches[0].shape, src_ranges, src_range_len, staging_ranges, staging_range_len)
+ for kv_cache in self.kv_caches: + for kv_cache in self.kv_caches:
+ for cache in kv_cache: + for cache in kv_cache:
+ rearrange_tensors(cache[src_range_start:src_range_start + src_range_len], cache[staging_range_start:staging_range_start + staging_range_len], tp_multiplier) + rearrange_tensors(cache[src_range[0]:src_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier)
+ +
+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000) + logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000)
+ +
...@@ -1035,7 +1079,7 @@ index 00000000..87020367 ...@@ -1035,7 +1079,7 @@ index 00000000..87020367
+ src_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] + src_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
+ +
+ for i in range(tp_multiplier): + for i in range(tp_multiplier):
+ staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_range=staging_ranges[0]) + staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_src_overlapping_ranges)
+ assert len(staging_block_descs_ids) == len(dst_block_descs_ids) + assert len(staging_block_descs_ids) == len(dst_block_descs_ids)
+ dst_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i] + dst_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
+ +
...@@ -1065,29 +1109,23 @@ index 00000000..87020367 ...@@ -1065,29 +1109,23 @@ index 00000000..87020367
+ staging_ranges = self._get_ranges(staging_block_ids) + staging_ranges = self._get_ranges(staging_block_ids)
+ dst_ranges = self._get_ranges(dst_block_ids) + dst_ranges = self._get_ranges(dst_block_ids)
+ +
+ assert len(src_ranges) == 1 + staging_src_overlapping_ranges, src_staging_overlapping_ranges = self._get_same_length_ranges(staging_ranges, src_ranges)
+ assert len(staging_ranges) == 1
+
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] + tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+ +
+ src_range_start, src_range_end = src_ranges[0] + for src_range, staging_range in zip(src_staging_overlapping_ranges, staging_src_overlapping_ranges):
+ src_range_len = src_range_end - src_range_start + 1 + logger.debug("Rearranging tensors for cache: %s, src_range: %s, staging_range: %s", self.kv_caches[0].shape, src_range, staging_range)
+ staging_range_start, staging_range_end = staging_ranges[0]
+ staging_range_len = staging_range_end - staging_range_start + 1
+
+ logger.debug("Rearranging tensors for cache: %s, src_ranges: %s of len %s, staging_ranges: %s of len %s", self.kv_caches[0].shape, src_ranges, src_range_len, staging_ranges, staging_range_len)
+ for kv_cache in self.kv_caches: + for kv_cache in self.kv_caches:
+ for cache in kv_cache: + for cache in kv_cache:
+ rearrange_tensors(cache[src_range_start:src_range_start + src_range_len], cache[staging_range_start:staging_range_start + staging_range_len], tp_multiplier) + rearrange_tensors(cache[src_range[0]:src_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier)
+ +
+ staging_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(staging_ranges, dst_ranges) + staging_overlapping_ranges, dst_overlapping_ranges, original_src_ranges = self._get_same_length_ranges(staging_src_overlapping_ranges, dst_ranges, return_original_src_ranges=True)
+ assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges) + assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges)
+ +
+ logger.debug("Time to get same length ranges: %s ms", (time.perf_counter() - start_time) * 1000) + logger.debug("Time to get same length ranges: %s ms", (time.perf_counter() - start_time) * 1000)
+ +
+ for i in range(tp_multiplier): + for i in range(tp_multiplier):
+ +
+ src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i) + src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i, staging_ranges=original_src_ranges)
+ dst_descs = self._get_range_descs(dst_overlapping_ranges, "all", self.kv_caches_base_addr[dst_engine_id][self.rank * tp_multiplier + i], tp_multiplier, rank=self.rank * tp_multiplier + i) + dst_descs = self._get_range_descs(dst_overlapping_ranges, "all", self.kv_caches_base_addr[dst_engine_id][self.rank * tp_multiplier + i], tp_multiplier, rank=self.rank * tp_multiplier + i)
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000) + logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+ +
...@@ -2531,7 +2569,7 @@ index 321902d1..b8937ef8 100644 ...@@ -2531,7 +2569,7 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d82d9ad9..53cace75 100644 index d82d9ad9..931784f8 100644
--- a/vllm/engine/llm_engine.py --- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@ @@ -2,13 +2,17 @@
...@@ -2624,7 +2662,7 @@ index d82d9ad9..53cace75 100644 ...@@ -2624,7 +2662,7 @@ index d82d9ad9..53cace75 100644
+ +
+ @property + @property
+ def is_nixl_initialized(self) -> bool: + def is_nixl_initialized(self) -> bool:
+ return self._nixl_agents_names is not None + return getattr(self, "_nixl_agents_names", None) is not None
+ +
+ def get_nixl_metadata(self) -> NixlMetadata: + def get_nixl_metadata(self) -> NixlMetadata:
+ if not self.is_nixl_initialized: + if not self.is_nixl_initialized:
...@@ -3680,7 +3718,7 @@ index 12baecde..a3f2c464 100644 ...@@ -3680,7 +3718,7 @@ index 12baecde..a3f2c464 100644
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 582aa460..36a21d10 100644 index 582aa460..76c2e6ab 100644
--- a/vllm/worker/worker.py --- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py +++ b/vllm/worker/worker.py
@@ -2,7 +2,7 @@ @@ -2,7 +2,7 @@
...@@ -3701,7 +3739,7 @@ index 582aa460..36a21d10 100644 ...@@ -3701,7 +3739,7 @@ index 582aa460..36a21d10 100644
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -306,6 +308,42 @@ class Worker(LocalOrDistributedWorkerBase): @@ -306,6 +308,46 @@ class Worker(LocalOrDistributedWorkerBase):
self._init_cache_engine() self._init_cache_engine()
self._warm_up_model() self._warm_up_model()
...@@ -3733,6 +3771,10 @@ index 582aa460..36a21d10 100644 ...@@ -3733,6 +3771,10 @@ index 582aa460..36a21d10 100644
+ return self.nixl_connector.kv_caches_base_addr[self.nixl_connector.engine_id] + return self.nixl_connector.kv_caches_base_addr[self.nixl_connector.engine_id]
+ +
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None: + def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+
+ if not self.is_driver_worker:
+ torch.cuda.synchronize() # to make sure that the blocks are ready, on driver worker we transfer after sampling, so there's no need to synchronize
+
+ if worker_input.src_block_ids is not None: + if worker_input.src_block_ids is not None:
+ for src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.staging_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg): + for src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.staging_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg) + self.nixl_connector.transfer_mem(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
...@@ -3744,7 +3786,7 @@ index 582aa460..36a21d10 100644 ...@@ -3744,7 +3786,7 @@ index 582aa460..36a21d10 100644
def _init_cache_engine(self): def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [ self.cache_engine = [
@@ -367,6 +405,8 @@ class Worker(LocalOrDistributedWorkerBase): @@ -367,6 +409,8 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device, device=self.device,
dtype=torch.int64).view(-1, 2) dtype=torch.int64).view(-1, 2)
...@@ -3753,7 +3795,7 @@ index 582aa460..36a21d10 100644 ...@@ -3753,7 +3795,7 @@ index 582aa460..36a21d10 100644
return WorkerInput( return WorkerInput(
num_seq_groups=num_seq_groups, num_seq_groups=num_seq_groups,
@@ -375,6 +415,11 @@ class Worker(LocalOrDistributedWorkerBase): @@ -375,6 +419,11 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
num_steps=num_steps, num_steps=num_steps,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment