Unverified commit 71b54eea authored by Cody Yu, committed by GitHub

Add cache metrics (#119)

parent 74b3bfaa
"""Base cache class."""
import time
class BaseCache:
def __init__(self, enable=True):
self.enable = enable
self.reset()
def reset(self):
self.cache = {}
self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
def query(self, key):
def _init_with_timer(key):
start = time.monotonic()
val = self.init_value(key)
init_time = time.monotonic() - start
curr_total = self.metrics["total"]
new_total = curr_total + 1
# Update average init time without old_avg * old_total to avoid overflow.
self.metrics["avg_init_time"] = (init_time / new_total) + (
curr_total / new_total
) * self.metrics["avg_init_time"]
self.metrics["total"] += 1
return val
if key in self.cache:
self.metrics["hit"] += 1
val = self.cache[key]
else:
# Cache miss or disabled.
val = _init_with_timer(key)
if self.enable:
self.cache[key] = val
return val
def init_value(self, key):
raise NotImplementedError
def get_cache_hit_rate(self):
if self.metrics["total"] == 0:
return 0
return self.metrics["hit"] / self.metrics["total"]
def get_avg_init_time(self):
return self.metrics["avg_init_time"]
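A quick sanity sketch of the new interface (illustrative, not part of the commit): a subclass only implements init_value(); query() handles memoization, hit counting, and init timing. SquareCache is a hypothetical stand-in for the real FSM and fast-forward caches, and the import assumes the file above is importable as sglang.srt.constrained.base_cache.

from sglang.srt.constrained.base_cache import BaseCache


class SquareCache(BaseCache):
    def init_value(self, key):
        return key * key  # stand-in for an expensive build such as RegexFSM(...)


cache = SquareCache()
assert cache.query(3) == 9  # miss: computed by init_value() and stored
assert cache.query(3) == 9  # hit: served from the dict
print(cache.get_cache_hit_rate())  # 0.5 (1 hit out of 2 queries)
print(cache.get_avg_init_time())   # mean wall-clock seconds per init_value() call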
 import interegular
+from sglang.srt.constrained.base_cache import BaseCache
 from sglang.srt.constrained.disk_cache import disk_cache
 from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm

@@ -56,15 +57,12 @@ class FastForwardMap:
         return fast_forward_str, next_state


-class FastForwardCache:
+class FastForwardCache(BaseCache):
     def __init__(self):
-        self.cache = {}
+        super().__init__()

-    def init_fast_forward_map(self, regex_string):
-        if regex_string not in self.cache:
-            fast_forward_map = FastForwardMap(regex_string)
-            self.cache[regex_string] = fast_forward_map
-        return self.cache[regex_string]
+    def init_value(self, regex):
+        return FastForwardMap(regex)


 def test_main():
......
+from sglang.srt.constrained.base_cache import BaseCache
 from sglang.srt.constrained.fsm import RegexFSM
 from sglang.srt.constrained.tokenizer import TransformerTokenizer

-_enable_memory_cache = True
-
-
-class FSMCache:
-    def __init__(self, tokenizer_path, tokenizer_args_dict):
-        self.cache = {}
+class FSMCache(BaseCache):
+    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
+        super().__init__(enable=enable)
         self.outlines_tokenizer = TransformerTokenizer(
             tokenizer_path, **tokenizer_args_dict
         )

-    def init_fsm(self, regex):
-        if _enable_memory_cache:
-            if regex not in self.cache:
-                fsm = RegexFSM(regex, self.outlines_tokenizer)
-                self.cache[regex] = fsm
-            return self.cache[regex]
+    def init_value(self, regex):
+        return RegexFSM(regex, self.outlines_tokenizer)
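One behavioral note on the new enable flag, which replaces the module-level _enable_memory_cache switch: a disabled cache still answers query() and still records metrics; it just recomputes on every call and never stores the result. A minimal sketch with a hypothetical CountingCache subclass (again assuming base_cache is importable):

from sglang.srt.constrained.base_cache import BaseCache


class CountingCache(BaseCache):
    def __init__(self, enable=True):
        super().__init__(enable=enable)
        self.builds = 0

    def init_value(self, key):
        self.builds += 1  # count how often the value is rebuilt
        return key.upper()


cache = CountingCache(enable=False)
cache.query("regex")
cache.query("regex")
print(cache.builds)  # 2: with the cache disabled, every query is a miss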
@@ -60,7 +60,11 @@ class Req:
     def tokenize_fast_forward(self, fast_forward_str, next_state):
         old_output_str = self.tokenizer.decode(self.output_ids)
-        if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith("▁"):
+        # FIXME: This logic does not really solve the problem of determining
+        # whether there should be a leading space.
+        first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
+        first_token = first_token.decode() if isinstance(first_token, bytes) else first_token
+        if first_token.startswith("▁"):
             old_output_str = " " + old_output_str
         new_input_string = (
             self.input_text
......
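The bytes guard above matters because convert_ids_to_tokens() returns different types across tokenizers: SentencePiece models encode a word-initial space as a U+2581 ("▁") prefix on the token string, while some other tokenizers can hand back raw bytes. An illustrative check, assuming the Hugging Face transformers package and access to the hf-internal-testing/llama-tokenizer repo:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
ids = tok.encode("Hello world", add_special_tokens=False)
print(tok.convert_ids_to_tokens(ids[0]))  # "▁Hello": the leading space lives in the token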
@@ -4,8 +4,7 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from enum import Enum, auto
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List

 import numpy as np
 import rpyc
@@ -99,6 +98,7 @@ class ModelRpcServer(rpyc.Service):
         # Init cache
         self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
+        self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
             self.max_num_running_seq,
@@ -136,6 +136,8 @@ class ModelRpcServer(rpyc.Service):
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
             self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
@@ -248,9 +250,9 @@ class ModelRpcServer(rpyc.Service):
         # Init regex fsm
         if req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
+            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
             if not self.no_regex_fast_forward:
-                req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map(
+                req.fast_forward_map = self.fast_forward_cache.query(
                     req.sampling_params.regex
                 )
@@ -285,7 +287,6 @@ class ModelRpcServer(rpyc.Service):
         can_run_list = []
         new_batch_total_tokens = 0
         new_batch_input_tokens = 0
-        new_batch_prefix_tokens = 0

         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -343,12 +344,26 @@ class ModelRpcServer(rpyc.Service):
             return None

         if self.tp_rank == 0:
+            running_req = 0 if self.running_batch is None else len(self.running_batch.reqs)
+            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
+            self.tree_cache_metrics["total"] += (hit_tokens + new_batch_input_tokens) / 10**9
+            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
+            tree_cache_hit_rate = (
+                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+            )
             logger.info(
                 f"new fill batch. #seq: {len(can_run_list)}. "
-                f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. "
+                f"#cached_token: {hit_tokens}. "
                 f"#new_token: {new_batch_input_tokens}. "
                 f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
-                f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+                f"#running_req: {running_req}. "
+                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
+            logger.debug(
+                f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+                f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+                f"ff_cache_hit_rate: {100.0 * self.fast_forward_cache.get_cache_hit_rate():.2f}%. "
+                f"ff_cache_avg_init_time: {self.fast_forward_cache.get_avg_init_time():.2f}s. "
+            )

         new_batch = Batch.init_new(
......
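The tree-cache counters above accumulate in units of 10**9 tokens so the running floats stay small on a long-lived server; the common scale factor cancels when the hit rate is taken. A standalone sketch with made-up batch sizes:

tree_cache_metrics = {"total": 0, "hit": 0}

for hit_tokens, new_batch_input_tokens in [(3000, 5000), (7000, 1000)]:
    tree_cache_metrics["total"] += (hit_tokens + new_batch_input_tokens) / 10**9
    tree_cache_metrics["hit"] += hit_tokens / 10**9
    rate = tree_cache_metrics["hit"] / tree_cache_metrics["total"]
    print(f"tree_cache_hit_rate: {100.0 * rate:.2f}%")  # 37.50%, then 62.50%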