Unverified commit 19818b9c, authored by Liangsheng Yin, committed by GitHub

Minor: style improvement of radix_cache and memory_pool (#395)

parent 9216b106
@@ -236,9 +236,8 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            if not self.tree_cache.disable:
-                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

             if out_cache_loc is None:
                 print("Prefill out of memory. This should never happen.")
@@ -307,8 +306,8 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True

-        if not self.tree_cache.disable:
-            self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)

         if self.token_to_kv_pool.available_size() >= bs:
             return True
@@ -341,7 +340,7 @@ class Batch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req_pool_indices_np[idx]
                 ][: seq_lens_np[idx]]
-                self.token_to_kv_pool.free(token_indices)
+                self.token_to_kv_pool.dec_refs(token_indices)

         self.filter_batch(sorted_indices)
@@ -372,7 +371,7 @@ class Batch:
                 prefix_len = self.tree_cache.insert(
                     token_ids_in_memory, indices.clone()
                 )
-                self.token_to_kv_pool.free(indices[:prefix_len])
+                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
                 self.req_to_token_pool.free(req_pool_idx)
                 self.tree_cache.dec_ref_counter(req.last_node)
......
@@ -113,7 +113,7 @@ class ModelRpcServer:
             logger.info(server_args.get_optional_modes_logging())

         # Init cache
-        self.tree_cache = RadixCache(server_args.disable_radix_cache)
+        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
@@ -628,7 +628,7 @@ class ModelRpcServer:
                     token_ids[:seq_len], indices.clone()
                 )
-                self.token_to_kv_pool.free(indices[:prefix_len])
+                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
                 self.req_to_token_pool.free(req_pool_idx)
                 self.tree_cache.dec_ref_counter(req.last_node)
......
 import heapq
 import time
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Tuple

 import torch
@@ -16,23 +14,23 @@ class TreeNode:
         self.ref_counter = 0
         self.last_access_time = time.time()

-    def __lt__(self, other):
+    def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time


-def match(key, seq):
+def _key_match(key0, key1):
     i = 0
-    for k, w in zip(key, seq):
-        if k != w:
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
             break
         i += 1
     return i


 class RadixCache:
-    def __init__(self, disable=False):
-        self.reset()
+    def __init__(self, disable: bool = False):
         self.disable = disable
+        self.reset()

     ##### Public API #####
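For reference, the renamed `_key_match` helper just returns the length of the longest common prefix of its two key sequences. A standalone copy with a few sanity checks:

```python
# Standalone copy of the helper above, for illustration.
def _key_match(key0, key1):
    i = 0
    for k0, k1 in zip(key0, key1):
        if k0 != k1:
            break
        i += 1
    return i

assert _key_match("Hello", "Help") == 3       # common prefix "Hel"
assert _key_match([1, 2, 3], [1, 2, 4]) == 2  # token-id lists work the same way
assert _key_match([], [1, 2]) == 0            # empty key matches nothing
```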
@@ -71,7 +69,7 @@ class RadixCache:
     def evict(self, num_tokens, evict_callback):
         if self.disable:
-            raise RuntimeError()
+            return

         leaves = self._collect_leaves()
         heapq.heapify(leaves)
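`evict` leans on the `TreeNode.__lt__` annotated above: `heapq` compares the leaf nodes directly, so ordering them by `last_access_time` makes the heap pop the least recently used leaf first. A small sketch of that mechanism, using a stand-in node class rather than the real `TreeNode`:

```python
import heapq
import time

# Stand-in for TreeNode; only the comparison behavior matters here.
class Node:
    def __init__(self, name, last_access_time):
        self.name = name
        self.last_access_time = last_access_time

    def __lt__(self, other: "Node"):
        # Older access time -> "smaller" -> popped first by heapq.
        return self.last_access_time < other.last_access_time


now = time.time()
leaves = [Node("fresh", now), Node("stale", now - 60.0)]
heapq.heapify(leaves)
print(heapq.heappop(leaves).name)  # "stale": the least recently used leaf
```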
@@ -115,6 +113,7 @@ class RadixCache:
         return self.evictable_size_

     ##### Internal Helper Functions #####
+
     def _match_prefix_helper(self, node, key, value, last_node):
         node.last_access_time = time.time()
         if len(key) == 0:
@@ -122,7 +121,7 @@ class RadixCache:
         if key[0] in node.children.keys():
             child = node.children[key[0]]
-            prefix_len = match(child.key, key)
+            prefix_len = _key_match(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 value.append(new_node.value)
@@ -153,7 +152,7 @@ class RadixCache:
         if key[0] in node.children.keys():
             child = node.children[key[0]]
-            prefix_len = match(child.key, key)
+            prefix_len = _key_match(child.key, key)
             if prefix_len == len(child.key):
                 if prefix_len == len(key):
@@ -212,7 +211,7 @@ class RadixCache:
 if __name__ == "__main__":
-    tree = RadixCache(disable=False)
+    tree = RadixCache()
     tree.insert("Hello")
     tree.insert("Hello")
......
@@ -31,9 +31,6 @@ class ReqToTokenPool:
         self.can_use_mem_size += free_index.shape[0]
         self.mem_state[free_index] = 1

-        # if self.can_use_mem_size == len(self.mem_state):
-        #     print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
-
     def clear(self):
         self.mem_state.fill_(1)
         self.can_use_mem_size = len(self.mem_state)
@@ -42,7 +39,7 @@ class ReqToTokenPool:
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
         self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
-        self.alloc_ct = 0
+        self.total_ref_ct = 0

         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
@@ -83,9 +80,6 @@ class TokenToKVPool:
         self.add_refs(select_index)
         return select_index.to(torch.int32), start_loc, start_loc + need_size

-    def free(self, free_index):
-        return self.decrease_refs(free_index)
-
     def used_size(self):
         return len(torch.nonzero(self.mem_state).squeeze(1))
@@ -93,20 +87,17 @@ class TokenToKVPool:
         return torch.sum(self.mem_state == 0).item()

     def add_refs(self, token_index: torch.Tensor):
-        self.alloc_ct += len(token_index)
+        self.total_ref_ct += len(token_index)
         self.mem_state[token_index] += 1

-    def decrease_refs(self, token_index: torch.Tensor):
-        self.alloc_ct -= len(token_index)
+    def dec_refs(self, token_index: torch.Tensor):
+        self.total_ref_ct -= len(token_index)
         self.mem_state[token_index] -= 1

         num_freed = torch.sum(self.mem_state[token_index] == 0)

-        # if self.alloc_ct == 0:
-        #     print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")
-
         return num_freed

     def clear(self):
         self.mem_state.fill_(0)
-        self.alloc_ct = 0
+        self.total_ref_ct = 0
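The pool keeps a per-slot reference count in `mem_state` (a slot is free while its count is zero) plus the aggregate `total_ref_ct`, and `dec_refs` returns how many slots the call actually freed, which is what the radix cache's eviction loop needs to know. A toy reproduction of just the counting logic, not the sglang class:

```python
import torch

# Per-slot reference counts; a slot is free while its count is zero.
mem_state = torch.zeros(4, dtype=torch.int16)

def add_refs(token_index: torch.Tensor):
    mem_state[token_index] += 1

def dec_refs(token_index: torch.Tensor):
    mem_state[token_index] -= 1
    # Number of slots this call dropped to zero, i.e. actually freed.
    return torch.sum(mem_state[token_index] == 0).item()

idx = torch.tensor([0, 1])
add_refs(idx)                # counts: [1, 1, 0, 0]
add_refs(torch.tensor([0]))  # slot 0 now shared: [2, 1, 0, 0]
print(dec_refs(idx))         # 1 -> only slot 1 was freed; slot 0 is still referenced
```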
@@ -500,7 +500,7 @@ async def v1_chat_completions(raw_request: Request):
     return response


-def launch_server(server_args, pipe_finish_writer):
+def launch_server(server_args: ServerArgs, pipe_finish_writer):
     global tokenizer_manager
     global chat_template_name
......
@@ -105,7 +105,7 @@ def test_generate_worker(
     for i in range(batch_size):
         req_idx = req_pool_indices[i].item()
-        model.token_to_kv_pool.free(
+        model.token_to_kv_pool.dec_refs(
             model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
         )

     model.req_to_token_pool.free(req_pool_indices)
......