Unverified commit 92510edc, authored by Kunshang Ji and committed by GitHub
Browse files

remove cuda check in `top_k_top_p_triton` kernel (#35011)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
parent a6c13752
...@@ -248,7 +248,7 @@ def apply_top_k_top_p( ...@@ -248,7 +248,7 @@ def apply_top_k_top_p(
if p is None and k is None: if p is None and k is None:
return logits return logits
if HAS_TRITON and logits.shape[0] >= 8 and logits.is_cuda: if HAS_TRITON and logits.shape[0] >= 8:
return apply_top_k_top_p_triton(logits, k, p) return apply_top_k_top_p_triton(logits, k, p)
# Use pytorch sort implementation for small batch sizes. # Use pytorch sort implementation for small batch sizes.
......
...@@ -967,7 +967,6 @@ def apply_top_k_top_p_triton( ...@@ -967,7 +967,6 @@ def apply_top_k_top_p_triton(
""" """
assert logits.ndim == 2 assert logits.ndim == 2
assert logits.dtype == torch.float32 assert logits.dtype == torch.float32
assert logits.is_cuda
batch_size, vocab_size = logits.shape batch_size, vocab_size = logits.shape
...@@ -978,13 +977,13 @@ def apply_top_k_top_p_triton( ...@@ -978,13 +977,13 @@ def apply_top_k_top_p_triton(
return logits return logits
if k is not None: if k is not None:
assert k.ndim == 1 and k.shape[0] == batch_size and k.is_cuda assert k.ndim == 1 and k.shape[0] == batch_size
k_ptr = k.to(torch.int32) k_ptr = k.to(torch.int32)
else: else:
k_ptr = logits # Dummy pointer (won't be read) k_ptr = logits # Dummy pointer (won't be read)
if p is not None: if p is not None:
assert p.ndim == 1 and p.shape[0] == batch_size and p.is_cuda assert p.ndim == 1 and p.shape[0] == batch_size
p_ptr = p.to(torch.float32) p_ptr = p.to(torch.float32)
else: else:
p_ptr = logits # Dummy pointer (won't be read) p_ptr = logits # Dummy pointer (won't be read)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment