Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
bc42616e
Commit
bc42616e
authored
Mar 07, 2025
by
ptarasiewiczNV
Committed by
GitHub
Mar 06, 2025
Browse files
feat: Enable make_xfer NIXL kv transfer (#39)
Co-authored-by: ptarasiewicz@nvidia.com <Piotr Tarasiewicz>
parent
09656f6c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
252 additions
and
107 deletions
+252
-107
container/deps/vllm/vllm_v0.7.2-dynemo-kv-disagg-patch.patch
container/deps/vllm/vllm_v0.7.2-dynemo-kv-disagg-patch.patch
+252
-107
No files found.
container/deps/vllm/vllm_v0.7.2-dynemo-kv-disagg-patch.patch
View file @
bc42616e
diff --git a/vllm/config.py b/vllm/config.py
index 9ba49757..
cbfeb715
100644
index 9ba49757..
3ec4bbab
100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2629,7 +2629,7 @@
class KVTransferConfig(BaseModel):
@@ -2620,6 +2620,9 @@
class KVTransferConfig(BaseModel):
# The KV connector for vLLM to transmit KV caches between vLLM instances.
kv_connector: Optional[str] = None
+ # Whether to use NIXL prepped xfer for KV cache transfer.
+ use_prepped_xfer: bool = False
+
# The device used by kv connector to buffer the KV cache.
# Currently only support 'cuda'.
kv_buffer_device: Optional[str] = "cuda"
@@ -2629,7 +2632,7 @@
class KVTransferConfig(BaseModel):
kv_buffer_size: float = 1e9
# Whether this vLLM instance produces, consumes KV cache, or both. Choices
...
...
@@ -11,7 +21,7 @@ index 9ba49757..cbfeb715 100644
kv_role: Optional[str] = None
# The rank of this vLLM instance in the KV cache transfer. Typical value:
@@ -2647,6 +26
47
,14 @@
class KVTransferConfig(BaseModel):
@@ -2647,6 +26
50
,14 @@
class KVTransferConfig(BaseModel):
# The KV connector port, used to build distributed connection
kv_port: int = 14579
...
...
@@ -26,7 +36,7 @@ index 9ba49757..cbfeb715 100644
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -2680,11 +26
88
,12 @@
class KVTransferConfig(BaseModel):
@@ -2680,11 +26
91
,12 @@
class KVTransferConfig(BaseModel):
f"Supported roles are `kv_producer`, `kv_consumer`, "
f"and `kv_both`")
...
...
@@ -40,7 +50,7 @@ index 9ba49757..cbfeb715 100644
@property
def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \
@@ -2694,6 +270
3
,8 @@
class KVTransferConfig(BaseModel):
@@ -2694,6 +270
6
,8 @@
class KVTransferConfig(BaseModel):
def need_kv_parallel_group(self) -> bool:
# for those database-based connector, vLLM does not need to create
# parallel group, and in that case the kv parallel size will be 1.
...
...
@@ -49,7 +59,7 @@ index 9ba49757..cbfeb715 100644
return self.kv_connector is not None and self.kv_parallel_size > 1
@property
@@ -2706,6 +27
17
,18 @@
class KVTransferConfig(BaseModel):
@@ -2706,6 +27
20
,18 @@
class KVTransferConfig(BaseModel):
return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"]
...
...
@@ -368,7 +378,7 @@ index 00000000..8699ca06
+
+ self.event_id_counter += 1
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f507847a..
abe574d1
100644
index f507847a..
4d299d7f
100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -4,22 +4,22 @@
import enum
...
...
@@ -398,7 +408,23 @@ index f507847a..abe574d1 100644
logger = init_logger(__name__)
# Test-only. If configured, decode is preempted with
@@ -325,12 +325,14 @@
class Scheduler:
@@ -285,6 +285,7 @@
class SchedulerPrefillOutputs:
# Ignored sequence groups.
ignored_seq_groups: List[SequenceGroup]
num_lookahead_slots: int
+ num_remote_prefill_groups: int
@classmethod
def create_empty(cls) -> "SchedulerPrefillOutputs":
@@ -292,6 +293,7 @@
class SchedulerPrefillOutputs:
seq_groups=[],
ignored_seq_groups=[],
num_lookahead_slots=0,
+ num_remote_prefill_groups=0,
)
@@ -325,12 +327,14 @@
class Scheduler:
def __init__(
self,
...
...
@@ -413,7 +439,7 @@ index f507847a..abe574d1 100644
self.scheduler_config = scheduler_config
self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
@@ -356,6 +3
58
,7 @@
class Scheduler:
@@ -356,6 +3
60
,7 @@
class Scheduler:
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
...
...
@@ -421,7 +447,7 @@ index f507847a..abe574d1 100644
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
@@ -371,6 +37
4
,16 @@
class Scheduler:
@@ -371,6 +37
6
,16 @@
class Scheduler:
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()
...
...
@@ -438,7 +464,7 @@ index f507847a..abe574d1 100644
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
@@ -501,7 +51
4
,7 @@
class Scheduler:
@@ -501,7 +51
6
,7 @@
class Scheduler:
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
...
...
@@ -447,7 +473,7 @@ index f507847a..abe574d1 100644
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
@@ -523,6 +53
6
,8 @@
class Scheduler:
@@ -523,6 +53
8
,8 @@
class Scheduler:
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
...
...
@@ -456,7 +482,7 @@ index f507847a..abe574d1 100644
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
@@ -537,6 +55
2
,8 @@
class Scheduler:
@@ -537,6 +55
4
,8 @@
class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
...
...
@@ -465,7 +491,7 @@ index f507847a..abe574d1 100644
Returns:
SchedulerRunningOutputs.
@@ -566,6 +58
3
,38 @@
class Scheduler:
@@ -566,6 +58
5
,38 @@
class Scheduler:
preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out
...
...
@@ -504,7 +530,15 @@ index f507847a..abe574d1 100644
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
@@ -1008,7 +1057,17 @@
class Scheduler:
@@ -925,6 +976,7 @@
class Scheduler:
seq_groups: List[ScheduledSequenceGroup] = []
waiting_queue = self.waiting
+ num_remote_prefill_groups = 0
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
@@ -1008,7 +1060,18 @@
class Scheduler:
if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id)
waiting_queue.popleft()
...
...
@@ -515,7 +549,8 @@ index f507847a..abe574d1 100644
+
+ logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id)
+ logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group)
+ is_remote_prefill = self._allocate_and_set_running_or_remote_prefill(seq_group)
+ num_remote_prefill_groups += is_remote_prefill
+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+ logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
...
...
@@ -523,17 +558,26 @@ index f507847a..abe574d1 100644
if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = []
@@ -1048,7 +1107,7 @@
class Scheduler:
@@ -1046,9 +1109,11 @@
class Scheduler:
seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking))
- is_prefill=True, enable_chunking=enable_chunking))
+ is_prefill=True, enable_chunking=enable_chunking),
+ num_remote_prefill_groups=num_remote_prefill_groups
+ )
- def _schedule_default(self) -> SchedulerOutputs:
+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests.
The current policy is designed to optimize the throughput. First,
@@ -1090,7 +1149,9 @@
class Scheduler:
if len(prefills.seq_groups) == 0:
@@ -1087,10 +1152,12 @@
class Scheduler:
# Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills.
- if len(prefills.seq_groups) == 0:
+ if len(prefills.seq_groups) == prefills.num_remote_prefill_groups:
running_scheduled = self._schedule_running(budget,
curr_loras,
- enable_chunking=False)
...
...
@@ -543,7 +587,7 @@ index f507847a..abe574d1 100644
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
@@ -1106,7 +11
6
7,12 @@
class Scheduler:
@@ -1106,7 +117
3
,12 @@
class Scheduler:
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
if len(prefills.seq_groups) > 0:
...
...
@@ -557,7 +601,7 @@ index f507847a..abe574d1 100644
self.running.extend(running_scheduled.decode_seq_groups_list)
@@ -1248,12 +13
14
,14 @@
class Scheduler:
@@ -1248,12 +13
20
,14 @@
class Scheduler:
len(running_scheduled.swapped_out)),
)
...
...
@@ -574,7 +618,7 @@ index f507847a..abe574d1 100644
def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool:
@@ -1287,14 +13
55
,16 @@
class Scheduler:
@@ -1287,14 +13
61
,16 @@
class Scheduler:
return no_single_seq
def schedule(
...
...
@@ -594,7 +638,7 @@ index f507847a..abe574d1 100644
now = time.time()
if not self.cache_config.enable_prefix_caching:
@@ -1333,7 +140
3
,8 @@
class Scheduler:
@@ -1333,7 +140
9
,8 @@
class Scheduler:
encoder_seq_data = None
cross_block_table = None
...
...
@@ -604,7 +648,7 @@ index f507847a..abe574d1 100644
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
@@ -1364,9 +14
35
,16 @@
class Scheduler:
@@ -1364,9 +14
41
,16 @@
class Scheduler:
< seqs[0].data.get_len()):
do_sample = False
...
...
@@ -621,7 +665,7 @@ index f507847a..abe574d1 100644
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
@@ -1392,6 +147
0
,7 @@
class Scheduler:
@@ -1392,6 +147
6
,7 @@
class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
...
...
@@ -629,22 +673,27 @@ index f507847a..abe574d1 100644
)
else:
# When SPMD mode is enabled, we only send delta data except for
@@ -1490,1
0
+15
69
,1
3
@@
class Scheduler:
@@ -1490,1
1
+15
75
,1
7
@@
class Scheduler:
self._async_stopped.clear()
- def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
+ def _allocate_and_set_running_or_remote_prefill(self, seq_group: SequenceGroup) ->
None
:
+ def _allocate_and_set_running_or_remote_prefill(self, seq_group: SequenceGroup) ->
bool
:
self.block_manager.allocate(seq_group)
+ is_remote_prefill = False
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
- seq.status = SequenceStatus.RUNNING
-
+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
+ seq.status = SequenceStatus.REMOTE_PREFILLING
+ is_remote_prefill = True
+ else:
+ seq.status = SequenceStatus.RUNNING
+ return is_remote_prefill
+
def _append_slots(self,
seq_group: SequenceGroup,
blocks_to_copy: List[Tuple[int, int]],
diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py
new file mode 100644
index 00000000..9b938039
...
...
@@ -715,10 +764,10 @@ index 00000000..9b938039
\
No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
index 00000000..
86248e7b
index 00000000..
523d58d4
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,
318
@@
@@ -0,0 +1,
405
@@
+import torch
+from typing import List, Tuple
+from vllm.config import VllmConfig
...
...
@@ -747,6 +796,7 @@ index 00000000..86248e7b
+ engine_id: str
+ agent_metadata: List[bytes]
+ kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values
+ num_blocks: int
+
+
+class DynemoNixlConnector:
...
...
@@ -758,6 +808,8 @@ index 00000000..86248e7b
+ logger.info("Initializing NIXL wrapper")
+ self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
+
+ self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer
+
+ self.num_layers = None
+ self.num_blocks = None
+ self.num_heads = None
...
...
@@ -770,11 +822,10 @@ index 00000000..86248e7b
+ self._remote_agents = {}
+ self.engine_id = engine_id
+ self.rank = rank
+ self.notifs = {}
+ self._tp_size = {}
+ self.
_block_desc
s = {}
+ self._xfer_side_handles =
{}
+
+ self.
src_xfer_side_handle
s = {}
+ self.
dst
_xfer_side_handles =
defaultdict(dict)
+
self.dst_num_blocks = {}
+
+ self._transfers = defaultdict(list)
+
...
...
@@ -796,16 +847,12 @@ index 00000000..86248e7b
+ self.kv_caches = kv_caches
+ kv_caches_base_addr = []
+ caches_data = []
+ blocks_data = []
+ for key_cache, value_cache in kv_caches:
+ for cache in [key_cache, value_cache]:
+ base_addr = cache.data_ptr()
+ region_len = num_blocks * self.block_len
+ base_addr = key_cache.data_ptr()
+ region_len = 2 * num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank))
+ for block_id in range(self.num_blocks):
+ blocks_data.append((base_addr + block_id * self.block_len, self.block_len, self.rank))
+
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_descs(("VRAM", caches_data))
...
...
@@ -813,26 +860,20 @@ index 00000000..86248e7b
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+
+ self._block_descs[self.engine_id] = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self._xfer_side_handles[self.engine_id] = self.nixl_wrapper.prep_xfer_side(self._block_descs[self.engine_id])
+
+ def get_agent_metadata(self):
+ return self.nixl_wrapper.get_agent_metadata()
+
+ def shutdown(self):
+ for descs_list in self._registered_descs:
+ self.nixl_wrapper.deregister_memory(descs_list)
+ for agent_name in self._remote_agents.values():
+ for agent_names in self._remote_agents.values():
+ for agent_name in agent_names:
+ self.nixl_wrapper.remove_remote_agent(agent_name)
+
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp):
+ self._tp_size[engine_id] = agent_tp
+ agent_names = []
+ for agent_meta in agent_metadata:
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
+ agent_names.append(agent_name)
+ self._remote_agents[engine_id] = agent_names
+ return agent_names
+ for src_xfer_side_handle in self.src_xfer_side_handles.values():
+ self.nixl_wrapper.delete_xfer_side(src_xfer_side_handle)
+ for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
+ for dst_xfer_side_handle in dst_xfer_side_handles.values():
+ self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle)
+
+ def get_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
...
...
@@ -869,7 +910,6 @@ index 00000000..86248e7b
+ blocks_len = range_len * block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, rank))
+ logger.debug("Blocks data: %s", blocks_data)
+ return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+
+ def _get_ranges(self, block_ids):
...
...
@@ -927,21 +967,40 @@ index 00000000..86248e7b
+
+
+
+ def _get_block_descs_ids(self, layer_ids, block_ids):
+ def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_range=None):
+
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ if block_ids == "all":
+ block_ids = list(range(self.num_blocks))
+
+ descs_ids = []
+
+ if i is not None:
+ num_blocks = self.num_blocks
+ start_offset = staging_range[0]
+ i_offset = i * (staging_range[-1] - start_offset + 1)
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * self.num_blocks + is_value * self.num_blocks + block_id)
+ descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
+ else:
+ num_blocks = self.dst_num_blocks[engine_id]
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id)
+ return descs_ids
+
+
+ def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+
+ if self.use_prepped_xfer:
+ self._transfer_mem_prepped_xfer(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+ else:
+ self._transfer_mem_create_xfer(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+
+ def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg
, use_prepped_xfer=False
):
+ def
_
transfer_mem
_prepped_xfer
(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+ start_time = time.perf_counter()
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+
...
...
@@ -951,22 +1010,57 @@ index 00000000..86248e7b
+ # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)]
+ assert len(staging_block_ids) == len(src_block_ids)
+ src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+
+ if use_prepped_xfer:
+ raise NotImplementedError("Prepped xfer is not implemented")
+ # src_block_descs_ids = self._get_block_descs_ids("all", src_block_ids)
+ # dst_block_descs_ids = self._get_block_descs_ids("all", dst_block_ids)
+ assert len(src_ranges) == 1
+ assert len(staging_ranges) == 1
+
+ # src_xfer_side_handle = self._xfer_side_handles[self.engine_id]
+ # dst_xfer_side_handle = self._xfer_side_handles[dst_engine_id]
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+
+ # logger.debug("Time to get block desc ids: %s ms", (time.perf_counter() - start_time) * 1000)
+ src_range_start, src_range_end = src_ranges[0]
+ src_range_len = src_range_end - src_range_start + 1
+ staging_range_start, staging_range_end = staging_ranges[0]
+ staging_range_len = staging_range_end - staging_range_start + 1
+
+ logger.debug("Rearranging tensors for cache: %s, src_ranges: %s of len %s, staging_ranges: %s of len %s", self.kv_caches[0].shape, src_ranges, src_range_len, staging_ranges, staging_range_len)
+ for kv_cache in self.kv_caches:
+ for cache in kv_cache:
+ rearrange_tensors(cache[src_range_start:src_range_start + src_range_len], cache[staging_range_start:staging_range_start + staging_range_len], tp_multiplier)
+
+ # handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, src_block_descs_ids,
+ # dst_xfer_side_handle, dst_block_descs_ids,
+ # notify_msg, "WRITE", no_check=True)
+ # else:
+ # Legacy path using range-based transfers
+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ # getting block descs ids
+ dst_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", dst_block_ids)
+ src_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
+
+ for i in range(tp_multiplier):
+ staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_range=staging_ranges[0])
+ assert len(staging_block_descs_ids) == len(dst_block_descs_ids)
+ dst_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
+
+
+ logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000)
+ handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, staging_block_descs_ids,
+ dst_xfer_side_handle, dst_block_descs_ids,
+ notify_msg, "WRITE")
+ self._transfers[notify_msg].append(handle)
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+
+ def _transfer_mem_create_xfer(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+ start_time = time.perf_counter()
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+
+ # hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last
+ # isl[-1] token is calculated in the first iteration in decode.
+ # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \
+ # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)]
+ assert len(staging_block_ids) == len(src_block_ids)
+ src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+ dst_ranges = self._get_ranges(dst_block_ids)
...
...
@@ -989,6 +1083,8 @@ index 00000000..86248e7b
+ staging_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(staging_ranges, dst_ranges)
+ assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges)
+
+ logger.debug("Time to get same length ranges: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ for i in range(tp_multiplier):
+
+ src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i)
...
...
@@ -1010,15 +1106,55 @@ index 00000000..86248e7b
+ return self.nixl_wrapper.deserialize_descs(serialized_descs)
+
+ def get_notifs(self):
+ self.notifs = self.nixl_wrapper.agent.getNotifs(self.notifs)
+ return self.notifs
+ return self.nixl_wrapper.update_notifs()
+
+ def get_new_notifs(self):
+ return self.nixl_wrapper.
a
ge
nt.getN
otifs(
{}
)
+ return self.nixl_wrapper.ge
t_new_n
otifs()
+
+ def add_remote_kv_caches_base_addr(self, engine_id, kv_caches_base_addr):
+
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks):
+ self._tp_size[engine_id] = agent_tp
+ agent_names = []
+ for agent_meta in agent_metadata:
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
+ agent_names.append(agent_name)
+ self._remote_agents[engine_id] = agent_names
+ self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
+
+ tp_multiplier = self._tp_size[engine_id] // self._tp_size[self.engine_id]
+ assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}"
+
+ logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier)
+ dst_block_len = self.block_len // tp_multiplier
+ if tp_multiplier not in self.src_xfer_side_handles:
+ # create descs and xfer side handles
+ blocks_data = []
+ for layer_id in range(self.num_layers):
+ for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]:
+ for block_id in range(self.num_blocks):
+ block_offset = block_id * self.block_len
+ for i in range(tp_multiplier):
+ tp_multiplier_offset = i * dst_block_len
+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_side(descs)
+
+ # create dst xfer side handles
+ self.dst_num_blocks[engine_id] = num_blocks
+ for i in range(tp_multiplier):
+ blocks_data = []
+ for layer_id in range(self.num_layers):
+ for base_addr in self.kv_caches_base_addr[engine_id][self.rank * tp_multiplier + i][layer_id]:
+ for block_id in range(num_blocks):
+ block_offset = block_id * dst_block_len
+ blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i))
+ logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_side(descs, remote_agent=self._remote_agents[engine_id][self.rank * tp_multiplier + i])
+
+ return agent_names
+
+ def get_done_tranfers(self) -> List[str]:
+ done_req_ids = []
+ for req_id, handles in self._transfers.items():
...
...
@@ -2395,7 +2531,7 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d82d9ad9..
62dbbd6e
100644
index d82d9ad9..
cc02b029
100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@
...
...
@@ -2472,7 +2608,7 @@ index d82d9ad9..62dbbd6e 100644
self.parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id]
if self.model_config.use_async_output_proc else None)
@@ -405,6 +417,
39
@@
class LLMEngine:
@@ -405,6 +417,
40
@@
class LLMEngine:
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
...
...
@@ -2495,7 +2631,7 @@ index d82d9ad9..62dbbd6e 100644
+ raise RuntimeError("Nixl is not initialized")
+ agent_metadata = self.model_executor.collective_rpc("get_nixl_agent_metadata")
+ kv_caches_base_addr = self.model_executor.collective_rpc("get_nixl_kv_caches_base_addr")
+ return NixlMetadata(engine_id=self.engine_id, agent_metadata=agent_metadata, kv_caches_base_addr=kv_caches_base_addr)
+ return NixlMetadata(engine_id=self.engine_id, agent_metadata=agent_metadata, kv_caches_base_addr=kv_caches_base_addr
, num_blocks=self.cache_config.num_gpu_blocks
)
+
+ def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata) -> List[str]:
+ if not self.is_nixl_initialized:
...
...
@@ -2503,7 +2639,8 @@ index d82d9ad9..62dbbd6e 100644
+ engine_id = nixl_metadata.engine_id
+ agents_metadata = nixl_metadata.agent_metadata
+ kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr))
+ num_blocks = nixl_metadata.num_blocks
+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr, num_blocks))
+
+ def _initialize_nixl(self) -> List[bytes]:
+ agents_names = self.model_executor.collective_rpc("initialize_nixl", args=(self.engine_id,))
...
...
@@ -2512,7 +2649,16 @@ index d82d9ad9..62dbbd6e 100644
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@@ -552,11 +597,14 @@
class LLMEngine:
@@ -500,6 +546,8 @@
class LLMEngine:
# Shutdown model executor when engine is garbage collected
# Use getattr since __init__ can fail before the field is set
if model_executor := getattr(self, "model_executor", None):
+ if self.is_nixl_initialized:
+ model_executor.collective_rpc("shutdown_nixl")
model_executor.shutdown()
def get_tokenizer_group(
@@ -552,11 +600,14 @@
class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
...
...
@@ -2527,7 +2673,7 @@ index d82d9ad9..62dbbd6e 100644
ParallelSampleSequenceGroup.add_request(
request_id,
self,
@@ -574,6 +62
2
,8 @@
class LLMEngine:
@@ -574,6 +62
5
,8 @@
class LLMEngine:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
...
...
@@ -2536,7 +2682,7 @@ index d82d9ad9..62dbbd6e 100644
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
@@ -584,7 +63
4
,7 @@
class LLMEngine:
@@ -584,7 +63
7
,7 @@
class LLMEngine:
encoder_inputs = None
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
...
...
@@ -2545,7 +2691,7 @@ index d82d9ad9..62dbbd6e 100644
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
@@ -601,8 +65
1
,12 @@
class LLMEngine:
@@ -601,8 +65
4
,12 @@
class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
...
...
@@ -2559,7 +2705,7 @@ index d82d9ad9..62dbbd6e 100644
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
@@ -673,6 +7
27
,7 @@
class LLMEngine:
@@ -673,6 +7
30
,7 @@
class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
...
...
@@ -2567,7 +2713,7 @@ index d82d9ad9..62dbbd6e 100644
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@@ -765,6 +82
0
,7 @@
class LLMEngine:
@@ -765,6 +82
3
,7 @@
class LLMEngine:
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
...
...
@@ -2575,7 +2721,7 @@ index d82d9ad9..62dbbd6e 100644
)
def _validate_token_prompt(self, prompt: PromptType,
@@ -799,6 +85
5
,7 @@
class LLMEngine:
@@ -799,6 +85
8
,7 @@
class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
...
...
@@ -2583,7 +2729,7 @@ index d82d9ad9..62dbbd6e 100644
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
@@ -829,7 +88
6
,9 @@
class LLMEngine:
@@ -829,7 +88
9
,9 @@
class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
...
...
@@ -2594,7 +2740,7 @@ index d82d9ad9..62dbbd6e 100644
return seq_group
@@ -995,11 +105
4
,11 @@
class LLMEngine:
@@ -995,11 +105
7
,11 @@
class LLMEngine:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
...
...
@@ -2608,7 +2754,7 @@ index d82d9ad9..62dbbd6e 100644
# Sanity check
assert len(seq_group_metadata_list) == len(
@@ -1325,15 +138
4
,49 @@
class LLMEngine:
@@ -1325,15 +138
7
,49 @@
class LLMEngine:
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
...
...
@@ -2660,7 +2806,7 @@ index d82d9ad9..62dbbd6e 100644
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
@@ -1383,9 +147
6
,31 @@
class LLMEngine:
@@ -1383,9 +147
9
,31 @@
class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
...
...
@@ -2694,7 +2840,7 @@ index d82d9ad9..62dbbd6e 100644
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
@@ -1396,7 +151
1
,26 @@
class LLMEngine:
@@ -1396,7 +151
4
,26 @@
class LLMEngine:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
...
...
@@ -2722,7 +2868,7 @@ index d82d9ad9..62dbbd6e 100644
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
@@ -1456,7 +159
0
,7 @@
class LLMEngine:
@@ -1456,7 +159
3
,7 @@
class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
...
...
@@ -3534,7 +3680,7 @@ index 12baecde..489d3b77 100644
prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 582aa460..
c01cfe00
100644
index 582aa460..
e4ed902e
100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -2,7 +2,7 @@
...
...
@@ -3555,7 +3701,7 @@ index 582aa460..c01cfe00 100644
logger = init_logger(__name__)
@@ -306,6 +308,4
3
@@
class Worker(LocalOrDistributedWorkerBase):
@@ -306,6 +308,4
2
@@
class Worker(LocalOrDistributedWorkerBase):
self._init_cache_engine()
self._warm_up_model()
...
...
@@ -3573,10 +3719,9 @@ index 582aa460..c01cfe00 100644
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ return self.nixl_connector.get_agent_metadata()
+
+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str:
+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]
, num_blocks: int
) -> str:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata)) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr)
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata), kv_caches_base_addr, num_blocks) # TODO ptarasiewicz: rank or local_rank?
+ return agent_name
+
+ def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None:
...
...
@@ -3599,7 +3744,7 @@ index 582aa460..c01cfe00 100644
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [
@@ -367,6 +40
6
,8 @@
class Worker(LocalOrDistributedWorkerBase):
@@ -367,6 +40
5
,8 @@
class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
...
...
@@ -3608,7 +3753,7 @@ index 582aa460..c01cfe00 100644
return WorkerInput(
num_seq_groups=num_seq_groups,
@@ -375,6 +41
6
,11 @@
class Worker(LocalOrDistributedWorkerBase):
@@ -375,6 +41
5
,11 @@
class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment