Commit 9ebe3034 authored by zhuwenwen
Browse files

Fix precision issue in MTP

parent 3cec42d1
...@@ -40,6 +40,7 @@ class CachedRequestState: ...@@ -40,6 +40,7 @@ class CachedRequestState:
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
output_token_ids: list[int] output_token_ids: list[int]
spec_token_ids: list[int] = None
mrope_positions: Optional[torch.Tensor] = None mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None mrope_position_delta: Optional[int] = None
...@@ -303,9 +304,16 @@ class InputBatch: ...@@ -303,9 +304,16 @@ class InputBatch:
end_idx = start_idx + len(request.output_token_ids) end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index, self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids start_idx:end_idx] = request.output_token_ids
num_spec_tokens = 0
if request.spec_token_ids != None:
num_spec_tokens = len(request.spec_token_ids)
self.token_ids_cpu[req_index,
end_idx:end_idx + num_spec_tokens] = request.spec_token_ids
# Number of token ids in token_ids_cpu. # Number of token ids in token_ids_cpu.
# NOTE(woosuk): This may include spec decode tokens. # NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens self.num_tokens[req_index] = request.num_tokens + num_spec_tokens
# Number of tokens without spec decode tokens. # Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens self.num_tokens_no_spec[req_index] = request.num_tokens
......
...@@ -555,6 +555,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -555,6 +555,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update the cached states. # Update the cached states.
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
if not is_last_rank: if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
...@@ -571,6 +573,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -571,6 +573,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif num_new_tokens > 0: elif num_new_tokens > 0:
req_state.output_token_ids.extend( req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:]) new_token_ids[-num_new_tokens:])
if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids
# Update the block IDs. # Update the block IDs.
if not resumed_from_preemption: if not resumed_from_preemption:
...@@ -610,8 +614,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -610,8 +614,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_batch.num_tokens[req_index] = end_token_index self.input_batch.num_tokens[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu. # Add spec_token_ids to token_ids_cpu.
spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
if spec_token_ids: if spec_token_ids:
num_spec_tokens = len(spec_token_ids) num_spec_tokens = len(spec_token_ids)
start_index = self.input_batch.num_tokens_no_spec[req_index] start_index = self.input_batch.num_tokens_no_spec[req_index]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment