Unverified commit 24b7fb45, authored by Ekagra Ranjan, committed by GitHub
Browse files

[Spec Decode] Fix input triton kernel for eagle (#15909)

parent 58f5a597
@@ -250,13 +250,12 @@ def prepare_input_kernel(
num_tokens = end_pos - start_pos
index_start = tl.load(cu_query_lens_ptr + pid)
indices = index_start + tl.arange(0, BLOCK_SIZE)
num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
for i in tl.range(num_blocks):
offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
out_ptr + start_pos + offset,
indices,
index_start + offset,
mask=offset < num_tokens,
)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment