Unverified commit ea13cb14, authored by Lianmin Zheng, committed by GitHub

[Auto Sync] Update test_deterministic.py, test_deterministi... (20251024) (#12083)


Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
parent a04212f1
@@ -184,7 +184,9 @@ def send_single(
     return ret["text"]
 
 
-def send_prefix(args, batch_size: int, prompts: List[str]):
+def send_prefix(
+    args, batch_size: int, prompts: List[str], return_full_response: bool = False
+):
     requests.post(f"http://{args.host}:{args.port}/flush_cache")
 
     batch_data = []
@@ -219,11 +221,36 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
         print(ret)
         return -1, -1, -1
 
-    ret_dict = {i: [] for i in range(len(prompts))}
-    for i in range(batch_size):
-        ret_dict[sampled_indices[i]].append(ret[i]["text"])
+    if return_full_response:
+        # Return full responses grouped by prompt index
+        ret_dict = {i: [] for i in range(len(prompts))}
+        for i in range(batch_size):
+            ret_dict[sampled_indices[i]].append(ret[i])
+        return ret_dict
+    else:
+        # Return only text grouped by prompt index
+        ret_dict = {i: [] for i in range(len(prompts))}
+        for i in range(batch_size):
+            ret_dict[sampled_indices[i]].append(ret[i]["text"])
+        return ret_dict
+
+
+def compare_logprobs(logprobs1, logprobs2, tolerance=0):
+    """Compare two logprobs sequences with a tolerance."""
+    if len(logprobs1) != len(logprobs2):
+        return False, f"Length mismatch: {len(logprobs1)} vs {len(logprobs2)}"
+
+    for i, (lp1, lp2) in enumerate(zip(logprobs1, logprobs2)):
+        # Each element is [logprob, token_id]
+        if lp1[1] != lp2[1]:
+            return False, f"Token ID mismatch at position {i}: {lp1[1]} vs {lp2[1]}"
+        if abs(lp1[0] - lp2[0]) > tolerance:
+            return (
+                False,
+                f"Logprob mismatch at position {i}: {lp1[0]} vs {lp2[0]} (diff: {abs(lp1[0] - lp2[0])})",
+            )
 
-    return ret_dict
+    return True, "Logprobs match"
 
 
 def _test_mode_p_vs_d(args, batch_size):
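
A quick sanity sketch of how compare_logprobs behaves on the [logprob, token_id] pairs it expects (the sample values below are invented for illustration):

    a = [[-0.105, 42], [-1.871, 7]]
    b = [[-0.105, 42], [-1.871, 7]]
    c = [[-0.105, 42], [-1.902, 7]]

    compare_logprobs(a, b)                 # (True, "Logprobs match")
    compare_logprobs(a, c)                 # (False, "Logprob mismatch at position 1: ...")
    compare_logprobs(a, c, tolerance=0.1)  # (True, ...): |(-1.871) - (-1.902)| = 0.031 <= 0.1

With the default tolerance=0, only bitwise-identical logprobs pass, which is exactly what a deterministic backend should produce across batch sizes.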
@@ -366,15 +393,28 @@ def test_deterministic(args):
         num_prompts = len(len_prefix)
         outputs = {i: [] for i in range(4)}
         prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
+
+        # If return_logprob is enabled, store full responses for comparison
+        if args.return_logprob:
+            full_responses = {i: [] for i in range(4)}
+
         for i in range(args.n_start, args.n_start + args.n_trials):
             batch_size = i
-            ret_dict = send_prefix(args, batch_size, prompts)
+            ret_dict = send_prefix(
+                args, batch_size, prompts, return_full_response=args.return_logprob
+            )
             msg = f"Testing Trial {i} with batch size {batch_size},"
             for i in range(num_prompts):
                 msg += f" # prefix length {len_prefix[i]}: {len(ret_dict[i])},"
             print(msg)
             for i in range(num_prompts):
-                outputs[i].extend(ret_dict[i])
+                if args.return_logprob:
+                    # Store full response for logprob comparison
+                    full_responses[i].extend(ret_dict[i])
+                    # Extract text for determinism check
+                    outputs[i].extend([resp["text"] for resp in ret_dict[i]])
+                else:
+                    outputs[i].extend(ret_dict[i])
 
         for i in range(num_prompts):
             print(
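
For reference, send_prefix now returns one of two shapes, both keyed by prompt index (the field names follow what the diff reads back; the values here are invented):

    # return_full_response=False: generated texts only
    {0: ["completion a", "completion b"], 1: [...], 2: [...], 3: [...]}

    # return_full_response=True: full response dicts, including meta_info
    {0: [{"text": "completion a",
          "meta_info": {"output_token_logprobs": [[-0.105, 42], ...]}},
         ...],
     1: [...], 2: [...], 3: [...]}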
@@ -384,6 +424,54 @@ def test_deterministic(args):
         results = []
         for i in range(num_prompts):
             results.append(len(set(outputs[i])))
+
+        # If logprobs are enabled, compare them across different batch sizes
+        if args.return_logprob:
+            print(f"\n{'='*60}")
+            print("Logprobs Comparison Across Batch Sizes")
+            print("=" * 60)
+
+            logprob_results = []
+            for prompt_idx in range(num_prompts):
+                print(
+                    f"\nPrompt {prompt_idx} (prefix length {len_prefix[prompt_idx]}):"
+                )
+                responses = full_responses[prompt_idx]
+
+                if len(responses) < 2:
+                    continue
+
+                # Compare all responses against the first one
+                reference = responses[0]
+                all_match = True
+                mismatches = []
+
+                for j, resp in enumerate(responses[1:], start=1):
+                    ref_logprobs = reference["meta_info"]["output_token_logprobs"]
+                    resp_logprobs = resp["meta_info"]["output_token_logprobs"]
+                    match, msg = compare_logprobs(ref_logprobs, resp_logprobs)
+
+                    if not match:
+                        print(f"  ✗ Sample {j+1}: {msg}")
+                        mismatches.append((j + 1, msg))
+                        all_match = False
+
+                if all_match:
+                    print(f"  ✓ All {len(responses)} samples have identical logprobs")
+                    logprob_results.append(1)
+                else:
+                    print(
+                        f"  ✗ Found {len(mismatches)} mismatches out of {len(responses)} samples"
+                    )
+                    logprob_results.append(0)
+
+            print(f"\n{'='*60}")
+            if all(r == 1 for r in logprob_results):
+                print("✓✓✓ Logprobs are identical across all batch sizes! ✓✓✓")
+            else:
+                print("✗✗✗ Some logprobs differ across batch sizes! ✗✗✗")
+
         return results
 
     elif args.test_mode == "radix_cache":
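
The comparison loop above boils down to the following standalone helper; this is a sketch only, assuming each stored response carries meta_info["output_token_logprobs"] as the diff does:

    def all_logprobs_identical(responses, tolerance=0.0):
        # Fewer than two samples means there is nothing to compare.
        if len(responses) < 2:
            return True
        reference = responses[0]["meta_info"]["output_token_logprobs"]
        for resp in responses[1:]:
            match, _ = compare_logprobs(
                reference, resp["meta_info"]["output_token_logprobs"], tolerance
            )
            if not match:
                return False
        return True

Every later sample is checked against responses[0], so a single reference anchors all batch sizes rather than comparing samples pairwise.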
@@ -60,7 +60,7 @@ class TestDeterministicBase(CustomTestCase):
         for result in results:
             assert result == 1
 
-    def test_prefix(self):
+    def test_prefix_with_logprobs(self):
         args = BenchArgs()
         url = DEFAULT_URL_FOR_TEST
         args.host, args.port = self._extract_host_and_port(url)
@@ -68,6 +68,7 @@ class TestDeterministicBase(CustomTestCase):
         args.n_start = 10
         args.n_trials = 10
         args.temperature = 0.5  # test for deterministic sampling
+        args.return_logprob = True  # Enable logprobs comparison
         results = test_deterministic(args)
         for result in results:
             assert result == 1
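
For context, the server round trip this test exercises looks roughly like the sketch below; it assumes a running SGLang server and mirrors the native /generate API, with the response fields matching what the code above reads (host, port, and prompt are placeholders):

    import requests

    resp = requests.post(
        "http://127.0.0.1:30000/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"temperature": 0.5, "max_new_tokens": 16},
            "return_logprob": True,
        },
    ).json()

    # Elements pair a logprob with its token id, the format compare_logprobs consumes.
    print(resp["meta_info"]["output_token_logprobs"])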