Commit 6efb195a authored by Brayden Zhong, committed by GitHub

[V1] Fix: make sure `k_index` is int64 for `apply_top_k_only` (#15907)


Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
parent 24b7fb45
@@ -200,7 +200,7 @@ def apply_top_k_only(
     # topk.values tensor has shape [batch_size, max_top_k].
     # Convert top k to 0-based index in range [0, max_top_k).
     k_index = k.sub_(1).unsqueeze(1)
-    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
+    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
     # Handle non-topk rows.
     top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
     logits.masked_fill_(logits < top_k_mask, -float("inf"))
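A minimal sketch of why the cast matters, assuming `k` arrives as an int32 tensor. `torch.gather` documents its `index` argument as a `LongTensor`, so the index is converted with `.long()` before the call. The function name `top_k_only_sketch` and the sample tensors are hypothetical, and the handling of rows without top-k (`no_top_k_mask`) is left out for brevity; this is not the actual `apply_top_k_only` implementation.

```python
import torch


def top_k_only_sketch(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Keep the k[i] largest logits in row i and set the rest to -inf.

    Hypothetical, simplified stand-in for illustration only.
    """
    max_top_k = int(k.max())
    # 0-based column of each row's k-th largest value, shape [batch_size, 1].
    k_index = (k - 1).unsqueeze(1)
    # gather expects an int64 index, so cast in case k is int32.
    threshold = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
    # Anything strictly below the row's threshold is masked out.
    return logits.masked_fill(logits < threshold, -float("inf"))


logits = torch.tensor([[1.0, 3.0, 2.0, 0.5],
                       [4.0, 0.0, 2.0, 5.0]])
k = torch.tensor([2, 1], dtype=torch.int32)  # int32, as a caller might pass
masked = top_k_only_sketch(logits, k)
print(masked)
```

Unlike the patched code, the sketch uses `k - 1` rather than the in-place `k.sub_(1)`, so the caller's tensor is left unchanged.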