Unverified commit fa82dfcc, authored by Yineng Zhang and committed by GitHub
Browse files

fix EagleVerifyInput (#3378)

parent 5da3d21c
...@@ -177,29 +177,21 @@ class EagleVerifyInput: ...@@ -177,29 +177,21 @@ class EagleVerifyInput:
spec_steps: int, spec_steps: int,
num_verify_token: int, num_verify_token: int,
): ):
score_list = torch.cat(score_list, dim=1).flatten( tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
1 build_tree_kernel(
) # b, n, topk; n= 1 + (num_steps-1) * self.topk verified_id,
ss_token_list = torch.cat( score_list,
token_list, dim=1 token_list,
) # b, (self.topk + (num_steps-1) * self.topk) parents_list,
top_scores = torch.topk(score_list, num_verify_token - 1, dim=-1)
top_scores_index = top_scores.indices
top_scores_index = torch.sort(top_scores_index).values
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1)
parent_list = torch.cat(parents_list[:-1], dim=1)
tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
parent_list,
top_scores_index,
seq_lens, seq_lens,
seq_lens_sum, seq_lens_sum,
topk, topk,
spec_steps, spec_steps,
num_verify_token, num_verify_token,
) )
)
return cls( return cls(
draft_tokens.flatten(), draft_tokens,
tree_mask, tree_mask,
position, position,
retrive_index, retrive_index,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment