Unverified commit 1b355479, authored by Liangsheng Yin, committed by GitHub

Organize `server_args` (#277)

parent faba293a
@@ -16,10 +16,10 @@ please build it from source (the compilation takes a long time).
 ### Run a Server With Flashinfer Mode
-Add `--model-mode flashinfer` argument to enable flashinfer when launching a server.
+Add `--enable-flashinfer` argument to enable flashinfer when launching a server.
 Example:
 ```bash
-python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --model-mode flashinfer
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --enable-flashinfer
 ```
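
A quick way to confirm the flag took effect is the `/get_server_args` endpoint added later in this commit. A minimal check, assuming the server from the command above is reachable on localhost:30000:

```python
import requests

# /get_server_args (added in this commit) returns the resolved ServerArgs as JSON.
args = requests.get("http://127.0.0.1:30000/get_server_args").json()
print(args["enable_flashinfer"])  # True when the server was started with --enable-flashinfer
```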
@@ -43,6 +43,17 @@ def Runtime(*args, **kwargs):
 def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
+def flush_cache(backend: BaseBackend = None):
+    backend = backend or global_config.default_backend
+    if backend is None:
+        return False
+    return backend.flush_cache()
+def get_server_args(backend: BaseBackend = None):
+    backend = backend or global_config.default_backend
+    if backend is None:
+        return None
+    return backend.get_server_args()
 def gen(
     name: Optional[str] = None,
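
The two new frontend helpers follow the `set_default_backend` pattern above: they fall back to `global_config.default_backend` and return `False`/`None` when no backend is set. A usage sketch, assuming they are re-exported at the package level like the other `api.py` helpers (otherwise import them from `sglang.api` directly):

```python
import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint

# Point the default backend at a running server (address is an assumption).
sgl.set_default_backend(RuntimeEndpoint("http://127.0.0.1:30000"))

print(sgl.get_server_args())  # dict of the server's ServerArgs, or None without a backend
print(sgl.flush_cache())      # True if the server acknowledged the flush, False without a backend
```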
@@ -72,3 +72,9 @@ class BaseBackend:
     def shutdown(self):
         pass
+    def flush_cache(self):
+        pass
+    def get_server_args(self):
+        pass
@@ -35,6 +35,22 @@ class RuntimeEndpoint(BaseBackend):
     def get_model_name(self):
         return self.model_info["model_path"]
+    def flush_cache(self):
+        res = http_request(
+            self.base_url + "/flush_cache",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        return res.status_code == 200
+    def get_server_args(self):
+        res = http_request(
+            self.base_url + "/get_server_args",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        return res.json()
     def get_chat_template(self):
         return self.chat_template
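
This is the layer the `api.py` helpers delegate to, and the methods can also be called directly on a `RuntimeEndpoint`. A sketch, assuming a local server and no auth token:

```python
from sglang.backend.runtime_endpoint import RuntimeEndpoint

backend = RuntimeEndpoint("http://127.0.0.1:30000")
print(backend.get_server_args()["enable_flashinfer"])  # echoes the launch flag
print(backend.flush_cache())  # True when /flush_cache answers with HTTP 200
```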
@@ -15,11 +15,9 @@ class RadixAttention(nn.Module):
         self.head_dim = head_dim
         self.layer_id = layer_id
-        from sglang.srt.managers.router.model_runner import global_server_args
+        from sglang.srt.managers.router.model_runner import global_server_args_dict
-        self.use_flashinfer = "flashinfer" in global_server_args.model_mode
-        if self.use_flashinfer:
+        if global_server_args_dict["enable_flashinfer"]:
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
@@ -4,10 +4,10 @@
 import torch
 import triton
 import triton.language as tl
-from sglang.srt.managers.router.model_runner import global_server_args
+from sglang.srt.managers.router.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher
-if global_server_args.attention_reduce_in_fp32:
+if global_server_args_dict["attention_reduce_in_fp32"]:
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
@@ -46,7 +46,6 @@ class ModelRpcServer(rpyc.Service):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
         # Copy arguments
-        self.model_mode = server_args.model_mode
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
@@ -61,15 +60,22 @@ class ModelRpcServer(rpyc.Service):
             server_args.trust_remote_code,
             context_length=server_args.context_length,
         )
+        # for model end global settings
+        server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
             nccl_port=port_args.nccl_port,
-            server_args=server_args,
+            load_format=server_args.load_format,
+            trust_remote_code=server_args.trust_remote_code,
+            server_args_dict=server_args_dict,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
@@ -104,11 +110,11 @@ class ModelRpcServer(rpyc.Service):
             f"max_total_num_token={self.max_total_num_token}, "
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
-            f"model_mode={self.model_mode}"
         )
+        logger.info(server_args.get_optional_modes_logging())
         # Init cache
-        self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
+        self.tree_cache = RadixCache(server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
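
The `server_args_dict` replaces passing the whole `ServerArgs` object into the model runner: only the settings that model code actually needs cross the boundary, and model-side modules read them from a module-level dict. A self-contained toy of that pattern (the key names mirror the diff; the functions are illustrative stand-ins, not the real modules):

```python
# Process-wide dict, filled in once before any model code runs.
global_server_args_dict: dict = {}


def init_model_side_settings(enable_flashinfer: bool, attention_reduce_in_fp32: bool) -> None:
    # Mirrors what ModelRpcServer does before constructing ModelRunner.
    global_server_args_dict.update(
        enable_flashinfer=enable_flashinfer,
        attention_reduce_in_fp32=attention_reduce_in_fp32,
    )


def pick_attention_backend() -> str:
    # Mirrors how RadixAttention chooses its forward functions.
    return "flashinfer" if global_server_args_dict["enable_flashinfer"] else "triton"


init_model_side_settings(enable_flashinfer=True, attention_reduce_in_fp32=False)
print(pick_attention_backend())  # -> flashinfer
```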
@@ -23,7 +23,7 @@ logger = logging.getLogger("model_runner")
 # for server args in model endpoints
-global_server_args = None
+global_server_args_dict: dict = None
 @lru_cache()
@@ -222,7 +222,7 @@ class InputMetadata:
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
-        if "flashinfer" in global_server_args.model_mode:
+        if global_server_args_dict["enable_flashinfer"]:
             ret.init_flashinfer_args(tp_size)
         return ret
@@ -236,9 +236,9 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
-        server_args,
         load_format="auto",
         trust_remote_code=True,
+        server_args_dict: dict = {},
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
@@ -248,8 +248,8 @@ class ModelRunner:
         self.load_format = load_format
         self.trust_remote_code = trust_remote_code
-        global global_server_args
-        global_server_args = server_args
+        global global_server_args_dict
+        global_server_args_dict = server_args_dict
         # Init torch distributed
         torch.cuda.set_device(self.tp_rank)
@@ -82,6 +82,8 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        self.server_args = server_args
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
"""SRT: SGLang Runtime"""
import asyncio
import dataclasses
import json
import multiprocessing as mp
import os
......@@ -86,6 +87,11 @@ async def get_model_info():
return result
@app.get("/get_server_args")
async def get_server_args():
return dataclasses.asdict(tokenizer_manager.server_args)
@app.get("/flush_cache")
async def flush_cache():
await tokenizer_manager.flush_cache()
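
`import dataclasses` is added at the top of this file because the new endpoint serializes the `ServerArgs` dataclass with `dataclasses.asdict()`, which FastAPI then returns as JSON. A standalone illustration (`ToyArgs` is made up for the example):

```python
import dataclasses


@dataclasses.dataclass
class ToyArgs:
    enable_flashinfer: bool = False
    disable_radix_cache: bool = False


print(dataclasses.asdict(ToyArgs(enable_flashinfer=True)))
# -> {'enable_flashinfer': True, 'disable_radix_cache': False}
```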
@@ -548,7 +554,6 @@ class Runtime:
         max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
         context_length: int = ServerArgs.context_length,
         tp_size: int = 1,
-        model_mode: List[str] = (),
         schedule_heuristic: str = "lpm",
         attention_reduce_in_fp32: bool = False,
         random_seed: int = 42,
@@ -571,7 +576,6 @@ class Runtime:
             max_prefill_num_token=max_prefill_num_token,
             context_length=context_length,
             tp_size=tp_size,
-            model_mode=model_mode,
             schedule_heuristic=schedule_heuristic,
             attention_reduce_in_fp32=attention_reduce_in_fp32,
             random_seed=random_seed,
@@ -18,7 +18,6 @@ class ServerArgs:
     max_prefill_num_token: Optional[int] = None
     context_length: Optional[int] = None
     tp_size: int = 1
-    model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
     attention_reduce_in_fp32: bool = False
@@ -27,6 +26,10 @@ class ServerArgs:
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
+    # optional modes
+    disable_radix_cache: bool = False
+    enable_flashinfer: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
@@ -131,14 +134,6 @@ class ServerArgs:
             default=ServerArgs.tp_size,
             help="Tensor parallelism degree.",
         )
-        parser.add_argument(
-            "--model-mode",
-            type=str,
-            default=[],
-            nargs="+",
-            choices=["flashinfer", "no-cache"],
-            help="Model mode: [flashinfer, no-cache]",
-        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
@@ -185,6 +180,17 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
+        # optional modes
+        parser.add_argument(
+            "--disable-radix-cache",
+            action="store_true",
+            help="Disable RadixAttention",
+        )
+        parser.add_argument(
+            "--enable-flashinfer",
+            action="store_true",
+            help="Enable flashinfer inference kernels",
+        )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
@@ -204,6 +210,15 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
+    def get_optional_modes_logging(self):
+        return (
+            f"disable_radix_cache={self.disable_radix_cache}, "
+            f"enable_flashinfer={self.enable_flashinfer}, "
+            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
+            f"disable_disk_cache={self.disable_disk_cache}, "
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
+        )
 @dataclasses.dataclass
 class PortArgs:
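
On the command line, the old `--model-mode flashinfer no-cache` combination becomes the explicit `--enable-flashinfer` and `--disable-radix-cache` flags. Programmatically, the reorganized fields can be set directly on `ServerArgs`; a rough sketch, assuming `model_path` is the only other required field:

```python
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    enable_flashinfer=True,
    disable_radix_cache=False,
)
print(args.get_optional_modes_logging())
# disable_radix_cache=False, enable_flashinfer=True, disable_regex_jump_forward=False, ...
```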
@@ -151,7 +151,7 @@ def bench_generate_worker(
     shared_len,
     unique_len,
     decode_len,
-    model_mode,
+    server_args_dict,
 ):
     assert unique_num % shared_num == 0
@@ -162,7 +162,7 @@ def bench_generate_worker(
         tp_rank=tp_rank,
         tp_size=tp_size,
         nccl_port=28888,
-        model_mode=model_mode,
+        server_args_dict=server_args_dict,
     )
     batch = BenchBatch(model_runner)
@@ -227,7 +227,7 @@ def bench_generate(
     shared_len,
     unique_len,
     decode_len,
-    model_mode,
+    server_args_dict,
 ):
     print(
         f"tp_size: {tp_size}, "
@@ -236,7 +236,7 @@ def bench_generate(
         f"shared_len: {shared_len}, "
         f"unique_len: {unique_len}, "
         f"decode_len: {decode_len}, "
-        f"model_mode: {model_mode}"
+        f"server_args: {server_args_dict}"
     )
     workers = []
     for tp_rank in range(tp_size):
@@ -251,7 +251,7 @@ def bench_generate(
                 shared_len,
                 unique_len,
                 decode_len,
-                model_mode,
+                server_args_dict,
             ),
         )
         proc.start()
@@ -270,5 +270,5 @@ if __name__ == "__main__":
         shared_len=256,
         unique_len=256,
         decode_len=8,
-        model_mode=[],
+        server_args_dict={},
    )
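
With `model_mode` gone, a benchmark run that exercises flashinfer would instead pass a dict carrying the same keys that `ModelRpcServer` builds from `ServerArgs`. A sketch of such a dict (both keys included, since model-side modules look them up when they load):

```python
server_args_dict = {
    "enable_flashinfer": True,
    "attention_reduce_in_fp32": False,
}
```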