"vscode:/vscode.git/clone" did not exist on "f790068600a0920e95d33d22686355bf8e48b077"
Commit 2cb921da authored by lizhigong's avatar lizhigong
Browse files

fix scheduler issu in pp + mtp

parent 5086453d
......@@ -1047,16 +1047,14 @@ class Scheduler(SchedulerInterface):
for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id
req_ids.append(req_id)
num_tokens = (num_scheduled_tokens[req_id] -
len(spec_decode_tokens.get(req_id, ())))
num_tokens = req.num_generated_token_ids
if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# will cache them.
token_ids = req.all_token_ids[req.num_computed_tokens:req.
num_computed_tokens + num_tokens]
token_ids = req.all_token_ids[-num_tokens:]
new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens)
......@@ -1190,6 +1188,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
request.num_generated_token_ids = 1
if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
......@@ -1197,9 +1196,11 @@ class Scheduler(SchedulerInterface):
# num_computed_tokens is decreased by the number of rejected
# tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
request.num_generated_token_ids = len(generated_token_ids)
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids),
......
......@@ -79,6 +79,7 @@ class Request:
self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
self.num_generated_token_ids = 0
self.cache_salt: Optional[str] = cache_salt
# Multi-modal related
......
......@@ -496,8 +496,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])
if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids
if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids
# Update the block IDs.
if not resumed_from_preemption:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment