"vscode:/vscode.git/clone" did not exist on "05b13c9fe3114f44e4ce0cf984b7c0d9109ffe9a"
Unverified Commit 13219e1e authored by Minglei Zhu's avatar Minglei Zhu Committed by GitHub
Browse files

completely remove mixed mode deterministic test as prefix mode could cover it (#11783)


Co-authored-by: default avatarBaizhou Zhang <sobereddiezhang@gmail.com>
parent 33e9bbec
......@@ -5,9 +5,6 @@ Usage:
# Single mode: test determinism with varying batch sizes
python3 -m sglang.test.test_deterministic --n-trials 50 --test-mode single
# Mixed mode: test with mixed prompts
python3 -m sglang.test.test_deterministic --n-trials 50 --test-mode mixed
# Prefix mode: test with shared prefixes
python3 -m sglang.test.test_deterministic --n-start 1 --n-trials 50 --test-mode prefix
......@@ -79,7 +76,6 @@ class BenchArgs:
default=BenchArgs.test_mode,
choices=[
"single",
"mixed",
"prefix",
"radix_cache",
],
......@@ -181,52 +177,6 @@ def send_single(
return ret["text"]
def send_mixed(args, batch_size: int):
num_long_prompt = 0 if batch_size <= 10 else random.randint(1, 10)
num_prompt_1 = random.randint(1, batch_size - num_long_prompt)
num_prompt_2 = batch_size - num_prompt_1 - num_long_prompt
json_data = {
"text": [PROMPT_1] * num_prompt_1
+ [PROMPT_2] * num_prompt_2
+ [LONG_PROMPT] * num_long_prompt,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
}
if args.sampling_seed is not None:
json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
response = requests.post(
f"http://{args.host}:{args.port}/generate",
json=json_data,
stream=args.stream,
)
ret = response.json()
if response.status_code != 200:
print(ret)
return -1, -1, -1
prompt_1_ret = [ret[i]["text"] for i in range(num_prompt_1)]
prompt_2_ret = [
ret[i]["text"] for i in range(num_prompt_1, num_prompt_1 + num_prompt_2)
]
long_prompt_ret = [
ret[i]["text"]
for i in range(
num_prompt_1 + num_prompt_2, num_prompt_1 + num_prompt_2 + num_long_prompt
)
]
return prompt_1_ret, prompt_2_ret, long_prompt_ret
def send_prefix(args, batch_size: int, prompts: List[str]):
requests.post(f"http://{args.host}:{args.port}/flush_cache")
......@@ -282,38 +232,6 @@ def test_deterministic(args):
print(f"Total samples: {len(texts)}, Unique samples: {len(set(texts))}")
return [len(set(texts))]
elif args.test_mode == "mixed":
# In mixed mode, we send a mixture of two short prompts and one long prompt in the same batch with batch size ranging from 1 to n_trials.
output_prompt_1 = []
output_prompt_2 = []
output_long_prompt = []
for i in range(1, args.n_trials + 1):
batch_size = i
ret_prompt_1, ret_prompt_2, ret_long_prompt = send_mixed(args, batch_size)
output_prompt_1.extend(ret_prompt_1)
output_prompt_2.extend(ret_prompt_2)
output_long_prompt.extend(ret_long_prompt)
print(
f"Testing Trial {i} with batch size {batch_size}, number of prompt 1: {len(ret_prompt_1)}, number of prompt 2: {len(ret_prompt_2)}, number of long prompt: {len(ret_long_prompt)}"
)
print(
f"Prompt 1: total samples: {len(output_prompt_1)}, Unique samples: {len(set(output_prompt_1))}"
)
print(
f"Prompt 2: total samples: {len(output_prompt_2)}, Unique samples: {len(set(output_prompt_2))}"
)
print(
f"Long prompt: total samples: {len(output_long_prompt)}, Unique samples: {len(set(output_long_prompt))}"
)
return [
len(set(output_prompt_1)),
len(set(output_prompt_2)),
len(set(output_long_prompt)),
]
elif args.test_mode == "prefix":
# In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
len_prefix = [1, 511, 2048, 4097]
......
......@@ -56,18 +56,6 @@ class TestDeterministicBase(CustomTestCase):
for result in results:
assert result == 1
def test_mixed(self):
args = BenchArgs()
url = DEFAULT_URL_FOR_TEST
args.host, args.port = self._extract_host_and_port(url)
args.test_mode = "mixed"
args.n_start = 10
args.n_trials = 20
args.temperature = 0.5 # test for deterministic sampling
results = test_deterministic(args)
for result in results:
assert result == 1
def test_prefix(self):
args = BenchArgs()
url = DEFAULT_URL_FOR_TEST
......
......@@ -69,7 +69,7 @@ suites = {
TestFile("test_build_eagle_tree.py", 8),
TestFile("test_chunked_prefill.py", 313),
TestFile("test_create_kvindices.py", 2),
TestFile("test_deterministic.py", 300),
TestFile("test_deterministic.py", 320),
TestFile("test_eagle_infer_a.py", 370),
TestFile("test_eagle_infer_b.py", 700),
TestFile("test_eagle_infer_beta.py", 300),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment