"vscode:/vscode.git/clone" did not exist on "c69dbbb736ee18c4aa073060e9f83ba6df83d16a"
Unverified commit c9b75917, authored by Kai-Hsun Chen and committed by GitHub
Browse files

[server] Passing `model_override_args` to `launch_server` via the CLI. (#1298)


Signed-off-by: Kai-Hsun Chen <kaihsun@anyscale.com>
parent 662ecd93
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
# Launch sglang # Launch sglang
# python -m sglang.launch_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87 # python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
# offline # offline
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11 python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
......
...@@ -480,6 +480,7 @@ def main(server_args, bench_args): ...@@ -480,6 +480,7 @@ def main(server_args, bench_args):
if __name__ == "__main__": if __name__ == "__main__":
# TODO(kevin85421): Make the parser setup unit testable.
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser) ServerArgs.add_cli_args(parser)
BenchArgs.add_cli_args(parser) BenchArgs.add_cli_args(parser)
......
"""Launch the inference server.""" """Launch the inference server."""
import argparse
import os import os
import sys
from sglang.srt.server import launch_server from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_child_process
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() server_args = prepare_server_args(sys.argv[1:])
ServerArgs.add_cli_args(parser) model_override_args = server_args.json_model_override_args
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
try: try:
launch_server(server_args) launch_server(server_args, model_override_args=model_override_args)
except Exception as e: except Exception as e:
raise e raise e
finally: finally:
......
"""Launch the inference server for Llava-video model.""" """Launch the inference server for Llava-video model."""
import argparse import sys
from sglang.srt.server import ServerArgs, launch_server from sglang.srt.server import launch_server, prepare_server_args
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() server_args = prepare_server_args(sys.argv[1:])
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
model_override_args = {} model_override_args = {}
model_override_args["mm_spatial_pool_stride"] = 2 model_override_args["mm_spatial_pool_stride"] = 2
...@@ -20,7 +17,7 @@ if __name__ == "__main__": ...@@ -20,7 +17,7 @@ if __name__ == "__main__":
model_override_args["max_sequence_length"] = 4096 * 2 model_override_args["max_sequence_length"] = 4096 * 2
model_override_args["tokenizer_model_max_length"] = 4096 * 2 model_override_args["tokenizer_model_max_length"] = 4096 * 2
model_override_args["model_max_length"] = 4096 * 2 model_override_args["model_max_length"] = 4096 * 2
if "34b" in args.model_path.lower(): if "34b" in server_args.model_path.lower():
model_override_args["image_token_index"] = 64002 model_override_args["image_token_index"] = 64002
launch_server(server_args, model_override_args, None) launch_server(server_args, model_override_args, None)
...@@ -17,6 +17,7 @@ limitations under the License. ...@@ -17,6 +17,7 @@ limitations under the License.
import argparse import argparse
import dataclasses import dataclasses
import json
import logging import logging
import random import random
from typing import List, Optional, Union from typing import List, Optional, Union
...@@ -95,6 +96,9 @@ class ServerArgs: ...@@ -95,6 +96,9 @@ class ServerArgs:
nnodes: int = 1 nnodes: int = 1
node_rank: Optional[int] = None node_rank: Optional[int] = None
# Model override args in JSON
json_model_override_args: Optional[dict] = None
def __post_init__(self): def __post_init__(self):
if self.tokenizer_path is None: if self.tokenizer_path is None:
self.tokenizer_path = self.model_path self.tokenizer_path = self.model_path
...@@ -455,10 +459,22 @@ class ServerArgs: ...@@ -455,10 +459,22 @@ class ServerArgs:
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
) )
# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
    """Build a ServerArgs instance from a parsed argparse namespace.

    NOTE(review): reconstructed from a corrupted (side-by-side duplicated)
    diff span; in the real file this is a classmethod of ``ServerArgs``.

    Args:
        args: The namespace returned by ``parser.parse_args``. It is
            mutated in place: long-form CLI names are mirrored onto the
            short dataclass field names, and the JSON override string is
            decoded into a dict.

    Returns:
        A ``cls`` instance populated from the matching namespace attributes.
    """
    # Mirror the long CLI flag names onto the dataclass field names.
    args.tp_size = args.tensor_parallel_size
    args.dp_size = args.data_parallel_size

    # Decode the JSON override payload once, up front; empty or missing
    # values normalize to None so downstream code can test truthiness.
    raw_override = args.json_model_override_args
    args.json_model_override_args = json.loads(raw_override) if raw_override else None

    # Pick from the namespace exactly the attributes the dataclass declares.
    field_names = [field.name for field in dataclasses.fields(cls)]
    return cls(**{name: getattr(args, name) for name in field_names})
...@@ -482,6 +498,24 @@ class ServerArgs: ...@@ -482,6 +498,24 @@ class ServerArgs:
self.disable_flashinfer = False self.disable_flashinfer = False
def prepare_server_args(args: List[str]) -> ServerArgs:
    """Parse raw command-line arguments into a ServerArgs instance.

    Args:
        args: The raw command-line argument list (a list of strings, NOT an
            argparse.Namespace — it is handed straight to
            ``parser.parse_args``). Typically ``sys.argv[1:]``, which keeps
            behavior consistent with ``parse_args`` when no arguments are
            passed.

    Returns:
        The fully constructed server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(args)
    return ServerArgs.from_cli_args(raw_args)
@dataclasses.dataclass @dataclasses.dataclass
class PortArgs: class PortArgs:
tokenizer_port: int tokenizer_port: int
......
...@@ -19,6 +19,7 @@ suites = { ...@@ -19,6 +19,7 @@ suites = {
"test_triton_attn_backend.py", "test_triton_attn_backend.py",
"test_update_weights.py", "test_update_weights.py",
"test_vision_openai_server.py", "test_vision_openai_server.py",
"test_server_args.py",
], ],
"sampling/penaltylib": glob.glob( "sampling/penaltylib": glob.glob(
"sampling/penaltylib/**/test_*.py", recursive=True "sampling/penaltylib/**/test_*.py", recursive=True
......
import unittest
from sglang.srt.server_args import prepare_server_args
class TestPrepareServerArgs(unittest.TestCase):
    """Unit tests for `prepare_server_args`."""

    def test_prepare_server_args(self):
        """CLI flags are parsed and the JSON override string is decoded."""
        cli_args = [
            "--model-path",
            "model_path",
            "--json-model-override-args",
            '{"rope_scaling": {"factor": 2.0, "type": "linear"}}',
        ]
        server_args = prepare_server_args(cli_args)

        self.assertEqual(server_args.model_path, "model_path")
        expected_override = {"rope_scaling": {"factor": 2.0, "type": "linear"}}
        self.assertEqual(server_args.json_model_override_args, expected_override)
if __name__ == "__main__":
unittest.main()
...@@ -12,7 +12,7 @@ class TestServingLatency(unittest.TestCase): ...@@ -12,7 +12,7 @@ class TestServingLatency(unittest.TestCase):
"python3", "python3",
"-m", "-m",
"sglang.bench_latency", "sglang.bench_latency",
"--model", "--model-path",
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
"--batch-size", "--batch-size",
"1", "1",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment