Commit b4610c06 authored by 王敏's avatar 王敏
Browse files

[fix]修复eagle 创建cu_num_tokens类型错误问题

parent 2071c380
...@@ -438,7 +438,7 @@ class EagleProposer: ...@@ -438,7 +438,7 @@ class EagleProposer:
# [batch_size] # [batch_size]
num_accepted_tokens_tensor: torch.Tensor, num_accepted_tokens_tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
cu_num_tokens = torch.arange(cu_target_query_lens.shape[0], device=cu_target_query_lens.device) cu_num_tokens = torch.arange(cu_target_query_lens.shape[0], device=cu_target_query_lens.device, dtype=torch.int32)
token_indices = num_accepted_tokens_tensor + cu_target_query_lens[:-1] token_indices = num_accepted_tokens_tensor + cu_target_query_lens[:-1]
return cu_num_tokens, token_indices return cu_num_tokens, token_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment