Unverified commit af02f99b authored by Lianmin Zheng, committed by GitHub

Add more logprob tests (#3162)

parent 9472e699
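
For orientation before the diff: all of the new tests exercise the logprob options of the server's /generate endpoint. Below is a minimal standalone sketch of such a request. It is not part of the commit; the server URL and prompt are assumptions for illustration, while the request parameters and meta_info fields are exactly the ones the tests assert on.

import requests

BASE_URL = "http://localhost:30000"  # assumed local SGLang server address

response = requests.post(
    BASE_URL + "/generate",
    json={
        "text": "Question: Is Paris the Capital of France? Answer:",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_logprob": True,   # return per-token logprobs
        "logprob_start_len": 0,   # collect input logprobs from the first prompt token
        "top_logprobs_num": 5,    # also return the top-5 candidates at each position
    },
)
meta = response.json()["meta_info"]
print(len(meta["output_token_logprobs"]))   # one entry per generated token
print(len(meta["output_top_logprobs"][0]))  # 5 candidates for the first output token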
@@ -32,7 +32,11 @@ class TestSRTEndpoint(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=("--enable-custom-logit-processor",),
+            other_args=(
+                "--enable-custom-logit-processor",
+                "--mem-fraction-static",
+                "0.8",
+            ),
         )
 
     @classmethod
@@ -155,14 +159,26 @@ class TestSRTEndpoint(unittest.TestCase):
                 },
                 "return_logprob": True,
                 "logprob_start_len": -1,
+                "top_logprobs_num": 5,
             },
         )
         response_json = response.json()
-        print(json.dumps(response_json, indent=2))
+        # print(json.dumps(response_json, indent=2))
 
         res = response_json
         self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
+
+        # Test that the number of tokens is correct
         self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
+        self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), new_tokens)
+
+        # Test that the top-1 tokens match the output tokens (because temp = 0.0)
+        for i in range(new_tokens):
+            self.assertListEqual(
+                res["meta_info"]["output_token_logprobs"][i],
+                res["meta_info"]["output_top_logprobs"][i][0],
+            )
+            self.assertEqual(len(res["meta_info"]["output_top_logprobs"][i]), 5)
 
     def test_logprob_match(self):
         """Test the output logprobs are close to the input logprobs if we run a prefill again."""
@@ -221,6 +237,103 @@ class TestSRTEndpoint(unittest.TestCase):
         max_diff = np.max(diff)
         self.assertLess(max_diff, 0.25)
 
+    def run_logprob_check(self, arg):
+        (
+            input_len,
+            output_len,
+            temperature,
+            logprob_start_len,
+            return_logprob,
+            top_logprobs_num,
+        ) = arg
+        input_ids = list(range(input_len))
+
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": input_ids,
+                "sampling_params": {
+                    "temperature": temperature,
+                    "max_new_tokens": output_len,
+                },
+                "return_logprob": return_logprob,
+                "logprob_start_len": logprob_start_len,
+                "top_logprobs_num": top_logprobs_num,
+            },
+        )
+        response_json = response.json()
+        res = response_json
+
+        self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
+        self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
+
+        # Test that the number of tokens is correct
+        if return_logprob:
+            # If logprob_start_len == 0, a padding entry is added for the first token;
+            # otherwise, no padding is added.
+            delta = 0 if logprob_start_len == 0 else 1
+            self.assertEqual(
+                len(res["meta_info"]["input_token_logprobs"])
+                + logprob_start_len
+                + delta,
+                res["meta_info"]["prompt_tokens"],
+            )
+            self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
+
+            if top_logprobs_num:
+                self.assertEqual(
+                    len(res["meta_info"]["input_top_logprobs"])
+                    + logprob_start_len
+                    + delta,
+                    res["meta_info"]["prompt_tokens"],
+                )
+                self.assertEqual(
+                    len(res["meta_info"]["output_top_logprobs"]), output_len
+                )
+
+                for i in range(output_len):
+                    self.assertEqual(
+                        len(res["meta_info"]["output_top_logprobs"][i]),
+                        top_logprobs_num,
+                    )
+
+                    # Test that the top-1 tokens match the output tokens if temperature == 0
+                    if temperature == 0:
+                        self.assertListEqual(
+                            res["meta_info"]["output_token_logprobs"][i],
+                            res["meta_info"]["output_top_logprobs"][i][0],
+                        )
+
+    def test_logprob_mixed(self):
+        args = []
+        temperature = 0
+        # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
+        for input_len in [1000, 2000]:
+            for output_len in [4, 8]:
+                for logprob_start_len in [0, 500, 1000]:
+                    for return_logprob in [True, False]:
+                        for top_logprobs_num in [0, 5]:
+                            if logprob_start_len >= input_len:
+                                continue
+                            args.append(
+                                (
+                                    input_len,
+                                    output_len,
+                                    temperature,
+                                    logprob_start_len,
+                                    return_logprob,
+                                    top_logprobs_num,
+                                )
+                            )
+
+        random.shuffle(args)
+
+        with ThreadPoolExecutor(8) as executor:
+            list(executor.map(self.run_logprob_check, args))
+
     def test_logprob_grammar(self):
         prompts = "Question: Is Paris the Capital of France? Answer:"
         allowed_tokens = [" Yes", " No"]
...
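
A quick sanity check on the padding logic in run_logprob_check: the test asserts len(input_token_logprobs) + logprob_start_len + delta == prompt_tokens, with delta = 0 only when logprob_start_len == 0. The worked example below just restates that arithmetic; expected_input_logprob_len is a hypothetical helper for illustration, not part of the commit.

def expected_input_logprob_len(prompt_tokens: int, logprob_start_len: int) -> int:
    # delta mirrors the test: a padding entry is added for the first token
    # only when logprob_start_len == 0.
    delta = 0 if logprob_start_len == 0 else 1
    return prompt_tokens - logprob_start_len - delta

assert expected_input_logprob_len(1000, 0) == 1000   # padded case: one entry per prompt token
assert expected_input_logprob_len(1000, 500) == 499  # unpadded case: one fewer than the 500 covered positions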