"""Common utilities for testing and benchmarking"""

import argparse
import copy
import logging
import os
import random
import subprocess
import threading
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from types import SimpleNamespace
from typing import Callable, List, Optional, Tuple

import numpy as np
import requests
import torch
import torch.nn.functional as F

from sglang.bench_serving import run_benchmark
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.utils import (
    get_bool_env_var,
    is_port_available,
    kill_process_tree,
    retry,
)
from sglang.test.run_eval import run_eval
from sglang.utils import get_exception_traceback

# General test models
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"

# MLA test models
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN"

# FP8 models
DEFAULT_MODEL_NAME_FOR_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_ACCURACY_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST_FP8 = (
    "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic"
)
DEFAULT_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_FP8 = (
    "nvidia/Llama-3.1-8B-Instruct-FP8"
)

# EAGLE
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"

# Other use cases
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
    "meta-llama/Llama-4-Scout-17B-16E-Instruct"
)
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_REASONING_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST = (
    "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
)

# Nightly tests
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4,hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_SMALL_VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B"

DEFAULT_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
DEFAULT_VIDEO_URL = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"

DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000


def is_in_ci():
    """Return whether the code is running in a CI runner."""
    return get_bool_env_var("SGLANG_IS_IN_CI")


if is_in_ci():
    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
        5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
    )
else:
    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
        7000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
    )
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"


def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
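    """Call a LightLLM-style HTTP server's generate endpoint and return the generated text."""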
    assert url is not None

    data = {
        "inputs": prompt,
        "parameters": {
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "stop_sequences": stop,
        },
    }
    res = requests.post(url, json=data)
    assert res.status_code == 200
    pred = res.json()["generated_text"][0]
    return pred


def find_available_port(base_port: int):
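    """Return a free port, probing from a random offset above base_port."""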
    port = base_port + random.randint(100, 1000)
    while True:
        if is_port_available(port):
            return port
        if port < 60000:
            port += 42
        else:
            port -= 43


def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
    assert url is not None

    data = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
        "n": n,
    }
    res = requests.post(url, json=data)
    assert res.status_code == 200
    if n == 1:
        pred = res.json()["text"][0][len(prompt) :]
    else:
        pred = [x[len(prompt) :] for x in res.json()["text"]]
    return pred


def call_generate_outlines(
    prompt, temperature, max_tokens, stop=None, regex=None, n=1, url=None
):
    assert url is not None

    data = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
        "regex": regex,
        "n": n,
    }
    res = requests.post(url, json=data)
    assert res.status_code == 200
    if n == 1:
        pred = res.json()["text"][0][len(prompt) :]
    else:
        pred = [x[len(prompt) :] for x in res.json()["text"]]
    return pred


def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
    assert url is not None

    data = {
        "text": prompt,
        "sampling_params": {
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "stop": stop,
        },
    }
    res = requests.post(url, json=data)
    assert res.status_code == 200
    obj = res.json()
    pred = obj["text"]
    return pred


def call_generate_guidance(
    prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
):
    assert model is not None
    from guidance import gen

    rets = []
    for _ in range(n):
        out = (
            model
            + prompt
            + gen(
                name="answer",
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop,
                regex=regex,
            )
        )
        rets.append(out["answer"])
    return rets if n > 1 else rets[0]


def call_select_lightllm(context, choices, url=None):
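    """Select among choices with a LightLLM server.

    Note: the per-choice scores below are placeholders (all zero), so np.argmax
    effectively returns the first choice.
    """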
    assert url is not None

    scores = []
    for i in range(len(choices)):
        data = {
            "inputs": context + choices[i],
            "parameters": {
                "max_new_tokens": 1,
            },
        }
        res = requests.post(url, json=data)
        assert res.status_code == 200
        scores.append(0)
    return np.argmax(scores)


def call_select_vllm(context, choices, url=None):
    assert url is not None

    scores = []
    for i in range(len(choices)):
        data = {
            "prompt": context + choices[i],
            "max_tokens": 1,
            "prompt_logprobs": 1,
        }
        res = requests.post(url, json=data)
        assert res.status_code == 200
        scores.append(res.json().get("prompt_score", 0))
    return np.argmax(scores)

    """
    Modify vllm/entrypoints/api_server.py

    if final_output.prompt_logprobs is not None:
        score = np.mean([prob[t_id] for t_id, prob in zip(final_output.prompt_token_ids[1:], final_output.prompt_logprobs[1:])])
        ret["prompt_score"] = score
    """


def call_select_guidance(context, choices, model=None):
    assert model is not None
    from guidance import select

    out = model + context + select(choices, name="answer")
    return choices.index(out["answer"])


def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
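    """Add the CLI arguments shared by the non-sglang benchmark backends, then parse and return them."""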
    parser.add_argument("--parallel", type=int, default=64)
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument(
        "--backend",
        type=str,
        required=True,
        choices=[
            "vllm",
            "outlines",
            "lightllm",
            "gserver",
            "guidance",
            "srt-raw",
            "llama.cpp",
        ],
    )
    parser.add_argument("--n-ctx", type=int, default=4096)
    parser.add_argument(
        "--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
    )
    parser.add_argument("--result-file", type=str, default="result.jsonl")
    args = parser.parse_args()

    if args.port is None:
        default_port = {
            "vllm": 21000,
            "outlines": 21000,
            "lightllm": 22000,
            "srt-raw": 30000,
            "gserver": 9988,
        }
        args.port = default_port.get(args.backend, None)
    return args


def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
    parser.add_argument("--parallel", type=int, default=64)
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    parser.add_argument("--backend", type=str, default="srt")
    parser.add_argument("--result-file", type=str, default="result.jsonl")
    args = parser.parse_args()
    return args


def select_sglang_backend(args: argparse.Namespace):
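    """Create the sglang language backend (RuntimeEndpoint or OpenAI) selected by the parsed args."""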
    if args.backend.startswith("srt"):
        if args.backend == "srt-no-parallel":
            global_config.enable_parallel_encoding = False
        backend = RuntimeEndpoint(f"{args.host}:{args.port}")
    elif args.backend.startswith("gpt-"):
        backend = OpenAI(args.backend)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")
    return backend


def _get_call_generate(args: argparse.Namespace):
    if args.backend == "lightllm":
        return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "vllm":
        return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "srt-raw":
        return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "gserver":
        return partial(call_generate_gserver, url=f"{args.host}:{args.port}")
    elif args.backend == "outlines":
        return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "guidance":
        from guidance import models

        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
        call_generate = partial(call_generate_guidance, model=model)
        call_generate("Hello,", 1.0, 8, ".")
        return call_generate
    else:
        raise ValueError(f"Invalid backend: {args.backend}")


def _get_call_select(args: argparse.Namespace):
    if args.backend == "lightllm":
        return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "vllm":
        return partial(call_select_vllm, url=f"{args.host}:{args.port}/generate")
    elif args.backend == "guidance":
        from guidance import models

        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
        call_select = partial(call_select_guidance, model=model)

        call_select("Hello,", ["world", "earth"])
        return call_select
    else:
        raise ValueError(f"Invalid backend: {args.backend}")


def get_call_generate(args: argparse.Namespace):
    call_generate = _get_call_generate(args)

    def func(*args, **kwargs):
        try:
            return call_generate(*args, **kwargs)
        except Exception:
            print("Exception in call_generate:\n" + get_exception_traceback())
            raise

    return func


def get_call_select(args: argparse.Namespace):
    call_select = _get_call_select(args)

    def func(*args, **kwargs):
        try:
            return call_select(*args, **kwargs)
        except Exception:
            print("Exception in call_select:\n" + get_exception_traceback())
            raise

    return func


def popen_launch_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: Optional[str] = None,
    other_args: list[str] = (),
    env: Optional[dict] = None,
    return_stdout_stderr: Optional[tuple] = None,
    pd_seperated: bool = False,
):
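    """Launch an sglang server in a subprocess and wait until it is healthy.

    Polls {base_url}/health_generate until it returns 200 and then returns the
    Popen handle; raises if the process exits early or the timeout expires.
    Illustrative usage (mirrors run_bench_serving below; the caller owns cleanup):

        process = popen_launch_server(
            DEFAULT_MODEL_NAME_FOR_TEST,
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )
        try:
            ...  # send requests to DEFAULT_URL_FOR_TEST
        finally:
            kill_process_tree(process.pid)
    """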
    _, host, port = base_url.split(":")
    host = host[2:]

    if pd_seperated:
        command = "sglang.launch_pd_server"
    else:
        command = "sglang.launch_server"

    command = [
        "python3",
        "-m",
        command,
        "--model-path",
        model,
        *[str(x) for x in other_args],
    ]

    if pd_seperated:
        command.extend(
            [
                "--lb-host",
                host,
                "--lb-port",
                port,
            ]
        )
    else:
        command.extend(
            [
                "--host",
                host,
                "--port",
                port,
            ]
        )

    if api_key:
        command += ["--api-key", api_key]

    print(f"command={' '.join(command)}")

    if return_stdout_stderr:
        process = subprocess.Popen(
            command,
            stdout=return_stdout_stderr[0],
            stderr=return_stdout_stderr[1],
            env=env,
            text=True,
        )
    else:
        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

    start_time = time.time()
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            try:
                headers = {
                    "Content-Type": "application/json; charset=utf-8",
                    "Authorization": f"Bearer {api_key}",
                }
                response = session.get(
                    f"{base_url}/health_generate",
                    headers=headers,
                )
                if response.status_code == 200:
                    return process
            except requests.RequestException:
                pass

            return_code = process.poll()
            if return_code is not None:
                raise Exception(
                    f"Server unexpectedly exits ({return_code=}). Usually there will be error logs describing the cause far above this line."
                )

            time.sleep(10)

    kill_process_tree(process.pid)
    raise TimeoutError("Server failed to start within the timeout period.")


def run_with_timeout(
    func: Callable,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    timeout: float = None,
):
    """Run a function with timeout."""
    ret_value = []

    def _target_func():
        ret_value.append(func(*args, **(kwargs or {})))

    t = threading.Thread(target=_target_func)
    t.start()
    t.join(timeout=timeout)
    if t.is_alive():
        raise TimeoutError(f"Function did not complete within {timeout} seconds")

    if not ret_value:
        raise RuntimeError(
            "Function returned no value; it may have raised an exception in the worker thread"
        )

    return ret_value[0]


@dataclass
class TestFile:
    name: str
    estimated_time: float = 60


def run_unittest_files(files: List[TestFile], timeout_per_file: float):
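    """Run each test file in a fresh Python subprocess with a per-file timeout.

    Returns 0 if every file exits with return code 0, otherwise -1.
    """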
    tic = time.time()
    success = True

    for i, file in enumerate(files):
        filename, estimated_time = file.name, file.estimated_time
        process = None

        def run_one_file(filename):
            nonlocal process

            filename = os.path.join(os.getcwd(), filename)
            print(
                f".\n.\nBegin ({i}/{len(files) - 1}):\npython3 {filename}\n.\n.\n",
                flush=True,
            )
            tic = time.time()

            process = subprocess.Popen(
                ["python3", filename], stdout=None, stderr=None, env=os.environ
            )
            process.wait()
            elapsed = time.time() - tic

            print(
                f".\n.\nEnd ({i}/{len(files) - 1}):\n{filename=}, {elapsed=:.0f}, {estimated_time=}\n.\n.\n",
                flush=True,
            )
            return process.returncode

        try:
            ret_code = run_with_timeout(
                run_one_file, args=(filename,), timeout=timeout_per_file
            )
            assert (
                ret_code == 0
            ), f"expected return code 0, but {filename} returned {ret_code}"
        except TimeoutError:
            kill_process_tree(process.pid)
            time.sleep(5)
            print(
                f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
                flush=True,
            )
            success = False
            break

    if success:
        print(f"Success. Time elapsed: {time.time() - tic:.2f}s", flush=True)
    else:
        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s", flush=True)

    return 0 if success else -1


def get_similarities(vec1, vec2):
    return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)


def get_benchmark_args(
    base_url="",
    dataset_name="",
    dataset_path="",
    tokenizer="",
    num_prompts=500,
    sharegpt_output_len=None,
    random_input_len=4096,
    random_output_len=2048,
    sharegpt_context_len=None,
    request_rate=float("inf"),
    disable_stream=False,
    disable_ignore_eos=False,
    seed: int = 0,
    pd_seperated: bool = False,
):
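    """Build the argument namespace consumed by sglang.bench_serving.run_benchmark."""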
    return SimpleNamespace(
        backend="sglang",
        base_url=base_url,
        host=None,
        port=None,
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        model=None,
        tokenizer=tokenizer,
        num_prompts=num_prompts,
        sharegpt_output_len=sharegpt_output_len,
        sharegpt_context_len=sharegpt_context_len,
        random_input_len=random_input_len,
        random_output_len=random_output_len,
        random_range_ratio=0.0,
        request_rate=request_rate,
        multi=None,
        output_file=None,
        disable_tqdm=False,
        disable_stream=disable_stream,
        return_logprob=False,
        seed=seed,
        disable_ignore_eos=disable_ignore_eos,
        extra_request_body=None,
        apply_chat_template=False,
        profile=None,
        lora_name=None,
        prompt_suffix="",
        pd_seperated=pd_seperated,
    )


def run_bench_serving(
    model,
    num_prompts,
    request_rate,
    other_server_args,
    dataset_name="random",
    dataset_path="",
    tokenizer=None,
    random_input_len=4096,
    random_output_len=2048,
    sharegpt_context_len=None,
    disable_stream=False,
    disable_ignore_eos=False,
    need_warmup=False,
    seed: int = 0,
):
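    """Launch a server, run sglang.bench_serving against it, and return the result dict."""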
    # Launch the server
    base_url = DEFAULT_URL_FOR_TEST
    process = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_server_args,
    )

    # Run benchmark
    args = get_benchmark_args(
        base_url=base_url,
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        num_prompts=num_prompts,
        random_input_len=random_input_len,
        random_output_len=random_output_len,
        sharegpt_context_len=sharegpt_context_len,
        request_rate=request_rate,
        disable_stream=disable_stream,
        disable_ignore_eos=disable_ignore_eos,
        seed=seed,
    )

    try:
        if need_warmup:
            warmup_args = copy.deepcopy(args)
            warmup_args.num_prompts = 16
            run_benchmark(warmup_args)
        res = run_benchmark(args)
    finally:
        kill_process_tree(process.pid)

    assert res["completed"] == num_prompts
    return res


def run_bench_serving_multi(
    model,
    base_url,
    other_server_args,
    benchmark_args,
    need_warmup=False,
    pd_seperated=False,
):
    # Launch the server
    process = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_server_args,
        pd_seperated=pd_seperated,
    )

    # run benchmark for all
    res_l = []
    try:
        for args in benchmark_args:
            if need_warmup:
                warmup_args = copy.deepcopy(args)
                warmup_args.num_prompts = 16
                run_benchmark(warmup_args)

            res = run_benchmark(args)
            res_l.append((args, res))
    finally:
        kill_process_tree(process.pid)

    return res_l


def run_bench_one_batch(model, other_args):
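    """Run sglang.bench_one_batch in a subprocess and parse the output throughput.

    Note: the throughput is read from the third-to-last line of stdout, so this
    helper is sensitive to the bench_one_batch output format.
    """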
    command = [
        "python3",
        "-m",
        "sglang.bench_one_batch",
        "--batch-size",
        "1",
        "--input",
        "128",
        "--output",
        "8",
        *[str(x) for x in other_args],
    ]
    if model is not None:
        command += ["--model-path", model]
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    try:
        stdout, stderr = process.communicate()
        output = stdout.decode()
        error = stderr.decode()
        print(f"Output: {output}", flush=True)
        print(f"Error: {error}", flush=True)

        lastline = output.split("\n")[-3]
        output_throughput = float(lastline.split(" ")[-2])
    finally:
        kill_process_tree(process.pid)

    return output_throughput


def lcs(X, Y):
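    """Length of the longest common subsequence of X and Y (dynamic programming)."""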
    m = len(X)
    n = len(Y)
    L = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0 or j == 0:
                L[i][j] = 0
            elif X[i - 1] == Y[j - 1]:
                L[i][j] = L[i - 1][j - 1] + 1
            else:
                L[i][j] = max(L[i - 1][j], L[i][j - 1])

    return L[m][n]


def calculate_rouge_l(output_strs_list1, output_strs_list2):
    """calculate the ROUGE-L score"""
    rouge_l_scores = []

    for s1, s2 in zip(output_strs_list1, output_strs_list2):
        lcs_len = lcs(s1, s2)
        precision = lcs_len / len(s1) if len(s1) > 0 else 0
        recall = lcs_len / len(s2) if len(s2) > 0 else 0
        if precision + recall > 0:
            fmeasure = (2 * precision * recall) / (precision + recall)
        else:
            fmeasure = 0.0
        rouge_l_scores.append(fmeasure)

    return rouge_l_scores


STDERR_FILENAME = "stderr.txt"
STDOUT_FILENAME = "stdout.txt"


def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
    """Print the output in real time with another thread."""
    while not os.path.exists(filename):
        time.sleep(1)

    pt = 0
    while pt >= 0:
        if pt > 0 and not os.path.exists(filename):
            break
        lines = open(filename).readlines()
        for line in lines[pt:]:
            print(line, end="", flush=True)
            output_lines.append(line)
            pt += 1
        time.sleep(0.1)


def run_and_check_memory_leak(
    workload_func,
    disable_radix_cache,
    enable_mixed_chunk,
    disable_overlap,
    chunked_prefill_size,
    assert_has_abort,
):
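    """Launch a server, run the given workload, then scan the captured logs.

    Asserts that the server started ("Uvicorn running"), that no log line
    mentions a "leak", and, if assert_has_abort is set, that an "Abort" was logged.
    """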
    other_args = [
        "--chunked-prefill-size",
        str(chunked_prefill_size),
        "--log-level",
        "debug",
    ]
    if disable_radix_cache:
        other_args += ["--disable-radix-cache"]
    if enable_mixed_chunk:
        other_args += ["--enable-mixed-chunk"]
    if disable_overlap:
        other_args += ["--disable-overlap-schedule"]

    model = DEFAULT_MODEL_NAME_FOR_TEST
    port = random.randint(4000, 5000)
    base_url = f"http://127.0.0.1:{port}"

    # Create files and launch the server
    stdout = open(STDOUT_FILENAME, "w")
    stderr = open(STDERR_FILENAME, "w")
    process = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_args,
        return_stdout_stderr=(stdout, stderr),
    )

    # Launch a thread to stream the output
    output_lines = []
    t = threading.Thread(target=read_output, args=(output_lines,))
    t.start()

    # Run the workload
    workload_func(base_url, model)

    # Clean up everything
    kill_process_tree(process.pid)
    stdout.close()
    stderr.close()
    if os.path.exists(STDOUT_FILENAME):
        os.remove(STDOUT_FILENAME)
    if os.path.exists(STDERR_FILENAME):
        os.remove(STDERR_FILENAME)
    kill_process_tree(process.pid)
    t.join()

    # Assert success
    has_new_server = False
    has_leak = False
    has_abort = False
    for line in output_lines:
        if "Uvicorn running" in line:
            has_new_server = True
        if "leak" in line:
            has_leak = True
        if "Abort" in line:
            has_abort = True

    assert has_new_server
    assert not has_leak
    if assert_has_abort:
        assert has_abort


def run_command_and_capture_output(command, env: Optional[dict] = None):
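    """Run a command, stream its combined stdout/stderr via read_output, and return the captured lines."""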
    stdout = open(STDOUT_FILENAME, "w")
    stderr = open(STDERR_FILENAME, "w")
    process = subprocess.Popen(
        command, stdout=stdout, stderr=stdout, env=env, text=True
    )

    # Launch a thread to stream the output
    output_lines = []
    t = threading.Thread(target=read_output, args=(output_lines, STDOUT_FILENAME))
    t.start()

    # Join the process
    process.wait()

    stdout.close()
    stderr.close()
    if os.path.exists(STDOUT_FILENAME):
        os.remove(STDOUT_FILENAME)
    if os.path.exists(STDERR_FILENAME):
        os.remove(STDERR_FILENAME)
    kill_process_tree(process.pid)
    t.join()

    return output_lines


def run_mmlu_test(
    disable_radix_cache=False,
    enable_mixed_chunk=False,
    disable_overlap=False,
    chunked_prefill_size=32,
):
    def workload_func(base_url, model):
        # Run the eval
        args = SimpleNamespace(
            base_url=base_url,
            model=model,
            eval_name="mmlu",
            num_examples=128,
            num_threads=128,
        )

        try:
            metrics = run_eval(args)
            assert metrics["score"] >= 0.65, f"{metrics=}"
        finally:
            pass

    run_and_check_memory_leak(
        workload_func,
        disable_radix_cache,
        enable_mixed_chunk,
        disable_overlap,
        chunked_prefill_size,
        assert_has_abort=False,
    )


def run_mulit_request_test(
    disable_radix_cache=False,
    enable_mixed_chunk=False,
    enable_overlap=False,
    chunked_prefill_size=32,
):
    def workload_func(base_url, model):
        def run_one(_):
            prompt = """
            System: You are a helpful assistant.
            User: What is the capital of France?
            Assistant: The capital of France is
            """

            response = requests.post(
                f"{base_url}/generate",
                json={
                    "text": prompt,
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 8,
                    },
                },
            )
            ret = response.json()

        with ThreadPoolExecutor(2) as executor:
            list(executor.map(run_one, list(range(4))))

    run_and_check_memory_leak(
        workload_func,
        disable_radix_cache,
        enable_mixed_chunk,
        enable_overlap,
        chunked_prefill_size,
        assert_has_abort=False,
    )


def write_github_step_summary(content):
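    """Append content to the GitHub Actions step summary, if GITHUB_STEP_SUMMARY is set."""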
    if not os.environ.get("GITHUB_STEP_SUMMARY"):
        logging.warning("GITHUB_STEP_SUMMARY environment variable not set")
        return

    with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
        f.write(content)


def run_logprob_check(self: unittest.TestCase, arg: Tuple):
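    """Send one /generate request for the given argument tuple and check the logprob metadata in the response."""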
    (
        input_len,
        output_len,
        temperature,
        logprob_start_len,
        return_logprob,
        top_logprobs_num,
    ) = arg
    input_ids = list(range(input_len))

    response = requests.post(
        self.base_url + "/generate",
        json={
            "input_ids": input_ids,
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
                "ignore_eos": True,
            },
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
        },
    )
    response_json = response.json()

    res = response_json
    self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
    self.assertEqual(res["meta_info"]["completion_tokens"], output_len)

    # Test the number of tokens are correct
    if return_logprob:
        self.assertEqual(
            len(res["meta_info"]["input_token_logprobs"]) + logprob_start_len,
            res["meta_info"]["prompt_tokens"],
        )
        self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)

        if top_logprobs_num:
            self.assertEqual(
                len(res["meta_info"]["input_top_logprobs"]) + logprob_start_len,
                res["meta_info"]["prompt_tokens"],
            )
            self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), output_len)

            for i in range(output_len):
                self.assertEqual(
                    len(res["meta_info"]["output_top_logprobs"][i]),
                    top_logprobs_num,
                )

                # Test the top-1 tokens are the same as output tokens if temperature == 0
                if temperature == 0:
                    rank = 0
                    while rank < len(res["meta_info"]["output_top_logprobs"][i]):
                        try:
                            self.assertListEqual(
                                res["meta_info"]["output_token_logprobs"][i],
                                res["meta_info"]["output_top_logprobs"][i][rank],
                            )
                            break
                        except AssertionError:
                            # There's a tie. Allow the second item in this case.
                            if (
                                res["meta_info"]["output_top_logprobs"][i][rank][0]
                                == res["meta_info"]["output_top_logprobs"][i][rank + 1][
                                    0
                                ]
                            ):
                                rank += 1
                            else:
                                raise


class CustomTestCase(unittest.TestCase):
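    """A TestCase that retries failing test methods, controlled by SGLANG_TEST_MAX_RETRY (defaults to one retry in CI)."""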
    def _callTestMethod(self, method):
        max_retry = int(
            os.environ.get("SGLANG_TEST_MAX_RETRY", "1" if is_in_ci() else "0")
        )
        retry(
            lambda: super(CustomTestCase, self)._callTestMethod(method),
            max_retry=max_retry,
        )