Commit ac238727 authored by Lianmin Zheng

Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
parent 0194948f
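At a glance, the features touched by this commit combine in a single /generate request: penalties now work with the overlap scheduler, and prefill logprobs are returned correctly together with chunked prefill. A minimal sketch, assuming a server already listening on localhost:30000 (the field names mirror the tests in this diff; the prompt is a placeholder):

import requests

# Penalties plus prefill/decode logprobs in one request.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Human: Write a haiku about GPUs.\n\nAssistant:",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
            "frequency_penalty": 0.5,
            "presence_penalty": 0.5,
        },
        "return_logprob": True,
        "logprob_start_len": 0,
        "top_logprobs_num": 5,
    },
)
meta = response.json()["meta_info"]
print(len(meta["input_token_logprobs"]), len(meta["output_token_logprobs"]))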
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import multiprocessing as mp import multiprocessing as mp
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -56,6 +56,13 @@ def get_top_logprobs(logits, k): ...@@ -56,6 +56,13 @@ def get_top_logprobs(logits, k):
return logprobs return logprobs
def get_token_ids_logprobs(logits, token_ids):
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
del logits
logprobs = logprobs[..., token_ids]
return logprobs
def _get_sentence_transformer_embedding_model(model_path, torch_dtype): def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
from sentence_transformers.util import is_sentence_transformer_model from sentence_transformers.util import is_sentence_transformer_model
...@@ -84,8 +91,13 @@ class ModelOutput: ...@@ -84,8 +91,13 @@ class ModelOutput:
output_ids: List[int] = None output_ids: List[int] = None
top_input_logprobs: List[torch.Tensor] = None top_input_logprobs: List[torch.Tensor] = None
top_output_logprobs: List[torch.Tensor] = None top_output_logprobs: List[torch.Tensor] = None
top_output_logprob_idx: List[List[int]] = None
embed_logits: List[torch.Tensor] = None embed_logits: List[torch.Tensor] = None
scores: List[float] = None scores: List[float] = None
input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
token_ids_input_logprobs: List[torch.Tensor] = None
token_ids_output_logprobs: List[torch.Tensor] = None
class HFRunner: class HFRunner:
...@@ -157,7 +169,7 @@ class HFRunner: ...@@ -157,7 +169,7 @@ class HFRunner:
# Run forward # Run forward
while True: while True:
prompts, max_new_tokens, lora_paths = in_queue.get() prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
if lora_paths is not None: if lora_paths is not None:
assert len(prompts) == len(lora_paths) assert len(prompts) == len(lora_paths)
...@@ -165,16 +177,16 @@ class HFRunner: ...@@ -165,16 +177,16 @@ class HFRunner:
if self.model_type == "generation": if self.model_type == "generation":
out_queue.put( out_queue.put(
self.forward_generation_raw( self.forward_generation_raw(
base_model=self.base_model,
prompts=prompts, prompts=prompts,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
base_model=self.base_model,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
lora_paths=lora_paths, lora_paths=lora_paths,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
output_str_only=self.output_str_only, output_str_only=self.output_str_only,
token_ids_logprob=token_ids_logprob,
) )
) )
elif self.model_type == "embedding": elif self.model_type == "embedding":
assert not self.output_str_only assert not self.output_str_only
logits = self.model.encode(prompts).tolist() logits = self.model.encode(prompts).tolist()
...@@ -199,10 +211,11 @@ class HFRunner: ...@@ -199,10 +211,11 @@ class HFRunner:
def forward( def forward(
self, self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8, max_new_tokens: int = 8,
lora_paths=None, lora_paths: Optional[List[str]] = None,
token_ids_logprob: Optional[int] = None,
): ):
self.in_queue.put((prompts, max_new_tokens, lora_paths)) self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
return self.out_queue.get() return self.out_queue.get()
def terminate(self): def terminate(self):
...@@ -218,17 +231,24 @@ class HFRunner: ...@@ -218,17 +231,24 @@ class HFRunner:
@staticmethod @staticmethod
def forward_generation_raw( def forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
base_model, base_model,
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens: int,
tokenizer, tokenizer,
lora_paths,
torch_dtype: torch.dtype, torch_dtype: torch.dtype,
output_str_only: bool, lora_paths: Optional[List[str]] = None,
output_str_only: bool = False,
token_ids_logprob: Optional[int] = None,
) -> ModelOutput: ) -> ModelOutput:
output_strs = [] output_strs = []
top_input_logprobs = [] top_input_logprobs = []
top_output_logprobs = [] top_output_logprobs = []
if token_ids_logprob is not None:
token_ids_input_logprobs = []
token_ids_output_logprobs = []
else:
token_ids_input_logprobs = token_ids_output_logprobs = None
for i, p in enumerate(prompts): for i, p in enumerate(prompts):
if isinstance(p, str): if isinstance(p, str):
input_ids = tokenizer.encode(p, return_tensors="pt").cuda() input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
...@@ -275,18 +295,33 @@ class HFRunner: ...@@ -275,18 +295,33 @@ class HFRunner:
for logits in outputs.scores for logits in outputs.scores
] ]
) )
if token_ids_logprob is not None:
token_ids_output_logprobs.append(
[
get_token_ids_logprobs(
logits[0], token_ids_logprob
).tolist()
for logits in outputs.scores
]
)
del outputs del outputs
input_logits = model.forward(input_ids).logits[0] input_logits = model.forward(input_ids).logits[0]
top_input_logprobs.append( top_input_logprobs.append(
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist() get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
) )
if token_ids_logprob is not None:
token_ids_input_logprobs.append(
get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
)
del input_logits del input_logits
return ModelOutput( return ModelOutput(
output_strs=output_strs, output_strs=output_strs,
top_input_logprobs=top_input_logprobs, top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs, top_output_logprobs=top_output_logprobs,
token_ids_input_logprobs=token_ids_input_logprobs,
token_ids_output_logprobs=token_ids_output_logprobs,
) )
...@@ -303,11 +338,31 @@ class SRTRunner: ...@@ -303,11 +338,31 @@ class SRTRunner:
lora_backend: str = "triton", lora_backend: str = "triton",
disable_cuda_graph: bool = False, disable_cuda_graph: bool = False,
disable_radix_cache: bool = False, disable_radix_cache: bool = False,
chunked_prefill_size: Optional[int] = None,
dp_size: int = 1,
tokenizer_path: Optional[str] = None,
enable_ep_moe: bool = False,
mem_fraction_static: float = 0.65, mem_fraction_static: float = 0.65,
trust_remote_code: bool = False, trust_remote_code: bool = False,
speculative_draft_model_path: Optional[str] = None,
speculative_algorithm: Optional[str] = None,
speculative_num_steps: Optional[int] = None,
speculative_eagle_topk: Optional[int] = None,
speculative_num_draft_tokens: Optional[int] = None,
disable_overlap_schedule: bool = False,
): ):
self.model_type = model_type self.model_type = model_type
self.is_generation = model_type == "generation" self.is_generation = model_type == "generation"
enable_dp_attention = dp_size > 1
spec_kwargs = {}
if speculative_draft_model_path:
spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
spec_kwargs["speculative_algorithm"] = speculative_algorithm
spec_kwargs["speculative_num_steps"] = speculative_num_steps
spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens
self.engine = Engine( self.engine = Engine(
model_path=model_path, model_path=model_path,
tp_size=tp_size, tp_size=tp_size,
...@@ -321,21 +376,41 @@ class SRTRunner: ...@@ -321,21 +376,41 @@ class SRTRunner:
lora_backend=lora_backend, lora_backend=lora_backend,
disable_cuda_graph=disable_cuda_graph, disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache, disable_radix_cache=disable_radix_cache,
chunked_prefill_size=chunked_prefill_size,
enable_dp_attention=enable_dp_attention,
dp_size=dp_size,
tokenizer_path=tokenizer_path,
enable_ep_moe=enable_ep_moe,
disable_overlap_schedule=disable_overlap_schedule,
cuda_graph_max_bs=4,
**spec_kwargs,
) )
self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
if tokenizer_path is None:
self.tokenizer = get_tokenizer(
model_path, trust_remote_code=trust_remote_code
)
else:
self.tokenizer = None
def forward( def forward(
self, self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8, max_new_tokens: int = 8,
lora_paths=None, lora_paths: Optional[List[str]] = None,
logprob_start_len: int = 0,
top_k: Optional[int] = None,
token_ids_logprob: Optional[List[int]] = None,
): ):
if self.is_generation: if self.is_generation:
return self.forward_generation_raw( return self.forward_generation_raw(
engine=self.engine,
prompts=prompts, prompts=prompts,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
lora_paths=lora_paths, lora_paths=lora_paths,
engine=self.engine, logprob_start_len=logprob_start_len,
top_k=top_k,
token_ids_logprob=token_ids_logprob,
) )
else: else:
response = self.engine.encode(prompts) response = self.engine.encode(prompts)
...@@ -358,10 +433,10 @@ class SRTRunner: ...@@ -358,10 +433,10 @@ class SRTRunner:
""" """
if self.is_generation: if self.is_generation:
return self.batch_forward_generation_raw( return self.batch_forward_generation_raw(
engine=self.engine,
prompts=prompts, prompts=prompts,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
lora_paths=lora_paths, lora_paths=lora_paths,
engine=self.engine,
) )
else: else:
response = self.engine.encode(prompts) response = self.engine.encode(prompts)
...@@ -381,24 +456,43 @@ class SRTRunner: ...@@ -381,24 +456,43 @@ class SRTRunner:
@staticmethod @staticmethod
def forward_generation_raw( def forward_generation_raw(
engine: Engine,
prompts: Union[List[str], List[torch.Tensor]], prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens, max_new_tokens: int = 8,
lora_paths, lora_paths: Optional[List[str]] = None,
engine, logprob_start_len: int = 0,
top_k: Optional[int] = None,
token_ids_logprob: Optional[List[int]] = None,
): ):
# the return value contains logprobs from prefill # the return value contains logprobs from prefill
output_strs = [] output_strs = []
output_ids = []
# Input logprobs. Note that the last item in input logprob is equivalent to
# the first item in the output logprob.
top_input_logprobs = [] top_input_logprobs = []
input_token_logprobs_lst = []
top_output_logprobs = [] top_output_logprobs = []
output_token_logprobs_lst = []
top_output_logprob_idx = []
if token_ids_logprob is not None:
token_ids_input_logprobs = []
token_ids_output_logprobs = []
else:
token_ids_input_logprobs = token_ids_output_logprobs = None
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
if top_k:
sampling_params["top_k"] = top_k
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
response = engine.generate( response = engine.generate(
prompt, prompt,
lora_path=lora_paths[i] if lora_paths else None, lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params, sampling_params=sampling_params,
return_logprob=True, return_logprob=True,
logprob_start_len=0, logprob_start_len=logprob_start_len,
top_logprobs_num=NUM_TOP_LOGPROBS, top_logprobs_num=NUM_TOP_LOGPROBS,
token_ids_logprob=token_ids_logprob,
) )
text = response["text"] text = response["text"]
...@@ -408,12 +502,36 @@ class SRTRunner: ...@@ -408,12 +502,36 @@ class SRTRunner:
"Received an empty text response. Please verify your input or model configuration." "Received an empty text response. Please verify your input or model configuration."
) )
output_strs.append(text) output_strs.append(text)
# output_ids.append(response["output_ids"])
input_token_logprobs = response["meta_info"]["input_token_logprobs"]
output_token_logprobs = response["meta_info"]["output_token_logprobs"]
# print(i, input_token_logprobs)
# print(i, output_token_logprobs)
logprobs = response["meta_info"]["input_top_logprobs"]
if token_ids_logprob is not None:
input_token_ids_logprobs = response["meta_info"][
"input_token_ids_logprobs"
][1:]
else:
input_token_ids_logprobs = None
num_prompt_tokens = response["meta_info"]["prompt_tokens"]
assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
assert len(logprobs) == num_prompt_tokens - logprob_start_len
# The first token logprob has no meaning in sglang.
input_token_logprobs = input_token_logprobs[1:]
logprobs = logprobs[1:]
assert len(input_token_logprobs) == len(logprobs)
input_token_logprobs_lst.append(
input_token_logprobs + [output_token_logprobs[0]]
)
output_token_logprobs_lst.append(output_token_logprobs)
top_input_logprobs.append( top_input_logprobs.append(
[ [[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs]
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [ + [
[ [
tup[0] tup[0]
...@@ -429,11 +547,41 @@ class SRTRunner: ...@@ -429,11 +547,41 @@ class SRTRunner:
for x in response["meta_info"]["output_top_logprobs"] for x in response["meta_info"]["output_top_logprobs"]
] ]
) )
top_output_logprob_idx.append(
[
[tup[1] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["output_top_logprobs"]
]
)
if token_ids_logprob is not None:
token_ids_input_logprobs.append(
[[tup[0] for tup in x] for x in input_token_ids_logprobs]
+ [
[
tup[0]
for tup in response["meta_info"][
"output_token_ids_logprobs"
][0]
]
]
)
token_ids_output_logprobs.append(
[
[tup[0] for tup in x]
for x in response["meta_info"]["output_token_ids_logprobs"]
]
)
return ModelOutput( return ModelOutput(
output_strs=output_strs, output_strs=output_strs,
output_ids=output_ids,
top_input_logprobs=top_input_logprobs, top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs, top_output_logprobs=top_output_logprobs,
input_token_logprobs_lst=input_token_logprobs_lst,
output_token_logprobs_lst=output_token_logprobs_lst,
top_output_logprob_idx=top_output_logprob_idx,
token_ids_input_logprobs=token_ids_input_logprobs,
token_ids_output_logprobs=token_ids_output_logprobs,
) )
@staticmethod @staticmethod
......
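The extended runner interface above can be exercised end-to-end as in the sketch below; the constructor keywords and the model path are assumptions based on this diff and the existing tests, and a GPU with sglang installed is required:

import torch

from sglang.test.runners import SRTRunner

prompts = ["The capital of France is"]
tracked_token_ids = [10, 20]  # hypothetical token ids to score

with SRTRunner(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
    torch_dtype=torch.float16,
    model_type="generation",
    chunked_prefill_size=32,  # return logprobs while prefill is chunked
) as srt_runner:
    out = srt_runner.forward(
        prompts,
        max_new_tokens=8,
        logprob_start_len=0,
        top_k=1,
        token_ids_logprob=tracked_token_ids,
    )
    print(out.output_strs)
    print(out.token_ids_output_logprobs)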
"""
Run one test prompt.
Usage:
python3 -m sglang.test.send_one
"""
import argparse
import json
import requests
def send_one_prompt(args):
if args.image:
args.prompt = (
"Human: Describe this image in a very short sentence.\n\nAssistant:"
)
image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
else:
image_data = None
response = requests.post(
"http://localhost:30000/generate",
json={
"text": args.prompt,
"image_data": image_data,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
},
stream=args.stream,
)
if args.stream:
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
ret = json.loads(chunk[5:].strip("\n"))
else:
ret = response.json()
latency = ret["meta_info"]["e2e_latency"]
if "spec_verify_ct" in ret["meta_info"]:
acc_length = (
ret["meta_info"]["completion_tokens"] / ret["meta_info"]["spec_verify_ct"]
)
else:
acc_length = 1.0
speed = ret["meta_info"]["completion_tokens"] / latency
print(ret["text"])
print()
print(f"{acc_length=:.2f}")
print(f"{speed=:.2f} token/s")
return acc_length, speed
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--frequency-penalty", type=float, default=0.0)
parser.add_argument("--presence-penalty", type=float, default=0.0)
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument(
"--prompt",
type=str,
default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
)
parser.add_argument(
"--image",
action="store_true",
)
parser.add_argument("--stream", action="store_true")
args = parser.parse_args()
send_one_prompt(args)
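send_one_prompt can also be driven programmatically; a small sketch, assuming a server is already listening on localhost:30000 (the Namespace attributes mirror the CLI flags above):

from argparse import Namespace

from sglang.test.send_one import send_one_prompt

args = Namespace(
    prompt="Human: Say hello in one short sentence.\n\nAssistant:",
    image=False,
    temperature=0.0,
    max_new_tokens=64,
    frequency_penalty=0.3,  # exercise the penalty path
    presence_penalty=0.0,
    return_logprob=False,
    stream=False,
)
acc_length, speed = send_one_prompt(args)
print(acc_length, speed)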
...@@ -8,10 +8,11 @@ import random ...@@ -8,10 +8,11 @@ import random
import subprocess import subprocess
import threading import threading
import time import time
import unittest
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
from types import SimpleNamespace from types import SimpleNamespace
from typing import Callable, List, Optional from typing import Callable, List, Optional, Tuple
import numpy as np import numpy as np
import requests import requests
...@@ -408,26 +409,49 @@ def popen_launch_server( ...@@ -408,26 +409,49 @@ def popen_launch_server(
other_args: list[str] = (), other_args: list[str] = (),
env: Optional[dict] = None, env: Optional[dict] = None,
return_stdout_stderr: Optional[tuple] = None, return_stdout_stderr: Optional[tuple] = None,
pd_seperated: bool = False,
): ):
_, host, port = base_url.split(":") _, host, port = base_url.split(":")
host = host[2:] host = host[2:]
if pd_seperated:
command = "sglang.launch_pd_server"
else:
command = "sglang.launch_server"
command = [ command = [
"python3", "python3",
"-m", "-m",
"sglang.launch_server", command,
"--model-path", "--model-path",
model, model,
"--host", *[str(x) for x in other_args],
host,
"--port",
port,
*other_args,
] ]
if pd_seperated:
command.extend(
[
"--lb-host",
host,
"--lb-port",
port,
]
)
else:
command.extend(
[
"--host",
host,
"--port",
port,
]
)
if api_key: if api_key:
command += ["--api-key", api_key] command += ["--api-key", api_key]
print(f"command={' '.join(command)}")
if return_stdout_stderr: if return_stdout_stderr:
process = subprocess.Popen( process = subprocess.Popen(
command, command,
...@@ -456,6 +480,8 @@ def popen_launch_server( ...@@ -456,6 +480,8 @@ def popen_launch_server(
except requests.RequestException: except requests.RequestException:
pass pass
time.sleep(10) time.sleep(10)
kill_process_tree(process.pid)
raise TimeoutError("Server failed to start within the timeout period.") raise TimeoutError("Server failed to start within the timeout period.")
...@@ -488,9 +514,11 @@ def run_unittest_files(files: List[str], timeout_per_file: float): ...@@ -488,9 +514,11 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
success = True success = True
for filename in files: for filename in files:
global process process = None
def run_one_file(filename): def run_one_file(filename):
nonlocal process
filename = os.path.join(os.getcwd(), filename) filename = os.path.join(os.getcwd(), filename)
print(f"\n\nRun:\npython3 {filename}\n\n", flush=True) print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
process = subprocess.Popen( process = subprocess.Popen(
...@@ -534,11 +562,14 @@ def get_benchmark_args( ...@@ -534,11 +562,14 @@ def get_benchmark_args(
dataset_path="", dataset_path="",
tokenizer="", tokenizer="",
num_prompts=500, num_prompts=500,
sharegpt_output_len=None,
random_input_len=4096, random_input_len=4096,
random_output_len=2048, random_output_len=2048,
sharegpt_context_len=None,
request_rate=float("inf"), request_rate=float("inf"),
disable_stream=False, disable_stream=False,
disable_ignore_eos=False, disable_ignore_eos=False,
pd_seperated: bool = False,
): ):
return SimpleNamespace( return SimpleNamespace(
backend="sglang", backend="sglang",
...@@ -550,8 +581,8 @@ def get_benchmark_args( ...@@ -550,8 +581,8 @@ def get_benchmark_args(
model=None, model=None,
tokenizer=tokenizer, tokenizer=tokenizer,
num_prompts=num_prompts, num_prompts=num_prompts,
sharegpt_output_len=None, sharegpt_output_len=sharegpt_output_len,
sharegpt_context_len=None, sharegpt_context_len=sharegpt_context_len,
random_input_len=random_input_len, random_input_len=random_input_len,
random_output_len=random_output_len, random_output_len=random_output_len,
random_range_ratio=0.0, random_range_ratio=0.0,
...@@ -567,6 +598,8 @@ def get_benchmark_args( ...@@ -567,6 +598,8 @@ def get_benchmark_args(
apply_chat_template=False, apply_chat_template=False,
profile=None, profile=None,
lora_name=None, lora_name=None,
prompt_suffix="",
pd_seperated=pd_seperated,
) )
...@@ -580,6 +613,7 @@ def run_bench_serving( ...@@ -580,6 +613,7 @@ def run_bench_serving(
tokenizer=None, tokenizer=None,
random_input_len=4096, random_input_len=4096,
random_output_len=2048, random_output_len=2048,
sharegpt_context_len=None,
disable_stream=False, disable_stream=False,
disable_ignore_eos=False, disable_ignore_eos=False,
need_warmup=False, need_warmup=False,
...@@ -602,6 +636,7 @@ def run_bench_serving( ...@@ -602,6 +636,7 @@ def run_bench_serving(
num_prompts=num_prompts, num_prompts=num_prompts,
random_input_len=random_input_len, random_input_len=random_input_len,
random_output_len=random_output_len, random_output_len=random_output_len,
sharegpt_context_len=sharegpt_context_len,
request_rate=request_rate, request_rate=request_rate,
disable_stream=disable_stream, disable_stream=disable_stream,
disable_ignore_eos=disable_ignore_eos, disable_ignore_eos=disable_ignore_eos,
...@@ -626,6 +661,7 @@ def run_bench_serving_multi( ...@@ -626,6 +661,7 @@ def run_bench_serving_multi(
other_server_args, other_server_args,
benchmark_args, benchmark_args,
need_warmup=False, need_warmup=False,
pd_seperated=False,
): ):
# Launch the server # Launch the server
process = popen_launch_server( process = popen_launch_server(
...@@ -633,6 +669,7 @@ def run_bench_serving_multi( ...@@ -633,6 +669,7 @@ def run_bench_serving_multi(
base_url, base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_server_args, other_args=other_server_args,
pd_seperated=pd_seperated,
) )
# run benchmark for all # run benchmark for all
...@@ -665,7 +702,7 @@ def run_bench_one_batch(model, other_args): ...@@ -665,7 +702,7 @@ def run_bench_one_batch(model, other_args):
"128", "128",
"--output", "--output",
"8", "8",
*other_args, *[str(x) for x in other_args],
] ]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
...@@ -816,7 +853,7 @@ def run_command_and_capture_output(command, env: Optional[dict] = None): ...@@ -816,7 +853,7 @@ def run_command_and_capture_output(command, env: Optional[dict] = None):
stdout = open(STDOUT_FILENAME, "w") stdout = open(STDOUT_FILENAME, "w")
stderr = open(STDERR_FILENAME, "w") stderr = open(STDERR_FILENAME, "w")
process = subprocess.Popen( process = subprocess.Popen(
command, stdout=stdout, stderr=stderr, env=env, text=True command, stdout=stdout, stderr=stdout, env=env, text=True
) )
# Launch a thread to stream the output # Launch a thread to stream the output
...@@ -914,3 +951,78 @@ def run_mulit_request_test( ...@@ -914,3 +951,78 @@ def run_mulit_request_test(
def write_github_step_summary(content): def write_github_step_summary(content):
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f: with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
f.write(content) f.write(content)
def run_logprob_check(self: unittest.TestCase, arg: Tuple):
(
input_len,
output_len,
temperature,
logprob_start_len,
return_logprob,
top_logprobs_num,
) = arg
input_ids = list(range(input_len))
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": temperature,
"max_new_tokens": output_len,
"ignore_eos": True,
},
"return_logprob": return_logprob,
"logprob_start_len": logprob_start_len,
"top_logprobs_num": top_logprobs_num,
},
)
response_json = response.json()
res = response_json
self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
# Test that the number of returned tokens is correct
if return_logprob:
self.assertEqual(
len(res["meta_info"]["input_token_logprobs"]) + logprob_start_len,
res["meta_info"]["prompt_tokens"],
)
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
if top_logprobs_num:
self.assertEqual(
len(res["meta_info"]["input_top_logprobs"]) + logprob_start_len,
res["meta_info"]["prompt_tokens"],
)
self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), output_len)
for i in range(output_len):
self.assertEqual(
len(res["meta_info"]["output_top_logprobs"][i]),
top_logprobs_num,
)
# Test the top-1 tokens are the same as output tokens if temperature == 0
if temperature == 0:
rank = 0
while rank < len(res["meta_info"]["output_top_logprobs"][i]):
try:
self.assertListEqual(
res["meta_info"]["output_token_logprobs"][i],
res["meta_info"]["output_top_logprobs"][i][rank],
)
break
except AssertionError:
# There's a tie. Allow the second item in this case.
if (
res["meta_info"]["output_top_logprobs"][i][rank][0]
== res["meta_info"]["output_top_logprobs"][i][rank + 1][
0
]
):
rank += 1
else:
raise
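A sketch of how run_logprob_check might be wired into a test case; the server launch (popen_launch_server in setUpClass, ideally with a small --chunked-prefill-size) is omitted here, and the argument tuples are illustrative:

import unittest

from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, run_logprob_check

class TestLogprobWithChunkedPrefill(unittest.TestCase):
    # Assumes a server was started in setUpClass and torn down in tearDownClass.
    base_url = DEFAULT_URL_FOR_TEST

    def test_logprob(self):
        # (input_len, output_len, temperature, logprob_start_len,
        #  return_logprob, top_logprobs_num)
        for arg in [(1000, 16, 0.0, 0, True, 5), (1000, 16, 0.0, 500, True, 0)]:
            run_logprob_check(self, arg)

if __name__ == "__main__":
    unittest.main()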
#!/bin/bash #!/bin/bash
# Check if sudo is available
if command -v sudo >/dev/null 2>&1; then
sudo apt-get update
sudo apt-get install -y lsof
else
apt-get update
apt-get install -y lsof
fi
# Show current GPU status # Show current GPU status
nvidia-smi nvidia-smi
...@@ -20,6 +11,14 @@ kill -9 $(ps aux | grep 'sglang.data_parallel' | grep -v 'grep' | awk '{print $2 ...@@ -20,6 +11,14 @@ kill -9 $(ps aux | grep 'sglang.data_parallel' | grep -v 'grep' | awk '{print $2
# Clean all GPU processes if any argument is provided # Clean all GPU processes if any argument is provided
if [ $# -gt 0 ]; then if [ $# -gt 0 ]; then
# Check if sudo is available
if command -v sudo >/dev/null 2>&1; then
sudo apt-get update
sudo apt-get install -y lsof
else
apt-get update
apt-get install -y lsof
fi
kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null
lsof /dev/nvidia* | awk '{print $2}' | xargs kill -9 2>/dev/null lsof /dev/nvidia* | awk '{print $2}' | xargs kill -9 2>/dev/null
fi fi
......
"""
Usage:
# single GPU
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B
"""
import argparse
import asyncio
import json
import os
import time
from types import SimpleNamespace
import numpy as np
import requests
from sglang.bench_serving import benchmark, set_global_args
from sglang.srt.server_args import ServerArgs
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
kill_process_tree,
popen_launch_server,
)
def node0_print(msg):
if server_args.node_rank == 0:
print(msg)
prompts = [
"Human: Give me a fully functional FastAPI server. Show the full, long python code without stop.\n\nAssistant:",
"Human: Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.\n\nAssistant:",
"Human: Write a travel blog post to Hawaii.\n\nAssistant:",
"Human: I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. My first sentence is 'istanbulu cok seviyom burada olmak cok guzel'. Answer in more than 5000 words.\n\nAssistant:",
"Human: I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it’s children then you can talk about animals; If it’s adults then history-based tales might engage them better etc. Answer in more than 5000 words. My first request is 'I need an interesting story on perseverance.'\n\nAssistant:",
"Human: Solve x^2 = -1. Think step-by-step. Give me a long detailed explanation. \n\nAssistant:",
"Human: Tell me about the president of the USA in wikipedia style.\n\nAssistant:",
"Human: Hello? Who are you? Write code, math, and poem to explanin yourself.\n\nAssistant:",
]
class FakeTokenizer:
def encode(self, text: str, add_special_tokens: bool = False):
return []
def send_one_batch(base_url, num_prompts, batch_size):
padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
:num_prompts
]
# Format: (prompt, input_len, output_len). input_len is set to a dummy value of 0.
input_requests = [(p, 0, 512) for p in padded_prompts]
# We need to set some dummy values in order to call `benchmark` below.
args = SimpleNamespace(
disable_ignore_eos=False,
disable_stream=False,
return_logprob=False,
backend="sglang",
dataset_name="custom",
num_prompts=None,
sharegpt_output_len=None,
random_input_len=None,
random_output_len=None,
random_range_ratio=None,
output_file=None,
)
set_global_args(args)
tokenizer = FakeTokenizer()
# Run benchmark
results = asyncio.run(
benchmark(
backend="sglang",
api_url=f"{base_url}/generate",
base_url=base_url,
model_id="default",
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=float("inf"),
max_concurrency=batch_size,
disable_tqdm=False,
lora_name=None,
extra_request_body={},
profile=None,
)
)
assert results["completed"] == len(input_requests)
acc_length = results["accept_length"] or 1.0
avg_output_token = results["total_output_tokens"] / results["completed"]
server_info = requests.get(base_url + "/get_server_info").json()
# We intentionally use the 20th percentile instead of the median
step_time = np.percentile(server_info["step_time_dict"][str(batch_size)], 20)
speed = 1 / step_time * acc_length
return (
round(acc_length, 3),
round(step_time, 5),
round(speed, 3),
avg_output_token,
)
def main(args, server_args):
base_url = "http://127.0.0.1:20000"
configs = []
for batch_size in args.batch_size:
for steps in args.steps:
for topk in args.topk:
for num_draft_tokens in args.num_draft_tokens:
if steps * topk + 1 < num_draft_tokens:
continue
if (steps == 0 or topk == 0 or num_draft_tokens == 0) and (
steps + topk + num_draft_tokens != 0
):
# steps == 0 and topk == 0 and num_draft_tokens == 0 is a special case for non-speculative decoding.
continue
configs.append((batch_size, steps, topk, num_draft_tokens))
for i in range(args.start, args.end or len(configs)):
batch_size, steps, topk, num_draft_tokens = configs[i]
node0_print(
f"Start {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}"
)
# Create an LLM.
if steps == 0:
other_args = []
else:
other_args = [
"--speculative-algorithm",
"EAGLE",
"--speculative-num-steps",
steps,
"--speculative-eagle-topk",
topk,
"--speculative-num-draft-tokens",
num_draft_tokens,
]
if server_args.speculative_draft_model_path is not None:
other_args.extend(
[
"--speculative-draft-model-path",
server_args.speculative_draft_model_path,
]
)
other_args.extend(
[
"--cuda-graph-max-bs",
batch_size,
"--mem-fraction-static",
server_args.mem_fraction_static,
"--tp-size",
server_args.tp_size,
"--max-running-requests",
batch_size,
]
)
if server_args.quantization:
other_args.extend(
[
"--quantization",
server_args.quantization,
]
)
process = popen_launch_server(
args.model_path,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
env={
"SGLANG_RECORD_STEP_TIME": "1",
**os.environ,
},
)
try:
# Warmup
send_one_batch(base_url, batch_size, batch_size)
# Benchmark
acc_length, step_time, speed, completion_tokens = send_one_batch(
base_url, max(args.num_prompts, batch_size), batch_size
)
finally:
kill_process_tree(process.pid)
node0_print(
f"Finish {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}, {speed=:.2f} token/s, step_time={step_time * 1000:.2f} ms"
)
record = {
"batch_size": batch_size,
"steps": steps,
"topk": topk,
"num_draft_tokens": num_draft_tokens,
"acc_length": acc_length,
"step_time": step_time,
"speed": speed,
"completion_tokens": completion_tokens,
}
with open(args.output, "a") as fout:
fout.write(json.dumps(record) + "\n")
# Wait for the server to shutdown
time.sleep(5)
# The __main__ guard is necessary here because we use "spawn" to create subprocesses.
# Spawn starts a fresh program every time; without the guard, sgl.Engine would keep
# spawning new processes in an infinite loop.
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
parser.add_argument(
"--batch-size",
type=int,
nargs="+",
default=(1, 2, 4, 8, 16),
)
parser.add_argument(
"--steps",
type=int,
nargs="+",
default=(0, 1, 3, 5, 7), # use (0, 1, 2, 3, 4) for large batch size
)
parser.add_argument(
"--topk",
type=int,
nargs="+",
default=(0, 1, 2, 4, 8),
)
parser.add_argument(
"--num_draft_tokens",
type=int,
nargs="+",
default=(0, 2, 4, 8, 16, 32), # use (0, 2, 4, 8) for large batch size
)
parser.add_argument("--num-prompts", type=int, default=16)
parser.add_argument("--start", type=int, default=0)
parser.add_argument("--end", type=int)
parser.add_argument("--output", type=str, default="output.jsonl")
args = parser.parse_args()
server_args: ServerArgs = ServerArgs.from_cli_args(args)
main(args, server_args)
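The sweep in main() keeps only feasible speculative configurations. The pruning rule, shown in isolation with hypothetical, smaller value lists:

def valid_configs(steps_list, topk_list, num_draft_tokens_list):
    # Keep a config only if num_draft_tokens fits within steps * topk + 1,
    # and keep the all-zero triple as the non-speculative baseline.
    for steps in steps_list:
        for topk in topk_list:
            for num_draft_tokens in num_draft_tokens_list:
                if steps * topk + 1 < num_draft_tokens:
                    continue
                if (steps == 0 or topk == 0 or num_draft_tokens == 0) and (
                    steps + topk + num_draft_tokens != 0
                ):
                    continue
                yield steps, topk, num_draft_tokens

print(list(valid_configs((0, 1, 3), (0, 1, 2), (0, 2, 4))))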
...@@ -111,6 +111,8 @@ else: ...@@ -111,6 +111,8 @@ else:
"cublas_grouped_gemm", "cublas_grouped_gemm",
"custom_dispose", "custom_dispose",
"custom_reduce", "custom_reduce",
"build_tree_kernel_efficient",
"build_tree_kernel",
"fp8_blockwise_scaled_mm", "fp8_blockwise_scaled_mm",
"fp8_scaled_mm", "fp8_scaled_mm",
"fused_add_rmsnorm", "fused_add_rmsnorm",
...@@ -127,12 +129,10 @@ else: ...@@ -127,12 +129,10 @@ else:
"register_graph_buffers", "register_graph_buffers",
"rmsnorm", "rmsnorm",
"sampling_scaling_penalties", "sampling_scaling_penalties",
"sgl_per_token_group_quant_fp8",
"silu_and_mul", "silu_and_mul",
"top_k_renorm_prob", "top_k_renorm_prob",
"top_k_top_p_sampling_from_probs", "top_k_top_p_sampling_from_probs",
"top_p_renorm_prob", "top_p_renorm_prob",
"tree_speculative_sampling_target_only", "tree_speculative_sampling_target_only",
"build_tree_kernel_efficient",
"build_tree_kernel",
"sgl_per_token_group_quant_fp8",
] ]
...@@ -30,7 +30,9 @@ class TestSRTBackend(unittest.TestCase): ...@@ -30,7 +30,9 @@ class TestSRTBackend(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.backend = sgl.Runtime(model_path=DEFAULT_MODEL_NAME_FOR_TEST) cls.backend = sgl.Runtime(
model_path=DEFAULT_MODEL_NAME_FOR_TEST, cuda_graph_max_bs=4
)
sgl.set_default_backend(cls.backend) sgl.set_default_backend(cls.backend)
@classmethod @classmethod
......
...@@ -12,7 +12,6 @@ suites = { ...@@ -12,7 +12,6 @@ suites = {
"models/test_generation_models.py", "models/test_generation_models.py",
"models/test_qwen_models.py", "models/test_qwen_models.py",
"models/test_reward_models.py", "models/test_reward_models.py",
"sampling/penaltylib",
"test_abort.py", "test_abort.py",
"test_chunked_prefill.py", "test_chunked_prefill.py",
"test_custom_allreduce.py", "test_custom_allreduce.py",
...@@ -31,6 +30,7 @@ suites = { ...@@ -31,6 +30,7 @@ suites = {
"test_no_chunked_prefill.py", "test_no_chunked_prefill.py",
"test_no_overlap_scheduler.py", "test_no_overlap_scheduler.py",
"test_openai_server.py", "test_openai_server.py",
"test_penalty.py",
"test_pytorch_sampling_backend.py", "test_pytorch_sampling_backend.py",
"test_radix_attention.py", "test_radix_attention.py",
"test_regex_constrained.py", "test_regex_constrained.py",
...@@ -38,7 +38,8 @@ suites = { ...@@ -38,7 +38,8 @@ suites = {
"test_request_length_validation.py", "test_request_length_validation.py",
"test_retract_decode.py", "test_retract_decode.py",
"test_server_args.py", "test_server_args.py",
"test_session_control.py", # Disabled temporarily
# "test_session_control.py",
"test_skip_tokenizer_init.py", "test_skip_tokenizer_init.py",
"test_srt_engine.py", "test_srt_engine.py",
"test_srt_endpoint.py", "test_srt_endpoint.py",
...@@ -64,9 +65,6 @@ suites = { ...@@ -64,9 +65,6 @@ suites = {
# Disable temporarily # Disable temporarily
# "test_nightly_math_eval.py", # "test_nightly_math_eval.py",
], ],
"sampling/penaltylib": glob.glob(
"sampling/penaltylib/**/test_*.py", recursive=True
),
} }
# Expand suite # Expand suite
...@@ -83,7 +81,7 @@ if __name__ == "__main__": ...@@ -83,7 +81,7 @@ if __name__ == "__main__":
arg_parser.add_argument( arg_parser.add_argument(
"--timeout-per-file", "--timeout-per-file",
type=int, type=int,
default=2000, default=1800,
help="The time limit for running one file in seconds.", help="The time limit for running one file in seconds.",
) )
arg_parser.add_argument( arg_parser.add_argument(
......
import unittest
from typing import List
import torch
from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import (
BatchedFrequencyPenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
BaseBatchedPenalizerTest,
MockSamplingParams,
Step,
StepType,
Subject,
)
class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
Penalizer = BatchedFrequencyPenalizer
frequency_penalty: float
def setUp(self):
if self.__class__ == BaseBatchedFrequencyPenalizerTest:
self.skipTest("Base class for frequency_penalty tests")
super().setUp()
def _create_subject(self, frequency_penalty: float) -> Subject:
return Subject(
sampling_params=MockSamplingParams(
frequency_penalty=frequency_penalty,
),
steps=[
Step(
type=StepType.INPUT,
token_ids=[0, 1, 2],
expected_tensors={
"frequency_penalties": self.tensor(
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
),
"cumulated_frequency_penalties": self.tensor(
[[0.0] * self.vocab_size], dtype=torch.float32
),
},
expected_logits=self.tensor(
[[1] * self.vocab_size], dtype=torch.float32
),
),
Step(
type=StepType.OUTPUT,
token_ids=[
1,
2,
2,
], # This is the output ids of one request in three steps.
expected_tensors={
"frequency_penalties": self.tensor(
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
),
"cumulated_frequency_penalties": self.tensor(
[
[
frequency_penalty * i if i in {1, 2} else 0.0
for i in range(self.vocab_size)
],
],
dtype=torch.float32,
),
},
expected_logits=self.tensor(
[
[
1.0 - frequency_penalty * i if i in {1, 2} else 1.0
for i in range(self.vocab_size)
],
],
dtype=torch.float32,
),
),
],
)
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
self.disabled = self._create_subject(frequency_penalty=0.0)
class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest):
frequency_penalty = 0.12
class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest):
frequency_penalty = -0.12
if __name__ == "__main__":
unittest.main()
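The expected tensors above encode a simple rule: each generated token subtracts frequency_penalty times its occurrence count from that token's logit. A standalone sketch with a hypothetical 5-token vocabulary and the same output sequence as the test:

import torch

vocab_size = 5
frequency_penalty = 0.12
logits = torch.ones(vocab_size)
output_ids = [1, 2, 2]  # token 1 generated once, token 2 twice

counts = torch.zeros(vocab_size)
for tok in output_ids:
    counts[tok] += 1

penalized = logits - frequency_penalty * counts
print(penalized)  # tensor([1.0000, 0.8800, 0.7600, 1.0000, 1.0000])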
import unittest
from typing import List
import torch
from sglang.srt.sampling.penaltylib.penalizers.min_new_tokens import (
BatchedMinNewTokensPenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
BaseBatchedPenalizerTest,
MockSamplingParams,
Step,
StepType,
Subject,
)
MIN_NEW_TOKENS = 2
EOS_TOKEN_ID = 4
STOP_TOKEN_ID = 3
ALL_STOP_TOKEN_IDS = {STOP_TOKEN_ID, EOS_TOKEN_ID}
class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest):
Penalizer = BatchedMinNewTokensPenalizer
def _create_subject(self, min_new_tokens: int) -> Subject:
return Subject(
eos_token_id=EOS_TOKEN_ID,
sampling_params=MockSamplingParams(
min_new_tokens=min_new_tokens,
stop_token_ids={STOP_TOKEN_ID},
),
steps=[
Step(
type=StepType.INPUT,
token_ids=[0, 1, 2],
expected_tensors={
"min_new_tokens": self.tensor(
[[min_new_tokens]], dtype=torch.int32
),
"stop_token_penalties": self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
),
"len_output_tokens": self.tensor([[0]], dtype=torch.int32),
},
expected_logits=(
self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
)
if min_new_tokens > 0
else torch.ones(
(1, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
),
),
Step(
type=StepType.OUTPUT,
token_ids=[0],
expected_tensors={
"min_new_tokens": self.tensor(
[[min_new_tokens]], dtype=torch.int32
),
"stop_token_penalties": self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
),
"len_output_tokens": self.tensor([[1]], dtype=torch.int32),
},
expected_logits=(
self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
)
if min_new_tokens > 1
else torch.ones(
(1, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
),
),
Step(
type=StepType.OUTPUT,
token_ids=[0],
expected_tensors={
"min_new_tokens": self.tensor(
[[min_new_tokens]], dtype=torch.int32
),
"stop_token_penalties": self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
),
"len_output_tokens": self.tensor([[2]], dtype=torch.int32),
},
expected_logits=(
self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
)
if min_new_tokens > 2
else torch.ones(
(1, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
),
),
],
)
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
self.disabled = self._create_subject(min_new_tokens=0.0)
if __name__ == "__main__":
unittest.main()
import unittest
from typing import List
import torch
from sglang.srt.sampling.penaltylib.penalizers.presence_penalty import (
BatchedPresencePenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
BaseBatchedPenalizerTest,
MockSamplingParams,
Step,
StepType,
Subject,
)
class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
Penalizer = BatchedPresencePenalizer
presence_penalty: float
def setUp(self):
if self.__class__ == BaseBatchedPresencePenalizerTest:
self.skipTest("Base class for presence_penalty tests")
super().setUp()
def _create_subject(self, presence_penalty: float) -> Subject:
return Subject(
sampling_params=MockSamplingParams(
presence_penalty=presence_penalty,
),
steps=[
Step(
type=StepType.INPUT,
token_ids=[0, 1, 2],
expected_tensors={
"presence_penalties": self.tensor(
[[presence_penalty] * self.vocab_size], dtype=torch.float32
),
"cumulated_presence_penalties": self.tensor(
[[0.0] * self.vocab_size], dtype=torch.float32
),
},
expected_logits=self.tensor(
[[1] * self.vocab_size], dtype=torch.float32
),
),
Step(
type=StepType.OUTPUT,
token_ids=[1, 2, 2],
expected_tensors={
"presence_penalties": self.tensor(
[[presence_penalty] * self.vocab_size], dtype=torch.float32
),
"cumulated_presence_penalties": self.tensor(
[
[
presence_penalty if i in {1, 2} else 0.0
for i in range(self.vocab_size)
],
],
dtype=torch.float32,
),
},
expected_logits=self.tensor(
[
[
1.0 - presence_penalty if i in {1, 2} else 1.0
for i in range(self.vocab_size)
],
],
dtype=torch.float32,
),
),
],
)
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
self.disabled = self._create_subject(presence_penalty=0.0)
class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest):
presence_penalty = 0.12
class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest):
presence_penalty = -0.12
if __name__ == "__main__":
unittest.main()
import unittest
from typing import List
import torch
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
BatchedRepetitionPenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
BaseBatchedPenalizerTest,
MockSamplingParams,
Step,
StepType,
Subject,
)
REPETITION_PENALTY = 2.0
class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
Penalizer = BatchedRepetitionPenalizer
def _create_subject(self, repetition_penalty: float) -> Subject:
l = 1.0 / repetition_penalty
return Subject(
sampling_params=MockSamplingParams(
repetition_penalty=repetition_penalty,
),
steps=[
Step(
type=StepType.INPUT,
token_ids=[0, 1, 2],
expected_tensors={
"repetition_penalties": self.tensor(
[[repetition_penalty] * self.vocab_size],
dtype=torch.float32,
),
"cumulated_repetition_penalties": (
self.tensor(
[[2.0, 2.0, 2.0, 1.0, 1.0]], dtype=torch.float32
)
if repetition_penalty != 1.0
else self.tensor(
[[1.0] * self.vocab_size], dtype=torch.float32
)
),
},
expected_logits=(
self.tensor([[l, l, l, 1.0, 1.0]], dtype=torch.float32)
if repetition_penalty != 1.0
else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
),
),
Step(
type=StepType.OUTPUT,
token_ids=[0, 1, 3],
expected_tensors={
"repetition_penalties": self.tensor(
[[repetition_penalty] * self.vocab_size],
dtype=torch.float32,
),
"cumulated_repetition_penalties": (
self.tensor(
[[2.0, 2.0, 2.0, 2.0, 1.0]], dtype=torch.float32
)
if repetition_penalty != 1.0
else self.tensor(
[[1.0] * self.vocab_size], dtype=torch.float32
)
),
},
expected_logits=(
self.tensor([[l, l, l, l, 1.0]], dtype=torch.float32)
if repetition_penalty != 1.0
else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
),
),
],
)
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
self.disabled = self._create_subject(repetition_penalty=1.0)
if __name__ == "__main__":
unittest.main()
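Similarly, the repetition-penalty expectations reduce to dividing the logits of previously seen tokens by the penalty, for the positive logits used in this test (the HF-style rule multiplies negative logits by the penalty instead). A standalone sketch:

import torch

vocab_size = 5
repetition_penalty = 2.0
logits = torch.ones(vocab_size)

seen = torch.zeros(vocab_size, dtype=torch.bool)
for tok in [0, 1, 2]:  # prompt tokens from the INPUT step
    seen[tok] = True

penalized = torch.where(
    seen,
    torch.where(logits > 0, logits / repetition_penalty, logits * repetition_penalty),
    logits,
)
print(penalized)  # tensor([0.5000, 0.5000, 0.5000, 1.0000, 1.0000])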
...@@ -138,6 +138,7 @@ class TestBenchServing(unittest.TestCase): ...@@ -138,6 +138,7 @@ class TestBenchServing(unittest.TestCase):
model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
num_prompts=50, num_prompts=50,
request_rate=1, request_rate=1,
sharegpt_context_len=3072,
disable_ignore_eos=True, disable_ignore_eos=True,
dataset_name="sharegpt", dataset_name="sharegpt",
other_server_args=[ other_server_args=[
...@@ -148,22 +149,23 @@ class TestBenchServing(unittest.TestCase): ...@@ -148,22 +149,23 @@ class TestBenchServing(unittest.TestCase):
"--speculative-num-steps", "--speculative-num-steps",
"5", "5",
"--speculative-eagle-topk", "--speculative-eagle-topk",
"8", "4",
"--speculative-num-draft-tokens", "--speculative-num-draft-tokens",
"64", "16",
"--mem-fraction-static", "--mem-fraction-static",
"0.7", "0.7",
"--cuda-graph-max-bs",
"32",
], ],
need_warmup=True,
) )
if is_in_ci(): if is_in_ci():
write_github_step_summary( write_github_step_summary(
f"### test_online_latency_eagle\n" f"### test_online_latency_eagle\n"
f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
f'accept_length : {res["accept_length"]:.2f} \n'
) )
self.assertLess(res["median_e2e_latency_ms"], 450) self.assertLess(res["median_e2e_latency_ms"], 700)
self.assertGreater(res["accept_length"], 2.50)
def test_moe_offline_throughput_default(self): def test_moe_offline_throughput_default(self):
res = run_bench_serving( res = run_bench_serving(
......
...@@ -12,7 +12,9 @@ from sglang.test.test_utils import ( ...@@ -12,7 +12,9 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
is_in_ci,
popen_launch_server, popen_launch_server,
write_github_step_summary,
) )
...@@ -44,6 +46,9 @@ class TestEvalAccuracyLarge(unittest.TestCase): ...@@ -44,6 +46,9 @@ class TestEvalAccuracyLarge(unittest.TestCase):
metrics = run_eval(args) metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.71) self.assertGreater(metrics["score"], 0.71)
if is_in_ci():
write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n')
def test_human_eval(self): def test_human_eval(self):
args = SimpleNamespace( args = SimpleNamespace(
base_url=self.base_url, base_url=self.base_url,
...@@ -56,6 +61,11 @@ class TestEvalAccuracyLarge(unittest.TestCase): ...@@ -56,6 +61,11 @@ class TestEvalAccuracyLarge(unittest.TestCase):
metrics = run_eval(args) metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.64) self.assertGreater(metrics["score"], 0.64)
if is_in_ci():
write_github_step_summary(
f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n'
)
def test_mgsm_en(self): def test_mgsm_en(self):
args = SimpleNamespace( args = SimpleNamespace(
base_url=self.base_url, base_url=self.base_url,
...@@ -68,6 +78,11 @@ class TestEvalAccuracyLarge(unittest.TestCase): ...@@ -68,6 +78,11 @@ class TestEvalAccuracyLarge(unittest.TestCase):
metrics = run_eval(args) metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.835) self.assertGreater(metrics["score"], 0.835)
if is_in_ci():
write_github_step_summary(
f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n'
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
import unittest
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestHealthCheck(unittest.TestCase):
def test_health_check(self):
"""Test that metrics endpoint returns data when enabled"""
with self.assertRaises(TimeoutError):
popen_launch_server(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
timeout=60,
other_args=[
"--disable-cuda-graph",
"--json-model-override-args",
'{"architectures": ["LlamaForCausalLMForHealthTest"]}',
],
)
if __name__ == "__main__":
unittest.main()
...@@ -49,7 +49,7 @@ class TestHiddenState(unittest.TestCase): ...@@ -49,7 +49,7 @@ class TestHiddenState(unittest.TestCase):
with torch.inference_mode(): with torch.inference_mode():
hf_out = model( hf_out = model(
torch.tensor( torch.tensor(
[input_id + output["token_ids"][:-1]], device=model.device [input_id + output["output_ids"][:-1]], device=model.device
), ),
output_hidden_states=True, output_hidden_states=True,
) )
......
...@@ -56,11 +56,13 @@ class TestEnableMetrics(unittest.TestCase): ...@@ -56,11 +56,13 @@ class TestEnableMetrics(unittest.TestCase):
"sglang:gen_throughput", "sglang:gen_throughput",
"sglang:num_queue_reqs", "sglang:num_queue_reqs",
"sglang:cache_hit_rate", "sglang:cache_hit_rate",
"sglang:spec_accept_length",
"sglang:prompt_tokens_total", "sglang:prompt_tokens_total",
"sglang:generation_tokens_total", "sglang:generation_tokens_total",
"sglang:num_requests_total", "sglang:num_requests_total",
"sglang:time_to_first_token_seconds", "sglang:time_to_first_token_seconds",
"sglang:time_per_output_token_seconds", "sglang:time_per_output_token_seconds",
"sglang:inter_token_latency_seconds",
"sglang:e2e_request_latency_seconds", "sglang:e2e_request_latency_seconds",
] ]
......
...@@ -141,7 +141,7 @@ class TestDeepseekV3MTP(unittest.TestCase): ...@@ -141,7 +141,7 @@ class TestDeepseekV3MTP(unittest.TestCase):
metrics = run_eval_few_shot_gsm8k(args) metrics = run_eval_few_shot_gsm8k(args)
print(metrics) print(metrics)
self.assertGreater(metrics["accuracy"], 0.62) self.assertGreater(metrics["accuracy"], 0.60)
if __name__ == "__main__": if __name__ == "__main__":
......
import json import json
import random
import unittest import unittest
from multiprocessing import Process from concurrent.futures import ThreadPoolExecutor
import requests import requests
...@@ -13,7 +14,7 @@ from sglang.test.test_utils import ( ...@@ -13,7 +14,7 @@ from sglang.test.test_utils import (
) )
class TestBatchPenalizerE2E(unittest.TestCase): class TestPenalty(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -23,24 +24,18 @@ class TestBatchPenalizerE2E(unittest.TestCase): ...@@ -23,24 +24,18 @@ class TestBatchPenalizerE2E(unittest.TestCase):
cls.model, cls.model,
cls.base_url, cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--random-seed",
"0",
),
) )
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_process_tree(cls.process.pid) kill_process_tree(cls.process.pid)
def run_decode( def run_decode(self, sampling_params):
self, return_logprob = True
return_logprob=True, top_logprobs_num = 5
top_logprobs_num=5, return_text = True
return_text=True, n = 1
n=1,
**sampling_params,
):
response = requests.post( response = requests.post(
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
...@@ -51,63 +46,45 @@ class TestBatchPenalizerE2E(unittest.TestCase): ...@@ -51,63 +46,45 @@ class TestBatchPenalizerE2E(unittest.TestCase):
"n": n, "n": n,
**sampling_params, **sampling_params,
}, },
"stream": False,
"return_logprob": return_logprob, "return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num, "top_logprobs_num": top_logprobs_num,
"return_text_in_logprobs": return_text, "return_text_in_logprobs": return_text,
"logprob_start_len": 0, "logprob_start_len": 0,
}, },
) )
assert response.status_code == 200, "Request failed: " + response.text self.assertEqual(response.status_code, 200)
print(json.dumps(response.json()))
print("=" * 100)
def test_default_values(self): def test_default_values(self):
self.run_decode() self.run_decode({})
def test_mixed(self):
"""
Sends two requests with one with penalizers disabled, and the other with penalizers enabled.
This will cause two different {ScheduleBatch} to be initialized and eventually gets merged.
Merging batch with penalizers enabled with enabled, or disabled is trivial. However disabled + enabled is not.
This is because the penalizer will not be prepared if it is not required, then it will be prepared during the merge.
This test triggers the merge of disabled + enabled.
"""
processes = []
p = Process(
target=self.run_decode,
)
processes.append(p)
p.start()
p = Process(
target=self.run_decode,
kwargs={
"frequency_penalty": 2,
"min_new_tokens": 16,
"presence_penalty": 2,
"repetition_penalty": 2,
},
)
processes.append(p)
p.start()
for p in processes:
p.join()
def test_frequency_penalty(self): def test_frequency_penalty(self):
self.run_decode(frequency_penalty=2) self.run_decode({"frequency_penalty": 2})
def test_min_new_tokens(self): def test_min_new_tokens(self):
self.run_decode(min_new_tokens=16) self.run_decode({"min_new_tokens": 16})
def test_presence_penalty(self): def test_presence_penalty(self):
self.run_decode(presence_penalty=2) self.run_decode({"presence_penalty": 2})
def test_repetition_penalty(self): def test_mixed(self):
self.run_decode(repetition_penalty=2) args = [
{},
{},
{},
{"frequency_penalty": 2},
{"min_new_tokens": 16},
{"presence_penalty": 1},
{"frequency_penalty": 0.2},
{"min_new_tokens": 8},
{"presence_penalty": 0.4},
{"presence_penalty": 0.4, "frequency_penalty": 2},
{"min_new_tokens": 12, "frequency_penalty": 2},
]
args = args * 5
random.shuffle(args)
with ThreadPoolExecutor(8) as executor:
list(executor.map(self.run_decode, args))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -70,7 +70,10 @@ class TestSessionControl(unittest.TestCase): ...@@ -70,7 +70,10 @@ class TestSessionControl(unittest.TestCase):
first_rid = None first_rid = None
outputs_from_session = [] outputs_from_session = []
logprobs_from_session = []
cur_logprob_start_len = 0
for i, chunk_ids in enumerate(chunks_ids): for i, chunk_ids in enumerate(chunks_ids):
max_new_tokens = gen_len if i > 0 else 1 # prefill only for the first chunk
response = requests.post( response = requests.post(
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
...@@ -83,12 +86,12 @@ class TestSessionControl(unittest.TestCase): ...@@ -83,12 +86,12 @@ class TestSessionControl(unittest.TestCase):
}, },
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": ( "max_new_tokens": max_new_tokens,
gen_len if i > 0 else 1
), # prefill only for the first chunk
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
"return_logprob": True,
"logprob_start_len": cur_logprob_start_len - 1,
}, },
).json() ).json()
rid = response["meta_info"]["id"] rid = response["meta_info"]["id"]
...@@ -96,8 +99,39 @@ class TestSessionControl(unittest.TestCase): ...@@ -96,8 +99,39 @@ class TestSessionControl(unittest.TestCase):
first_rid = rid first_rid = rid
if i > 0: if i > 0:
outputs_from_session.append(response["text"]) outputs_from_session.append(response["text"])
logprobs_from_session.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
cur_logprob_start_len += len(chunk_ids) + max_new_tokens
# query with a logprob_start_len longer than the request, should see error
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": chunk_ids,
"session_params": {
"id": session_id,
"rid": rid,
"offset": -1,
"replace": True,
},
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
"logprob_start_len": cur_logprob_start_len + len(chunk_ids),
},
).json()
assert "Request with a lower logprob_start_len" in response["error"]["message"]
# backtrack to the first request and regenerate # backtrack to the first request and regenerate
cur_logprob_start_len = 0
response = requests.post( response = requests.post(
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
...@@ -114,9 +148,17 @@ class TestSessionControl(unittest.TestCase): ...@@ -114,9 +148,17 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
"return_logprob": True,
"logprob_start_len": cur_logprob_start_len,
}, },
).json() ).json()
outputs_from_session.append(response["text"]) outputs_from_session.append(response["text"])
logprobs_from_session.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
# query with a non-existing rid (the last one should have disappeared because of backtrack), should see abort
response = requests.post( response = requests.post(
...@@ -135,6 +177,7 @@ class TestSessionControl(unittest.TestCase): ...@@ -135,6 +177,7 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
"return_logprob": True,
}, },
).json() ).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort" assert response["meta_info"]["finish_reason"]["type"] == "abort"
...@@ -162,6 +205,7 @@ class TestSessionControl(unittest.TestCase): ...@@ -162,6 +205,7 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
"return_logprob": True,
}, },
).json() ).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort" assert response["meta_info"]["finish_reason"]["type"] == "abort"
...@@ -172,6 +216,7 @@ class TestSessionControl(unittest.TestCase): ...@@ -172,6 +216,7 @@ class TestSessionControl(unittest.TestCase):
input_ids_first_req = None input_ids_first_req = None
input_ids = [] input_ids = []
outputs_normal = [] outputs_normal = []
logprobs_normal = []
for i, chunk_ids in enumerate(chunks_ids): for i, chunk_ids in enumerate(chunks_ids):
input_ids += chunk_ids input_ids += chunk_ids
response = requests.post( response = requests.post(
...@@ -186,6 +231,7 @@ class TestSessionControl(unittest.TestCase): ...@@ -186,6 +231,7 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
"return_logprob": True,
}, },
).json() ).json()
if i > 0: if i > 0:
...@@ -194,6 +240,12 @@ class TestSessionControl(unittest.TestCase): ...@@ -194,6 +240,12 @@ class TestSessionControl(unittest.TestCase):
output_ids = output_ids[1:] output_ids = output_ids[1:]
input_ids += output_ids[:-1] input_ids += output_ids[:-1]
outputs_normal.append(response["text"]) outputs_normal.append(response["text"])
logprobs_normal.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
if i == 0: if i == 0:
input_ids_first_req = input_ids.copy() input_ids_first_req = input_ids.copy()
...@@ -208,17 +260,31 @@ class TestSessionControl(unittest.TestCase): ...@@ -208,17 +260,31 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
"return_logprob": True,
}, },
).json() ).json()
outputs_normal.append(response["text"]) outputs_normal.append(response["text"])
logprobs_normal.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
print("outputs from chunked queries with session control:") print("outputs from chunked queries with session control:")
print(outputs_from_session) print(outputs_from_session)
print("outputs from normal queries:") print("outputs from normal queries:")
print(outputs_normal) print(outputs_normal)
assert ( assert outputs_from_session == outputs_normal
outputs_from_session == outputs_normal print("logprobs from chunked queries with session control:")
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" print(logprobs_from_session)
print("logprobs from normal queries:")
print(logprobs_normal)
assert len(logprobs_from_session) == len(
logprobs_normal
), "logprobs must have equal length"
for a, b in zip(logprobs_from_session, logprobs_normal):
assert abs(a - b) <= 0.1, f"logprobs {a} and {b} differ by more than 0.1"
async def async_generate(self, payload): async def async_generate(self, payload):
url = self.base_url + "/generate" url = self.base_url + "/generate"
......