Commit ad477f82 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix scheduler issue in pp + mtp

parent 1e636721
...@@ -1099,16 +1099,14 @@ class Scheduler(SchedulerInterface): ...@@ -1099,16 +1099,14 @@ class Scheduler(SchedulerInterface):
for req in itertools.chain(running_reqs, resumed_reqs): for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id req_id = req.request_id
req_ids.append(req_id) req_ids.append(req_id)
num_tokens = (num_scheduled_tokens[req_id] - num_tokens = req.num_generated_token_ids
len(spec_decode_tokens.get(req_id, ())))
if self.use_pp: if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first- # because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't # stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner # need to send the sampled tokens back because the model runner
# will cache them. # will cache them.
token_ids = req.all_token_ids[req.num_computed_tokens:req. token_ids = req.all_token_ids[-num_tokens:]
num_computed_tokens + num_tokens]
new_token_ids.append(token_ids) new_token_ids.append(token_ids)
elif use_connector: elif use_connector:
# When using a KVConnector, we add a placeholder to avoid index # When using a KVConnector, we add a placeholder to avoid index
...@@ -1318,6 +1316,7 @@ class Scheduler(SchedulerInterface): ...@@ -1318,6 +1316,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids = ( scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id)) scheduler_output.scheduled_spec_decode_tokens.get(req_id))
request.num_generated_token_ids = 1
if scheduled_spec_token_ids: if scheduled_spec_token_ids:
num_draft_tokens = len(scheduled_spec_token_ids) num_draft_tokens = len(scheduled_spec_token_ids)
num_accepted = len(generated_token_ids) - 1 num_accepted = len(generated_token_ids) - 1
...@@ -1328,6 +1327,7 @@ class Scheduler(SchedulerInterface): ...@@ -1328,6 +1327,7 @@ class Scheduler(SchedulerInterface):
# num_computed_tokens is decreased by the number of rejected # num_computed_tokens is decreased by the number of rejected
# tokens. # tokens.
request.num_computed_tokens -= num_rejected request.num_computed_tokens -= num_rejected
request.num_generated_token_ids = len(generated_token_ids)
spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats, spec_decoding_stats,
num_draft_tokens=num_draft_tokens, num_draft_tokens=num_draft_tokens,
......
...@@ -85,6 +85,7 @@ class Request: ...@@ -85,6 +85,7 @@ class Request:
self.num_output_placeholders = 0 # Used in async scheduling. self.num_output_placeholders = 0 # Used in async scheduling.
self.spec_token_ids: list[int] = [] self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0 self.num_computed_tokens = 0
self.num_generated_token_ids = 0
self.cache_salt: Optional[str] = cache_salt self.cache_salt: Optional[str] = cache_salt
# Multi-modal related # Multi-modal related
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment