Commit 17b624a0 authored by yangshj1's avatar yangshj1
Browse files

support pp+mtp

parent 7b8f9aa2
......@@ -429,6 +429,12 @@ class Scheduler(SchedulerInterface):
request, num_new_tokens
)
if self.use_pp and self.num_spec_tokens > 0:
# For PP with spec decoding, we only schedule a request when it has new tokens to compute.
if request.num_output_tokens == 0 and request.num_computed_tokens == request.num_prompt_tokens:
req_index += 1
continue
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
......@@ -1642,18 +1648,15 @@ class Scheduler(SchedulerInterface):
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
req_id = req.request_id
req_ids.append(req_id)
num_tokens = req.num_generated_token_ids
if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# will cache them.
num_tokens = num_scheduled_tokens[req_id] - len(
spec_decode_tokens.get(req_id, ())
)
token_ids = req.all_token_ids[
req.num_computed_tokens : req.num_computed_tokens + num_tokens
]
num_tokens = req.num_generated_token_ids
token_ids = req.all_token_ids[-num_tokens:] if num_tokens > 0 else []
new_token_ids.append(token_ids)
scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
if idx >= num_running_reqs:
......@@ -1925,6 +1928,8 @@ class Scheduler(SchedulerInterface):
sampled_token_ids[req_index] if sampled_token_ids else []
)
request.num_generated_token_ids = len(generated_token_ids)
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id)
)
......
......@@ -131,6 +131,7 @@ class Request:
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
self.num_generated_token_ids = 0
self.cache_salt: str | None = cache_salt
# Multi-modal related
......
......@@ -120,6 +120,7 @@ class InputBatch:
self._req_ids: list[str | None] = []
self.req_id_to_index: dict[str, int] = {}
self.invalid_req_indices: list[int] = []
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
ysj
import functools
import gc
import itertools
......@@ -972,6 +971,7 @@ class GPUModelRunner(
# have low request overlap (e.g., alternating between two distinct
# sets of requests), this optimization becomes very inefficient.
for req_id in unscheduled_req_ids:
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
self.input_batch.remove_request(req_id)
reqs_to_add: list[CachedRequestState] = []
......@@ -1088,9 +1088,7 @@ class GPUModelRunner(
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (
num_computed_tokens + len(new_token_ids) - req_state.num_tokens
)
num_new_tokens = len(new_token_ids)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
......@@ -1141,13 +1139,13 @@ class GPUModelRunner(
# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
if not is_last_rank:
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids)
if not is_last_rank:
if len(new_token_ids) > 0:
end_token_index = num_computed_tokens + 1
self.input_batch.token_ids_cpu[
req_index, start_token_index:end_token_index
] = new_token_ids
req_index,
start_token_index:end_token_index] = new_token_ids[-1]
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
......@@ -3316,6 +3314,7 @@ class GPUModelRunner(
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
self.input_batch.invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids)
......@@ -4211,6 +4210,9 @@ class GPUModelRunner(
if not self.num_spec_tokens or not self._draft_token_req_ids:
return None
draft_token_ids, req_ids = self._get_draft_token_ids_cpu()
if draft_token_ids is not None:
for i in self.input_batch.invalid_req_indices:
draft_token_ids[i].clear()
return DraftTokenIds(req_ids, draft_token_ids)
def _copy_draft_token_ids_to_cpu(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment