Commit ac238727 authored by Lianmin Zheng

Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
parent 0194948f
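
For orientation before the diff: the logprob changes below are exercised through the server's /generate endpoint. Here is a minimal client-side sketch of the case that chunked prefill must now handle, returning logprobs for a prompt long enough to span several prefill chunks. The local URL and the synthetic token ids are placeholders, not part of this commit:

import requests

# A prompt long enough that chunked prefill splits it into several chunks;
# the server must stitch input-token logprobs together across chunk boundaries.
input_ids = list(range(8192))  # synthetic token ids, as the tests below also use
logprob_start_len = 0

resp = requests.post(
    "http://localhost:30000/generate",  # assumed local server URL
    json={
        "input_ids": input_ids,
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_logprob": True,
        "logprob_start_len": logprob_start_len,
        "top_logprobs_num": 5,
    },
)
meta = resp.json()["meta_info"]
# Per the updated test below: prompt_tokens == logprob_start_len + len(input_token_logprobs)
assert meta["prompt_tokens"] == logprob_start_len + len(meta["input_token_logprobs"])
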
"""
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.run_decode_stream
"""
import json
import unittest
@@ -12,42 +17,26 @@ from sglang.test.test_utils import (
popen_launch_server,
)
-_server_process = None
-_base_url = None
-_tokenizer = None
-
-
-def setUpModule():
-    """
-    Launch the server once before all tests and initialize the tokenizer.
-    """
-    global _server_process, _base_url, _tokenizer
-    _server_process = popen_launch_server(
-        DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
-        DEFAULT_URL_FOR_TEST,
-        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-        other_args=["--skip-tokenizer-init"],
-    )
-    _base_url = DEFAULT_URL_FOR_TEST
-    _tokenizer = AutoTokenizer.from_pretrained(
-        DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
-    )
-    print(">>> setUpModule: Server launched, tokenizer ready")
-
-
-def tearDownModule():
-    """
-    Terminate the server once after all tests have completed.
-    """
-    global _server_process
-    if _server_process is not None:
-        kill_process_tree(_server_process.pid)
-        _server_process = None
-    print(">>> tearDownModule: Server terminated")
class TestSkipTokenizerInit(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--skip-tokenizer-init", "--stream-output"],
+        )
+        cls.tokenizer = AutoTokenizer.from_pretrained(
+            DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
def run_decode(
self,
prompt_text="The capital of France is",
@@ -56,19 +45,19 @@ class TestSkipTokenizerInit(unittest.TestCase):
top_logprobs_num=0,
n=1,
):
-        input_ids = _tokenizer(prompt_text, return_tensors="pt")["input_ids"][
+        input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
0
].tolist()
response = requests.post(
-            _base_url + "/generate",
+            self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [_tokenizer.eos_token_id],
"stop_token_ids": [self.tokenizer.eos_token_id],
},
"stream": False,
"return_logprob": return_logprob,
@@ -83,13 +72,13 @@ class TestSkipTokenizerInit(unittest.TestCase):
if item["meta_info"]["finish_reason"]["type"] == "stop":
self.assertEqual(
item["meta_info"]["finish_reason"]["matched"],
-                    _tokenizer.eos_token_id,
+                    self.tokenizer.eos_token_id,
)
elif item["meta_info"]["finish_reason"]["type"] == "length":
self.assertEqual(
-                len(item["token_ids"]), item["meta_info"]["completion_tokens"]
+                len(item["output_ids"]), item["meta_info"]["completion_tokens"]
)
-            self.assertEqual(len(item["token_ids"]), max_new_tokens)
+            self.assertEqual(len(item["output_ids"]), max_new_tokens)
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
if return_logprob:
@@ -113,6 +102,63 @@ class TestSkipTokenizerInit(unittest.TestCase):
print("=" * 100)
+    def run_decode_stream(self, return_logprob=False, top_logprobs_num=0, n=1):
+        max_new_tokens = 32
+        input_ids = [128000, 791, 6864, 315, 9822, 374]  # The capital of France is
+
+        requests.post(self.base_url + "/flush_cache")
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": input_ids,
+                "sampling_params": {
+                    "temperature": 0 if n == 1 else 0.5,
+                    "max_new_tokens": max_new_tokens,
+                    "n": n,
+                    "stop_token_ids": [119690],
+                },
+                "stream": False,
+                "return_logprob": return_logprob,
+                "top_logprobs_num": top_logprobs_num,
+                "logprob_start_len": 0,
+            },
+        )
+        ret = response.json()
+        print(json.dumps(ret))
+        output_ids = ret["output_ids"]
+
+        requests.post(self.base_url + "/flush_cache")
+        response_stream = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": input_ids,
+                "sampling_params": {
+                    "temperature": 0 if n == 1 else 0.5,
+                    "max_new_tokens": max_new_tokens,
+                    "n": n,
+                    "stop_token_ids": [119690],
+                },
+                "stream": True,
+                "return_logprob": return_logprob,
+                "top_logprobs_num": top_logprobs_num,
+                "logprob_start_len": 0,
+            },
+        )
+        print("output from non-streaming request:")
+        print(output_ids)
+        response_stream_json = []
+        for line in response_stream.iter_lines():
+            if line.startswith(b"data: ") and line[6:] != b"[DONE]":
+                response_stream_json.append(json.loads(line[6:]))
+
+        out_stream_ids = []
+        for x in response_stream_json:
+            out_stream_ids += x["output_ids"]
+        print("output from streaming request:")
+        print(out_stream_ids)
+
+        assert output_ids == out_stream_ids
def test_simple_decode(self):
self.run_decode()
@@ -126,6 +172,9 @@ class TestSkipTokenizerInit(unittest.TestCase):
def test_eos_behavior(self):
self.run_decode(max_new_tokens=256)
+    def test_simple_decode_stream(self):
+        self.run_decode_stream()
if __name__ == "__main__":
unittest.main()
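
As a usage note on the file above: with --skip-tokenizer-init the server accepts and returns raw token ids only, so the client owns tokenization on both ends. A minimal sketch of that contract under assumed names (the model id and URL are placeholders):

import requests
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model id

# The server tokenizes nothing, so we send token ids ...
input_ids = tokenizer("The capital of France is")["input_ids"]
ret = requests.post(
    "http://localhost:30000/generate",  # assumed URL
    json={
        "input_ids": input_ids,
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
).json()

# ... and detokenize the returned ids ourselves.
print(tokenizer.decode(ret["output_ids"]))
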
@@ -8,6 +8,7 @@ import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
+from functools import partial
from typing import Optional
import numpy as np
@@ -20,6 +21,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
+    run_logprob_check,
)
@@ -35,7 +37,9 @@ class TestSRTEndpoint(unittest.TestCase):
other_args=(
"--enable-custom-logit-processor",
"--mem-fraction-static",
"0.8",
"0.7",
"--cuda-graph-max-bs",
"8",
),
)
@@ -131,7 +135,7 @@ class TestSRTEndpoint(unittest.TestCase):
for i, res in enumerate(response_json):
self.assertEqual(
res["meta_info"]["prompt_tokens"],
-            logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]),
+            logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
)
assert prompts[i].endswith(
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
@@ -235,83 +239,15 @@ class TestSRTEndpoint(unittest.TestCase):
diff = np.abs(output_logprobs - output_logprobs_score)
max_diff = np.max(diff)
-        self.assertLess(max_diff, 0.25)
-
-    def run_logprob_check(self, arg):
-        (
-            input_len,
-            output_len,
-            temperature,
-            logprob_start_len,
-            return_logprob,
-            top_logprobs_num,
-        ) = arg
-        input_ids = list(range(input_len))
-        response = requests.post(
-            self.base_url + "/generate",
-            json={
-                "input_ids": input_ids,
-                "sampling_params": {
-                    "temperature": temperature,
-                    "max_new_tokens": output_len,
-                },
-                "return_logprob": return_logprob,
-                "logprob_start_len": logprob_start_len,
-                "top_logprobs_num": top_logprobs_num,
-            },
-        )
-        response_json = response.json()
-        res = response_json
-        self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
-        self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
-
-        # Test that the number of returned tokens is correct
-        if return_logprob:
-            # If logprob_start_len == 0, a padding entry is added for the first
-            # token; otherwise no padding is added.
-            delta = 0 if logprob_start_len == 0 else 1
-            self.assertEqual(
-                len(res["meta_info"]["input_token_logprobs"])
-                + logprob_start_len
-                + delta,
-                res["meta_info"]["prompt_tokens"],
-            )
-            self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
-
-            if top_logprobs_num:
-                self.assertEqual(
-                    len(res["meta_info"]["input_top_logprobs"])
-                    + logprob_start_len
-                    + delta,
-                    res["meta_info"]["prompt_tokens"],
-                )
-                self.assertEqual(
-                    len(res["meta_info"]["output_top_logprobs"]), output_len
-                )
-                for i in range(output_len):
-                    self.assertEqual(
-                        len(res["meta_info"]["output_top_logprobs"][i]),
-                        top_logprobs_num,
-                    )
-
-                    # The top-1 token must match the sampled token when temperature == 0
-                    if temperature == 0:
-                        self.assertListEqual(
-                            res["meta_info"]["output_token_logprobs"][i],
-                            res["meta_info"]["output_top_logprobs"][i][0],
-                        )
+        self.assertLess(max_diff, 0.35)
def test_logprob_mixed(self):
args = []
temperature = 0
# input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
-        for input_len in [1000, 2000]:
+        for input_len in [1000, 5000, 10000, 50000]:
for output_len in [4, 8]:
-                for logprob_start_len in [0, 500, 1000]:
+                for logprob_start_len in [0, 500, 2500, 5000, 25000]:
for return_logprob in [True, False]:
for top_logprobs_num in [0, 5]:
@@ -331,8 +267,9 @@ class TestSRTEndpoint(unittest.TestCase):
random.shuffle(args)
+        func = partial(run_logprob_check, self)
with ThreadPoolExecutor(8) as executor:
-            list(executor.map(self.run_logprob_check, args))
+            list(executor.map(func, args))
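
The partial above is needed because run_logprob_check now lives in sglang.test.test_utils and, judging by the call site, takes the test case as its first argument, while executor.map supplies exactly one packed-args tuple per call. The pattern in isolation (the stub name is hypothetical):

from concurrent.futures import ThreadPoolExecutor
from functools import partial

def run_logprob_check_stub(test_case, arg):
    # Stands in for the shared helper; prints instead of asserting.
    print(test_case, arg)

with ThreadPoolExecutor(8) as executor:
    # Bind the first argument up front; map() then supplies one tuple per call.
    list(executor.map(partial(run_logprob_check_stub, "testcase"), [(1000, 4), (2000, 8)]))
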
def test_logprob_grammar(self):
prompts = "Question: Is Paris the Capital of France? Answer:"
@@ -427,6 +364,77 @@ class TestSRTEndpoint(unittest.TestCase):
f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
)
+    def run_stateful_custom_logit_processor(
+        self, first_token_id: int | None, delay: int = 2
+    ):
+        """Test a custom logit processor with custom params and state.
+
+        The first `delay` tokens are sampled normally; after that the processor
+        forces `first_token_id` and then consecutive token ids.
+        If `first_token_id` is None, no custom logit processor is passed in.
+        """
+        custom_params = {"token_id": first_token_id, "delay": delay}
+
+        class DeterministicStatefulLogitProcessor(CustomLogitProcessor):
+            """A dummy logit processor that changes the logits to always
+            sample the given token id.
+            """
+
+            def __call__(self, logits, custom_param_list):
+                assert logits.shape[0] == len(custom_param_list)
+                for i, param_dict in enumerate(custom_param_list):
+                    if param_dict["delay"] > 0:
+                        param_dict["delay"] -= 1
+                        continue
+                    if param_dict["delay"] == 0:
+                        param_dict["delay"] -= 1
+                        force_token = param_dict["token_id"]
+                    else:
+                        output_ids = param_dict["__req__"].output_ids
+                        force_token = output_ids[-1] + 1
+                    # Mask all other tokens
+                    logits[i, :] = -float("inf")
+                    # Assign the highest probability to the specified token
+                    logits[i, force_token] = 0.0
+                return logits
+
+        prompts = "Question: Is Paris the Capital of France? Answer:"
+
+        # Base case json data to be posted to the server.
+        base_json = {
+            "text": prompts,
+            "sampling_params": {"temperature": 0.0},
+            "return_logprob": True,
+        }
+
+        # Custom json data with custom logit processor and params.
+        custom_json = base_json.copy()
+        # Only set the custom logit processor if first_token_id is not None.
+        if first_token_id is not None:
+            custom_json["custom_logit_processor"] = (
+                DeterministicStatefulLogitProcessor().to_str()
+            )
+            custom_json["sampling_params"]["custom_params"] = custom_params
+
+        custom_response = requests.post(
+            self.base_url + "/generate",
+            json=custom_json,
+        ).json()
+        output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
+        sampled_tokens = [x[1] for x in output_token_logprobs]
+
+        # After the delay, the processor should always sample the forced token,
+        # since the masked logits are deterministic.
+        if first_token_id is not None:
+            self.assertTrue(
+                all(
+                    x == custom_params["token_id"] + k
+                    for k, x in enumerate(sampled_tokens[custom_params["delay"] :])
+                ),
+                # Print the detailed test case info if the test fails.
+                f"{first_token_id=}\n{sampled_tokens=}\n{custom_response=}",
+            )
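
A side note on the masking trick in the processor above: setting a row of logits to -inf everywhere except one index gives that token a softmax probability of exactly 1, so greedy decoding and temperature sampling both select it. A self-contained check, assuming torch is available as in the server's sampling path:

import torch

logits = torch.randn(1, 10)
force_token = 5
logits[0, :] = -float("inf")  # mask every token ...
logits[0, force_token] = 0.0  # ... except the forced one
assert torch.softmax(logits, dim=-1)[0, force_token] == 1.0
assert torch.argmax(logits, dim=-1).item() == force_token
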
def test_custom_logit_processor(self):
"""Test custom logit processor with a single request."""
self.run_custom_logit_processor(target_token_id=5)
@@ -438,6 +446,19 @@ class TestSRTEndpoint(unittest.TestCase):
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
+    def test_stateful_custom_logit_processor(self):
+        """Test the stateful custom logit processor with a single request."""
+        self.run_stateful_custom_logit_processor(first_token_id=5)
+
+    def test_stateful_custom_logit_processor_batch_mixed(self):
+        """Test a batch mixing requests with and without the custom logit processor."""
+        target_token_ids = list(range(32)) + [None] * 16
+        random.shuffle(target_token_ids)
+        with ThreadPoolExecutor(len(target_token_ids)) as executor:
+            list(
+                executor.map(self.run_stateful_custom_logit_processor, target_token_ids)
+            )
def test_cache_tokens(self):
for _ in range(2):
time.sleep(1)
@@ -476,6 +497,21 @@ class TestSRTEndpoint(unittest.TestCase):
version = response_json["version"]
self.assertIsInstance(version, str)
+    def test_get_server_info_concurrent(self):
+        """Make sure concurrent get_server_info calls don't crash the server."""
+        tp = ThreadPoolExecutor(max_workers=30)
+
+        def fetch_server_info():
+            server_info = requests.get(self.base_url + "/get_server_info")
+            server_info.json()
+
+        futures = []
+        for _ in range(4):
+            futures.append(tp.submit(fetch_server_info))
+        for f in futures:
+            f.result()
if __name__ == "__main__":
unittest.main()
@@ -168,9 +168,9 @@ def _run_subprocess(
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
hf_outputs = HFRunner.forward_generation_raw(
-        base_model=hf_model,
prompts=_PROMPTS,
max_new_tokens=_MAX_NEW_TOKENS,
+        base_model=hf_model,
tokenizer=hf_tokenizer,
lora_paths=None,
torch_dtype=_TORCH_DTYPE,