test_utils.py 29.5 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
"""Common utilities for testing and benchmarking"""
2

3
import argparse
4
import copy
5
import logging
6
import os
7
import random
8
import subprocess
9
import threading
10
import time
11
import traceback
12
import unittest
13
from concurrent.futures import ThreadPoolExecutor
Byron Hsu's avatar
Byron Hsu committed
14
from dataclasses import dataclass
Liangsheng Yin's avatar
Liangsheng Yin committed
15
from functools import partial
16
from types import SimpleNamespace
17
from typing import Callable, List, Optional, Tuple
Liangsheng Yin's avatar
Liangsheng Yin committed
18

Lianmin Zheng's avatar
Lianmin Zheng committed
19
20
import numpy as np
import requests
21
22
import torch
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
23

24
from sglang.bench_serving import run_benchmark
Lianmin Zheng's avatar
Lianmin Zheng committed
25
from sglang.global_config import global_config
Ying Sheng's avatar
Ying Sheng committed
26
27
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
28
from sglang.srt.utils import get_bool_env_var, kill_process_tree, retry
29
from sglang.test.run_eval import run_eval
30
from sglang.utils import get_exception_traceback
Liangsheng Yin's avatar
Liangsheng Yin committed
31

Lianmin Zheng's avatar
Lianmin Zheng committed
32
# ---------------------------------------------------------------------------
# FP8-quantized checkpoints
# ---------------------------------------------------------------------------
DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST = "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST = (
    "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic"
)
DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST = (
    "nvidia/Llama-3.1-8B-Instruct-FP8"
)

# ---------------------------------------------------------------------------
# Default models for individual tests
# ---------------------------------------------------------------------------
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_REASONING_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST = (
    "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
)

# Seconds to wait for a launched server to become healthy.
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000

# ---------------------------------------------------------------------------
# Nightly-eval model lists (single comma-separated strings)
# ---------------------------------------------------------------------------
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = ",".join(
    [
        "meta-llama/Llama-3.1-8B-Instruct",
        "mistralai/Mistral-7B-Instruct-v0.3",
        "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
        "google/gemma-2-27b-it",
    ]
)
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = ",".join(
    [
        "meta-llama/Llama-3.1-70B-Instruct",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "Qwen/Qwen2-57B-A14B-Instruct",
    ]
)
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = ",".join(
    [
        "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8",
        "neuralmagic/Mistral-7B-Instruct-v0.3-FP8",
        "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8",
        "neuralmagic/gemma-2-2b-it-FP8",
    ]
)
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = ",".join(
    [
        "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8",
        "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8",
        "neuralmagic/Qwen2-72B-Instruct-FP8",
        "neuralmagic/Qwen2-57B-A14B-Instruct-FP8",
        "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8",
    ]
)
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = ",".join(
    [
        "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
        "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4",
        "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4",
    ]
)
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"

# ---------------------------------------------------------------------------
# Vision-language, speculative-decoding and multimodal assets
# ---------------------------------------------------------------------------
DEFAULT_SMALL_VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B"

DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"

DEFAULT_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
DEFAULT_VIDEO_URL = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"

67
68
69

def is_in_ci():
    """Tell whether the ``SGLANG_IS_IN_CI`` environment flag marks this as a CI run."""
    in_ci = get_bool_env_var("SGLANG_IS_IN_CI")
    return in_ci
71
72
73


def _first_visible_gpu_digit() -> int:
    """Return the first character of ``CUDA_VISIBLE_DEVICES`` as an int.

    Only the first character is used (so ``"3,4"`` -> 3), which keeps the
    historical port layout. An unset variable counts as ``"0"``. An empty value
    (GPUs hidden) or a non-numeric one (e.g. a GPU UUID) falls back to 0
    instead of raising IndexError/ValueError at import time.
    """
    first_char = os.environ.get("CUDA_VISIBLE_DEVICES", "0")[:1]
    return int(first_char) if first_char.isdecimal() else 0


# CI runs use ports from 5000, local runs from 7000. Each visible-GPU index
# shifts the range by 100, presumably so runs pinned to different GPUs don't
# collide.
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
    5000 if is_in_ci() else 7000
) + _first_visible_gpu_digit() * 100
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
82

Lianmin Zheng's avatar
Lianmin Zheng committed
83

Liangsheng Yin's avatar
Liangsheng Yin committed
84
85
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
    """Request a completion from a LightLLM server.

    Posts ``{"inputs": prompt, "parameters": {...}}`` to ``url`` and returns
    the first element of the response's ``generated_text`` field.
    """
    assert url is not None

    params = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "stop_sequences": stop,
    }
    response = requests.post(url, json={"inputs": prompt, "parameters": params})
    assert response.status_code == 200
    return response.json()["generated_text"][0]


Liangsheng Yin's avatar
Liangsheng Yin committed
101
102
103
def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
    """Request ``n`` completions from a vLLM ``/generate`` server.

    The server echoes the prompt in each returned text, so the first
    ``len(prompt)`` characters are cut off. Returns a single string when
    ``n == 1``, otherwise a list of strings.
    """
    assert url is not None

    payload = dict(
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
        n=n,
    )
    response = requests.post(url, json=payload)
    assert response.status_code == 200
    texts = response.json()["text"]
    cut = len(prompt)
    if n == 1:
        return texts[0][cut:]
    return [text[cut:] for text in texts]


120
def call_generate_outlines(
    prompt, temperature, max_tokens, stop=None, regex=None, n=1, url=None
):
    """Request ``n`` (optionally regex-constrained) completions from an Outlines server.

    The returned texts include the prompt, which is stripped by length.
    Returns a single string when ``n == 1``, otherwise a list of strings.
    """
    assert url is not None

    payload = dict(
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
        regex=regex,
        n=n,
    )
    response = requests.post(url, json=payload)
    assert response.status_code == 200
    texts = response.json()["text"]
    cut = len(prompt)
    if n == 1:
        return texts[0][cut:]
    return [text[cut:] for text in texts]


Liangsheng Yin's avatar
Liangsheng Yin committed
142
143
144
def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
    """Request a completion from an SGLang runtime's native ``/generate`` API.

    Returns the ``text`` field of the JSON response.
    """
    assert url is not None

    sampling_params = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "stop": stop,
    }
    response = requests.post(
        url, json={"text": prompt, "sampling_params": sampling_params}
    )
    assert response.status_code == 200
    return response.json()["text"]


Liangsheng Yin's avatar
Liangsheng Yin committed
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def call_generate_guidance(
    prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
):
    """Sample ``n`` completions from a ``guidance`` model, one program run each.

    Returns the list of answers when ``n > 1``, otherwise just the first.
    """
    assert model is not None
    from guidance import gen

    answers = []
    for _ in range(n):
        generator = gen(
            name="answer",
            max_tokens=max_tokens,
            temperature=temperature,
            stop=stop,
            regex=regex,
        )
        program = model + prompt + generator
        answers.append(program["answer"])
    if n > 1:
        return answers
    return answers[0]


def call_select_lightllm(context, choices, url=None):
    """Query a LightLLM server once per choice and return the argmax index.

    NOTE(review): no score is extracted from the response. Every choice gets
    score 0, so this always returns index 0 when ``choices`` is non-empty.
    It effectively only exercises the request path.
    """
    assert url is not None

    scores = []
    for choice in choices:
        body = {
            "inputs": context + choice,
            "parameters": {
                "max_new_tokens": 1,
            },
        }
        response = requests.post(url, json=body)
        assert response.status_code == 200
        scores.append(0)
    return np.argmax(scores)


Liangsheng Yin's avatar
Liangsheng Yin committed
200
201
202
def call_select_vllm(context, choices, url=None):
    """Pick the choice with the highest prompt score from a (patched) vLLM server.

    Each ``context + choice`` is posted as a prompt with ``max_tokens=1`` and
    ``prompt_logprobs=1``. The score is read from the response's
    ``prompt_score`` field (``0`` if absent), and the index of the highest
    score is returned.

    Server requirement (carried over from the original note in this module):
    ``vllm/entrypoints/api_server.py`` must be modified to emit
    ``prompt_score``, e.g.::

        if final_output.prompt_logprobs is not None:
            score = np.mean([prob[t_id] for t_id, prob in zip(final_output.prompt_token_ids[1:], final_output.prompt_logprobs[1:])])
            ret["prompt_score"] = score
    """
    assert url is not None

    scores = []
    for choice in choices:
        data = {
            "prompt": context + choice,
            "max_tokens": 1,
            "prompt_logprobs": 1,
        }
        res = requests.post(url, json=data)
        assert res.status_code == 200
        scores.append(res.json().get("prompt_score", 0))
    return np.argmax(scores)


Liangsheng Yin's avatar
Liangsheng Yin committed
224
225
226
227
228
229
230
231
def call_select_guidance(context, choices, model=None):
    """Let a ``guidance`` model choose among ``choices``; return the chosen index."""
    assert model is not None
    from guidance import select

    program = model + context + select(choices, name="answer")
    picked = program["answer"]
    return choices.index(picked)


232
def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
    """Add the shared flags for non-SGLang benchmark backends and parse ``sys.argv``.

    When ``--port`` is omitted it is filled from a per-backend default table.
    Backends without an entry (``guidance``, ``llama.cpp``) get ``None``.
    """
    supported_backends = [
        "vllm",
        "outlines",
        "lightllm",
        "gserver",
        "guidance",
        "srt-raw",
        "llama.cpp",
    ]
    parser.add_argument("--parallel", type=int, default=64)
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument(
        "--backend", type=str, required=True, choices=supported_backends
    )
    parser.add_argument("--n-ctx", type=int, default=4096)
    parser.add_argument(
        "--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
    )
    parser.add_argument("--result-file", type=str, default="result.jsonl")
    parsed = parser.parse_args()

    backend_ports = {
        "vllm": 21000,
        "outlines": 21000,
        "lightllm": 22000,
        "srt-raw": 30000,
        "gserver": 9988,
    }
    if parsed.port is None:
        parsed.port = backend_ports.get(parsed.backend, None)
    return parsed


269
def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
    """Add the shared flags for SGLang-frontend benchmarks and parse ``sys.argv``."""
    flag_specs = (
        ("--parallel", int, 64),
        ("--host", str, "http://127.0.0.1"),
        ("--port", int, 30000),
        ("--backend", str, "srt"),
        ("--result-file", str, "result.jsonl"),
    )
    for flag, flag_type, default in flag_specs:
        parser.add_argument(flag, type=flag_type, default=default)
    return parser.parse_args()


279
def select_sglang_backend(args: argparse.Namespace):
    """Build the SGLang frontend backend named by ``args.backend``.

    ``srt*`` names connect a ``RuntimeEndpoint`` to ``args.host:args.port``.
    ``srt-no-parallel`` also turns off ``global_config.enable_parallel_encoding``.
    ``gpt-*`` names create an ``OpenAI`` backend. Anything else raises ValueError.
    """
    name = args.backend
    if name.startswith("gpt-"):
        return OpenAI(name)
    if not name.startswith("srt"):
        raise ValueError(f"Invalid backend: {args.backend}")

    if name == "srt-no-parallel":
        global_config.enable_parallel_encoding = False
    return RuntimeEndpoint(f"{args.host}:{args.port}")
Liangsheng Yin's avatar
Liangsheng Yin committed
289
290


291
def _get_call_generate(args: argparse.Namespace):
    """Return a generate callable bound to the backend selected by ``args.backend``.

    HTTP backends get a ``partial`` with the server URL filled in. The
    ``guidance`` backend loads a local LlamaCpp model and makes one call on it
    (presumably a warm-up) before returning. Unknown names raise ValueError.
    """
    backend = args.backend

    def endpoint(path=""):
        return f"{args.host}:{args.port}{path}"

    if backend == "guidance":
        from guidance import models

        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
        call_generate = partial(call_generate_guidance, model=model)
        call_generate("Hello,", 1.0, 8, ".")
        return call_generate
    if backend == "lightllm":
        return partial(call_generate_lightllm, url=endpoint("/generate"))
    if backend == "vllm":
        return partial(call_generate_vllm, url=endpoint("/generate"))
    if backend == "srt-raw":
        return partial(call_generate_srt_raw, url=endpoint("/generate"))
    if backend == "outlines":
        return partial(call_generate_outlines, url=endpoint("/generate"))
    if backend == "gserver":
        # NOTE(review): call_generate_gserver is not defined in the visible part
        # of this module. Confirm it exists, or this branch raises NameError.
        return partial(call_generate_gserver, url=endpoint())
    raise ValueError(f"Invalid backend: {args.backend}")


313
def _get_call_select(args: argparse.Namespace):
    """Return a select callable bound to the backend selected by ``args.backend``.

    ``lightllm``/``vllm`` get the server's ``/generate`` URL bound in. The
    ``guidance`` backend loads a local LlamaCpp model and makes one call on it
    before returning. Unknown names raise ValueError.
    """
    backend = args.backend
    if backend == "guidance":
        from guidance import models

        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
        call_select = partial(call_select_guidance, model=model)

        call_select("Hello,", ["world", "earth"])
        return call_select
    if backend in ("lightllm", "vllm"):
        impl = call_select_lightllm if backend == "lightllm" else call_select_vllm
        return partial(impl, url=f"{args.host}:{args.port}/generate")
    raise ValueError(f"Invalid backend: {args.backend}")


330
def get_call_generate(args: argparse.Namespace):
    """Return the backend's generate callable, wrapped to print tracebacks.

    Any exception raised by the wrapped call is printed with the prefix
    ``"Exception in call_generate:"`` and then re-raised unchanged.
    """
    inner = _get_call_generate(args)

    def func(*call_args, **call_kwargs):
        try:
            return inner(*call_args, **call_kwargs)
        except Exception:
            print("Exception in call_generate:\n" + get_exception_traceback())
            raise

    return func


343
def get_call_select(args: argparse.Namespace):
    """Return the backend's select callable, wrapped to print tracebacks.

    Any exception raised by the wrapped call is printed with the prefix
    ``"Exception in call_select:"`` and then re-raised unchanged.
    """
    inner = _get_call_select(args)

    def func(*call_args, **call_kwargs):
        try:
            return inner(*call_args, **call_kwargs)
        except Exception:
            print("Exception in call_select:\n" + get_exception_traceback())
            raise

    return func
354
355


356
def popen_launch_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: Optional[str] = None,
    other_args: list[str] = (),
    env: Optional[dict] = None,
    return_stdout_stderr: Optional[tuple] = None,
    pd_seperated: bool = False,
):
    """Launch an SGLang server subprocess and wait until it passes a health check.

    Args:
        model: Passed as ``--model-path``.
        base_url: URL of the form ``http://HOST:PORT``. Host and port are
            parsed from it and forwarded to the launcher.
        timeout: Maximum seconds to wait for ``{base_url}/health_generate``
            to answer HTTP 200.
        api_key: If truthy, forwarded as ``--api-key`` and sent as a bearer
            token with each health probe.
        other_args: Extra CLI arguments; each one is converted with ``str``.
        env: Environment for the child process (``None`` inherits).
        return_stdout_stderr: Optional ``(stdout, stderr)`` targets for the
            child's output (opened in text mode).
        pd_seperated: Run ``sglang.launch_pd_server`` with ``--lb-host`` /
            ``--lb-port`` instead of ``sglang.launch_server`` with
            ``--host`` / ``--port``.

    Returns:
        The ``subprocess.Popen`` handle of the healthy server.

    Raises:
        Exception: If the server process exits before becoming healthy.
        TimeoutError: If it is not healthy within ``timeout`` seconds. The
            process tree is killed before raising.
    """
    _, host, port = base_url.split(":")
    host = host[2:]  # "//127.0.0.1" -> "127.0.0.1"

    if pd_seperated:
        module = "sglang.launch_pd_server"
        host_flag, port_flag = "--lb-host", "--lb-port"
    else:
        module = "sglang.launch_server"
        host_flag, port_flag = "--host", "--port"

    command = [
        "python3",
        "-m",
        module,
        "--model-path",
        model,
        *[str(x) for x in other_args],
        host_flag,
        host,
        port_flag,
        port,
    ]
    if api_key:
        command += ["--api-key", api_key]

    print(f"command={' '.join(command)}")

    if return_stdout_stderr:
        process = subprocess.Popen(
            command,
            stdout=return_stdout_stderr[0],
            stderr=return_stdout_stderr[1],
            env=env,
            text=True,
        )
    else:
        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

    headers = {"Content-Type": "application/json; charset=utf-8"}
    if api_key:
        # Only attach credentials when we have them (avoid a literal "Bearer None").
        headers["Authorization"] = f"Bearer {api_key}"

    start_time = time.time()
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            remaining = timeout - (time.time() - start_time)
            try:
                # Bound each probe by the remaining budget. Without a request
                # timeout, a server that accepts the connection but never
                # answers would block here forever and ignore `timeout`.
                # requests rejects non-positive timeouts, hence the 1 s floor.
                response = session.get(
                    f"{base_url}/health_generate",
                    headers=headers,
                    timeout=max(remaining, 1.0),
                )
                if response.status_code == 200:
                    return process
            except requests.RequestException:
                pass

            return_code = process.poll()
            if return_code is not None:
                raise Exception(f"Server unexpectedly exits ({return_code=}).")

            time.sleep(10)

    kill_process_tree(process.pid)
    raise TimeoutError("Server failed to start within the timeout period.")
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468


def run_with_timeout(
    func: Callable,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    timeout: Optional[float] = None,
):
    """Call ``func(*args, **kwargs)`` in a worker thread and wait up to ``timeout`` seconds.

    Args:
        func: The callable to run.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func`` (``None`` means none).
        timeout: Seconds to wait; ``None`` waits indefinitely.

    Returns:
        Whatever ``func`` returned.

    Raises:
        TimeoutError: If ``func`` is still running after ``timeout`` seconds.
            The worker thread is not stopped and keeps running.
        RuntimeError: If ``func`` raised. The original exception is attached
            as ``__cause__`` instead of being lost inside the worker thread.
    """
    ret_value = []
    errors = []

    def _target_func():
        try:
            ret_value.append(func(*args, **(kwargs or {})))
        except Exception as exc:  # re-raised in the caller's thread below
            errors.append(exc)

    t = threading.Thread(target=_target_func)
    t.start()
    t.join(timeout=timeout)
    if t.is_alive():
        raise TimeoutError()

    if errors:
        raise RuntimeError(f"{func!r} raised {errors[0]!r}") from errors[0]

    if not ret_value:
        # e.g. func raised a BaseException such as SystemExit
        raise RuntimeError()

    return ret_value[0]


Byron Hsu's avatar
Byron Hsu committed
469
470
471
472
473
474
475
@dataclass
class TestFile:
    """A test script to run via run_unittest_files, plus its expected runtime."""

    # Script path; run_unittest_files joins it onto os.getcwd() before running it.
    name: str
    # Expected runtime in seconds. It is only printed next to the measured
    # elapsed time and is not enforced as a limit.
    estimated_time: float = 60


def run_unittest_files(files: List["TestFile"], timeout_per_file: float):
    """Run each test file in its own ``python3`` subprocess, in order.

    Args:
        files: Test files to run. ``name`` is resolved against the current
            working directory.
        timeout_per_file: Wall-clock limit in seconds for each file.

    Returns:
        ``0`` if every file finished in time, or ``-1`` after the first
        timeout (remaining files are skipped).

    Raises:
        AssertionError: If a file exits with a non-zero return code. It is
            raised explicitly rather than via ``assert``, so the check still
            works under ``python -O``.
    """
    tic = time.time()
    success = True

    for file in files:
        filename, estimated_time = file.name, file.estimated_time
        process = None

        def run_one_file(filename):
            nonlocal process

            filename = os.path.join(os.getcwd(), filename)
            print(f".\n.\nBegin:\npython3 {filename}\n.\n.\n", flush=True)
            tic = time.time()

            process = subprocess.Popen(
                ["python3", filename], stdout=None, stderr=None, env=os.environ
            )
            process.wait()
            elapsed = time.time() - tic

            print(
                f".\n.\nEnd:\n{filename=}, {elapsed=:.0f}, {estimated_time=}\n.\n.\n",
                flush=True,
            )
            return process.returncode

        try:
            ret_code = run_with_timeout(
                run_one_file, args=(filename,), timeout=timeout_per_file
            )
            if ret_code != 0:
                raise AssertionError(
                    f"expected return code 0, but {filename} returned {ret_code}"
                )
        except TimeoutError:
            # With a very small timeout the subprocess may not have started yet.
            if process is not None:
                kill_process_tree(process.pid)
            time.sleep(5)
            print(
                f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
                flush=True,
            )
            success = False
            break

    if success:
        print(f"Success. Time elapsed: {time.time() - tic:.2f}s", flush=True)
    else:
        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s", flush=True)

    return 0 if success else -1
525
526
527
528


def get_similarities(vec1, vec2):
    """Return the cosine similarity of two vectors (compared along dim 0) as a tensor."""
    first = torch.tensor(vec1)
    second = torch.tensor(vec2)
    return F.cosine_similarity(first, second, dim=0)
529
530


531
532
533
534
535
536
def get_benchmark_args(
    base_url="",
    dataset_name="",
    dataset_path="",
    tokenizer="",
    num_prompts=500,
    sharegpt_output_len=None,
    random_input_len=4096,
    random_output_len=2048,
    sharegpt_context_len=None,
    request_rate=float("inf"),
    disable_stream=False,
    disable_ignore_eos=False,
    seed: int = 0,
    pd_seperated: bool = False,
):
    """Build the namespace ``bench_serving.run_benchmark`` expects for the sglang backend.

    The caller-facing knobs come from the parameters. Every other field is
    pinned to a fixed value (no host/port/model override, no output file,
    random range ratio 0.0, tqdm enabled, no logprobs, no profiling/LoRA).
    """
    settings = dict(
        backend="sglang",
        base_url=base_url,
        host=None,
        port=None,
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        model=None,
        tokenizer=tokenizer,
        num_prompts=num_prompts,
        sharegpt_output_len=sharegpt_output_len,
        sharegpt_context_len=sharegpt_context_len,
        random_input_len=random_input_len,
        random_output_len=random_output_len,
        random_range_ratio=0.0,
        request_rate=request_rate,
        multi=None,
        output_file=None,
        disable_tqdm=False,
        disable_stream=disable_stream,
        return_logprob=False,
        seed=seed,
        disable_ignore_eos=disable_ignore_eos,
        extra_request_body=None,
        apply_chat_template=False,
        profile=None,
        lora_name=None,
        prompt_suffix="",
        pd_seperated=pd_seperated,
    )
    return SimpleNamespace(**settings)


579
580
581
582
583
584
def run_bench_serving(
    model,
    num_prompts,
    request_rate,
    other_server_args,
    dataset_name="random",
    dataset_path="",
    tokenizer=None,
    random_input_len=4096,
    random_output_len=2048,
    sharegpt_context_len=None,
    disable_stream=False,
    disable_ignore_eos=False,
    need_warmup=False,
    seed: int = 0,
):
    """Launch a server at DEFAULT_URL_FOR_TEST, benchmark it once, and return the results.

    With ``need_warmup`` a 16-prompt copy of the benchmark runs first. The
    server is killed even if benchmarking fails. Afterwards the run must
    report exactly ``num_prompts`` completed requests (checked with assert).
    """
    base_url = DEFAULT_URL_FOR_TEST
    server = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_server_args,
    )

    bench_args = get_benchmark_args(
        base_url=base_url,
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        num_prompts=num_prompts,
        random_input_len=random_input_len,
        random_output_len=random_output_len,
        sharegpt_context_len=sharegpt_context_len,
        request_rate=request_rate,
        disable_stream=disable_stream,
        disable_ignore_eos=disable_ignore_eos,
        seed=seed,
    )

    try:
        if need_warmup:
            warmup = copy.deepcopy(bench_args)
            warmup.num_prompts = 16
            run_benchmark(warmup)
        res = run_benchmark(bench_args)
    finally:
        kill_process_tree(server.pid)

    assert res["completed"] == num_prompts
    return res
631
632


633
634
635
636
637
638
def run_bench_serving_multi(
    model,
    base_url,
    other_server_args,
    benchmark_args,
    need_warmup=False,
    pd_seperated=False,
):
    """Launch one server and run every benchmark configuration in ``benchmark_args`` on it.

    Each configuration is optionally preceded by a 16-prompt warm-up copy.
    The server is killed even if a run fails.

    Returns:
        A list of ``(args, result)`` pairs in the order of ``benchmark_args``.
    """
    server = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_server_args,
        pd_seperated=pd_seperated,
    )

    results = []
    try:
        for bench_args in benchmark_args:
            if need_warmup:
                warmup = copy.deepcopy(bench_args)
                warmup.num_prompts = 16
                run_benchmark(warmup)
            results.append((bench_args, run_benchmark(bench_args)))
    finally:
        kill_process_tree(server.pid)

    return results


667
def run_bench_one_batch(model, other_args):
    """Run ``sglang.bench_one_batch`` (batch 1, input 128, output 8) and return its throughput.

    The child's stdout and stderr are captured and echoed. The throughput is
    the second-to-last space-separated token on the third-to-last line of
    stdout, parsed as a float.
    """
    command = [
        "python3",
        "-m",
        "sglang.bench_one_batch",
        "--batch-size",
        "1",
        "--input",
        "128",
        "--output",
        "8",
    ]
    command.extend(str(arg) for arg in other_args)
    if model is not None:
        command.extend(["--model-path", model])

    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    try:
        raw_out, raw_err = process.communicate()
        output = raw_out.decode()
        error = raw_err.decode()
        print(f"Output: {output}", flush=True)
        print(f"Error: {error}", flush=True)

        summary_line = output.split("\n")[-3]
        output_throughput = float(summary_line.split(" ")[-2])
    finally:
        kill_process_tree(process.pid)

    return output_throughput
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730


def lcs(X, Y):
    """Return the length of the longest common subsequence of ``X`` and ``Y``.

    Standard dynamic programming in O(len(X) * len(Y)) time. Only two DP rows
    are kept, so memory is O(len(Y)) instead of a full
    (len(X)+1) x (len(Y)+1) table. This matters for ROUGE-L on long generated
    outputs. ``Y`` must support ``len``; ``X`` may be any iterable.
    """
    n = len(Y)
    prev = [0] * (n + 1)
    for x in X:
        curr = [0] * (n + 1)
        for j, y in enumerate(Y, start=1):
            if x == y:
                curr[j] = prev[j - 1] + 1
            else:
                curr[j] = max(prev[j], curr[j - 1])
        prev = curr
    return prev[n]


def calculate_rouge_l(output_strs_list1, output_strs_list2):
    """Compute pairwise ROUGE-L F1 scores between two lists of strings.

    Pairs are formed with ``zip``, so extra items in the longer list are
    ignored. Precision is LCS / len(first) and recall is LCS / len(second);
    an empty string contributes 0 to its side.
    """

    def _f_measure(candidate, reference):
        overlap = lcs(candidate, reference)
        precision = overlap / len(candidate) if len(candidate) > 0 else 0
        recall = overlap / len(reference) if len(reference) > 0 else 0
        total = precision + recall
        if not total > 0:
            return 0.0
        return (2 * precision * recall) / total

    return [
        _f_measure(first, second)
        for first, second in zip(output_strs_list1, output_strs_list2)
    ]
731
732
733


# Scratch files in the current directory that receive a child process's
# stderr / stdout so a background thread can tail them (see read_output).
STDERR_FILENAME = "stderr.txt"
STDOUT_FILENAME = "stdout.txt"
735
736


737
def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
    """Tail ``filename``, echoing and collecting new lines until the file is deleted.

    Meant to run in a background thread. It waits (polling every second) for
    ``filename`` to appear, then re-reads it every 0.1 s. Each line not seen
    before is printed and appended to ``output_lines``. It returns once the
    file has been removed, which is how callers signal that they are done.

    Args:
        output_lines: List shared with the caller that receives every line read.
        filename: File to follow.
    """
    while not os.path.exists(filename):
        time.sleep(1)

    pt = 0  # number of lines already consumed
    while True:
        if pt > 0 and not os.path.exists(filename):
            break
        try:
            # `with` closes the handle on every poll instead of leaking it.
            with open(filename) as f:
                lines = f.readlines()
        except FileNotFoundError:
            # Deleted between the existence check and open(): treat as done
            # instead of crashing the reader thread.
            break
        for line in lines[pt:]:
            print(line, end="", flush=True)
            output_lines.append(line)
            pt += 1
        time.sleep(0.1)
752
753


754
755
def run_and_check_memory_leak(
    workload_func,
    disable_radix_cache,
    enable_mixed_chunk,
    disable_overlap,
    chunked_prefill_size,
    assert_has_abort,
):
    """Launch a debug-logging server, run ``workload_func`` against it, then scan its logs.

    The server's stdout/stderr go to STDOUT_FILENAME / STDERR_FILENAME, and a
    background ``read_output`` thread tails the stderr file. Teardown happens
    in ``finally``, so it also runs if the launch or the workload fails: kill
    the server, close and delete the files (which stops the reader), and join
    the reader.

    Args:
        workload_func: Called as ``workload_func(base_url, model)``.
        disable_radix_cache: Add ``--disable-radix-cache``.
        enable_mixed_chunk: Add ``--enable-mixed-chunk``.
        disable_overlap: Add ``--disable-overlap-schedule``.
        chunked_prefill_size: Value for ``--chunked-prefill-size``.
        assert_has_abort: Also require a log line containing "Abort".

    Raises:
        AssertionError: If no log line contains "Uvicorn running", any line
            contains "leak", or (when requested) no line contains "Abort".
    """
    other_args = [
        "--chunked-prefill-size",
        str(chunked_prefill_size),
        "--log-level",
        "debug",
    ]
    if disable_radix_cache:
        other_args += ["--disable-radix-cache"]
    if enable_mixed_chunk:
        other_args += ["--enable-mixed-chunk"]
    if disable_overlap:
        other_args += ["--disable-overlap-schedule"]

    model = DEFAULT_MODEL_NAME_FOR_TEST
    port = random.randint(4000, 5000)
    base_url = f"http://127.0.0.1:{port}"

    # Create files and launch the server
    stdout = open(STDOUT_FILENAME, "w")
    stderr = open(STDERR_FILENAME, "w")
    process = None
    reader = None
    output_lines = []
    try:
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
            return_stdout_stderr=(stdout, stderr),
        )

        # Launch a thread to stream the output
        reader = threading.Thread(target=read_output, args=(output_lines,))
        reader.start()

        # Run the workload
        workload_func(base_url, model)
    finally:
        # Without this cleanup a failing workload would leave the server
        # running, and the reader thread would never exit, because it stops
        # only once the tailed file is deleted.
        if process is not None:
            kill_process_tree(process.pid)
        stdout.close()
        stderr.close()
        for path in (STDOUT_FILENAME, STDERR_FILENAME):
            if os.path.exists(path):
                os.remove(path)
        if reader is not None:
            reader.join()

    # Assert success
    has_new_server = any("Uvicorn running" in line for line in output_lines)
    has_leak = any("leak" in line for line in output_lines)
    has_abort = any("Abort" in line for line in output_lines)

    assert has_new_server
    assert not has_leak
    if assert_has_abort:
        assert has_abort
825
826


827
828
829
830
def run_command_and_capture_output(command, env: Optional[dict] = None):
    """Run ``command`` to completion, echo its output live, and return the lines.

    The child's stdout and stderr are both written to STDOUT_FILENAME
    (stderr is pointed at the same file object), which a ``read_output``
    thread tails. STDERR_FILENAME is also created but receives nothing here.
    Both files are deleted once the command exits.
    """
    out_file = open(STDOUT_FILENAME, "w")
    err_file = open(STDERR_FILENAME, "w")
    process = subprocess.Popen(
        command, stdout=out_file, stderr=out_file, env=env, text=True
    )

    captured = []
    reader = threading.Thread(target=read_output, args=(captured, STDOUT_FILENAME))
    reader.start()

    process.wait()

    out_file.close()
    err_file.close()
    for path in (STDOUT_FILENAME, STDERR_FILENAME):
        if os.path.exists(path):
            os.remove(path)
    kill_process_tree(process.pid)
    reader.join()

    return captured


854
855
856
def run_mmlu_test(
    disable_radix_cache=False,
    enable_mixed_chunk=False,
    disable_overlap=False,
    chunked_prefill_size=32,
):
    """Run a small MMLU eval through ``run_and_check_memory_leak``.

    ``run_and_check_memory_leak`` calls the workload with a server URL and
    model name. It then inspects the server's output for leak reports.

    Args:
        disable_radix_cache: Forwarded to ``run_and_check_memory_leak``.
        enable_mixed_chunk: Forwarded to ``run_and_check_memory_leak``.
        disable_overlap: Forwarded to ``run_and_check_memory_leak``.
        chunked_prefill_size: Forwarded to ``run_and_check_memory_leak``.

    Raises:
        AssertionError: if the MMLU score is below 0.65, or if the leak check fails.
    """

    def workload_func(base_url, model):
        # Run the eval: 128 MMLU examples spread over 128 threads.
        args = SimpleNamespace(
            base_url=base_url,
            model=model,
            eval_name="mmlu",
            num_examples=128,
            num_threads=128,
        )
        metrics = run_eval(args)
        assert metrics["score"] >= 0.65, f"{metrics=}"

    run_and_check_memory_leak(
        workload_func,
        disable_radix_cache,
        enable_mixed_chunk,
        disable_overlap,
        chunked_prefill_size,
        assert_has_abort=False,
    )
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914


def run_mulit_request_test(
    disable_radix_cache=False,
    enable_mixed_chunk=False,
    enable_overlap=False,
    chunked_prefill_size=32,
):
    """Send a few concurrent ``/generate`` requests under the memory-leak harness.

    The workload sends four identical greedy requests (8 new tokens each).
    It uses a two-worker thread pool, and the last four arguments are
    forwarded positionally to ``run_and_check_memory_leak``.

    NOTE(review): ``run_mmlu_test`` fills this same positional slot with
    ``disable_overlap``. If that is the callee's parameter name,
    ``enable_overlap`` is being forwarded with the opposite meaning. Confirm
    against ``run_and_check_memory_leak``'s signature.
    """

    def workload_func(base_url, model):
        def send_request(_idx):
            prompt = """
            System: You are a helpful assistant.
            User: What is the capital of France?
            Assistant: The capital of France is
            """

            payload = {
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 8,
                },
            }
            # The decoded body is discarded; decoding it still makes a
            # non-JSON reply raise inside the worker.
            requests.post(f"{base_url}/generate", json=payload).json()

        with ThreadPoolExecutor(2) as pool:
            # Drain the iterator so a failure in any worker propagates here.
            for _ in pool.map(send_request, range(4)):
                pass

    run_and_check_memory_leak(
        workload_func,
        disable_radix_cache,
        enable_mixed_chunk,
        enable_overlap,
        chunked_prefill_size,
        assert_has_abort=False,
    )
923
924
925


def write_github_step_summary(content):
    """Append ``content`` to the file named by ``GITHUB_STEP_SUMMARY``.

    If the variable is unset or empty, a warning is logged and nothing is
    written.

    Args:
        content: Text to append verbatim (no newline is added).
    """
    summary_path = os.environ.get("GITHUB_STEP_SUMMARY")
    if not summary_path:
        logging.warning("GITHUB_STEP_SUMMARY environment variable not set")
        return

    # Write UTF-8 explicitly so non-ASCII Markdown does not depend on the
    # runner's locale default encoding.
    with open(summary_path, "a", encoding="utf-8") as f:
        f.write(content)
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006


def run_logprob_check(self: unittest.TestCase, arg: Tuple):
    """Send one ``/generate`` request and validate the logprob fields returned.

    Args:
        self: Test case supplying ``base_url`` and the ``assert*`` helpers.
        arg: ``(input_len, output_len, temperature, logprob_start_len,
            return_logprob, top_logprobs_num)``. The prompt is the token ids
            ``0 .. input_len - 1``. ``output_len`` is requested as
            ``max_new_tokens`` with ``ignore_eos`` set.

    Raises:
        AssertionError: if the prompt/completion token counts differ from the
            request, or if a logprob list has the wrong length. When
            ``temperature == 0``, also raised if the emitted token is not the
            top candidate (or tied with it by logprob).
    """
    (
        input_len,
        output_len,
        temperature,
        logprob_start_len,
        return_logprob,
        top_logprobs_num,
    ) = arg
    input_ids = list(range(input_len))

    response = requests.post(
        self.base_url + "/generate",
        json={
            "input_ids": input_ids,
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
                "ignore_eos": True,
            },
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
        },
    )
    meta_info = response.json()["meta_info"]

    self.assertEqual(meta_info["prompt_tokens"], input_len)
    self.assertEqual(meta_info["completion_tokens"], output_len)

    # Check that the number of returned logprobs matches the token counts.
    if not return_logprob:
        return

    # Input logprobs start at position logprob_start_len.
    self.assertEqual(
        len(meta_info["input_token_logprobs"]) + logprob_start_len,
        meta_info["prompt_tokens"],
    )
    self.assertEqual(len(meta_info["output_token_logprobs"]), output_len)

    if not top_logprobs_num:
        return

    self.assertEqual(
        len(meta_info["input_top_logprobs"]) + logprob_start_len,
        meta_info["prompt_tokens"],
    )
    self.assertEqual(len(meta_info["output_top_logprobs"]), output_len)

    for i in range(output_len):
        top_candidates = meta_info["output_top_logprobs"][i]
        self.assertEqual(len(top_candidates), top_logprobs_num)

        # With temperature == 0 the emitted token should be the top-1 candidate.
        if temperature == 0:
            chosen = meta_info["output_token_logprobs"][i]
            rank = 0
            while rank < len(top_candidates):
                try:
                    self.assertListEqual(chosen, top_candidates[rank])
                    break
                except AssertionError:
                    # On a tie (equal logprob at index 0) the emitted token may
                    # sit further down the run of tied candidates, so keep
                    # walking. Stop at the last candidate instead of indexing
                    # past the end, which used to raise IndexError here.
                    tied_with_next = (
                        rank + 1 < len(top_candidates)
                        and top_candidates[rank][0] == top_candidates[rank + 1][0]
                    )
                    if not tied_with_next:
                        raise
                    rank += 1
1007
1008
1009
1010


class CustomTestCase(unittest.TestCase):
    def _callTestMethod(self, method):
        """Run ``method`` via unittest's hook, wrapped in ``retry``.

        The retry budget comes from ``SGLANG_TEST_MAX_RETRY``. When that
        variable is unset, the default is "1" if ``is_in_ci()`` is true and
        "0" otherwise.
        """
        fallback = "1" if is_in_ci() else "0"
        retries = int(os.environ.get("SGLANG_TEST_MAX_RETRY", fallback))

        # Explicit two-argument super(): a zero-argument super() would not
        # work inside this nested function.
        def invoke_once():
            return super(CustomTestCase, self)._callTestMethod(method)

        retry(invoke_once, max_retry=retries)