Unverified Commit b730aa6b authored by 996_icu's avatar 996_icu Committed by GitHub
Browse files

[EAGLE] Fix some boundary situations when retracting reqs and a req's max token = 1 (#2939)


Co-authored-by: josephyou <josephyou@tencent.com>
parent 60b2a44a
...@@ -1112,6 +1112,8 @@ class ScheduleBatch: ...@@ -1112,6 +1112,8 @@ class ScheduleBatch:
self.has_grammar = any(req.grammar for req in self.reqs) self.has_grammar = any(req.grammar for req in self.reqs)
self.sampling_info.filter_batch(keep_indices, new_indices) self.sampling_info.filter_batch(keep_indices, new_indices)
if self.spec_info:
self.spec_info.filter_batch(new_indices)
def merge_batch(self, other: "ScheduleBatch"): def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
......
...@@ -228,6 +228,14 @@ class EAGLEDraftInput(SpecInfo): ...@@ -228,6 +228,14 @@ class EAGLEDraftInput(SpecInfo):
assert len(batch.extend_lens) == 1 assert len(batch.extend_lens) == 1
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id)) batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
def filter_batch(
    self,
    new_indices: torch.Tensor,
):
    """Shrink cached speculative-decoding state after the batch is filtered.

    Keeps only the first ``len(new_indices)`` rows of each cached tensor.
    NOTE(review): this assumes the surviving requests occupy a contiguous
    prefix (i.e. retracted requests were taken from the tail) — confirm
    against ``ScheduleBatch.filter_batch`` callers.

    Args:
        new_indices: indices of the requests kept in the batch; only its
            length is used here.
    """
    keep = len(new_indices)
    self.sample_output = self.sample_output[:keep]
    self.hidden_states = self.hidden_states[:keep]
    self.verified_id = self.verified_id[:keep]
def prepare_for_decode(self, batch: ScheduleBatch): def prepare_for_decode(self, batch: ScheduleBatch):
prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab) prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
top = torch.topk(prob, self.topk, dim=-1) top = torch.topk(prob, self.topk, dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment