Unverified Commit af9ee90e authored by ptarasiewiczNV's avatar ptarasiewiczNV Committed by GitHub
Browse files

fix: fix missing num_remote_prefill_groups in vLLM patch (#981)

parent 8af8c82f
...@@ -533,7 +533,7 @@ index 000000000..79eb8db67 ...@@ -533,7 +533,7 @@ index 000000000..79eb8db67
+ +
+ self.event_id_counter += 1 + self.event_id_counter += 1
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index cf85a2135..f9087b5c3 100644 index cf85a2135..f157aa231 100644
--- a/vllm/core/scheduler.py --- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py
@@ -1,16 +1,30 @@ @@ -1,16 +1,30 @@
...@@ -702,7 +702,13 @@ index cf85a2135..f9087b5c3 100644 ...@@ -702,7 +702,13 @@ index cf85a2135..f9087b5c3 100644
running_queue = self.running running_queue = self.running
assert len(self._async_stopped) == 0 assert len(self._async_stopped) == 0
while running_queue: while running_queue:
@@ -1073,6 +1138,7 @@ class Scheduler: @@ -1068,11 +1133,13 @@ class Scheduler:
ignored_seq_groups=[],
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking),
+ num_remote_prefill_groups=0
)
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[ScheduledSequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = []
waiting_queue = self.waiting waiting_queue = self.waiting
...@@ -710,7 +716,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -710,7 +716,7 @@ index cf85a2135..f9087b5c3 100644
leftover_waiting_sequences: Deque[SequenceGroup] = deque() leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue: while self._passed_delay(time.time()) and waiting_queue:
@@ -1121,8 +1187,10 @@ class Scheduler: @@ -1121,8 +1188,10 @@ class Scheduler:
True, enable_chunking) True, enable_chunking)
# If the sequence group cannot be allocated, stop. # If the sequence group cannot be allocated, stop.
...@@ -722,7 +728,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -722,7 +728,7 @@ index cf85a2135..f9087b5c3 100644
if can_allocate == AllocStatus.LATER: if can_allocate == AllocStatus.LATER:
break break
elif can_allocate == AllocStatus.NEVER: elif can_allocate == AllocStatus.NEVER:
@@ -1170,7 +1238,18 @@ class Scheduler: @@ -1170,7 +1239,18 @@ class Scheduler:
if curr_loras is not None and lora_int_id > 0: if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id) curr_loras.add(lora_int_id)
waiting_queue.popleft() waiting_queue.popleft()
...@@ -742,7 +748,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -742,7 +748,7 @@ index cf85a2135..f9087b5c3 100644
if partial_prefill_metadata is not None: if partial_prefill_metadata is not None:
partial_prefill_metadata.maybe_increment_partial_prefills( partial_prefill_metadata.maybe_increment_partial_prefills(
@@ -1214,9 +1293,10 @@ class Scheduler: @@ -1214,9 +1294,10 @@ class Scheduler:
ignored_seq_groups=ignored_seq_groups, ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots( num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking), is_prefill=True, enable_chunking=enable_chunking),
...@@ -754,7 +760,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -754,7 +760,7 @@ index cf85a2135..f9087b5c3 100644
"""Schedule queued requests. """Schedule queued requests.
The current policy is designed to optimize the throughput. First, The current policy is designed to optimize the throughput. First,
@@ -1234,6 +1314,9 @@ class Scheduler: @@ -1234,6 +1315,9 @@ class Scheduler:
for seq_group in self.running: for seq_group in self.running:
budget.add_num_seqs(seq_group.request_id, budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs()) seq_group.get_max_num_running_seqs())
...@@ -764,7 +770,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -764,7 +770,7 @@ index cf85a2135..f9087b5c3 100644
curr_loras = (set( curr_loras = (set(
seq_group.lora_int_id for seq_group in self.running seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None) if seq_group.lora_int_id > 0) if self.lora_enabled else None)
@@ -1258,7 +1341,9 @@ class Scheduler: @@ -1258,7 +1342,9 @@ class Scheduler:
if len(prefills.seq_groups) == 0: if len(prefills.seq_groups) == 0:
running_scheduled = self._schedule_running(budget, running_scheduled = self._schedule_running(budget,
curr_loras, curr_loras,
...@@ -775,7 +781,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -775,7 +781,7 @@ index cf85a2135..f9087b5c3 100644
# If any sequence group is preempted, do not swap in any sequence # If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests. # group. because it means there's no slot for new running requests.
@@ -1275,7 +1360,12 @@ class Scheduler: @@ -1275,7 +1361,12 @@ class Scheduler:
self.waiting.extendleft(running_scheduled.preempted) self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests. # Update new running requests.
if len(prefills.seq_groups) > 0: if len(prefills.seq_groups) > 0:
...@@ -789,7 +795,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -789,7 +795,7 @@ index cf85a2135..f9087b5c3 100644
self.running.extend(running_scheduled.decode_seq_groups_list) self.running.extend(running_scheduled.decode_seq_groups_list)
@@ -1452,12 +1542,14 @@ class Scheduler: @@ -1452,12 +1543,14 @@ class Scheduler:
] ]
return finishing + not_finishing return finishing + not_finishing
...@@ -806,7 +812,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -806,7 +812,7 @@ index cf85a2135..f9087b5c3 100644
def _can_append_slots(self, seq_group: SequenceGroup, def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool: enable_chunking: bool) -> bool:
@@ -1491,14 +1583,16 @@ class Scheduler: @@ -1491,14 +1584,16 @@ class Scheduler:
return no_single_seq return no_single_seq
def schedule( def schedule(
...@@ -826,7 +832,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -826,7 +832,7 @@ index cf85a2135..f9087b5c3 100644
now = time.time() now = time.time()
if not self.cache_config.enable_prefix_caching: if not self.cache_config.enable_prefix_caching:
@@ -1537,7 +1631,8 @@ class Scheduler: @@ -1537,7 +1632,8 @@ class Scheduler:
encoder_seq_data = None encoder_seq_data = None
cross_block_table = None cross_block_table = None
...@@ -836,7 +842,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -836,7 +842,7 @@ index cf85a2135..f9087b5c3 100644
seq_id = seq.seq_id seq_id = seq.seq_id
seq_data[seq_id] = seq.data seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq) block_tables[seq_id] = self.block_manager.get_block_table(seq)
@@ -1546,7 +1641,9 @@ class Scheduler: @@ -1546,7 +1642,9 @@ class Scheduler:
if self.cache_config.enable_prefix_caching: if self.cache_config.enable_prefix_caching:
common_computed_block_nums = ( common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids( self.block_manager.get_common_computed_block_ids(
...@@ -847,7 +853,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -847,7 +853,7 @@ index cf85a2135..f9087b5c3 100644
do_sample = True do_sample = True
is_prompt = seq_group.is_prefill() is_prompt = seq_group.is_prefill()
@@ -1568,9 +1665,30 @@ class Scheduler: @@ -1568,9 +1666,30 @@ class Scheduler:
< seqs[0].data.get_len()): < seqs[0].data.get_len()):
do_sample = False do_sample = False
...@@ -878,7 +884,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -878,7 +884,7 @@ index cf85a2135..f9087b5c3 100644
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id, request_id=seq_group.request_id,
is_prompt=is_prompt, is_prompt=is_prompt,
@@ -1598,6 +1716,7 @@ class Scheduler: @@ -1598,6 +1717,7 @@ class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None), if scheduler_outputs.num_prefill_groups > 0 else None),
mm_processor_kwargs=seq_group.mm_processor_kwargs, mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request, prompt_adapter_request=seq_group.prompt_adapter_request,
...@@ -886,7 +892,7 @@ index cf85a2135..f9087b5c3 100644 ...@@ -886,7 +892,7 @@ index cf85a2135..f9087b5c3 100644
) )
else: else:
# When SPMD mode is enabled, we only send delta data except for # When SPMD mode is enabled, we only send delta data except for
@@ -1696,10 +1815,16 @@ class Scheduler: @@ -1696,10 +1816,16 @@ class Scheduler:
self._async_stopped.clear() self._async_stopped.clear()
...@@ -1039,10 +1045,10 @@ index 000000000..a2f9ce99e ...@@ -1039,10 +1045,10 @@ index 000000000..a2f9ce99e
\ No newline at end of file \ No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644 new file mode 100644
index 000000000..4c5ed707f index 000000000..bd4ac984e
--- /dev/null --- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py +++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,447 @@ @@ -0,0 +1,445 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 +# SPDX-License-Identifier: Apache-2.0
+# +#
...@@ -1487,8 +1493,6 @@ index 000000000..4c5ed707f ...@@ -1487,8 +1493,6 @@ index 000000000..4c5ed707f
+ done_req_ids.append(req_id) + done_req_ids.append(req_id)
+ else: + else:
+ self._transfers[req_id] = running_reqs + self._transfers[req_id] = running_reqs
+ for req_id in done_req_ids:
+ del self._transfers[req_id]
+ return done_req_ids + return done_req_ids
diff --git a/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py diff --git a/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py
new file mode 100644 new file mode 100644
...@@ -2892,7 +2896,7 @@ index 975afe5ad..2208abea0 100644 ...@@ -2892,7 +2896,7 @@ index 975afe5ad..2208abea0 100644
use_v1 = True use_v1 = True
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 54f7b8fb6..0559f9db2 100644 index 54f7b8fb6..9c1c2635f 100644
--- a/vllm/engine/llm_engine.py --- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py
@@ -1,11 +1,28 @@ @@ -1,11 +1,28 @@
...@@ -3135,7 +3139,7 @@ index 54f7b8fb6..0559f9db2 100644 ...@@ -3135,7 +3139,7 @@ index 54f7b8fb6..0559f9db2 100644
# Skip the scheduler if there are any remaining steps in the seq groups. # Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current # This ensures that the scheduler is only called again when the current
@@ -1372,7 +1452,41 @@ class LLMEngine: @@ -1372,7 +1452,43 @@ class LLMEngine:
# Schedule iteration # Schedule iteration
(seq_group_metadata_list, scheduler_outputs, (seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc allow_async_output_proc
...@@ -3165,6 +3169,7 @@ index 54f7b8fb6..0559f9db2 100644 ...@@ -3165,6 +3169,7 @@ index 54f7b8fb6..0559f9db2 100644
+ logger.debug("No blocks to prefill") + logger.debug("No blocks to prefill")
+ self._finished_prefills.add(seq_group_metadata.request_id) + self._finished_prefills.add(seq_group_metadata.request_id)
+ continue + continue
+
+ remote_prefill_request = RemotePrefillRequest( + remote_prefill_request = RemotePrefillRequest(
+ request_id=seq_group_metadata.request_id, + request_id=seq_group_metadata.request_id,
+ # prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids[:-1], # last one will be decoded on decode for sampling anyway + # prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids[:-1], # last one will be decoded on decode for sampling anyway
...@@ -3173,12 +3178,13 @@ index 54f7b8fb6..0559f9db2 100644 ...@@ -3173,12 +3178,13 @@ index 54f7b8fb6..0559f9db2 100644
+ block_ids=block_table, + block_ids=block_table,
+ engine_id=self.engine_id, + engine_id=self.engine_id,
+ computed_block_ids=seq_group_metadata.computed_block_nums, + computed_block_ids=seq_group_metadata.computed_block_nums,
+ multimodal_data_source=scheduled_seq_group.seq_group.remote_prefill_params.multimodal_data_source
+ ) + )
+ scheduled_seq_group.seq_group.remote_prefill_params.remote_prefill_request_callback(remote_prefill_request) + scheduled_seq_group.seq_group.remote_prefill_params.remote_prefill_request_callback(remote_prefill_request)
ctx.seq_group_metadata_list = seq_group_metadata_list ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs ctx.scheduler_outputs = scheduler_outputs
@@ -1427,8 +1541,46 @@ class LLMEngine: @@ -1427,8 +1543,46 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[ execute_model_req.async_callback = self.async_callbacks[
virtual_engine] virtual_engine]
...@@ -3226,7 +3232,7 @@ index 54f7b8fb6..0559f9db2 100644 ...@@ -3226,7 +3232,7 @@ index 54f7b8fb6..0559f9db2 100644
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
self._skip_scheduling_next_step = False self._skip_scheduling_next_step = False
except InputProcessingError as e: except InputProcessingError as e:
@@ -1444,7 +1596,6 @@ class LLMEngine: @@ -1444,7 +1598,6 @@ class LLMEngine:
allow_async_output_proc=allow_async_output_proc) allow_async_output_proc=allow_async_output_proc)
# Raise so the caller is notified that this request failed # Raise so the caller is notified that this request failed
raise raise
...@@ -3234,7 +3240,7 @@ index 54f7b8fb6..0559f9db2 100644 ...@@ -3234,7 +3240,7 @@ index 54f7b8fb6..0559f9db2 100644
# We need to do this here so that last step's sampled_token_ids can # We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP. # be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
@@ -1455,7 +1606,26 @@ class LLMEngine: @@ -1455,7 +1608,26 @@ class LLMEngine:
if len(ctx.output_queue) > 0: if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx) self._process_model_outputs(ctx=ctx)
# No outputs in this case # No outputs in this case
...@@ -3262,7 +3268,7 @@ index 54f7b8fb6..0559f9db2 100644 ...@@ -3262,7 +3268,7 @@ index 54f7b8fb6..0559f9db2 100644
# Finish the current step for all the sequence groups. # Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
@@ -1515,7 +1685,7 @@ class LLMEngine: @@ -1515,7 +1687,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters. # queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.") logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
...@@ -4174,10 +4180,10 @@ index 0ed221043..08dbc0e78 100644 ...@@ -4174,10 +4180,10 @@ index 0ed221043..08dbc0e78 100644
"The vLLM package was not found, so its version could not be " "The vLLM package was not found, so its version could not be "
diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py
new file mode 100644 new file mode 100644
index 000000000..83f6cd575 index 000000000..0a063f1ca
--- /dev/null --- /dev/null
+++ b/vllm/remote_prefill.py +++ b/vllm/remote_prefill.py
@@ -0,0 +1,82 @@ @@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 +# SPDX-License-Identifier: Apache-2.0
+# +#
...@@ -4223,6 +4229,7 @@ index 000000000..83f6cd575 ...@@ -4223,6 +4229,7 @@ index 000000000..83f6cd575
+ sampling_params: SamplingParams + sampling_params: SamplingParams
+ block_ids: List[int] + block_ids: List[int]
+ computed_block_ids: List[int] + computed_block_ids: List[int]
+ multimodal_data_source: Optional[dict[str, str]] = None
+ +
+ +
+class MemoryOpType(str, Enum): +class MemoryOpType(str, Enum):
...@@ -4260,7 +4267,7 @@ index 000000000..83f6cd575 ...@@ -4260,7 +4267,7 @@ index 000000000..83f6cd575
+ decode_computed_block_ids: Optional[List[int]] = None + decode_computed_block_ids: Optional[List[int]] = None
+ decode_engine_id: Optional[str] = None + decode_engine_id: Optional[str] = None
+ remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None + remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None
\ No newline at end of file + multimodal_data_source: Optional[dict[str, str]] = None
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 68ed99664..5b0b7e6dc 100644 index 68ed99664..5b0b7e6dc 100644
--- a/vllm/sampling_params.py --- a/vllm/sampling_params.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment