Commit 71b54eea authored by Cody Yu, committed by GitHub

Add cache metrics (#119)

parent 74b3bfaa
"""Base cache class."""
import time
class BaseCache:
def __init__(self, enable=True):
self.enable = enable
self.reset()
def reset(self):
self.cache = {}
self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
def query(self, key):
def _init_with_timer(key):
start = time.monotonic()
val = self.init_value(key)
init_time = time.monotonic() - start
curr_total = self.metrics["total"]
new_total = curr_total + 1
# Update average init time without old_avg * old_total to avoid overflow.
self.metrics["avg_init_time"] = (init_time / new_total) + (
curr_total / new_total
) * self.metrics["avg_init_time"]
self.metrics["total"] += 1
return val
if key in self.cache:
self.metrics["hit"] += 1
val = self.cache[key]
else:
# Cache miss or disabled.
val = _init_with_timer(key)
if self.enable:
self.cache[key] = val
return val
def init_value(self, key):
raise NotImplementedError
def get_cache_hit_rate(self):
if self.metrics["total"] == 0:
return 0
return self.metrics["hit"] / self.metrics["total"]
def get_avg_init_time(self):
return self.metrics["avg_init_time"]
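For reference, a minimal sketch of how a subclass is meant to use this interface. SquareCache and the literal values below are hypothetical, not part of the commit:

# Hypothetical subclass: only init_value() needs to be implemented;
# query() supplies the caching and the metrics bookkeeping.
class SquareCache(BaseCache):
    def init_value(self, key):
        return key * key  # stands in for an expensive computation

cache = SquareCache()
cache.query(3)  # miss: init_value runs and is timed
cache.query(3)  # hit: served from self.cache
print(cache.get_cache_hit_rate())  # 0.5 (1 hit out of 2 queries)
print(cache.get_avg_init_time())   # mean init_value latency in seconds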
 import interegular
+from sglang.srt.constrained.base_cache import BaseCache
 from sglang.srt.constrained.disk_cache import disk_cache
 from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
@@ -56,15 +57,12 @@ class FastForwardMap:
         return fast_forward_str, next_state


-class FastForwardCache:
+class FastForwardCache(BaseCache):
     def __init__(self):
-        self.cache = {}
+        super().__init__()

-    def init_fast_forward_map(self, regex_string):
-        if regex_string not in self.cache:
-            fast_forward_map = FastForwardMap(regex_string)
-            self.cache[regex_string] = fast_forward_map
-        return self.cache[regex_string]
+    def init_value(self, regex):
+        return FastForwardMap(regex)


 def test_main():
    ...
+from sglang.srt.constrained.base_cache import BaseCache
 from sglang.srt.constrained.fsm import RegexFSM
 from sglang.srt.constrained.tokenizer import TransformerTokenizer

-_enable_memory_cache = True
-
-class FSMCache:
-    def __init__(self, tokenizer_path, tokenizer_args_dict):
-        self.cache = {}
+class FSMCache(BaseCache):
+    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
+        super().__init__(enable=enable)
         self.outlines_tokenizer = TransformerTokenizer(
             tokenizer_path, **tokenizer_args_dict
         )

-    def init_fsm(self, regex):
-        if _enable_memory_cache:
-            if regex not in self.cache:
-                fsm = RegexFSM(regex, self.outlines_tokenizer)
-                self.cache[regex] = fsm
-            return self.cache[regex]
-        return RegexFSM(regex, self.outlines_tokenizer)
+    def init_value(self, regex):
+        return RegexFSM(regex, self.outlines_tokenizer)
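Both refactors follow the same pattern: each subclass keeps only its construction logic in init_value() and inherits lookup, storage, and metrics from BaseCache. A hedged sketch of the resulting call site; the tokenizer path, args, and regex below are placeholders, not values from this commit:

# Placeholder arguments for illustration; real values come from server config.
fsm_cache = FSMCache("some-tokenizer-path", {}, enable=True)
fsm = fsm_cache.query(r"(yes|no)")  # builds a RegexFSM on first use, cached after
print(f"hit rate: {100.0 * fsm_cache.get_cache_hit_rate():.2f}%, "
      f"avg init: {fsm_cache.get_avg_init_time():.2f}s")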
@@ -60,7 +60,11 @@ class Req:
     def tokenize_fast_forward(self, fast_forward_str, next_state):
         old_output_str = self.tokenizer.decode(self.output_ids)
-        if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith("▁"):
+        # FIXME: This logic does not really solve the problem of determining whether
+        # there should be a leading space.
+        first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
+        first_token = first_token.decode() if isinstance(first_token, bytes) else first_token
+        if first_token.startswith("▁"):
             old_output_str = " " + old_output_str
         new_input_string = (
             self.input_text
    ...
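The FIXME concerns SentencePiece-style tokenizers, which mark a leading space with "▁" on the first token while decode() may drop it; the new isinstance check also covers tokenizers whose convert_ids_to_tokens returns bytes. A toy illustration of the idea, with invented token values and no real tokenizer:

# Invented values for illustration only.
first_token = b"\xe2\x96\x81Yes"  # SentencePiece token "▁Yes" returned as bytes
first_token = first_token.decode() if isinstance(first_token, bytes) else first_token
old_output_str = "Yes"            # decode() dropped the leading space
if first_token.startswith("▁"):
    old_output_str = " " + old_output_str  # restore it before re-tokenizing
print(repr(old_output_str))       # ' Yes'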
@@ -4,8 +4,7 @@
 import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from enum import Enum, auto
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List

 import numpy as np
 import rpyc
@@ -99,6 +98,7 @@ class ModelRpcServer(rpyc.Service):
         # Init cache
         self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
+        self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
             self.max_num_running_seq,
@@ -136,6 +136,8 @@ class ModelRpcServer(rpyc.Service):
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
             self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
@@ -248,9 +250,9 @@ class ModelRpcServer(rpyc.Service):
         # Init regex fsm
         if req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
+            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
             if not self.no_regex_fast_forward:
-                req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map(
+                req.fast_forward_map = self.fast_forward_cache.query(
                     req.sampling_params.regex
                 )
@@ -285,7 +287,6 @@ class ModelRpcServer(rpyc.Service):
         can_run_list = []
         new_batch_total_tokens = 0
         new_batch_input_tokens = 0
-        new_batch_prefix_tokens = 0
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
return None return None
if self.tp_rank == 0: if self.tp_rank == 0:
running_req = 0 if self.running_batch is None else len(self.running_batch.reqs)
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
self.tree_cache_metrics["total"] += (hit_tokens + new_batch_input_tokens) / 10**9
self.tree_cache_metrics["hit"] += hit_tokens / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
logger.info( logger.info(
f"new fill batch. #seq: {len(can_run_list)}. " f"new fill batch. #seq: {len(can_run_list)}. "
f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. " f"#cached_token: {hit_tokens}. "
f"#new_token: {new_batch_input_tokens}. " f"#new_token: {new_batch_input_tokens}. "
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. " f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" f"#running_req: {running_req}. "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
)
logger.debug(
f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
f"ff_cache_hit_rate: {100.0 * self.fast_forward_cache.get_cache_hit_rate():.2f}%. "
f"ff_cache_avg_init_time: {self.fast_forward_cache.get_avg_init_time():.2f}s. "
) )
new_batch = Batch.init_new( new_batch = Batch.init_new(
......
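The tree-cache counters are accumulated in units of 10**9 tokens so the floats stay small over long runs, and the logged rate is cumulative across all fill batches rather than per batch. A worked toy example with invented batch sizes, mirroring the accumulation above:

# Invented numbers for illustration.
tree_cache_metrics = {"total": 0.0, "hit": 0.0}
for hit_tokens, new_batch_input_tokens in [(700, 300), (200, 800)]:
    tree_cache_metrics["total"] += (hit_tokens + new_batch_input_tokens) / 10**9
    tree_cache_metrics["hit"] += hit_tokens / 10**9
rate = tree_cache_metrics["hit"] / tree_cache_metrics["total"]
print(f"tree_cache_hit_rate: {100.0 * rate:.2f}%")  # (700+200)/(1000+1000) = 45.00%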