Commit bc42616e authored by ptarasiewiczNV's avatar ptarasiewiczNV Committed by GitHub
Browse files

feat: Enable make_xfer NIXL kv transfer (#39)

Co-authored-by: ptarasiewicz@nvidia.com <Piotr Tarasiewicz>
parent 09656f6c
diff --git a/vllm/config.py b/vllm/config.py diff --git a/vllm/config.py b/vllm/config.py
index 9ba49757..cbfeb715 100644 index 9ba49757..3ec4bbab 100644
--- a/vllm/config.py --- a/vllm/config.py
+++ b/vllm/config.py +++ b/vllm/config.py
@@ -2629,7 +2629,7 @@ class KVTransferConfig(BaseModel): @@ -2620,6 +2620,9 @@ class KVTransferConfig(BaseModel):
# The KV connector for vLLM to transmit KV caches between vLLM instances.
kv_connector: Optional[str] = None
+ # Whether to use NIXL prepped xfer for KV cache transfer.
+ use_prepped_xfer: bool = False
+
# The device used by kv connector to buffer the KV cache.
# Currently only support 'cuda'.
kv_buffer_device: Optional[str] = "cuda"
@@ -2629,7 +2632,7 @@ class KVTransferConfig(BaseModel):
kv_buffer_size: float = 1e9 kv_buffer_size: float = 1e9
# Whether this vLLM instance produces, consumes KV cache, or both. Choices # Whether this vLLM instance produces, consumes KV cache, or both. Choices
...@@ -11,7 +21,7 @@ index 9ba49757..cbfeb715 100644 ...@@ -11,7 +21,7 @@ index 9ba49757..cbfeb715 100644
kv_role: Optional[str] = None kv_role: Optional[str] = None
# The rank of this vLLM instance in the KV cache transfer. Typical value: # The rank of this vLLM instance in the KV cache transfer. Typical value:
@@ -2647,6 +2647,14 @@ class KVTransferConfig(BaseModel): @@ -2647,6 +2650,14 @@ class KVTransferConfig(BaseModel):
# The KV connector port, used to build distributed connection # The KV connector port, used to build distributed connection
kv_port: int = 14579 kv_port: int = 14579
...@@ -26,7 +36,7 @@ index 9ba49757..cbfeb715 100644 ...@@ -26,7 +36,7 @@ index 9ba49757..cbfeb715 100644
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
@@ -2680,11 +2688,12 @@ class KVTransferConfig(BaseModel): @@ -2680,11 +2691,12 @@ class KVTransferConfig(BaseModel):
f"Supported roles are `kv_producer`, `kv_consumer`, " f"Supported roles are `kv_producer`, `kv_consumer`, "
f"and `kv_both`") f"and `kv_both`")
...@@ -40,7 +50,7 @@ index 9ba49757..cbfeb715 100644 ...@@ -40,7 +50,7 @@ index 9ba49757..cbfeb715 100644
@property @property
def is_kv_transfer_instance(self) -> bool: def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \ return self.kv_connector is not None and \
@@ -2694,6 +2703,8 @@ class KVTransferConfig(BaseModel): @@ -2694,6 +2706,8 @@ class KVTransferConfig(BaseModel):
def need_kv_parallel_group(self) -> bool: def need_kv_parallel_group(self) -> bool:
# for those database-based connector, vLLM does not need to create # for those database-based connector, vLLM does not need to create
# parallel group, and in that case the kv parallel size will be 1. # parallel group, and in that case the kv parallel size will be 1.
...@@ -49,7 +59,7 @@ index 9ba49757..cbfeb715 100644 ...@@ -49,7 +59,7 @@ index 9ba49757..cbfeb715 100644
return self.kv_connector is not None and self.kv_parallel_size > 1 return self.kv_connector is not None and self.kv_parallel_size > 1
@property @property
@@ -2706,6 +2717,18 @@ class KVTransferConfig(BaseModel): @@ -2706,6 +2720,18 @@ class KVTransferConfig(BaseModel):
return self.kv_connector is not None and \ return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"] self.kv_role in ["kv_consumer", "kv_both"]
...@@ -368,7 +378,7 @@ index 00000000..8699ca06 ...@@ -368,7 +378,7 @@ index 00000000..8699ca06
+ +
+ self.event_id_counter += 1 + self.event_id_counter += 1
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f507847a..abe574d1 100644 index f507847a..4d299d7f 100644
--- a/vllm/core/scheduler.py --- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py
@@ -4,22 +4,22 @@ import enum @@ -4,22 +4,22 @@ import enum
...@@ -398,7 +408,23 @@ index f507847a..abe574d1 100644 ...@@ -398,7 +408,23 @@ index f507847a..abe574d1 100644
logger = init_logger(__name__) logger = init_logger(__name__)
# Test-only. If configured, decode is preempted with # Test-only. If configured, decode is preempted with
@@ -325,12 +325,14 @@ class Scheduler: @@ -285,6 +285,7 @@ class SchedulerPrefillOutputs:
# Ignored sequence groups.
ignored_seq_groups: List[SequenceGroup]
num_lookahead_slots: int
+ num_remote_prefill_groups: int
@classmethod
def create_empty(cls) -> "SchedulerPrefillOutputs":
@@ -292,6 +293,7 @@ class SchedulerPrefillOutputs:
seq_groups=[],
ignored_seq_groups=[],
num_lookahead_slots=0,
+ num_remote_prefill_groups=0,
)
@@ -325,12 +327,14 @@ class Scheduler:
def __init__( def __init__(
self, self,
...@@ -413,7 +439,7 @@ index f507847a..abe574d1 100644 ...@@ -413,7 +439,7 @@ index f507847a..abe574d1 100644
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely # Note for LoRA scheduling: the current policy is extremely
@@ -356,6 +358,7 @@ class Scheduler: @@ -356,6 +360,7 @@ class Scheduler:
# Create the block space manager. # Create the block space manager.
self.block_manager = BlockSpaceManagerImpl( self.block_manager = BlockSpaceManagerImpl(
...@@ -421,7 +447,7 @@ index f507847a..abe574d1 100644 ...@@ -421,7 +447,7 @@ index f507847a..abe574d1 100644
block_size=self.cache_config.block_size, block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks, num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks, num_cpu_blocks=num_cpu_blocks,
@@ -371,6 +374,16 @@ class Scheduler: @@ -371,6 +376,16 @@ class Scheduler:
# Sequence groups in the SWAPPED state. # Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out. # Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque() self.swapped: Deque[SequenceGroup] = deque()
...@@ -438,7 +464,7 @@ index f507847a..abe574d1 100644 ...@@ -438,7 +464,7 @@ index f507847a..abe574d1 100644
# Sequence groups finished requests ids since last step iteration. # Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests # It lets the model know that any state associated with these requests
# can and must be released after the current step. # can and must be released after the current step.
@@ -501,7 +514,7 @@ class Scheduler: @@ -501,7 +516,7 @@ class Scheduler:
def has_unfinished_seqs(self) -> bool: def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len( return len(self.waiting) != 0 or len(self.running) != 0 or len(
...@@ -447,7 +473,7 @@ index f507847a..abe574d1 100644 ...@@ -447,7 +473,7 @@ index f507847a..abe574d1 100644
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device) return self.block_manager.get_prefix_cache_hit_rate(device)
@@ -523,6 +536,8 @@ class Scheduler: @@ -523,6 +538,8 @@ class Scheduler:
budget: SchedulingBudget, budget: SchedulingBudget,
curr_loras: Optional[Set[int]], curr_loras: Optional[Set[int]],
enable_chunking: bool = False, enable_chunking: bool = False,
...@@ -456,7 +482,7 @@ index f507847a..abe574d1 100644 ...@@ -456,7 +482,7 @@ index f507847a..abe574d1 100644
) -> SchedulerRunningOutputs: ) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running. """Schedule sequence groups that are running.
@@ -537,6 +552,8 @@ class Scheduler: @@ -537,6 +554,8 @@ class Scheduler:
chunked number of tokens are scheduled if chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule `budget.num_batched_tokens` has not enough capacity to schedule
all tokens. all tokens.
...@@ -465,7 +491,7 @@ index f507847a..abe574d1 100644 ...@@ -465,7 +491,7 @@ index f507847a..abe574d1 100644
Returns: Returns:
SchedulerRunningOutputs. SchedulerRunningOutputs.
@@ -566,6 +583,38 @@ class Scheduler: @@ -566,6 +585,38 @@ class Scheduler:
preempted: List[SequenceGroup] = ret.preempted preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out swapped_out: List[SequenceGroup] = ret.swapped_out
...@@ -504,7 +530,15 @@ index f507847a..abe574d1 100644 ...@@ -504,7 +530,15 @@ index f507847a..abe574d1 100644
running_queue = self.running running_queue = self.running
assert len(self._async_stopped) == 0 assert len(self._async_stopped) == 0
while running_queue: while running_queue:
@@ -1008,7 +1057,17 @@ class Scheduler: @@ -925,6 +976,7 @@ class Scheduler:
seq_groups: List[ScheduledSequenceGroup] = []
waiting_queue = self.waiting
+ num_remote_prefill_groups = 0
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
@@ -1008,7 +1060,18 @@ class Scheduler:
if curr_loras is not None and lora_int_id > 0: if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id) curr_loras.add(lora_int_id)
waiting_queue.popleft() waiting_queue.popleft()
...@@ -515,7 +549,8 @@ index f507847a..abe574d1 100644 ...@@ -515,7 +549,8 @@ index f507847a..abe574d1 100644
+ +
+ logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id) + logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id)
+ logger.debug("Seq id: %s", seq_group.seqs[0].seq_id) + logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group) + is_remote_prefill = self._allocate_and_set_running_or_remote_prefill(seq_group)
+ num_remote_prefill_groups += is_remote_prefill
+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode: + if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+ logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id) + logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group_copy) + self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
...@@ -523,17 +558,26 @@ index f507847a..abe574d1 100644 ...@@ -523,17 +558,26 @@ index f507847a..abe574d1 100644
if enable_chunking and self.scheduler_config.is_multi_step: if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = [] blocks_to_copy: List[Tuple[int, int]] = []
@@ -1048,7 +1107,7 @@ class Scheduler: @@ -1046,9 +1109,11 @@ class Scheduler:
seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots( num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking)) - is_prefill=True, enable_chunking=enable_chunking))
+ is_prefill=True, enable_chunking=enable_chunking),
+ num_remote_prefill_groups=num_remote_prefill_groups
+ )
- def _schedule_default(self) -> SchedulerOutputs: - def _schedule_default(self) -> SchedulerOutputs:
+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs: + def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests. """Schedule queued requests.
The current policy is designed to optimize the throughput. First, The current policy is designed to optimize the throughput. First,
@@ -1090,7 +1149,9 @@ class Scheduler: @@ -1087,10 +1152,12 @@ class Scheduler:
if len(prefills.seq_groups) == 0: # Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills.
- if len(prefills.seq_groups) == 0:
+ if len(prefills.seq_groups) == prefills.num_remote_prefill_groups:
running_scheduled = self._schedule_running(budget, running_scheduled = self._schedule_running(budget,
curr_loras, curr_loras,
- enable_chunking=False) - enable_chunking=False)
...@@ -543,7 +587,7 @@ index f507847a..abe574d1 100644 ...@@ -543,7 +587,7 @@ index f507847a..abe574d1 100644
# If any sequence group is preempted, do not swap in any sequence # If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests. # group. because it means there's no slot for new running requests.
@@ -1106,7 +1167,12 @@ class Scheduler: @@ -1106,7 +1173,12 @@ class Scheduler:
self.waiting.extendleft(running_scheduled.preempted) self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests. # Update new running requests.
if len(prefills.seq_groups) > 0: if len(prefills.seq_groups) > 0:
...@@ -557,7 +601,7 @@ index f507847a..abe574d1 100644 ...@@ -557,7 +601,7 @@ index f507847a..abe574d1 100644
self.running.extend(running_scheduled.decode_seq_groups_list) self.running.extend(running_scheduled.decode_seq_groups_list)
@@ -1248,12 +1314,14 @@ class Scheduler: @@ -1248,12 +1320,14 @@ class Scheduler:
len(running_scheduled.swapped_out)), len(running_scheduled.swapped_out)),
) )
...@@ -574,7 +618,7 @@ index f507847a..abe574d1 100644 ...@@ -574,7 +618,7 @@ index f507847a..abe574d1 100644
def _can_append_slots(self, seq_group: SequenceGroup, def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool: enable_chunking: bool) -> bool:
@@ -1287,14 +1355,16 @@ class Scheduler: @@ -1287,14 +1361,16 @@ class Scheduler:
return no_single_seq return no_single_seq
def schedule( def schedule(
...@@ -594,7 +638,7 @@ index f507847a..abe574d1 100644 ...@@ -594,7 +638,7 @@ index f507847a..abe574d1 100644
now = time.time() now = time.time()
if not self.cache_config.enable_prefix_caching: if not self.cache_config.enable_prefix_caching:
@@ -1333,7 +1403,8 @@ class Scheduler: @@ -1333,7 +1409,8 @@ class Scheduler:
encoder_seq_data = None encoder_seq_data = None
cross_block_table = None cross_block_table = None
...@@ -604,7 +648,7 @@ index f507847a..abe574d1 100644 ...@@ -604,7 +648,7 @@ index f507847a..abe574d1 100644
seq_id = seq.seq_id seq_id = seq.seq_id
seq_data[seq_id] = seq.data seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq) block_tables[seq_id] = self.block_manager.get_block_table(seq)
@@ -1364,9 +1435,16 @@ class Scheduler: @@ -1364,9 +1441,16 @@ class Scheduler:
< seqs[0].data.get_len()): < seqs[0].data.get_len()):
do_sample = False do_sample = False
...@@ -621,7 +665,7 @@ index f507847a..abe574d1 100644 ...@@ -621,7 +665,7 @@ index f507847a..abe574d1 100644
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id, request_id=seq_group.request_id,
is_prompt=is_prompt, is_prompt=is_prompt,
@@ -1392,6 +1470,7 @@ class Scheduler: @@ -1392,6 +1476,7 @@ class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None, if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs, mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request, prompt_adapter_request=seq_group.prompt_adapter_request,
...@@ -629,22 +673,27 @@ index f507847a..abe574d1 100644 ...@@ -629,22 +673,27 @@ index f507847a..abe574d1 100644
) )
else: else:
# When SPMD mode is enabled, we only send delta data except for # When SPMD mode is enabled, we only send delta data except for
@@ -1490,10 +1569,13 @@ class Scheduler: @@ -1490,11 +1575,17 @@ class Scheduler:
self._async_stopped.clear() self._async_stopped.clear()
- def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: - def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
+ def _allocate_and_set_running_or_remote_prefill(self, seq_group: SequenceGroup) -> None: + def _allocate_and_set_running_or_remote_prefill(self, seq_group: SequenceGroup) -> bool:
self.block_manager.allocate(seq_group) self.block_manager.allocate(seq_group)
+ is_remote_prefill = False
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
- seq.status = SequenceStatus.RUNNING - seq.status = SequenceStatus.RUNNING
-
+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: + if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
+ seq.status = SequenceStatus.REMOTE_PREFILLING + seq.status = SequenceStatus.REMOTE_PREFILLING
+ is_remote_prefill = True
+ else: + else:
+ seq.status = SequenceStatus.RUNNING + seq.status = SequenceStatus.RUNNING
+ return is_remote_prefill
+
def _append_slots(self, def _append_slots(self,
seq_group: SequenceGroup, seq_group: SequenceGroup,
blocks_to_copy: List[Tuple[int, int]],
diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py
new file mode 100644 new file mode 100644
index 00000000..9b938039 index 00000000..9b938039
...@@ -715,10 +764,10 @@ index 00000000..9b938039 ...@@ -715,10 +764,10 @@ index 00000000..9b938039
\ No newline at end of file \ No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644 new file mode 100644
index 00000000..86248e7b index 00000000..523d58d4
--- /dev/null --- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py +++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,318 @@ @@ -0,0 +1,405 @@
+import torch +import torch
+from typing import List, Tuple +from typing import List, Tuple
+from vllm.config import VllmConfig +from vllm.config import VllmConfig
...@@ -747,6 +796,7 @@ index 00000000..86248e7b ...@@ -747,6 +796,7 @@ index 00000000..86248e7b
+ engine_id: str + engine_id: str
+ agent_metadata: List[bytes] + agent_metadata: List[bytes]
+ kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values + kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values
+ num_blocks: int
+ +
+ +
+class DynemoNixlConnector: +class DynemoNixlConnector:
...@@ -758,6 +808,8 @@ index 00000000..86248e7b ...@@ -758,6 +808,8 @@ index 00000000..86248e7b
+ logger.info("Initializing NIXL wrapper") + logger.info("Initializing NIXL wrapper")
+ self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
+ +
+ self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer
+
+ self.num_layers = None + self.num_layers = None
+ self.num_blocks = None + self.num_blocks = None
+ self.num_heads = None + self.num_heads = None
...@@ -770,11 +822,10 @@ index 00000000..86248e7b ...@@ -770,11 +822,10 @@ index 00000000..86248e7b
+ self._remote_agents = {} + self._remote_agents = {}
+ self.engine_id = engine_id + self.engine_id = engine_id
+ self.rank = rank + self.rank = rank
+ self.notifs = {}
+ self._tp_size = {} + self._tp_size = {}
+ self._block_descs = {} + self.src_xfer_side_handles = {}
+ self._xfer_side_handles = {} + self.dst_xfer_side_handles = defaultdict(dict)
+ + self.dst_num_blocks = {}
+ +
+ self._transfers = defaultdict(list) + self._transfers = defaultdict(list)
+ +
...@@ -796,16 +847,12 @@ index 00000000..86248e7b ...@@ -796,16 +847,12 @@ index 00000000..86248e7b
+ self.kv_caches = kv_caches + self.kv_caches = kv_caches
+ kv_caches_base_addr = [] + kv_caches_base_addr = []
+ caches_data = [] + caches_data = []
+ blocks_data = []
+ for key_cache, value_cache in kv_caches: + for key_cache, value_cache in kv_caches:
+ for cache in [key_cache, value_cache]: + base_addr = key_cache.data_ptr()
+ base_addr = cache.data_ptr() + region_len = 2 * num_blocks * self.block_len
+ region_len = num_blocks * self.block_len + caches_data.append((base_addr, region_len, self.rank))
+ caches_data.append((base_addr, region_len, self.rank))
+ for block_id in range(self.num_blocks):
+ blocks_data.append((base_addr + block_id * self.block_len, self.block_len, self.rank))
+
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr())) + kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+ +
+ descs = self.nixl_wrapper.get_descs(("VRAM", caches_data)) + descs = self.nixl_wrapper.get_descs(("VRAM", caches_data))
...@@ -813,26 +860,20 @@ index 00000000..86248e7b ...@@ -813,26 +860,20 @@ index 00000000..86248e7b
+ self.nixl_wrapper.register_memory(descs) + self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs) + self._registered_descs.append(descs)
+ +
+ self._block_descs[self.engine_id] = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self._xfer_side_handles[self.engine_id] = self.nixl_wrapper.prep_xfer_side(self._block_descs[self.engine_id])
+
+ def get_agent_metadata(self): + def get_agent_metadata(self):
+ return self.nixl_wrapper.get_agent_metadata() + return self.nixl_wrapper.get_agent_metadata()
+ +
+ def shutdown(self): + def shutdown(self):
+ for descs_list in self._registered_descs: + for descs_list in self._registered_descs:
+ self.nixl_wrapper.deregister_memory(descs_list) + self.nixl_wrapper.deregister_memory(descs_list)
+ for agent_name in self._remote_agents.values(): + for agent_names in self._remote_agents.values():
+ self.nixl_wrapper.remove_remote_agent(agent_name) + for agent_name in agent_names:
+ + self.nixl_wrapper.remove_remote_agent(agent_name)
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp): + for src_xfer_side_handle in self.src_xfer_side_handles.values():
+ self._tp_size[engine_id] = agent_tp + self.nixl_wrapper.delete_xfer_side(src_xfer_side_handle)
+ agent_names = [] + for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
+ for agent_meta in agent_metadata: + for dst_xfer_side_handle in dst_xfer_side_handles.values():
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_meta) + self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle)
+ agent_names.append(agent_name)
+ self._remote_agents[engine_id] = agent_names
+ return agent_names
+ +
+ def get_descs_ids(self, layer_ids, block_ids): + def get_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all": + if layer_ids == "all":
...@@ -869,7 +910,6 @@ index 00000000..86248e7b ...@@ -869,7 +910,6 @@ index 00000000..86248e7b
+ blocks_len = range_len * block_len + blocks_len = range_len * block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, rank)) + blocks_data.append((key_base_addr + start_offset, blocks_len, rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, rank)) + blocks_data.append((value_base_addr + start_offset, blocks_len, rank))
+ logger.debug("Blocks data: %s", blocks_data)
+ return self.nixl_wrapper.get_descs(("VRAM", blocks_data)) + return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ +
+ def _get_ranges(self, block_ids): + def _get_ranges(self, block_ids):
...@@ -927,21 +967,40 @@ index 00000000..86248e7b ...@@ -927,21 +967,40 @@ index 00000000..86248e7b
+ +
+ +
+ +
+ def _get_block_descs_ids(self, layer_ids, block_ids): + def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_range=None):
+
+ if layer_ids == "all": + if layer_ids == "all":
+ layer_ids = list(range(self.num_layers)) + layer_ids = list(range(self.num_layers))
+ if block_ids == "all": + if block_ids == "all":
+ block_ids = list(range(self.num_blocks)) + block_ids = list(range(self.num_blocks))
+
+ descs_ids = [] + descs_ids = []
+ for layer_id in layer_ids: +
+ for is_value in [0, 1]: + if i is not None:
+ for block_id in block_ids: + num_blocks = self.num_blocks
+ descs_ids.append(layer_id * 2 * self.num_blocks + is_value * self.num_blocks + block_id) + start_offset = staging_range[0]
+ i_offset = i * (staging_range[-1] - start_offset + 1)
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
+ else:
+ num_blocks = self.dst_num_blocks[engine_id]
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id)
+ return descs_ids + return descs_ids
+ +
+
+ +
+ def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg, use_prepped_xfer=False): + def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+
+ if self.use_prepped_xfer:
+ self._transfer_mem_prepped_xfer(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+ else:
+ self._transfer_mem_create_xfer(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+
+ def _transfer_mem_prepped_xfer(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+ start_time = time.perf_counter() + start_time = time.perf_counter()
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg) + logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+ +
...@@ -951,22 +1010,57 @@ index 00000000..86248e7b ...@@ -951,22 +1010,57 @@ index 00000000..86248e7b
+ # one less block due to the missing last token. + # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)] + dst_block_ids = dst_block_ids[:len(src_block_ids)]
+ assert len(staging_block_ids) == len(src_block_ids) + assert len(staging_block_ids) == len(src_block_ids)
+ src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+
+ assert len(src_ranges) == 1
+ assert len(staging_ranges) == 1
+
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+
+ src_range_start, src_range_end = src_ranges[0]
+ src_range_len = src_range_end - src_range_start + 1
+ staging_range_start, staging_range_end = staging_ranges[0]
+ staging_range_len = staging_range_end - staging_range_start + 1
+
+ logger.debug("Rearranging tensors for cache: %s, src_ranges: %s of len %s, staging_ranges: %s of len %s", self.kv_caches[0].shape, src_ranges, src_range_len, staging_ranges, staging_range_len)
+ for kv_cache in self.kv_caches:
+ for cache in kv_cache:
+ rearrange_tensors(cache[src_range_start:src_range_start + src_range_len], cache[staging_range_start:staging_range_start + staging_range_len], tp_multiplier)
+ +
+ if use_prepped_xfer: + logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000)
+ raise NotImplementedError("Prepped xfer is not implemented")
+ # src_block_descs_ids = self._get_block_descs_ids("all", src_block_ids)
+ # dst_block_descs_ids = self._get_block_descs_ids("all", dst_block_ids)
+ +
+ # src_xfer_side_handle = self._xfer_side_handles[self.engine_id] + # getting block descs ids
+ # dst_xfer_side_handle = self._xfer_side_handles[dst_engine_id] + dst_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", dst_block_ids)
+ src_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
+
+ for i in range(tp_multiplier):
+ staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_range=staging_ranges[0])
+ assert len(staging_block_descs_ids) == len(dst_block_descs_ids)
+ dst_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
+ +
+ # logger.debug("Time to get block desc ids: %s ms", (time.perf_counter() - start_time) * 1000) +
+ + logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000)
+ # handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, src_block_descs_ids, + handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, staging_block_descs_ids,
+ # dst_xfer_side_handle, dst_block_descs_ids, + dst_xfer_side_handle, dst_block_descs_ids,
+ # notify_msg, "WRITE", no_check=True) + notify_msg, "WRITE")
+ # else: + self._transfers[notify_msg].append(handle)
+ # Legacy path using range-based transfers + logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+
+ def _transfer_mem_create_xfer(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+ start_time = time.perf_counter()
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+
+ # hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last
+ # isl[-1] token is calculated in the first iteration in decode.
+ # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \
+ # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)]
+ assert len(staging_block_ids) == len(src_block_ids)
+ src_ranges = self._get_ranges(src_block_ids) + src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids) + staging_ranges = self._get_ranges(staging_block_ids)
+ dst_ranges = self._get_ranges(dst_block_ids) + dst_ranges = self._get_ranges(dst_block_ids)
...@@ -989,6 +1083,8 @@ index 00000000..86248e7b ...@@ -989,6 +1083,8 @@ index 00000000..86248e7b
+ staging_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(staging_ranges, dst_ranges) + staging_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(staging_ranges, dst_ranges)
+ assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges) + assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges)
+ +
+ logger.debug("Time to get same length ranges: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ for i in range(tp_multiplier): + for i in range(tp_multiplier):
+ +
+ src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i) + src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i)
...@@ -1010,15 +1106,55 @@ index 00000000..86248e7b ...@@ -1010,15 +1106,55 @@ index 00000000..86248e7b
+ return self.nixl_wrapper.deserialize_descs(serialized_descs) + return self.nixl_wrapper.deserialize_descs(serialized_descs)
+ +
+ def get_notifs(self): + def get_notifs(self):
+ self.notifs = self.nixl_wrapper.agent.getNotifs(self.notifs) + return self.nixl_wrapper.update_notifs()
+ return self.notifs
+ +
+ def get_new_notifs(self): + def get_new_notifs(self):
+ return self.nixl_wrapper.agent.getNotifs({}) + return self.nixl_wrapper.get_new_notifs()
+ +
+ def add_remote_kv_caches_base_addr(self, engine_id, kv_caches_base_addr): +
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks):
+ self._tp_size[engine_id] = agent_tp
+ agent_names = []
+ for agent_meta in agent_metadata:
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
+ agent_names.append(agent_name)
+ self._remote_agents[engine_id] = agent_names
+ self.kv_caches_base_addr[engine_id] = kv_caches_base_addr + self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
+ +
+ tp_multiplier = self._tp_size[engine_id] // self._tp_size[self.engine_id]
+ assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}"
+
+ logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier)
+ dst_block_len = self.block_len // tp_multiplier
+ if tp_multiplier not in self.src_xfer_side_handles:
+ # create descs and xfer side handles
+ blocks_data = []
+ for layer_id in range(self.num_layers):
+ for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]:
+ for block_id in range(self.num_blocks):
+ block_offset = block_id * self.block_len
+ for i in range(tp_multiplier):
+ tp_multiplier_offset = i * dst_block_len
+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_side(descs)
+
+ # create dst xfer side handles
+ self.dst_num_blocks[engine_id] = num_blocks
+ for i in range(tp_multiplier):
+ blocks_data = []
+ for layer_id in range(self.num_layers):
+ for base_addr in self.kv_caches_base_addr[engine_id][self.rank * tp_multiplier + i][layer_id]:
+ for block_id in range(num_blocks):
+ block_offset = block_id * dst_block_len
+ blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i))
+ logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_side(descs, remote_agent=self._remote_agents[engine_id][self.rank * tp_multiplier + i])
+
+ return agent_names
+
+ def get_done_tranfers(self) -> List[str]: + def get_done_tranfers(self) -> List[str]:
+ done_req_ids = [] + done_req_ids = []
+ for req_id, handles in self._transfers.items(): + for req_id, handles in self._transfers.items():
...@@ -2395,7 +2531,7 @@ index 321902d1..b8937ef8 100644 ...@@ -2395,7 +2531,7 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d82d9ad9..62dbbd6e 100644 index d82d9ad9..cc02b029 100644
--- a/vllm/engine/llm_engine.py --- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@ @@ -2,13 +2,17 @@
...@@ -2472,7 +2608,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2472,7 +2608,7 @@ index d82d9ad9..62dbbd6e 100644
self.parallel_config.pipeline_parallel_size, self.parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id] self.async_callbacks[v_id]
if self.model_config.use_async_output_proc else None) if self.model_config.use_async_output_proc else None)
@@ -405,6 +417,39 @@ class LLMEngine: @@ -405,6 +417,40 @@ class LLMEngine:
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
...@@ -2495,7 +2631,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2495,7 +2631,7 @@ index d82d9ad9..62dbbd6e 100644
+ raise RuntimeError("Nixl is not initialized") + raise RuntimeError("Nixl is not initialized")
+ agent_metadata = self.model_executor.collective_rpc("get_nixl_agent_metadata") + agent_metadata = self.model_executor.collective_rpc("get_nixl_agent_metadata")
+ kv_caches_base_addr = self.model_executor.collective_rpc("get_nixl_kv_caches_base_addr") + kv_caches_base_addr = self.model_executor.collective_rpc("get_nixl_kv_caches_base_addr")
+ return NixlMetadata(engine_id=self.engine_id, agent_metadata=agent_metadata, kv_caches_base_addr=kv_caches_base_addr) + return NixlMetadata(engine_id=self.engine_id, agent_metadata=agent_metadata, kv_caches_base_addr=kv_caches_base_addr, num_blocks=self.cache_config.num_gpu_blocks)
+ +
+ def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata) -> List[str]: + def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata) -> List[str]:
+ if not self.is_nixl_initialized: + if not self.is_nixl_initialized:
...@@ -2503,7 +2639,8 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2503,7 +2639,8 @@ index d82d9ad9..62dbbd6e 100644
+ engine_id = nixl_metadata.engine_id + engine_id = nixl_metadata.engine_id
+ agents_metadata = nixl_metadata.agent_metadata + agents_metadata = nixl_metadata.agent_metadata
+ kv_caches_base_addr = nixl_metadata.kv_caches_base_addr + kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr)) + num_blocks = nixl_metadata.num_blocks
+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr, num_blocks))
+ +
+ def _initialize_nixl(self) -> List[bytes]: + def _initialize_nixl(self) -> List[bytes]:
+ agents_names = self.model_executor.collective_rpc("initialize_nixl", args=(self.engine_id,)) + agents_names = self.model_executor.collective_rpc("initialize_nixl", args=(self.engine_id,))
...@@ -2512,7 +2649,16 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2512,7 +2649,16 @@ index d82d9ad9..62dbbd6e 100644
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
@@ -552,11 +597,14 @@ class LLMEngine: @@ -500,6 +546,8 @@ class LLMEngine:
# Shutdown model executor when engine is garbage collected
# Use getattr since __init__ can fail before the field is set
if model_executor := getattr(self, "model_executor", None):
+ if self.is_nixl_initialized:
+ model_executor.collective_rpc("shutdown_nixl")
model_executor.shutdown()
def get_tokenizer_group(
@@ -552,11 +600,14 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
...@@ -2527,7 +2673,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2527,7 +2673,7 @@ index d82d9ad9..62dbbd6e 100644
ParallelSampleSequenceGroup.add_request( ParallelSampleSequenceGroup.add_request(
request_id, request_id,
self, self,
@@ -574,6 +622,8 @@ class LLMEngine: @@ -574,6 +625,8 @@ class LLMEngine:
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
seq_id = next(self.seq_counter) seq_id = next(self.seq_counter)
...@@ -2536,7 +2682,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2536,7 +2682,7 @@ index d82d9ad9..62dbbd6e 100644
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs): if is_encoder_decoder_inputs(processed_inputs):
@@ -584,7 +634,7 @@ class LLMEngine: @@ -584,7 +637,7 @@ class LLMEngine:
encoder_inputs = None encoder_inputs = None
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
...@@ -2545,7 +2691,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2545,7 +2691,7 @@ index d82d9ad9..62dbbd6e 100644
encoder_seq = (None if encoder_inputs is None else Sequence( encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request, seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
@@ -601,8 +651,12 @@ class LLMEngine: @@ -601,8 +654,12 @@ class LLMEngine:
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq, encoder_seq=encoder_seq,
...@@ -2559,7 +2705,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2559,7 +2705,7 @@ index d82d9ad9..62dbbd6e 100644
seq_group = self._create_sequence_group_with_pooling( seq_group = self._create_sequence_group_with_pooling(
request_id, request_id,
seq, seq,
@@ -673,6 +727,7 @@ class LLMEngine: @@ -673,6 +730,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
...@@ -2567,7 +2713,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2567,7 +2713,7 @@ index d82d9ad9..62dbbd6e 100644
*, *,
inputs: Optional[PromptType] = None, # DEPRECATED inputs: Optional[PromptType] = None, # DEPRECATED
) -> None: ) -> None:
@@ -765,6 +820,7 @@ class LLMEngine: @@ -765,6 +823,7 @@ class LLMEngine:
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
...@@ -2575,7 +2721,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2575,7 +2721,7 @@ index d82d9ad9..62dbbd6e 100644
) )
def _validate_token_prompt(self, prompt: PromptType, def _validate_token_prompt(self, prompt: PromptType,
@@ -799,6 +855,7 @@ class LLMEngine: @@ -799,6 +858,7 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
priority: int = 0, priority: int = 0,
...@@ -2583,7 +2729,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2583,7 +2729,7 @@ index d82d9ad9..62dbbd6e 100644
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams.""" """Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs max_logprobs = self.get_model_config().max_logprobs
@@ -829,7 +886,9 @@ class LLMEngine: @@ -829,7 +889,9 @@ class LLMEngine:
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq, encoder_seq=encoder_seq,
...@@ -2594,7 +2740,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2594,7 +2740,7 @@ index d82d9ad9..62dbbd6e 100644
return seq_group return seq_group
@@ -995,11 +1054,11 @@ class LLMEngine: @@ -995,11 +1057,11 @@ class LLMEngine:
# When we process only one request, no pop is required # When we process only one request, no pop is required
# (since later we will process all of the rest) # (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async, (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
...@@ -2608,7 +2754,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2608,7 +2754,7 @@ index d82d9ad9..62dbbd6e 100644
# Sanity check # Sanity check
assert len(seq_group_metadata_list) == len( assert len(seq_group_metadata_list) == len(
@@ -1325,15 +1384,49 @@ class LLMEngine: @@ -1325,15 +1387,49 @@ class LLMEngine:
# Clear outputs for each new scheduler iteration # Clear outputs for each new scheduler iteration
ctx.request_outputs.clear() ctx.request_outputs.clear()
...@@ -2660,7 +2806,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2660,7 +2806,7 @@ index d82d9ad9..62dbbd6e 100644
ctx.seq_group_metadata_list = seq_group_metadata_list ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs ctx.scheduler_outputs = scheduler_outputs
@@ -1383,9 +1476,31 @@ class LLMEngine: @@ -1383,9 +1479,31 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[ execute_model_req.async_callback = self.async_callbacks[
virtual_engine] virtual_engine]
...@@ -2694,7 +2840,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2694,7 +2840,7 @@ index d82d9ad9..62dbbd6e 100644
# We need to do this here so that last step's sampled_token_ids can # We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP. # be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
@@ -1396,7 +1511,26 @@ class LLMEngine: @@ -1396,7 +1514,26 @@ class LLMEngine:
if len(ctx.output_queue) > 0: if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx) self._process_model_outputs(ctx=ctx)
# No outputs in this case # No outputs in this case
...@@ -2722,7 +2868,7 @@ index d82d9ad9..62dbbd6e 100644 ...@@ -2722,7 +2868,7 @@ index d82d9ad9..62dbbd6e 100644
# Finish the current step for all the sequence groups. # Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
@@ -1456,7 +1590,7 @@ class LLMEngine: @@ -1456,7 +1593,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters. # queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.") logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
...@@ -3534,7 +3680,7 @@ index 12baecde..489d3b77 100644 ...@@ -3534,7 +3680,7 @@ index 12baecde..489d3b77 100644
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 582aa460..c01cfe00 100644 index 582aa460..e4ed902e 100644
--- a/vllm/worker/worker.py --- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py +++ b/vllm/worker/worker.py
@@ -2,7 +2,7 @@ @@ -2,7 +2,7 @@
...@@ -3555,7 +3701,7 @@ index 582aa460..c01cfe00 100644 ...@@ -3555,7 +3701,7 @@ index 582aa460..c01cfe00 100644
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -306,6 +308,43 @@ class Worker(LocalOrDistributedWorkerBase): @@ -306,6 +308,42 @@ class Worker(LocalOrDistributedWorkerBase):
self._init_cache_engine() self._init_cache_engine()
self._warm_up_model() self._warm_up_model()
...@@ -3573,10 +3719,9 @@ index 582aa460..c01cfe00 100644 ...@@ -3573,10 +3719,9 @@ index 582aa460..c01cfe00 100644
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ return self.nixl_connector.get_agent_metadata() + return self.nixl_connector.get_agent_metadata()
+ +
+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str: + def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]], num_blocks: int) -> str:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata)) # TODO ptarasiewicz: rank or local_rank? + agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata), kv_caches_base_addr, num_blocks) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr)
+ return agent_name + return agent_name
+ +
+ def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None: + def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None:
...@@ -3599,7 +3744,7 @@ index 582aa460..c01cfe00 100644 ...@@ -3599,7 +3744,7 @@ index 582aa460..c01cfe00 100644
def _init_cache_engine(self): def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [ self.cache_engine = [
@@ -367,6 +406,8 @@ class Worker(LocalOrDistributedWorkerBase): @@ -367,6 +405,8 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device, device=self.device,
dtype=torch.int64).view(-1, 2) dtype=torch.int64).view(-1, 2)
...@@ -3608,7 +3753,7 @@ index 582aa460..c01cfe00 100644 ...@@ -3608,7 +3753,7 @@ index 582aa460..c01cfe00 100644
return WorkerInput( return WorkerInput(
num_seq_groups=num_seq_groups, num_seq_groups=num_seq_groups,
@@ -375,6 +416,11 @@ class Worker(LocalOrDistributedWorkerBase): @@ -375,6 +415,11 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
num_steps=num_steps, num_steps=num_steps,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment