Unverified Commit d7056c52 authored by fzyzcjy, committed by GitHub

Enhance tests in deterministic kernels (#12070)

parent 13bf565d
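In brief, the new `p_vs_d` mode checks prefill-vs-decode determinism: pass A generates a batch normally (prefill + decode) while recording per-token logprobs; pass B re-sends each full token sequence from pass A with max_new_tokens=1, so the same positions are scored by the prefill path alone; the test then compares token ids and logprobs for exact equality. Below is a minimal standalone sketch of that idea, assuming an sglang-style server on localhost:30000; the endpoint and field names mirror the diff that follows but are not verified here.

    import requests

    BASE = "http://localhost:30000"  # assumed server address

    LOGPROB_FIELDS = {
        "return_logprob": True,
        "logprob_start_len": 0,           # request logprobs for every input position
        "return_text_in_logprobs": True,  # entries become (logprob, token_id, text)
    }


    def generate(payload):
        return requests.post(f"{BASE}/generate", json=payload).json()


    # Pass A: ordinary prefill + decode.
    a = generate(
        {
            "text": "The capital of France is",
            "sampling_params": {"max_new_tokens": 8},
            **LOGPROB_FIELDS,
        }
    )
    meta_a = a["meta_info"]

    requests.post(f"{BASE}/flush_cache")  # avoid radix-cache reuse between passes

    # Pass B: re-prefill the full sequence (input ids + generated ids) with a
    # single new token, so every logprob is produced by the prefill path.
    full_ids = [t[1] for t in meta_a["input_token_logprobs"]] + [
        t[1] for t in meta_a["output_token_logprobs"]
    ]
    b = generate(
        {
            "input_ids": full_ids,
            "sampling_params": {"max_new_tokens": 1},
            **LOGPROB_FIELDS,
        }
    )

    # With deterministic kernels, pass B must reproduce pass A's logprobs exactly.
    a_logprobs = [
        t[0] for t in meta_a["input_token_logprobs"] + meta_a["output_token_logprobs"]
    ]
    b_logprobs = [t[0] for t in b["meta_info"]["input_token_logprobs"]]
    print("match:", a_logprobs == b_logprobs)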
@@ -17,7 +17,7 @@ import dataclasses
 import json
 import os
 import random
-from typing import List
+from typing import Any, Dict, List, Optional
 
 import requests
@@ -78,6 +78,7 @@ class BenchArgs:
                 "single",
                 "prefix",
                 "radix_cache",
+                "p_vs_d",
             ],
         )
         parser.add_argument("--profile", action="store_true")
@@ -101,6 +102,8 @@ def send_single(
     input_ids: List[int] = None,
     prompt: List[str] = None,
     max_new_tokens: int = None,
+    extra_params: Optional[Dict[str, Any]] = None,
+    pick_first_result: bool = True,
 ):
     base_url = f"http://{args.host}:{args.port}"
@@ -121,6 +124,7 @@ def send_single(
             },
             "return_logprob": args.return_logprob,
             "stream": args.stream,
+            **(extra_params or {}),
         }
     else:
         assert input_ids is None
@@ -138,6 +142,7 @@ def send_single(
             },
             "return_logprob": args.return_logprob,
             "stream": args.stream,
+            **(extra_params or {}),
         }
 
     if args.sampling_seed is not None:
@@ -170,7 +175,8 @@ def send_single(
     else:
         ret = response.json()
-        ret = ret[0] if isinstance(ret, list) else ret
+        if pick_first_result:
+            ret = ret[0] if isinstance(ret, list) else ret
 
     if return_full_response:
         return ret
@@ -220,6 +226,127 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
     return ret_dict
+
+
+def _test_mode_p_vs_d(args, batch_size):
+    print()
+    print(f"Execute: test p_vs_d {batch_size=}")
+    random.seed(42)
+
+    args.return_logprob = True
+    query_extra_params = {
+        "logprob_start_len": 0,
+        "return_text_in_logprobs": True,
+    }
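+    # With logprob_start_len=0 the server returns a logprob for every input
+    # position, and return_text_in_logprobs turns each logprob entry into a
+    # (logprob, token_id, text) triple; _extract_part below unpacks exactly that.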
+
+    def _create_prompts():
+        ans = [PROMPT_1, PROMPT_2]
+        for _ in range(batch_size - len(ans)):
+            end = random.randrange(1, 4096)
+            if random.random() < 0.5:
+                begin = 0
+            else:
+                begin = random.randrange(0, end)
+            ans.append(LONG_PROMPT[begin:end])
+        return ans[:batch_size]
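+
+    # _create_prompts mixes two fixed prompts with random slices of
+    # LONG_PROMPT, so prefill lengths differ across requests in one batch.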
+
+    # warmup + flush
+    send_single(args, input_ids=[1] * 64, max_new_tokens=65, return_full_response=True)
+    requests.post(f"http://{args.host}:{args.port}/flush_cache")
+
+    prompts = _create_prompts()
+
+    resp_a = send_single(
+        args,
+        prompt=prompts,
+        max_new_tokens=args.max_new_tokens,
+        return_full_response=True,
+        pick_first_result=False,
+        extra_params=query_extra_params,
+    )
+    info_a = _extract_ids_and_logprobs(resp_a)
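+    # Pass A: ordinary prefill + decode; each result's "io" field concatenates
+    # input and output token ids/logprobs into one full sequence.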
+
+    requests.post(f"http://{args.host}:{args.port}/flush_cache")
+
+    resp_b = send_single(
+        args,
+        input_ids=[x["io"].token_ids for x in info_a],
+        max_new_tokens=1,
+        return_full_response=True,
+        pick_first_result=False,
+        extra_params=query_extra_params,
+    )
+    info_b = _extract_ids_and_logprobs(resp_b)
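+    # Pass B re-prefills A's full sequences with max_new_tokens=1, so all of
+    # info_b's input logprobs come from the prefill path; comparing them with
+    # A's io logprobs is exactly the prefill-vs-decode check.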
+
+    ans = []
+    for i, (info_a_item, info_b_item) in enumerate(zip(info_a, info_b, strict=True)):
+        print(f"Compare sequence {i} in batch...")
+        correct = TokenIdsAndLogprobs.compare(info_a_item["io"], info_b_item["input"])
+        ans.append(int(correct))
+    return ans
+
+
+@dataclasses.dataclass
+class TokenIdsAndLogprobs:
+    token_ids: List[int]
+    logprobs: List[float]
+
+    def __add__(self, other):
+        return TokenIdsAndLogprobs(
+            token_ids=self.token_ids + other.token_ids,
+            logprobs=self.logprobs + other.logprobs,
+        )
+
+    @classmethod
+    def compare(cls, a: "TokenIdsAndLogprobs", b: "TokenIdsAndLogprobs"):
+        assert len(a.token_ids) == len(b.token_ids)
+        token_match = a.token_ids == b.token_ids
+        logprobs_match = a.logprobs == b.logprobs
+
+        if token_match:
+            print(f"Token match: {a.token_ids}")
+        else:
+            print(f"❗Token mismatch: {a.token_ids=} {b.token_ids=}")
+
+        if logprobs_match:
+            print("Logprobs match:", a.logprobs)
+        else:
+            print("❗Logprobs mismatch")
+            print(
+                "  A:   ",
+                [f"{x:.10f}" if x is not None else "None" for x in a.logprobs],
+            )
+            print(
+                "  B:   ",
+                [f"{x:.10f}" if x is not None else "None" for x in b.logprobs],
+            )
+            diff = [
+                abs(x - y) if x is not None and y is not None else float("nan")
+                for x, y in zip(a.logprobs, b.logprobs)
+            ]
+            print("  Diff:", [f"{x:.10e}" for x in diff])
+
+        return token_match and logprobs_match
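+
+# Hypothetical sanity check (illustrative values, not from a real run):
+# __add__ concatenates sequences; compare prints a verdict and returns a bool.
+#
+#   a = TokenIdsAndLogprobs([10, 11], [None, -0.25])
+#   b = TokenIdsAndLogprobs([10, 11], [None, -0.25])
+#   assert TokenIdsAndLogprobs.compare(a, b)      # prints "Token match: ..."
+#   assert (a + b).token_ids == [10, 11, 10, 11]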
+
+
+def _extract_ids_and_logprobs(responses):
+    def _extract_part(response, name):
+        token_ids, logprobs = [], []
+        for item in response["meta_info"][name]:
+            logprob, token_id, text = item
+            token_ids.append(token_id)
+            logprobs.append(logprob)
+        return TokenIdsAndLogprobs(token_ids=token_ids, logprobs=logprobs)
+
+    def _extract_one_response(response):
+        input = _extract_part(response, "input_token_logprobs")
+        output = _extract_part(response, "output_token_logprobs")
+        return dict(input=input, output=output, io=input + output)
+
+    if not isinstance(responses, list):
+        responses = [responses]
+    return [_extract_one_response(x) for x in responses]
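+
+# Assumed meta_info layout (illustrative values only): each *_token_logprobs
+# entry is a [logprob, token_id, text] triple, and the first input logprob is
+# None since the first token is unconditioned, e.g.
+#   response["meta_info"]["input_token_logprobs"] == [
+#       [None, 128000, "<|begin_of_text|>"], [-3.14, 791, "The"], ...]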
+
+
 def test_deterministic(args):
     if args.test_mode == "single":
         # In single mode, we test the deterministic behavior by sending the same prompt in batch sizes ranging from 1 to n_trials.
@@ -416,6 +543,13 @@ def test_deterministic(args):
         print("✗✗✗ TEST FAILED - Radix cache produces different results! ✗✗✗")
         return [0]
+    elif args.test_mode == "p_vs_d":
+        # TODO: also extract the other test modes into helper functions
+        ans = []
+        for i in range(1, args.n_trials + 1):
+            ans += _test_mode_p_vs_d(args, batch_size=i)
+        return ans
     else:
         raise ValueError(f"Invalid test mode: {args.test_mode}")
...