Commit 923ca4fa authored by zhuwenwen
Browse files

Merge branch 'v0.11.0-dev-wm' into 'v0.11.0-dev'

[feat] Support scheduling prefill and decoding separately (支持prefill和decoding分开调度)

See merge request dcutoolkit/deeplearing/vllm!260
parents 1eff9d04 ff3e7f0a
......@@ -166,6 +166,10 @@ class Scheduler(SchedulerInterface):
self.use_eagle = True
self.num_lookahead_tokens = self.num_spec_tokens
self.compilation_config = vllm_config.compilation_config
self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.use_mla = vllm_config.model_config.use_mla
# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
......@@ -629,27 +633,28 @@ class Scheduler(SchedulerInterface):
return scheduler_output
def schedule_split_pd(self) -> SchedulerOutput:
# Give priority to scheduling waiting requests
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. num_tokens_with_spec =
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens_with_spec. This is general enough to cover
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
# NOTE: structured_output_request_ids maps
# a request's (request that uses structured output)
# request_id to the running request index.
# This helps us determine how to slice the grammar bitmask
# and apply a valid mask only for requests that
# use structured decoding.
structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
encoder_compute_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
......@@ -718,6 +723,14 @@ class Scheduler(SchedulerInterface):
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
if num_external_computed_tokens is None:
# The request cannot be scheduled because
# the KVConnector couldn't determine
# the number of matched tokens.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
......@@ -730,7 +743,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
new_encoder_compute_budget = encoder_compute_budget
# KVTransfer: loading remote KV, do not allocate for new work.
if load_kv_async:
......@@ -761,22 +774,44 @@ class Scheduler(SchedulerInterface):
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget
new_encoder_compute_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
encoder_compute_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
# Handles an edge case when P/D Disaggregation
# is used with Spec Decoding where an
# extra block gets allocated which
# creates a mismatch between the number
# of local and remote blocks.
effective_lookahead_tokens = (0 if request.num_computed_tokens
== 0 else
self.num_lookahead_tokens)
# Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs:
# TODO(russellb): For Whisper, we know that the input is
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
num_encoder_tokens =\
self.scheduler_config.max_num_encoder_input_tokens
else:
num_encoder_tokens = 0
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_local_computed_tokens,
new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
num_lookahead_tokens=effective_lookahead_tokens,
delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens,
)
if new_blocks is None:
# The request cannot be scheduled.
break
......@@ -802,9 +837,6 @@ class Scheduler(SchedulerInterface):
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
if request.use_structured_output:
structured_output_request_ids[request.request_id] = (
req_index)
req_index += 1
self.running.append(request)
if self.log_stats:
......@@ -820,8 +852,8 @@ class Scheduler(SchedulerInterface):
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
req_to_new_blocks[request.request_id] = (
self.kv_cache_manager.get_blocks(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
......@@ -836,7 +868,7 @@ class Scheduler(SchedulerInterface):
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
encoder_compute_budget = new_encoder_compute_budget
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
......@@ -848,7 +880,8 @@ class Scheduler(SchedulerInterface):
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
num_new_tokens = (request.num_tokens_with_spec -
num_new_tokens = (request.num_tokens_with_spec +
request.num_output_placeholders -
request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
......@@ -864,19 +897,22 @@ class Scheduler(SchedulerInterface):
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
new_encoder_compute_budget
) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_budget)
encoder_compute_budget)
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
# 1. No new tokens to schedule. This may happen when PP>1 and
# we have already scheduled all prompt tokens but they are
# not finished yet.
# 1. No new tokens to schedule. This may happen when
# (1) PP>1 and we have already scheduled all prompt tokens
# but they are not finished yet.
# (2) Async scheduling and the request has reached either
# its max_total_tokens or max_model_len.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
......@@ -885,15 +921,10 @@ class Scheduler(SchedulerInterface):
req_index += 1
continue
num_draft_tokens = max(
num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0)
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None:
# The request cannot be scheduled.
......@@ -904,10 +935,13 @@ class Scheduler(SchedulerInterface):
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.log_stats:
......@@ -930,14 +964,7 @@ class Scheduler(SchedulerInterface):
# Schedule the request.
scheduled_running_reqs.append(request)
if request.use_structured_output:
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
req_to_new_blocks[request.request_id] = new_blocks
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
......@@ -960,7 +987,7 @@ class Scheduler(SchedulerInterface):
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
encoder_compute_budget = new_encoder_compute_budget
# Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set()
......@@ -970,7 +997,6 @@ class Scheduler(SchedulerInterface):
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
......@@ -992,15 +1018,10 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running)))
grammar_bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
)
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
req_to_new_block_ids[req.request_id])
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
......@@ -1008,8 +1029,13 @@ class Scheduler(SchedulerInterface):
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_block_ids,
req_to_new_blocks,
)
scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs +
scheduled_resumed_reqs)
structured_output_request_ids, grammar_bitmask = (
self.get_grammar_bitmask(scheduled_requests,
scheduled_spec_decode_tokens))
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
......@@ -1037,7 +1063,19 @@ class Scheduler(SchedulerInterface):
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
# collect KV cache events from KV cache manager
events = self.kv_cache_manager.take_events()
# collect KV cache events from connector
if self.connector is not None:
connector_events = self.connector.take_events()
if connector_events:
if events is None:
events = list(connector_events)
else:
events.extend(connector_events)
# publish collected KV cache events
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
......@@ -1046,7 +1084,7 @@ class Scheduler(SchedulerInterface):
return scheduler_output
# Scheduling entry point: dispatches to `schedule_split_pd()` or
# `schedule_default()` and returns whatever the chosen method returns.
def schedule(self) -> SchedulerOutput:
# NOTE(review): the two consecutive `if` lines below look like the removed
# and the added side of a single changed line in this diff view (no +/-
# markers survive here). Presumably only the second, stricter condition
# exists in the file after the merge -- confirm against the repository.
#
# Earlier condition: take the split path whenever num_spec_tokens > 0
# or envs.VLLM_USE_PD_SPLIT is truthy.
if self.num_spec_tokens > 0 or envs.VLLM_USE_PD_SPLIT:
# Later condition: num_spec_tokens > 0 alone is no longer enough; it must
# be combined with both self.full_cuda_graph and self.use_mla, unless
# envs.VLLM_USE_PD_SPLIT is truthy, which selects the split path by itself.
if (self.full_cuda_graph and self.use_mla and self.num_spec_tokens > 0) or envs.VLLM_USE_PD_SPLIT:
return self.schedule_split_pd()
else:
# Every other configuration uses the default scheduling path.
return self.schedule_default()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment