Commit 17b624a0 authored by yangshj1
Browse files

Support PP + MTP (pipeline parallelism combined with multi-token-prediction speculative decoding)

parent 7b8f9aa2
...@@ -429,6 +429,12 @@ class Scheduler(SchedulerInterface): ...@@ -429,6 +429,12 @@ class Scheduler(SchedulerInterface):
request, num_new_tokens request, num_new_tokens
) )
if self.use_pp and self.num_spec_tokens > 0:
# For PP with spec decoding, we only schedule a request when it has new tokens to compute.
if request.num_output_tokens == 0 and request.num_computed_tokens == request.num_prompt_tokens:
req_index += 1
continue
if num_new_tokens == 0: if num_new_tokens == 0:
# The request cannot be scheduled because one of the following # The request cannot be scheduled because one of the following
# reasons: # reasons:
...@@ -1642,18 +1648,15 @@ class Scheduler(SchedulerInterface): ...@@ -1642,18 +1648,15 @@ class Scheduler(SchedulerInterface):
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)): for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
req_id = req.request_id req_id = req.request_id
req_ids.append(req_id) req_ids.append(req_id)
num_tokens = req.num_generated_token_ids
if self.use_pp: if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first- # because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't # stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner # need to send the sampled tokens back because the model runner
# will cache them. # will cache them.
num_tokens = num_scheduled_tokens[req_id] - len( num_tokens = req.num_generated_token_ids
spec_decode_tokens.get(req_id, ()) token_ids = req.all_token_ids[-num_tokens:] if num_tokens > 0 else []
)
token_ids = req.all_token_ids[
req.num_computed_tokens : req.num_computed_tokens + num_tokens
]
new_token_ids.append(token_ids) new_token_ids.append(token_ids)
scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
if idx >= num_running_reqs: if idx >= num_running_reqs:
...@@ -1925,6 +1928,8 @@ class Scheduler(SchedulerInterface): ...@@ -1925,6 +1928,8 @@ class Scheduler(SchedulerInterface):
sampled_token_ids[req_index] if sampled_token_ids else [] sampled_token_ids[req_index] if sampled_token_ids else []
) )
request.num_generated_token_ids = len(generated_token_ids)
scheduled_spec_token_ids = ( scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id) scheduler_output.scheduled_spec_decode_tokens.get(req_id)
) )
......
...@@ -131,6 +131,7 @@ class Request: ...@@ -131,6 +131,7 @@ class Request:
self.spec_token_ids: list[int] = [] self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0 self.num_computed_tokens = 0
self.num_generated_token_ids = 0
self.cache_salt: str | None = cache_salt self.cache_salt: str | None = cache_salt
# Multi-modal related # Multi-modal related
......
...@@ -120,6 +120,7 @@ class InputBatch: ...@@ -120,6 +120,7 @@ class InputBatch:
self._req_ids: list[str | None] = [] self._req_ids: list[str | None] = []
self.req_id_to_index: dict[str, int] = {} self.req_id_to_index: dict[str, int] = {}
self.invalid_req_indices: list[int] = []
# TODO(woosuk): This buffer could be too large if max_model_len is big. # TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage. # Find a way to reduce the CPU memory usage.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
ysj
import functools import functools
import gc import gc
import itertools import itertools
...@@ -972,6 +971,7 @@ class GPUModelRunner( ...@@ -972,6 +971,7 @@ class GPUModelRunner(
# have low request overlap (e.g., alternating between two distinct # have low request overlap (e.g., alternating between two distinct
# sets of requests), this optimization becomes very inefficient. # sets of requests), this optimization becomes very inefficient.
for req_id in unscheduled_req_ids: for req_id in unscheduled_req_ids:
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
self.input_batch.remove_request(req_id) self.input_batch.remove_request(req_id)
reqs_to_add: list[CachedRequestState] = [] reqs_to_add: list[CachedRequestState] = []
...@@ -1088,9 +1088,7 @@ class GPUModelRunner( ...@@ -1088,9 +1088,7 @@ class GPUModelRunner(
new_token_ids = req_data.new_token_ids[i] new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any). # Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens. # This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = ( num_new_tokens = len(new_token_ids)
num_computed_tokens + len(new_token_ids) - req_state.num_tokens
)
if num_new_tokens == 1: if num_new_tokens == 1:
# Avoid slicing list in most common case. # Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1]) req_state.output_token_ids.append(new_token_ids[-1])
...@@ -1141,14 +1139,14 @@ class GPUModelRunner( ...@@ -1141,14 +1139,14 @@ class GPUModelRunner(
# For the last rank, we don't need to update the token_ids_cpu # For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached. # because the sampled tokens are already cached.
start_token_index = num_computed_tokens
if not is_last_rank: if not is_last_rank:
# Add new_token_ids to token_ids_cpu. if len(new_token_ids) > 0:
start_token_index = num_computed_tokens end_token_index = num_computed_tokens + 1
end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[
self.input_batch.token_ids_cpu[ req_index,
req_index, start_token_index:end_token_index start_token_index:end_token_index] = new_token_ids[-1]
] = new_token_ids self.input_batch.num_tokens_no_spec[req_index] = end_token_index
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu. # Add spec_token_ids to token_ids_cpu.
self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens) self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
...@@ -3316,6 +3314,7 @@ class GPUModelRunner( ...@@ -3316,6 +3314,7 @@ class GPUModelRunner(
if not self.use_async_scheduling: if not self.use_async_scheduling:
# Get the valid generated tokens. # Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1] max_gen_len = sampled_token_ids.shape[-1]
self.input_batch.invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
if max_gen_len == 1: if max_gen_len == 1:
# No spec decode tokens. # No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids) valid_sampled_token_ids = self._to_list(sampled_token_ids)
...@@ -4211,6 +4210,9 @@ class GPUModelRunner( ...@@ -4211,6 +4210,9 @@ class GPUModelRunner(
if not self.num_spec_tokens or not self._draft_token_req_ids: if not self.num_spec_tokens or not self._draft_token_req_ids:
return None return None
draft_token_ids, req_ids = self._get_draft_token_ids_cpu() draft_token_ids, req_ids = self._get_draft_token_ids_cpu()
if draft_token_ids is not None:
for i in self.input_batch.invalid_req_indices:
draft_token_ids[i].clear()
return DraftTokenIds(req_ids, draft_token_ids) return DraftTokenIds(req_ids, draft_token_ids)
def _copy_draft_token_ids_to_cpu( def _copy_draft_token_ids_to_cpu(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment