Unverified Commit 92510edc authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

remove cuda check in `top_k_top_p_triton` kernel (#35011)


Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent a6c13752
......@@ -248,7 +248,7 @@ def apply_top_k_top_p(
if p is None and k is None:
return logits
if HAS_TRITON and logits.shape[0] >= 8 and logits.is_cuda:
if HAS_TRITON and logits.shape[0] >= 8:
return apply_top_k_top_p_triton(logits, k, p)
# Use pytorch sort implementation for small batch sizes.
......
......@@ -967,7 +967,6 @@ def apply_top_k_top_p_triton(
"""
assert logits.ndim == 2
assert logits.dtype == torch.float32
assert logits.is_cuda
batch_size, vocab_size = logits.shape
......@@ -978,13 +977,13 @@ def apply_top_k_top_p_triton(
return logits
if k is not None:
assert k.ndim == 1 and k.shape[0] == batch_size and k.is_cuda
assert k.ndim == 1 and k.shape[0] == batch_size
k_ptr = k.to(torch.int32)
else:
k_ptr = logits # Dummy pointer (won't be read)
if p is not None:
assert p.ndim == 1 and p.shape[0] == batch_size and p.is_cuda
assert p.ndim == 1 and p.shape[0] == batch_size
p_ptr = p.to(torch.float32)
else:
p_ptr = logits # Dummy pointer (won't be read)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment