Unverified Commit 6a9d6ca3 authored by zyksir's avatar zyksir Committed by GitHub
Browse files

fix unexpected answer in EAGLE mode (#9252)

parent 94371dbb
......@@ -177,11 +177,24 @@ class EagleDraftInput:
)
return kv_indices, cum_kv_seq_len, qo_indptr, None
def filter_batch(self, new_indices: torch.Tensor):
    """Keep only the leading ``len(new_indices)`` entries of the draft state.

    The values stored in ``new_indices`` are not consulted here; only its
    length is used, as the size of the prefix retained along the first
    dimension of ``topk_p``, ``topk_index``, ``hidden_states`` and
    ``verified_id``.

    NOTE(review): a second ``filter_batch`` definition follows in this view;
    this one looks like the pre-change side of the diff -- confirm.
    """
    prefix = slice(len(new_indices))
    for field in ("topk_p", "topk_index", "hidden_states", "verified_id"):
        setattr(self, field, getattr(self, field)[prefix])
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
    """Shrink the per-request draft state to the requests that remain.

    Updates ``topk_p``, ``topk_index``, ``hidden_states`` and
    ``verified_id`` in place on ``self``.

    Args:
        new_indices: Indices of the requests to keep. When
            ``has_been_filtered`` is true only its length is used.
        has_been_filtered: When true (the default), the state is expected
            to already contain exactly the surviving requests, so it is
            merely truncated to ``len(new_indices)`` entries and a warning
            is logged if the lengths disagree. When false, the state is
            indexed by ``new_indices`` to select the surviving rows.
    """
    if has_been_filtered:
        # Per the original note: eagle_utils.py:verify has already filtered
        # the batch by `unfinished_index`, so the scheduler does not need to
        # filter it again; truncation is only a safety net.
        keep_len = len(new_indices)
        current_len = len(self.topk_p)
        if keep_len != current_len:
            # Lazy %-style arguments: the message is only formatted when the
            # warning is actually emitted.
            logger.warning(
                "length of new_indices: %d != length of topk_p: %d, this should not happen",
                keep_len,
                current_len,
            )
        selector = slice(keep_len)
    else:
        # In some cases (e.g. draft_extend) the batch has not been filtered
        # by `unfinished_index` yet, so select the surviving rows explicitly.
        selector = new_indices

    self.topk_p = self.topk_p[selector]
    self.topk_index = self.topk_index[selector]
    self.hidden_states = self.hidden_states[selector]
    self.verified_id = self.verified_id[selector]
def merge_batch(self, spec_info: EagleDraftInput):
if self.hidden_states is None:
......
......@@ -836,6 +836,21 @@ class EAGLEWorker(TpModelWorker):
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
has_finished, unfinished_req_index = False, []
for i, req in enumerate(batch.reqs):
if req.finished():
has_finished = True
else:
unfinished_req_index.append(i)
if has_finished:
unfinished_index_device = torch.tensor(
unfinished_req_index,
dtype=torch.int64,
device=batch.spec_info.topk_p.device,
)
batch.spec_info.filter_batch(
unfinished_index_device, has_been_filtered=False
)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
assert isinstance(batch.spec_info, EagleDraftInput)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment