Commit 6d2d0ce2 (unverified), authored by Ziming Huang, committed by GitHub
Browse files

[PD] Improve eagle acceptance rate by transferring draft model hidden states (#10801)


Co-authored-by: Shangming Cai <csmthu@gmail.com>
parent 271d3d0d
...@@ -430,24 +430,12 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -430,24 +430,12 @@ class SchedulerDisaggregationPrefillMixin:
self.tree_cache.cache_unfinished_req(req) # update the tree and lock self.tree_cache.cache_unfinished_req(req) # update the tree and lock
req.add_latency(RequestStage.PREFILL_FORWARD) req.add_latency(RequestStage.PREFILL_FORWARD)
self.disagg_prefill_inflight_queue.append(req) self.disagg_prefill_inflight_queue.append(req)
if ( if self.spec_algorithm.is_eagle() and batch.spec_info is not None:
logits_output is not None
and logits_output.hidden_states is not None
):
last_hidden_index = (
hidden_state_offset + extend_input_len_per_req[i] - 1
)
req.output_topk_p = batch.spec_info.topk_p[i] req.output_topk_p = batch.spec_info.topk_p[i]
req.output_topk_index = batch.spec_info.topk_index[i] req.output_topk_index = batch.spec_info.topk_index[i]
if self.spec_algorithm.is_eagle3(): req.hidden_states_tensor = (
req.hidden_states_tensor = ( batch.spec_info.hidden_states[i].cpu().clone()
batch.spec_info.hidden_states[i].cpu().clone() )
)
else:
req.hidden_states_tensor = (
logits_output.hidden_states[last_hidden_index].cpu().clone()
)
hidden_state_offset += extend_input_len_per_req[i]
else: else:
req.hidden_states_tensor = None req.hidden_states_tensor = None
if req.return_logprob: if req.return_logprob:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment