Unverified Commit 19818b9c authored by Liangsheng Yin, committed by GitHub

Minor: style improvement of radix_cache and memory_pool (#395)

parent 9216b106
@@ -236,9 +236,8 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            if not self.tree_cache.disable:
-                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

             if out_cache_loc is None:
                 print("Prefill out of memory. This should never happen.")
@@ -307,8 +306,8 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True

-        if not self.tree_cache.disable:
-            self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)

         if self.token_to_kv_pool.available_size() >= bs:
             return True
@@ -341,7 +340,7 @@ class Batch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req_pool_indices_np[idx]
                 ][: seq_lens_np[idx]]
-                self.token_to_kv_pool.free(token_indices)
+                self.token_to_kv_pool.dec_refs(token_indices)

         self.filter_batch(sorted_indices)
@@ -372,7 +371,7 @@ class Batch:
                 prefix_len = self.tree_cache.insert(
                     token_ids_in_memory, indices.clone()
                 )
-                self.token_to_kv_pool.free(indices[:prefix_len])
+                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
                 self.req_to_token_pool.free(req_pool_idx)
                 self.tree_cache.dec_ref_counter(req.last_node)
...
@@ -113,7 +113,7 @@ class ModelRpcServer:
         logger.info(server_args.get_optional_modes_logging())

         # Init cache
-        self.tree_cache = RadixCache(server_args.disable_radix_cache)
+        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
@@ -628,7 +628,7 @@ class ModelRpcServer:
                     token_ids[:seq_len], indices.clone()
                 )
-                self.token_to_kv_pool.free(indices[:prefix_len])
+                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
                 self.req_to_token_pool.free(req_pool_idx)
                 self.tree_cache.dec_ref_counter(req.last_node)
...
 import heapq
 import time
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Tuple

 import torch
@@ -16,23 +14,23 @@ class TreeNode:
         self.ref_counter = 0
         self.last_access_time = time.time()

-    def __lt__(self, other):
+    def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time


-def match(key, seq):
+def _key_match(key0, key1):
     i = 0
-    for k, w in zip(key, seq):
-        if k != w:
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
             break
         i += 1
     return i


 class RadixCache:
-    def __init__(self, disable=False):
-        self.reset()
+    def __init__(self, disable: bool = False):
         self.disable = disable
+        self.reset()

     ##### Public API #####
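For context, `_key_match` (renamed from `match`) returns the length of the longest common prefix of its two arguments. A quick illustration of the expected behavior:

    # _key_match returns the length of the shared prefix of two sequences.
    assert _key_match([1, 2, 3, 4], [1, 2, 9]) == 2
    assert _key_match("Hello", "Help") == 3
    assert _key_match([], [1, 2]) == 0   # empty key matches nothing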
@@ -71,7 +69,7 @@ class RadixCache:
     def evict(self, num_tokens, evict_callback):
         if self.disable:
-            raise RuntimeError()
+            return

         leaves = self._collect_leaves()
         heapq.heapify(leaves)
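With the early return in place, callers no longer need to check `disable` themselves. Eviction then heapifies the leaves, which `TreeNode.__lt__` orders by `last_access_time`, so the least recently used leaves are reclaimed first. A simplified sketch of that loop, assuming the `num_tokens`, `evict_callback`, and `leaves` names from this function (not the exact sglang implementation):

    # Simplified LRU eviction loop over the leaf heap (illustrative only).
    num_evicted = 0
    while num_evicted < num_tokens and leaves:
        node = heapq.heappop(leaves)       # least recently accessed leaf
        if node.ref_counter > 0:
            continue                       # still referenced by a running request
        num_evicted += evict_callback(node.value)  # free its KV cache slots
        # Removing a leaf may turn its parent into a new evictable leaf,
        # which would then be pushed back onto the heap.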
@@ -115,6 +113,7 @@ class RadixCache:
         return self.evictable_size_

     ##### Internal Helper Functions #####
+
     def _match_prefix_helper(self, node, key, value, last_node):
         node.last_access_time = time.time()
         if len(key) == 0:
@@ -122,7 +121,7 @@ class RadixCache:
         if key[0] in node.children.keys():
             child = node.children[key[0]]
-            prefix_len = match(child.key, key)
+            prefix_len = _key_match(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 value.append(new_node.value)
@@ -153,7 +152,7 @@ class RadixCache:
         if key[0] in node.children.keys():
             child = node.children[key[0]]
-            prefix_len = match(child.key, key)
+            prefix_len = _key_match(child.key, key)
             if prefix_len == len(child.key):
                 if prefix_len == len(key):
@@ -212,7 +211,7 @@ class RadixCache:
 if __name__ == "__main__":
-    tree = RadixCache(disable=False)
+    tree = RadixCache()

     tree.insert("Hello")
     tree.insert("Hello")
...
@@ -31,9 +31,6 @@ class ReqToTokenPool:
         self.can_use_mem_size += free_index.shape[0]
         self.mem_state[free_index] = 1

-        # if self.can_use_mem_size == len(self.mem_state):
-        #     print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
-
     def clear(self):
         self.mem_state.fill_(1)
         self.can_use_mem_size = len(self.mem_state)
@@ -42,7 +39,7 @@ class ReqToTokenPool:
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
         self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
-        self.alloc_ct = 0
+        self.total_ref_ct = 0

         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
@@ -83,9 +80,6 @@ class TokenToKVPool:
         self.add_refs(select_index)
         return select_index.to(torch.int32), start_loc, start_loc + need_size

-    def free(self, free_index):
-        return self.decrease_refs(free_index)
-
     def used_size(self):
         return len(torch.nonzero(self.mem_state).squeeze(1))
@@ -93,20 +87,17 @@ class TokenToKVPool:
         return torch.sum(self.mem_state == 0).item()

     def add_refs(self, token_index: torch.Tensor):
-        self.alloc_ct += len(token_index)
+        self.total_ref_ct += len(token_index)
         self.mem_state[token_index] += 1

-    def decrease_refs(self, token_index: torch.Tensor):
-        self.alloc_ct -= len(token_index)
+    def dec_refs(self, token_index: torch.Tensor):
+        self.total_ref_ct -= len(token_index)
         self.mem_state[token_index] -= 1

         num_freed = torch.sum(self.mem_state[token_index] == 0)
-
-        # if self.alloc_ct == 0:
-        #     print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")

         return num_freed

     def clear(self):
         self.mem_state.fill_(0)
-        self.alloc_ct = 0
+        self.total_ref_ct = 0
@@ -500,7 +500,7 @@ async def v1_chat_completions(raw_request: Request):
     return response


-def launch_server(server_args, pipe_finish_writer):
+def launch_server(server_args: ServerArgs, pipe_finish_writer):
     global tokenizer_manager
     global chat_template_name
...
@@ -105,7 +105,7 @@ def test_generate_worker(
     for i in range(batch_size):
         req_idx = req_pool_indices[i].item()
-        model.token_to_kv_pool.free(
+        model.token_to_kv_pool.dec_refs(
             model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
         )

     model.req_to_token_pool.free(req_pool_indices)
...