# launch_server.py: builds and runs an `sglang.launch_server` command, optionally with LoRA adapters.
import argparse
import os

4
# Number of LoRA adapter entries passed to the server via --lora-paths
# (they are registered under the names lora0 .. lora{NUM_LORAS - 1}).
NUM_LORAS = 4

# Model identifiers used to build the launch command: "base" is given to
# --model, and "lora" is the adapter path reused for every --lora-paths entry.
LORA_PATH = {
    "base": "meta-llama/Llama-2-7b-hf",
    "lora": "winddude/wizardLM-LlaMA-LoRA-7B",
}


def build_server_command(args, base_path=None, lora_path=None, num_loras=None):
    """Assemble the ``sglang.launch_server`` shell command for ``args``.

    Args:
        args: Parsed CLI namespace providing ``base_only``,
            ``max_loras_per_batch``, ``max_running_requests``,
            ``lora_backend``, ``tp_size``, ``disable_custom_all_reduce``,
            ``enable_mscclpp`` and ``enable_torch_symm_mem``.
        base_path: Value for ``--model``; defaults to ``LORA_PATH["base"]``.
        lora_path: Adapter path used for every ``--lora-paths`` entry;
            defaults to ``LORA_PATH["lora"]``.
        num_loras: Number of ``loraN=<path>`` entries to emit; defaults to
            ``NUM_LORAS``.

    Returns:
        The command line as a single space-separated string.
    """
    if base_path is None:
        base_path = LORA_PATH["base"]
    if lora_path is None:
        lora_path = LORA_PATH["lora"]
    if num_loras is None:
        num_loras = NUM_LORAS

    # Collect tokens in a list and join once, so every flag is separated by
    # exactly one space no matter which optional flags are enabled.
    parts = [f"python3 -m sglang.launch_server --model {base_path}"]
    if not args.base_only:
        parts.append("--lora-paths")
        parts.extend(f"lora{i}={lora_path}" for i in range(num_loras))

    parts += [
        "--disable-radix",
        f"--max-loras-per-batch {args.max_loras_per_batch}",
        f"--max-running-requests {args.max_running_requests}",
        f"--lora-backend {args.lora_backend}",
        f"--tp-size {args.tp_size}",
    ]

    if args.disable_custom_all_reduce:
        parts.append("--disable-custom-all-reduce")
    if args.enable_mscclpp:
        parts.append("--enable-mscclpp")
    if args.enable_torch_symm_mem:
        parts.append("--enable-torch-symm-mem")

    return " ".join(parts)


def launch_server(args):
    """Print the server launch command built from ``args`` and run it.

    The command is executed through the shell with ``os.system``; its exit
    status is not inspected.
    """
    cmd = build_server_command(args)
    print(cmd)
    os.system(cmd)


def build_arg_parser():
    """Create the command-line parser for this launcher script.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace carries the
        attributes read by ``launch_server``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base-only",
        action="store_true",
        help="Launch with the base model only (no --lora-paths entries)",
    )
    parser.add_argument(
        "--max-loras-per-batch",
        type=int,
        default=8,
        help="Value forwarded to the server's --max-loras-per-batch option",
    )
    parser.add_argument(
        "--max-running-requests",
        type=int,
        default=8,
        help="Value forwarded to the server's --max-running-requests option",
    )
    parser.add_argument(
        "--lora-backend",
        type=str,
        default="csgmv",
        help="Value forwarded to the server's --lora-backend option",
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size for distributed inference",
    )
    parser.add_argument(
        "--disable-custom-all-reduce",
        action="store_true",
        help="Disable custom all reduce when device does not support p2p communication",
    )
    parser.add_argument(
        "--enable-mscclpp",
        action="store_true",
        help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
    )
    parser.add_argument(
        "--enable-torch-symm-mem",
        action="store_true",
        help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL.",
    )
    return parser


if __name__ == "__main__":
    # Parse the CLI options and hand them to the launcher.
    args = build_arg_parser().parse_args()

    launch_server(args)