Commit ad477f82 authored by zhuwenwen
Browse files

fix scheduler issue in pp + mtp

parent 1e636721
......@@ -1099,16 +1099,14 @@ class Scheduler(SchedulerInterface):
for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id
req_ids.append(req_id)
num_tokens = (num_scheduled_tokens[req_id] -
len(spec_decode_tokens.get(req_id, ())))
num_tokens = req.num_generated_token_ids
if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# will cache them.
token_ids = req.all_token_ids[req.num_computed_tokens:req.
num_computed_tokens + num_tokens]
token_ids = req.all_token_ids[-num_tokens:]
new_token_ids.append(token_ids)
elif use_connector:
# When using a KVConnector, we add a placeholder to avoid index
......@@ -1318,6 +1316,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
request.num_generated_token_ids = 1
if scheduled_spec_token_ids:
num_draft_tokens = len(scheduled_spec_token_ids)
num_accepted = len(generated_token_ids) - 1
......@@ -1328,6 +1327,7 @@ class Scheduler(SchedulerInterface):
# num_computed_tokens is decreased by the number of rejected
# tokens.
request.num_computed_tokens -= num_rejected
request.num_generated_token_ids = len(generated_token_ids)
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=num_draft_tokens,
......
......@@ -85,6 +85,7 @@ class Request:
self.num_output_placeholders = 0 # Used in async scheduling.
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
self.num_generated_token_ids = 0
self.cache_salt: Optional[str] = cache_salt
# Multi-modal related
......
......@@ -606,8 +606,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])
if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids
if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids
# Update the block IDs.
if not resumed_from_preemption:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment