Commit 2be9c33c authored by lixh6
Browse files

[BUGFIX]解决推测解码内核类型不匹配

parent fd2a4660
......@@ -79,7 +79,8 @@ def eagle_prepare_next_token_padded_kernel(
if is_discarded:
backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
valid_count = tl.full((), 0, dtype=tl.uint32)
# valid_count = tl.full((), 0, dtype=tl.uint32)
valid_count = tl.cast(0, tl.uint32)
tl.store(next_token_ids_ptr + req_idx, backup_token)
tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
else:
......@@ -92,7 +93,8 @@ def eagle_prepare_next_token_padded_kernel(
# Rejected tokens are -1, valid tokens are in [0, vocab_size)
is_valid_mask = (token_ids != -1) & (token_ids < vocab_size) & token_mask
valid_count = tl.sum(is_valid_mask)
# valid_count = tl.sum(is_valid_mask)
valid_count = tl.cast(tl.sum(is_valid_mask), tl.uint32)
if valid_count > 0:
# Guaranteed to be well-defined since
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment