Unverified Commit c9b75917 authored by Kai-Hsun Chen, committed by GitHub

[server] Passing `model_override_args` to `launch_server` via the CLI. (#1298)


Signed-off-by: Kai-Hsun Chen <kaihsun@anyscale.com>
parent 662ecd93
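
In short, this change adds a `--json-model-override-args` CLI flag and threads its parsed value into `launch_server` as `model_override_args`. A minimal usage sketch (the model path is a placeholder; the `rope_scaling` payload mirrors the unit test added below):

python -m sglang.launch_server --model-path <model-path> --json-model-override-args '{"rope_scaling": {"factor": 2.0, "type": "linear"}}'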
@@ -6,7 +6,7 @@
 # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
 # Launch sglang
-# python -m sglang.launch_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
+# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
 # offline
 python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
@@ -480,6 +480,7 @@ def main(server_args, bench_args):

 if __name__ == "__main__":
+    # TODO(kevin85421): Make the parser setup unit testable.
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
"""Launch the inference server."""
import argparse
import os
import sys
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_child_process
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
server_args = prepare_server_args(sys.argv[1:])
model_override_args = server_args.json_model_override_args
try:
launch_server(server_args)
launch_server(server_args, model_override_args=model_override_args)
except Exception as e:
raise e
finally:
......
"""Launch the inference server for Llava-video model."""
import argparse
import sys
from sglang.srt.server import ServerArgs, launch_server
from sglang.srt.server import launch_server, prepare_server_args
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
server_args = prepare_server_args(sys.argv[1:])
model_override_args = {}
model_override_args["mm_spatial_pool_stride"] = 2
......@@ -20,7 +17,7 @@ if __name__ == "__main__":
model_override_args["max_sequence_length"] = 4096 * 2
model_override_args["tokenizer_model_max_length"] = 4096 * 2
model_override_args["model_max_length"] = 4096 * 2
if "34b" in args.model_path.lower():
if "34b" in server_args.model_path.lower():
model_override_args["image_token_index"] = 64002
launch_server(server_args, model_override_args, None)
@@ -17,6 +17,7 @@ limitations under the License.

 import argparse
 import dataclasses
+import json
 import logging
 import random
 from typing import List, Optional, Union

@@ -95,6 +96,9 @@ class ServerArgs:
     nnodes: int = 1
     node_rank: Optional[int] = None

+    # Model override args in JSON
+    json_model_override_args: Optional[dict] = None
+
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -455,10 +459,22 @@ class ServerArgs:
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )

+        # Model override args
+        parser.add_argument(
+            "--json-model-override-args",
+            type=str,
+            help="A dictionary in JSON string format used to override default model configurations.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
         args.dp_size = args.data_parallel_size
+        args.json_model_override_args = (
+            json.loads(args.json_model_override_args)
+            if args.json_model_override_args
+            else None
+        )
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})
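
For intuition, here is a standalone sketch of the parsing step in `from_cli_args` above (the variable names are illustrative, not part of the diff):

import json

# What `--json-model-override-args '{"rope_scaling": ...}'` arrives as: a raw JSON string.
raw = '{"rope_scaling": {"factor": 2.0, "type": "linear"}}'
overrides = json.loads(raw) if raw else None
assert overrides == {"rope_scaling": {"factor": 2.0, "type": "linear"}}
# When the flag is omitted, argparse leaves the attribute as None, so the
# conditional keeps it None instead of calling json.loads(None).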
@@ -482,6 +498,24 @@ class ServerArgs:
             self.disable_flashinfer = False


+def prepare_server_args(args: List[str]) -> ServerArgs:
+    """
+    Prepare the server arguments from the command-line arguments.
+
+    Args:
+        args: The command-line arguments. Typically, it should be `sys.argv[1:]`
+            to ensure compatibility with `parse_args` when no arguments are passed.
+
+    Returns:
+        The server arguments.
+    """
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    raw_args = parser.parse_args(args)
+    server_args = ServerArgs.from_cli_args(raw_args)
+    return server_args
+
+
 @dataclasses.dataclass
 class PortArgs:
     tokenizer_port: int
@@ -19,6 +19,7 @@ suites = {
         "test_triton_attn_backend.py",
         "test_update_weights.py",
         "test_vision_openai_server.py",
+        "test_server_args.py",
     ],
     "sampling/penaltylib": glob.glob(
         "sampling/penaltylib/**/test_*.py", recursive=True
import unittest

from sglang.srt.server_args import prepare_server_args


class TestPrepareServerArgs(unittest.TestCase):
    def test_prepare_server_args(self):
        server_args = prepare_server_args(
            [
                "--model-path",
                "model_path",
                "--json-model-override-args",
                '{"rope_scaling": {"factor": 2.0, "type": "linear"}}',
            ]
        )
        self.assertEqual(server_args.model_path, "model_path")
        self.assertEqual(
            server_args.json_model_override_args,
            {"rope_scaling": {"factor": 2.0, "type": "linear"}},
        )


if __name__ == "__main__":
    unittest.main()
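
Since the file ends with `unittest.main()`, the new test can also be run directly, outside the registered suite (exact path assumed from the suite entry above):

python3 test_server_args.py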
@@ -12,7 +12,7 @@ class TestServingLatency(unittest.TestCase):
                 "python3",
                 "-m",
                 "sglang.bench_latency",
-                "--model",
+                "--model-path",
                 DEFAULT_MODEL_NAME_FOR_TEST,
                 "--batch-size",
                 "1",