Unverified Commit 81561f8e authored by Liangsheng Yin, committed by GitHub

Flush Cache API (#103)

parent 3a581e99
"""Flush cache in the backend by sending random requests."""
import argparse
import random
import string
import time
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
import sglang as sgl
@sgl.function
def flush_radix_cache(s, prompt):
s += prompt + sgl.gen("flush", max_tokens=1, stop="END")
def main(args, max_total_tokens, context_length, print_flag):
    backend = select_sglang_backend(args)

    # Size the batch so the random prompts roughly fill the token pool,
    # evicting whatever is currently cached.
    flush_length = int(context_length * 0.8)
    batch_size = int(max_total_tokens / flush_length)
    prompt_length = flush_length * 2
    prompts = [
        " ".join(random.choices(string.ascii_letters, k=int(prompt_length)))
        for _ in range(batch_size)
    ]
    arguments = [{"prompt": prompts[i]} for i in range(batch_size)]

    start_time = time.time()
    flush_radix_cache.run_batch(
        arguments, temperature=0, backend=backend, num_threads=1
    )
    end_time = time.time()

    if print_flag:
        print(
            f"Flush length: {flush_length}\n",
            f"Prompt length: {prompt_length}\n",
            f"Total Prompt letters: {batch_size * prompt_length}\n",
            f"Flush radix cache latency: {end_time - start_time:.3f}",
            sep="",
        )

    # Give the backend a moment to finish before the caller proceeds.
    time.sleep(1)
def run_flush(args, max_total_tokens=20000, context_length=1024, print_flag=False):
    main(args, max_total_tokens, context_length, print_flag=print_flag)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-total-tokens", type=int, default=20000)
    parser.add_argument("--context-length", type=int, default=1024)
    args = add_common_sglang_args_and_parse(parser)

    random.seed(0)
    main(args, args.max_total_tokens, args.context_length, print_flag=True)
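For reference, a usage sketch that is not part of this commit: another benchmark can import run_flush to clear the prefix cache between trials so each run starts cold. The module name (flush_cache) is an assumption, and args is the namespace returned by add_common_sglang_args_and_parse, as in the script above.

from flush_cache import run_flush  # assumed module name for the script above


def run_cold_trial(args, run_once):
    """Flush the backend cache, then run one benchmark trial."""
    run_flush(args)        # send random prompts that evict cached prefixes
    return run_once(args)  # the workload actually being measured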
@@ -87,3 +87,8 @@ class BatchStrOut:
     output_str: List[str]
     meta_info: List[Dict]
     finished: List[bool]
+
+
+@dataclass
+class FlushCacheReq:
+    pass
@@ -15,7 +15,11 @@ from rpyc.utils.server import ThreadedServer
 from sglang.srt.constrained.fast_forward import FastForwardCache
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
+from sglang.srt.managers.io_struct import (
+    BatchTokenIDOut,
+    TokenizedGenerateReqInput,
+    FlushCacheReq,
+)
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
@@ -127,6 +131,22 @@ class ModelRpcServer(rpyc.Service):
         self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
         self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
 
+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                "Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)
@@ -136,6 +156,8 @@ class ModelRpcServer(rpyc.Service):
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
+            elif isinstance(recv_req, FlushCacheReq):
+                self.flush_cache()
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
@@ -30,14 +30,17 @@ def match(key, seq):
 class RadixCache:
     def __init__(self, disable=False):
+        self.reset()
+        self.disable = disable
+
+    ##### Public API #####
+
+    def reset(self):
         self.root_node = TreeNode()
         self.root_node.value = []
         self.root_node.ref_counter = 1
         self.evictable_size_ = 0
-        self.disable = disable
-
-    ##### Public API #####
 
     def match_prefix(self, key):
         if self.disable:
             return [], self.root_node
@@ -20,6 +20,7 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     GenerateReqInput,
     TokenizedGenerateReqInput,
+    FlushCacheReq,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
@@ -228,6 +229,10 @@ class TokenizerManager:
             yield output_list
 
+    async def flush_cache(self):
+        flush_cache_req = FlushCacheReq()
+        self.send_to_router.send_pyobj(flush_cache_req)
+
     async def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
@@ -71,6 +71,15 @@ async def get_model_info():
     return result
 
+
+@app.get("/flush_cache")
+async def flush_cache():
+    await tokenizer_manager.flush_cache()
+    return Response(
+        content="Cache flushed.\nPlease check backend logs for more details. (When there are running or waiting requests, the operation will not be performed.)\n",
+        status_code=200,
+    )
+
 async def stream_generator(obj):
     async for out in tokenizer_manager.generate_request(obj):
         yield out
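For reference, a minimal client-side sketch that is not part of this commit: calling the new endpoint over HTTP with the requests library. The address localhost:30000 is an assumption; use whatever --host/--port the sglang server was launched with.

import requests

# Assumed server address; adjust to your server's --host/--port.
resp = requests.get("http://localhost:30000/flush_cache")
print(resp.status_code)  # 200
print(resp.text)         # the "Cache flushed." message returned by the handler above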