Unverified Commit f919d4cb authored by Eugene Cheah's avatar Eugene Cheah Committed by GitHub
Browse files

[BugFix] Fix logits repetition penalty cuda check (#22592)

parent afa5b7ca
...@@ -311,7 +311,7 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, ...@@ -311,7 +311,7 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
output_mask: A boolean tensor indicating which tokens appear in the output. output_mask: A boolean tensor indicating which tokens appear in the output.
repetition_penalties: The repetition penalties of shape (num_seqs, ). repetition_penalties: The repetition penalties of shape (num_seqs, ).
""" """
if current_platform.is_cuda() and logits.is_contiguous(): if logits.is_cuda and logits.is_contiguous():
apply_repetition_penalties_cuda(logits, prompt_mask, output_mask, apply_repetition_penalties_cuda(logits, prompt_mask, output_mask,
repetition_penalties) repetition_penalties)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment