Unverified commit 1b355479, authored by Liangsheng Yin, committed by GitHub

Organize `server_args` (#277)

parent faba293a
@@ -16,10 +16,10 @@ please build it from source (the compilation takes a long time).
 ### Run a Server With Flashinfer Mode
-Add `--model-mode flashinfer` argument to enable flashinfer when launching a server.
+Add `--enable-flashinfer` argument to enable flashinfer when launching a server.
 Example:
 ```bash
-python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --model-mode flashinfer
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --enable-flashinfer
 ```
@@ -43,6 +43,17 @@ def Runtime(*args, **kwargs):
 def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
 
+def flush_cache(backend: BaseBackend = None):
+    backend = backend or global_config.default_backend
+    if backend is None:
+        return False
+    return backend.flush_cache()
+
+
+def get_server_args(backend: BaseBackend = None):
+    backend = backend or global_config.default_backend
+    if backend is None:
+        return None
+    return backend.get_server_args()
+
+
 def gen(
     name: Optional[str] = None,
......
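For context, a minimal usage sketch of the two new frontend helpers. It assumes they are re-exported from the top-level `sglang` package like the other `api.py` functions, and that a server launched with `sglang.launch_server` is reachable on port 30000; neither assumption is shown in this diff.

```python
# Hedged sketch: assumes flush_cache / get_server_args are re-exported from
# the top-level sglang package and a local server is running on port 30000.
import sglang as sgl

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

# True if the server cleared its radix cache (HTTP 200), False otherwise.
print(sgl.flush_cache())

# The ServerArgs the server was launched with, as a plain dict
# (None if no default backend has been set).
print(sgl.get_server_args())
```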
@@ -72,3 +72,9 @@ class BaseBackend:
     def shutdown(self):
         pass
+
+    def flush_cache(self):
+        pass
+
+    def get_server_args(self):
+        pass
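The two new hooks default to no-ops, so existing backends keep working; a backend that wants to support them just overrides both. A hedged sketch with a hypothetical backend class (only the method names and return conventions, a bool and a dict, come from the code above; the import path is assumed from the repository layout):

```python
from sglang.backend.base_backend import BaseBackend  # assumed import path


class InProcessBackend(BaseBackend):
    """Hypothetical backend used only to illustrate the new hooks."""

    def __init__(self, launch_args: dict):
        super().__init__()
        self._launch_args = dict(launch_args)
        self._cache = {}

    def flush_cache(self):
        # Report success with a boolean, mirroring RuntimeEndpoint.flush_cache.
        self._cache.clear()
        return True

    def get_server_args(self):
        # Return launch arguments as a plain dict, mirroring RuntimeEndpoint.
        return dict(self._launch_args)
```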
@@ -35,6 +35,22 @@ class RuntimeEndpoint(BaseBackend):
     def get_model_name(self):
         return self.model_info["model_path"]
 
+    def flush_cache(self):
+        res = http_request(
+            self.base_url + "/flush_cache",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        return res.status_code == 200
+
+    def get_server_args(self):
+        res = http_request(
+            self.base_url + "/get_server_args",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        return res.json()
+
     def get_chat_template(self):
         return self.chat_template
......
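These methods are thin wrappers over HTTP GET endpoints, so the same information is available to any client. A small sketch using `requests` (assumptions: `requests` is installed, a server is reachable at the base URL below, and no auth token is configured):

```python
import requests

base_url = "http://localhost:30000"

# /flush_cache clears the prefix (radix) cache; the backend treats HTTP 200 as success.
flushed = requests.get(f"{base_url}/flush_cache").status_code == 200

# /get_server_args returns the launch-time ServerArgs serialized to JSON.
args = requests.get(f"{base_url}/get_server_args").json()

print(flushed, args.get("enable_flashinfer"), args.get("disable_radix_cache"))
```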
@@ -15,11 +15,9 @@ class RadixAttention(nn.Module):
         self.head_dim = head_dim
         self.layer_id = layer_id
 
-        from sglang.srt.managers.router.model_runner import global_server_args
+        from sglang.srt.managers.router.model_runner import global_server_args_dict
 
-        self.use_flashinfer = "flashinfer" in global_server_args.model_mode
-        if self.use_flashinfer:
+        if global_server_args_dict["enable_flashinfer"]:
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
......
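The branch above binds the flashinfer forward paths once at construction time instead of checking the flag on every forward call. A standalone sketch of that dispatch pattern (hypothetical class and method names, not sglang code):

```python
# Process-global settings dict, filled in once by the model runner.
global_server_args_dict = {"enable_flashinfer": True}


class ToyAttention:
    def __init__(self):
        # Resolve the implementation once; later calls pay no branching cost.
        if global_server_args_dict["enable_flashinfer"]:
            self.prefill_forward = self._prefill_flashinfer
        else:
            self.prefill_forward = self._prefill_triton

    def _prefill_flashinfer(self, q):
        return f"flashinfer_prefill({q})"

    def _prefill_triton(self, q):
        return f"triton_prefill({q})"


print(ToyAttention().prefill_forward("q"))  # -> flashinfer_prefill(q)
```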
@@ -4,10 +4,10 @@
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.router.model_runner import global_server_args
+from sglang.srt.managers.router.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher
 
-if global_server_args.attention_reduce_in_fp32:
+if global_server_args_dict["attention_reduce_in_fp32"]:
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
......
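For background on why `attention_reduce_in_fp32` exists: accumulating many small attention terms in float16 eventually loses the addend to rounding once the running sum grows, while a float32 accumulator stays accurate. A self-contained sketch (plain PyTorch, not sglang code):

```python
import torch

terms = torch.full((10_000,), 1e-3)  # exact sum is ~10.0

acc16 = torch.tensor(0.0, dtype=torch.float16)
acc32 = torch.tensor(0.0, dtype=torch.float32)
for t in terms:
    # Once the fp16 running sum is large enough, adding 1e-3 rounds to nothing.
    acc16 = acc16 + t.to(torch.float16)
    acc32 = acc32 + t

print(acc16.item(), acc32.item())  # roughly 4.0 vs ~10.0
```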
@@ -46,7 +46,6 @@ class ModelRpcServer(rpyc.Service):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
 
         # Copy arguments
-        self.model_mode = server_args.model_mode
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
@@ -61,15 +60,22 @@ class ModelRpcServer(rpyc.Service):
             server_args.trust_remote_code,
             context_length=server_args.context_length,
         )
+
+        # for model end global settings
+        server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
             nccl_port=port_args.nccl_port,
-            server_args=server_args,
             load_format=server_args.load_format,
             trust_remote_code=server_args.trust_remote_code,
+            server_args_dict=server_args_dict,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
@@ -104,11 +110,11 @@ class ModelRpcServer(rpyc.Service):
             f"max_total_num_token={self.max_total_num_token}, "
             f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
-            f"model_mode={self.model_mode}"
         )
+        logger.info(server_args.get_optional_modes_logging())
 
         # Init cache
-        self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
+        self.tree_cache = RadixCache(server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
......
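The model-facing code now receives only a small dict of flags instead of the whole `ServerArgs` object. A generic sketch of that pattern (hypothetical dataclass and helper, not the sglang classes), which keeps worker-side modules decoupled from the full argument schema:

```python
import dataclasses


@dataclasses.dataclass
class LaunchArgs:  # stand-in for a full server-args object
    enable_flashinfer: bool = False
    attention_reduce_in_fp32: bool = False
    tp_size: int = 1  # not needed inside the model code


# Only these keys are read by the model-side globals in this commit.
MODEL_GLOBAL_KEYS = ("enable_flashinfer", "attention_reduce_in_fp32")


def model_global_args(args: LaunchArgs) -> dict:
    """Copy just the model-level flags into a plain dict."""
    return {key: getattr(args, key) for key in MODEL_GLOBAL_KEYS}


print(model_global_args(LaunchArgs(enable_flashinfer=True)))
# {'enable_flashinfer': True, 'attention_reduce_in_fp32': False}
```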
@@ -23,7 +23,7 @@ logger = logging.getLogger("model_runner")
 
 # for server args in model endpoints
-global_server_args = None
+global_server_args_dict: dict = None
 
 
 @lru_cache()
@@ -222,7 +222,7 @@ class InputMetadata:
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
-        if "flashinfer" in global_server_args.model_mode:
+        if global_server_args_dict["enable_flashinfer"]:
             ret.init_flashinfer_args(tp_size)
 
         return ret
@@ -236,9 +236,9 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
-        server_args,
         load_format="auto",
         trust_remote_code=True,
+        server_args_dict: dict = {},
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
@@ -248,8 +248,8 @@ class ModelRunner:
         self.load_format = load_format
         self.trust_remote_code = trust_remote_code
 
-        global global_server_args
-        global_server_args = server_args
+        global global_server_args_dict
+        global_server_args_dict = server_args_dict
 
         # Init torch distributed
         torch.cuda.set_device(self.tp_rank)
......
@@ -82,6 +82,8 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        self.server_args = server_args
+
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
......
"""SRT: SGLang Runtime""" """SRT: SGLang Runtime"""
import asyncio import asyncio
import dataclasses
import json import json
import multiprocessing as mp import multiprocessing as mp
import os import os
...@@ -86,6 +87,11 @@ async def get_model_info(): ...@@ -86,6 +87,11 @@ async def get_model_info():
return result return result
@app.get("/get_server_args")
async def get_server_args():
return dataclasses.asdict(tokenizer_manager.server_args)
@app.get("/flush_cache") @app.get("/flush_cache")
async def flush_cache(): async def flush_cache():
await tokenizer_manager.flush_cache() await tokenizer_manager.flush_cache()
...@@ -548,7 +554,6 @@ class Runtime: ...@@ -548,7 +554,6 @@ class Runtime:
max_prefill_num_token: int = ServerArgs.max_prefill_num_token, max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
context_length: int = ServerArgs.context_length, context_length: int = ServerArgs.context_length,
tp_size: int = 1, tp_size: int = 1,
model_mode: List[str] = (),
schedule_heuristic: str = "lpm", schedule_heuristic: str = "lpm",
attention_reduce_in_fp32: bool = False, attention_reduce_in_fp32: bool = False,
random_seed: int = 42, random_seed: int = 42,
...@@ -571,7 +576,6 @@ class Runtime: ...@@ -571,7 +576,6 @@ class Runtime:
max_prefill_num_token=max_prefill_num_token, max_prefill_num_token=max_prefill_num_token,
context_length=context_length, context_length=context_length,
tp_size=tp_size, tp_size=tp_size,
model_mode=model_mode,
schedule_heuristic=schedule_heuristic, schedule_heuristic=schedule_heuristic,
attention_reduce_in_fp32=attention_reduce_in_fp32, attention_reduce_in_fp32=attention_reduce_in_fp32,
random_seed=random_seed, random_seed=random_seed,
......
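The new `/get_server_args` route works because FastAPI serializes the dict produced by `dataclasses.asdict` straight to JSON. A self-contained sketch of the same pattern with a hypothetical app and dataclass (not sglang's `server.py`):

```python
import dataclasses

from fastapi import FastAPI


@dataclasses.dataclass
class DemoServerArgs:
    model_path: str = "meta-llama/Llama-2-7b-chat-hf"
    enable_flashinfer: bool = False
    disable_radix_cache: bool = False


app = FastAPI()
launch_args = DemoServerArgs(enable_flashinfer=True)


@app.get("/get_server_args")
async def get_server_args():
    # asdict() gives a JSON-friendly dict; FastAPI handles the serialization.
    return dataclasses.asdict(launch_args)

# Run with, e.g.:  uvicorn demo_server:app --port 30000
# GET http://localhost:30000/get_server_args then returns the flags as JSON.
```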
@@ -18,7 +18,6 @@ class ServerArgs:
     max_prefill_num_token: Optional[int] = None
     context_length: Optional[int] = None
     tp_size: int = 1
-    model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
     attention_reduce_in_fp32: bool = False
@@ -27,6 +26,10 @@ class ServerArgs:
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
+
+    # optional modes
+    disable_radix_cache: bool = False
+    enable_flashinfer: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
@@ -131,14 +134,6 @@ class ServerArgs:
             default=ServerArgs.tp_size,
             help="Tensor parallelism degree.",
         )
-        parser.add_argument(
-            "--model-mode",
-            type=str,
-            default=[],
-            nargs="+",
-            choices=["flashinfer", "no-cache"],
-            help="Model mode: [flashinfer, no-cache]",
-        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
@@ -185,6 +180,17 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
+        # optional modes
+        parser.add_argument(
+            "--disable-radix-cache",
+            action="store_true",
+            help="Disable RadixAttention",
+        )
+        parser.add_argument(
+            "--enable-flashinfer",
+            action="store_true",
+            help="Enable flashinfer inference kernels",
+        )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
@@ -204,6 +210,15 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
 
+    def get_optional_modes_logging(self):
+        return (
+            f"disable_radix_cache={self.disable_radix_cache}, "
+            f"enable_flashinfer={self.enable_flashinfer}, "
+            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
+            f"disable_disk_cache={self.disable_disk_cache}, "
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
+        )
+
 
 @dataclasses.dataclass
 class PortArgs:
......
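A hedged sketch of exercising the reorganized flags programmatically. It assumes `ServerArgs` lives at `sglang.srt.server_args` and that `model_path` is its only required field, which matches the fields shown in this diff but is not verified here:

```python
from sglang.srt.server_args import ServerArgs  # assumed import path

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    enable_flashinfer=True,
    disable_radix_cache=False,
)

# The one-line summary that ModelRpcServer now logs at startup.
print(args.get_optional_modes_logging())
# disable_radix_cache=False, enable_flashinfer=True,
# disable_regex_jump_forward=False, disable_disk_cache=False,
# attention_reduce_in_fp32=False
```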
@@ -151,7 +151,7 @@ def bench_generate_worker(
     shared_len,
     unique_len,
     decode_len,
-    model_mode,
+    server_args_dict,
 ):
     assert unique_num % shared_num == 0
@@ -162,7 +162,7 @@ def bench_generate_worker(
         tp_rank=tp_rank,
         tp_size=tp_size,
         nccl_port=28888,
-        model_mode=model_mode,
+        server_args_dict=server_args_dict,
     )
 
     batch = BenchBatch(model_runner)
@@ -227,7 +227,7 @@ def bench_generate(
     shared_len,
     unique_len,
     decode_len,
-    model_mode,
+    server_args_dict,
 ):
     print(
         f"tp_size: {tp_size}, "
@@ -236,7 +236,7 @@ def bench_generate(
         f"shared_len: {shared_len}, "
        f"unique_len: {unique_len}, "
        f"decode_len: {decode_len}, "
-        f"model_mode: {model_mode}"
+        f"server_args: {server_args_dict}"
    )
    workers = []
    for tp_rank in range(tp_size):
@@ -251,7 +251,7 @@ def bench_generate(
                shared_len,
                unique_len,
                decode_len,
-                model_mode,
+                server_args_dict,
            ),
        )
        proc.start()
@@ -270,5 +270,5 @@ if __name__ == "__main__":
        shared_len=256,
        unique_len=256,
        decode_len=8,
-        model_mode=[],
+        server_args_dict={},
    )
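In this benchmark, `server_args_dict` is forwarded to `ModelRunner`, which installs it as `global_server_args_dict`; the model code added in this commit reads two keys from it. A hedged example of a non-empty configuration for the `__main__` call above (all other bench parameters unchanged from the diff):

```python
# Keys match what the model-side code reads from global_server_args_dict.
server_args_dict = {
    "enable_flashinfer": True,           # read by RadixAttention.__init__
    "attention_reduce_in_fp32": False,   # read by the Triton attention kernels
}
```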