Commit ac238727 authored by Lianmin Zheng

Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
parent 0194948f
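The two runtime behaviors named in the commit title can be exercised directly against the /generate endpoint used throughout the tests below. The following sketch is illustrative only and not part of the diff: it assumes a locally launched server on the default port 30000, started with chunked prefill enabled (for example via --chunked-prefill-size) and the overlap scheduler active; the prompt and penalty value are hypothetical.

import requests

base_url = "http://127.0.0.1:30000"  # assumed local server address
long_prompt = "The capital of France is " * 400  # long enough to be prefilled in several chunks

resp = requests.post(
    base_url + "/generate",
    json={
        "text": long_prompt,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 16,
            # Penalties are ordinary sampling params; per the commit title they
            # now also apply when the overlap scheduler is enabled.
            "frequency_penalty": 0.5,
        },
        # Logprobs are now returned even when the prompt is prefilled in chunks.
        "return_logprob": True,
        "logprob_start_len": 0,
        "top_logprobs_num": 2,
    },
)
meta = resp.json()["meta_info"]
print(len(meta["input_token_logprobs"]), len(meta["output_token_logprobs"]))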
"""
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_simple_decode_stream
"""
import json
import unittest
@@ -12,42 +17,26 @@ from sglang.test.test_utils import (
popen_launch_server,
)
_server_process = None
_base_url = None
_tokenizer = None
def setUpModule():
"""
Launch the server once before all tests and initialize the tokenizer.
"""
global _server_process, _base_url, _tokenizer
_server_process = popen_launch_server(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--skip-tokenizer-init"],
)
_base_url = DEFAULT_URL_FOR_TEST
_tokenizer = AutoTokenizer.from_pretrained(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
)
print(">>> setUpModule: Server launched, tokenizer ready")
def tearDownModule():
"""
Terminate the server once after all tests have completed.
"""
global _server_process
if _server_process is not None:
kill_process_tree(_server_process.pid)
_server_process = None
print(">>> tearDownModule: Server terminated")
class TestSkipTokenizerInit(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--skip-tokenizer-init", "--stream-output"],
)
cls.tokenizer = AutoTokenizer.from_pretrained(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(
self,
prompt_text="The capital of France is",
@@ -56,19 +45,19 @@ class TestSkipTokenizerInit(unittest.TestCase):
top_logprobs_num=0,
n=1,
):
- input_ids = _tokenizer(prompt_text, return_tensors="pt")["input_ids"][
+ input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
0
].tolist()
response = requests.post(
- _base_url + "/generate",
+ self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
- "stop_token_ids": [_tokenizer.eos_token_id],
+ "stop_token_ids": [self.tokenizer.eos_token_id],
},
"stream": False,
"return_logprob": return_logprob,
@@ -83,13 +72,13 @@ class TestSkipTokenizerInit(unittest.TestCase):
if item["meta_info"]["finish_reason"]["type"] == "stop":
self.assertEqual(
item["meta_info"]["finish_reason"]["matched"],
- _tokenizer.eos_token_id,
+ self.tokenizer.eos_token_id,
)
elif item["meta_info"]["finish_reason"]["type"] == "length":
self.assertEqual(
- len(item["token_ids"]), item["meta_info"]["completion_tokens"]
+ len(item["output_ids"]), item["meta_info"]["completion_tokens"]
)
- self.assertEqual(len(item["token_ids"]), max_new_tokens)
+ self.assertEqual(len(item["output_ids"]), max_new_tokens)
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
if return_logprob:
@@ -113,6 +102,63 @@ class TestSkipTokenizerInit(unittest.TestCase):
print("=" * 100)
def run_decode_stream(self, return_logprob=False, top_logprobs_num=0, n=1):
max_new_tokens = 32
input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is
requests.post(self.base_url + "/flush_cache")
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [119690],
},
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
)
ret = response.json()
print(json.dumps(ret))
output_ids = ret["output_ids"]
requests.post(self.base_url + "/flush_cache")
response_stream = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [119690],
},
"stream": True,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
)
ret = response.json()
output_ids = ret["output_ids"]
print("output from non-streaming request:")
print(output_ids)
response_stream_json = []
for line in response_stream.iter_lines():
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
response_stream_json.append(json.loads(line[6:]))
out_stream_ids = []
for x in response_stream_json:
out_stream_ids += x["output_ids"]
print("output from streaming request:")
print(out_stream_ids)
assert output_ids == out_stream_ids
def test_simple_decode(self):
self.run_decode()
@@ -126,6 +172,9 @@ class TestSkipTokenizerInit(unittest.TestCase):
def test_eos_behavior(self):
self.run_decode(max_new_tokens=256)
def test_simple_decode_stream(self):
self.run_decode_stream()
if __name__ == "__main__":
unittest.main()
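The first file ends here; the hunks that follow come from other test files (TestSRTEndpoint and an HFRunner-based test). For context on the --skip-tokenizer-init flow exercised above: the server neither tokenizes prompts nor detokenizes outputs, so the client sends input_ids and receives output_ids (the field this commit standardizes on, replacing token_ids). A hedged client-side sketch of that round trip, with an assumed server address and a hypothetical model name, is:

import requests
from transformers import AutoTokenizer

base_url = "http://127.0.0.1:30000"  # assumed: server launched with --skip-tokenizer-init
model_path = "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical; use the model the server was launched with
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Tokenize locally and send token ids only.
input_ids = tokenizer("The capital of France is")["input_ids"]
ret = requests.post(
    base_url + "/generate",
    json={
        "input_ids": input_ids,
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
).json()

# The response carries token ids, not text; decode them locally.
print(tokenizer.decode(ret["output_ids"], skip_special_tokens=True))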
@@ -8,6 +8,7 @@ import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Optional
import numpy as np
@@ -20,6 +21,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
run_logprob_check,
)
@@ -35,7 +37,9 @@ class TestSRTEndpoint(unittest.TestCase):
other_args=(
"--enable-custom-logit-processor",
"--mem-fraction-static",
- "0.8",
+ "0.7",
"--cuda-graph-max-bs",
"8",
),
)
@@ -131,7 +135,7 @@ class TestSRTEndpoint(unittest.TestCase):
for i, res in enumerate(response_json):
self.assertEqual(
res["meta_info"]["prompt_tokens"],
- logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]),
+ logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
)
assert prompts[i].endswith(
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
@@ -235,83 +239,15 @@ class TestSRTEndpoint(unittest.TestCase):
diff = np.abs(output_logprobs - output_logprobs_score)
max_diff = np.max(diff)
- self.assertLess(max_diff, 0.25)
+ self.assertLess(max_diff, 0.35)
def run_logprob_check(self, arg):
(
input_len,
output_len,
temperature,
logprob_start_len,
return_logprob,
top_logprobs_num,
) = arg
input_ids = list(range(input_len))
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": temperature,
"max_new_tokens": output_len,
},
"return_logprob": return_logprob,
"logprob_start_len": logprob_start_len,
"top_logprobs_num": top_logprobs_num,
},
)
response_json = response.json()
res = response_json
self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
# Check that the returned token counts are correct
if return_logprob:
# This is because if logprob_start_len == 0, we added a padding for the first token.
# In other cases, we do not add the padding
delta = 0 if logprob_start_len == 0 else 1
self.assertEqual(
len(res["meta_info"]["input_token_logprobs"])
+ logprob_start_len
+ delta,
res["meta_info"]["prompt_tokens"],
)
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
if top_logprobs_num:
self.assertEqual(
len(res["meta_info"]["input_top_logprobs"])
+ logprob_start_len
+ delta,
res["meta_info"]["prompt_tokens"],
)
self.assertEqual(
len(res["meta_info"]["output_top_logprobs"]), output_len
)
for i in range(output_len):
self.assertEqual(
len(res["meta_info"]["output_top_logprobs"][i]),
top_logprobs_num,
)
# Test the top-1 tokens are the same as output tokens if temperature == 0
if temperature == 0:
self.assertListEqual(
res["meta_info"]["output_token_logprobs"][i],
res["meta_info"]["output_top_logprobs"][i][0],
)
def test_logprob_mixed(self):
args = []
temperature = 0
# input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
- for input_len in [1000, 2000]:
+ for input_len in [1000, 5000, 10000, 50000]:
for output_len in [4, 8]:
- for logprob_start_len in [0, 500, 1000]:
+ for logprob_start_len in [0, 500, 2500, 5000, 25000]:
for return_logprob in [True, False]:
for top_logprobs_num in [0, 5]:
@@ -331,8 +267,9 @@ class TestSRTEndpoint(unittest.TestCase):
random.shuffle(args)
func = partial(run_logprob_check, self)
with ThreadPoolExecutor(8) as executor:
- list(executor.map(self.run_logprob_check, args))
+ list(executor.map(func, args))
def test_logprob_grammar(self):
prompts = "Question: Is Paris the Capital of France? Answer:"
@@ -427,6 +364,77 @@ class TestSRTEndpoint(unittest.TestCase):
f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
)
def run_stateful_custom_logit_processor(
self, first_token_id: int | None, delay: int = 2
):
"""Test custom logit processor with custom params and state.
Should sample the first `delay` tokens normally, then output first_token_id and consecutive tokens after that.
If first_token_id is None, the custom logit processor won't be passed in.
"""
custom_params = {"token_id": first_token_id, "delay": delay}
class DeterministicStatefulLogitProcessor(CustomLogitProcessor):
"""A dummy logit processor that changes the logits to always
sample the given token id.
"""
def __call__(self, logits, custom_param_list):
assert logits.shape[0] == len(custom_param_list)
for i, param_dict in enumerate(custom_param_list):
if param_dict["delay"] > 0:
param_dict["delay"] -= 1
continue
if param_dict["delay"] == 0:
param_dict["delay"] -= 1
force_token = param_dict["token_id"]
else:
output_ids = param_dict["__req__"].output_ids
force_token = output_ids[-1] + 1
# Mask all other tokens
logits[i, :] = -float("inf")
# Assign highest probability to the specified token
logits[i, force_token] = 0.0
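# With every other logit at -inf, the softmax puts all probability on force_token,
# so sampling must return it.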
return logits
prompts = "Question: Is Paris the Capital of France? Answer:"
# Base case json data to be posted to the server.
base_json = {
"text": prompts,
"sampling_params": {"temperature": 0.0},
"return_logprob": True,
}
# Custom json data with custom logit processor and params.
custom_json = base_json.copy()
# Only set the custom logit processor if first_token_id is not None.
if first_token_id is not None:
custom_json["custom_logit_processor"] = (
DeterministicStatefulLogitProcessor().to_str()
)
custom_json["sampling_params"]["custom_params"] = custom_params
custom_response = requests.post(
self.base_url + "/generate",
json=custom_json,
).json()
output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
sampled_tokens = [x[1] for x in output_token_logprobs]
# The logit processor should force the expected token sequence, since the modified logits are deterministic.
if first_token_id is not None:
self.assertTrue(
all(
x == custom_params["token_id"] + k
for k, x in enumerate(sampled_tokens[custom_params["delay"] :])
),
# Print the detailed test case info if the test fails.
f"{first_token_id=}\n{sampled_tokens=}\n{custom_response=}",
)
def test_custom_logit_processor(self):
"""Test custom logit processor with a single request."""
self.run_custom_logit_processor(target_token_id=5)
@@ -438,6 +446,19 @@ class TestSRTEndpoint(unittest.TestCase):
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
def test_stateful_custom_logit_processor(self):
"""Test custom logit processor with a single request."""
self.run_stateful_custom_logit_processor(first_token_id=5)
def test_stateful_custom_logit_processor_batch_mixed(self):
"""Test a batch of requests mixed of requests with and without custom logit processor."""
target_token_ids = list(range(32)) + [None] * 16
random.shuffle(target_token_ids)
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(
executor.map(self.run_stateful_custom_logit_processor, target_token_ids)
)
def test_cache_tokens(self):
for _ in range(2):
time.sleep(1)
@@ -476,6 +497,21 @@ class TestSRTEndpoint(unittest.TestCase):
version = response_json["version"]
self.assertIsInstance(version, str)
def test_get_server_info_concurrent(self):
"""Make sure the concurrent get_server_info doesn't crash the server."""
tp = ThreadPoolExecutor(max_workers=30)
def s():
server_info = requests.get(self.base_url + "/get_server_info")
server_info.json()
futures = []
for _ in range(4):
futures.append(tp.submit(s))
for f in futures:
f.result()
if __name__ == "__main__":
unittest.main()
@@ -168,9 +168,9 @@ def _run_subprocess(
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
hf_outputs = HFRunner.forward_generation_raw(
+ base_model=hf_model,
prompts=_PROMPTS,
max_new_tokens=_MAX_NEW_TOKENS,
- base_model=hf_model,
tokenizer=hf_tokenizer,
lora_paths=None,
torch_dtype=_TORCH_DTYPE,
...