Unverified commit e99729c9, authored by jiapingW, committed by GitHub
Browse files

Fix EAGLE-3 TPOT (time per output token) being worse than running without EAGLE-3. (#9404)

parent c10b8e6a
...@@ -453,12 +453,13 @@ class EagleVerifyInput: ...@@ -453,12 +453,13 @@ class EagleVerifyInput:
sampling_info.top_ks, self.draft_token_num, dim=0 sampling_info.top_ks, self.draft_token_num, dim=0
), ),
) # (bs * draft_token_num, vocab_size) ) # (bs * draft_token_num, vocab_size)
target_probs = top_p_renorm_prob( if not torch.all(sampling_info.top_ps == 1.0):
target_probs, target_probs = top_p_renorm_prob(
torch.repeat_interleave( target_probs,
sampling_info.top_ps, self.draft_token_num, dim=0 torch.repeat_interleave(
), sampling_info.top_ps, self.draft_token_num, dim=0
) ),
)
target_probs = target_probs.reshape(bs, self.draft_token_num, -1) target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
draft_probs = torch.zeros( draft_probs = torch.zeros(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment