Commit 7e3e2339 authored by jujl1's avatar jujl1
Browse files

fix: pp+chunkprefill多并发input ids更新bug

parent 794553fd
...@@ -1101,13 +1101,14 @@ class Scheduler(SchedulerInterface): ...@@ -1101,13 +1101,14 @@ class Scheduler(SchedulerInterface):
req_id = req.request_id req_id = req.request_id
req_ids.append(req_id) req_ids.append(req_id)
num_tokens = req.num_generated_token_ids num_tokens = req.num_generated_token_ids
if self.use_pp: if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first- # because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't # stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner # need to send the sampled tokens back because the model runner
# will cache them. # will cache them.
token_ids = req.all_token_ids[-num_tokens:] token_ids = req.all_token_ids[-num_tokens:] if num_tokens > 0 else []
new_token_ids.append(token_ids) new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id]) new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens) num_computed_tokens.append(req.num_computed_tokens)
...@@ -1241,7 +1242,7 @@ class Scheduler(SchedulerInterface): ...@@ -1241,7 +1242,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids = ( scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id)) scheduler_output.scheduled_spec_decode_tokens.get(req_id))
request.num_generated_token_ids = 1 request.num_generated_token_ids = len(generated_token_ids)
if scheduled_spec_token_ids: if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens # num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled # processed in the current step, considering scheduled
...@@ -1253,7 +1254,6 @@ class Scheduler(SchedulerInterface): ...@@ -1253,7 +1254,6 @@ class Scheduler(SchedulerInterface):
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids)) len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected request.num_computed_tokens -= num_tokens_rejected
request.num_generated_token_ids = len(generated_token_ids)
spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats, spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids), num_draft_tokens=len(scheduled_spec_token_ids),
......
...@@ -512,14 +512,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -512,14 +512,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
new_token_ids = req_data.new_token_ids[i] new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any). # Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens. # This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (num_computed_tokens + len(new_token_ids) - num_new_tokens = len(new_token_ids)
req_state.num_tokens) if num_new_tokens > 0:
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend( req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:]) new_token_ids)
if len(spec_token_ids) > 0: if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids req_state.spec_token_ids = spec_token_ids
...@@ -539,8 +535,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -539,8 +535,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# The request is not in the persistent batch. # The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not # The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again. # scheduled in the previous step and needs to be added again.
req_ids_to_add.append(req_id) if not is_last_rank:
continue req_state = self.requests[req_id]
self.input_batch.add_request(req_state)
req_index = self.input_batch.req_id_to_index.get(req_id)
else:
req_ids_to_add.append(req_id)
continue
# Update the persistent batch. # Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = ( self.input_batch.num_computed_tokens_cpu[req_index] = (
...@@ -552,13 +553,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -552,13 +553,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if not is_last_rank: if not is_last_rank:
# Add new_token_ids to token_ids_cpu. # Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + 1 if len(new_token_ids) > 0:
self.input_batch.token_ids_cpu[ end_token_index = num_computed_tokens + 1
req_index, self.input_batch.token_ids_cpu[
start_token_index:end_token_index] = new_token_ids[-1] req_index,
self.input_batch.num_tokens_no_spec[ start_token_index:end_token_index] = new_token_ids[-1]
req_index] = end_token_index self.input_batch.num_tokens_no_spec[
self.input_batch.num_tokens[req_index] = end_token_index req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu. # Add spec_token_ids to token_ids_cpu.
if spec_token_ids: if spec_token_ids:
...@@ -1276,7 +1278,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1276,7 +1278,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# prefills, causing unnecessary and excessive padding of activations. # prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager: if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# auto # auto
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto": if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
# Early exit. # Early exit.
return 0, None return 0, None
...@@ -2089,7 +2091,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2089,7 +2091,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
This is to help balance expert-selection This is to help balance expert-selection
- during profile_run - during profile_run
- during DP rank dummy run - during DP rank dummy run
""" """
dp_size = self.vllm_config.parallel_config.data_parallel_size dp_size = self.vllm_config.parallel_config.data_parallel_size
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
...@@ -3481,7 +3483,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3481,7 +3483,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
hidden_states[:num_scheduled_tokens], hidden_states[:num_scheduled_tokens],
scheduler_output, scheduler_output,
) )
#-----------------------------------
# Get the valid generated tokens. # Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1] max_gen_len = sampled_token_ids.shape[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment