Unverified commit f5bbf603, authored by Zhiqiang Xie, committed by GitHub
Browse files

Fix: Complete int32 to int64 conversion (#4465)

parent 5cbd709e
......@@ -305,7 +305,7 @@ class HiRadixCache(RadixCache):
if value:
value = torch.cat(value)
else:
-value = torch.tensor([], dtype=torch.int32)
+value = torch.tensor([], dtype=torch.int64)
last_node_global = last_node
while last_node.evicted:
......
......@@ -622,11 +622,10 @@ class HostKVCache(abc.ABC):
self.mem_state = torch.zeros(
(self.size,), dtype=torch.uint8, device=self.device
)
-self.free_slots = torch.arange(self.size, dtype=torch.int32)
self.can_use_mem_size = self.size
# A lock for synchronized operations on memory allocation and state transitions.
self.lock = threading.RLock()
self.clear()
@abc.abstractmethod
def get_size_per_token(self):
......@@ -656,7 +655,7 @@ class HostKVCache(abc.ABC):
def clear(self):
self.mem_state.fill_(0)
self.can_use_mem_size = self.size
-self.free_slots = torch.arange(self.size, dtype=torch.int32)
+self.free_slots = torch.arange(self.size, dtype=torch.int64)
@synchronized
def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
......
......@@ -140,7 +140,7 @@ class RadixCache(BasePrefixCache):
return (
torch.empty(
(0,),
-dtype=torch.int32,
+dtype=torch.int64,
device=self.device,
),
self.root_node,
......@@ -154,7 +154,7 @@ class RadixCache(BasePrefixCache):
if value:
value = torch.cat(value)
else:
-value = torch.empty((0,), dtype=torch.int32, device=self.device)
+value = torch.empty((0,), dtype=torch.int64, device=self.device)
return value, last_node
def insert(self, key: List, value=None):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment