Unverified commit fa82dfcc, authored by Yineng Zhang and committed by GitHub
Browse files

fix EagleVerifyInput (#3378)

parent 5da3d21c
......@@ -177,29 +177,21 @@ class EagleVerifyInput:
spec_steps: int,
num_verify_token: int,
):
score_list = torch.cat(score_list, dim=1).flatten(
1
) # b, n, topk; n= 1 + (num_steps-1) * self.topk
ss_token_list = torch.cat(
token_list, dim=1
) # b, (self.topk + (num_steps-1) * self.topk)
top_scores = torch.topk(score_list, num_verify_token - 1, dim=-1)
top_scores_index = top_scores.indices
top_scores_index = torch.sort(top_scores_index).values
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1)
parent_list = torch.cat(parents_list[:-1], dim=1)
tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
parent_list,
top_scores_index,
seq_lens,
seq_lens_sum,
topk,
spec_steps,
num_verify_token,
tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
build_tree_kernel(
verified_id,
score_list,
token_list,
parents_list,
seq_lens,
seq_lens_sum,
topk,
spec_steps,
num_verify_token,
)
)
return cls(
draft_tokens.flatten(),
draft_tokens,
tree_mask,
position,
retrive_index,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment