Unverified commit b1a3a454, authored by Liangsheng Yin, committed by GitHub
Browse files

add `--disable-disk-cache` (#160)


Co-authored-by: Ja1Zhou <50169346+Ja1Zhou@users.noreply.github.com>
parent 79e6b84b
...@@ -49,7 +49,7 @@ class ModelRpcServer(rpyc.Service): ...@@ -49,7 +49,7 @@ class ModelRpcServer(rpyc.Service):
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.tp_size = server_args.tp_size self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic self.schedule_heuristic = server_args.schedule_heuristic
self.no_regex_jump_forward = server_args.no_regex_jump_forward self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
# Init model and tokenizer # Init model and tokenizer
self.model_config = ModelConfig( self.model_config = ModelConfig(
...@@ -254,7 +254,7 @@ class ModelRpcServer(rpyc.Service): ...@@ -254,7 +254,7 @@ class ModelRpcServer(rpyc.Service):
# Init regex fsm # Init regex fsm
if req.sampling_params.regex is not None: if req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
if not self.no_regex_jump_forward: if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query( req.jump_forward_map = self.jump_forward_cache.query(
req.sampling_params.regex req.sampling_params.regex
) )
...@@ -451,7 +451,7 @@ class ModelRpcServer(rpyc.Service): ...@@ -451,7 +451,7 @@ class ModelRpcServer(rpyc.Service):
self.min_new_token_ratio, self.min_new_token_ratio,
) )
if not self.no_regex_jump_forward: if not self.disable_regex_jump_forward:
# check for jump-forward # check for jump-forward
jump_forward_reqs = batch.check_for_jump_forward() jump_forward_reqs = batch.check_for_jump_forward()
......
...@@ -21,6 +21,7 @@ from fastapi import FastAPI, HTTPException, Request ...@@ -21,6 +21,7 @@ from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained.disk_cache import disable_cache
from sglang.srt.conversation import ( from sglang.srt.conversation import (
Conversation, Conversation,
SeparatorStyle, SeparatorStyle,
...@@ -372,6 +373,10 @@ def launch_server(server_args, pipe_finish_writer): ...@@ -372,6 +373,10 @@ def launch_server(server_args, pipe_finish_writer):
global tokenizer_manager global tokenizer_manager
global chat_template_name global chat_template_name
# disable disk cache if needed
if server_args.disable_disk_cache:
disable_cache()
# Handle ports # Handle ports
server_args.port, server_args.additional_ports = handle_port_init( server_args.port, server_args.additional_ports = handle_port_init(
server_args.port, server_args.additional_ports, server_args.tp_size server_args.port, server_args.additional_ports, server_args.tp_size
...@@ -499,6 +504,7 @@ def launch_server(server_args, pipe_finish_writer): ...@@ -499,6 +504,7 @@ def launch_server(server_args, pipe_finish_writer):
timeout=60, timeout=60,
) )
print(f"Warmup done. model response: {res.json()['text']}") print(f"Warmup done. model response: {res.json()['text']}")
print("=" * 20, "Server is ready", "=" * 20, flush=True)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
if pipe_finish_writer is not None: if pipe_finish_writer is not None:
pipe_finish_writer.send(str(e)) pipe_finish_writer.send(str(e))
......
...@@ -25,7 +25,8 @@ class ServerArgs: ...@@ -25,7 +25,8 @@ class ServerArgs:
disable_log_stats: bool = False disable_log_stats: bool = False
log_stats_interval: int = 10 log_stats_interval: int = 10
log_level: str = "info" log_level: str = "info"
no_regex_jump_forward: bool = False disable_regex_jump_forward: bool = False
disable_disk_cache: bool = False
def __post_init__(self): def __post_init__(self):
if self.tokenizer_path is None: if self.tokenizer_path is None:
...@@ -172,10 +173,15 @@ class ServerArgs: ...@@ -172,10 +173,15 @@ class ServerArgs:
help="Log stats interval in second.", help="Log stats interval in second.",
) )
parser.add_argument( parser.add_argument(
"--no-regex-jump-forward", "--disable-regex-jump-forward",
action="store_true", action="store_true",
help="Disable regex jump-forward", help="Disable regex jump-forward",
) )
parser.add_argument(
"--disable-disk-cache",
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment