Unverified commit f5b34a51, authored by hzh0425, committed by GitHub
Browse files

Bugfix: Fix Type consistency for KV indices in SWARadixCache (#11452)


Co-authored-by: yizhang2077 <1109276519@qq.com>
parent 5a6ec8f9
......@@ -449,11 +449,13 @@ class SWARadixCache(BasePrefixCache):
if self.page_size != 1:
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.clone()
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
if self.is_eagle:
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
......@@ -502,10 +504,12 @@ class SWARadixCache(BasePrefixCache):
if self.page_size != 1:
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
else:
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.clone()
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
page_aligned_token_len = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment