Unverified commit f5bbf603, authored by Zhiqiang Xie, committed by GitHub
Browse files

Fix: Complete int32 to int64 conversion (#4465)

parent 5cbd709e
...@@ -305,7 +305,7 @@ class HiRadixCache(RadixCache): ...@@ -305,7 +305,7 @@ class HiRadixCache(RadixCache):
if value: if value:
value = torch.cat(value) value = torch.cat(value)
else: else:
value = torch.tensor([], dtype=torch.int32) value = torch.tensor([], dtype=torch.int64)
last_node_global = last_node last_node_global = last_node
while last_node.evicted: while last_node.evicted:
......
...@@ -622,11 +622,10 @@ class HostKVCache(abc.ABC): ...@@ -622,11 +622,10 @@ class HostKVCache(abc.ABC):
self.mem_state = torch.zeros( self.mem_state = torch.zeros(
(self.size,), dtype=torch.uint8, device=self.device (self.size,), dtype=torch.uint8, device=self.device
) )
self.free_slots = torch.arange(self.size, dtype=torch.int32)
self.can_use_mem_size = self.size
# A lock for synchronized operations on memory allocation and state transitions. # A lock for synchronized operations on memory allocation and state transitions.
self.lock = threading.RLock() self.lock = threading.RLock()
self.clear()
@abc.abstractmethod @abc.abstractmethod
def get_size_per_token(self): def get_size_per_token(self):
...@@ -656,7 +655,7 @@ class HostKVCache(abc.ABC): ...@@ -656,7 +655,7 @@ class HostKVCache(abc.ABC):
def clear(self): def clear(self):
self.mem_state.fill_(0) self.mem_state.fill_(0)
self.can_use_mem_size = self.size self.can_use_mem_size = self.size
self.free_slots = torch.arange(self.size, dtype=torch.int32) self.free_slots = torch.arange(self.size, dtype=torch.int64)
@synchronized @synchronized
def get_state(self, indices: torch.Tensor) -> MemoryStateInt: def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
......
...@@ -140,7 +140,7 @@ class RadixCache(BasePrefixCache): ...@@ -140,7 +140,7 @@ class RadixCache(BasePrefixCache):
return ( return (
torch.empty( torch.empty(
(0,), (0,),
dtype=torch.int32, dtype=torch.int64,
device=self.device, device=self.device,
), ),
self.root_node, self.root_node,
...@@ -154,7 +154,7 @@ class RadixCache(BasePrefixCache): ...@@ -154,7 +154,7 @@ class RadixCache(BasePrefixCache):
if value: if value:
value = torch.cat(value) value = torch.cat(value)
else: else:
value = torch.empty((0,), dtype=torch.int32, device=self.device) value = torch.empty((0,), dtype=torch.int64, device=self.device)
return value, last_node return value, last_node
def insert(self, key: List, value=None): def insert(self, key: List, value=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment