Commit 343f6523 (unverified), authored by Giancarlo Delfin, committed by GitHub

[Model Runner V2][BugFix] fix num_sampled dtype for probabilistic rej… (#39951)


Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
parent 19fa90ed
@@ -527,7 +527,7 @@ def probabilistic_rejection_sample(
     sampled = draft_sampled.new_empty(
         num_reqs, num_speculative_steps + 1, dtype=torch.int64
     )
-    num_sampled = sampled.new_empty(num_reqs)
+    num_sampled = sampled.new_empty(num_reqs, dtype=torch.int32)
     target_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32)
     draft_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32)
     _probabilistic_rejection_kernel[(num_reqs,)](
...
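
For context on why the one-line change matters: `Tensor.new_empty` returns a tensor with the same dtype as the tensor it is called on unless `dtype` is passed explicitly. Because `sampled` is allocated as `torch.int64`, the old line silently made `num_sampled` an int64 tensor as well. The sketch below reproduces only that inheritance behavior. The sizes are hypothetical stand-ins, and `torch.empty` replaces `draft_sampled.new_empty` so the snippet is self-contained. Why the surrounding code wants int32 is not stated in the commit.

```python
import torch

# Hypothetical stand-in sizes; these are not values from the commit.
num_reqs, num_speculative_steps = 4, 3

# Stand-in for the int64 `sampled` buffer allocated in the diff.
sampled = torch.empty(num_reqs, num_speculative_steps + 1, dtype=torch.int64)

# Before the fix: no dtype given, so the new tensor inherits int64 from `sampled`.
inherited = sampled.new_empty(num_reqs)

# After the fix: the dtype is stated explicitly and overrides the inherited one.
explicit = sampled.new_empty(num_reqs, dtype=torch.int32)

print(inherited.dtype)  # torch.int64
print(explicit.dtype)   # torch.int32
```

Both tensors are uninitialized, so only their dtype and shape are meaningful here, not their contents.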