Unverified Commit 7e4813e8 authored by wang jiahao, committed by GitHub

Merge pull request #1184 from kvcache-ai/update_param

change test
parents 73935878 3a044e6b
@@ -25,19 +25,10 @@ class DataEvaluator:
"""
# Read the Parquet file
# dataset = load_dataset('parquet', data_files=file_path)
ds = load_dataset(file_path, "all")
df = pd.DataFrame(ds['test'])
# print(ds)
# # ds_1 = ds['train']
# ds_2 = ds['validation']
# ds_3 = ds['test']
# # Convert the dataset to a Pandas DataFrame
# df_test = pd.DataFrame(ds['test'])
# df_val = pd.DataFrame(ds['validation'])
# for _, row in df.iterrows():
# self.data.append(row.to_dict())
# df = pd.read_parquet(file_path)
splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
'dev': 'all/dev-00000-of-00001.parquet',
'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
for _, row in df.iterrows():
self.data.append(row.to_dict())
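For context, the replacement path above streams the MMLU test split straight from the Hugging Face Hub. A minimal standalone sketch of that loading step (it assumes the huggingface_hub and fsspec packages are installed, since pandas resolves hf:// URLs through them):

import pandas as pd

# Same parquet path as in the diff; pandas delegates the hf:// protocol to
# huggingface_hub's filesystem implementation (assumed installed).
df = pd.read_parquet("hf://datasets/cais/mmlu/all/test-00000-of-00001.parquet")
print(len(df), list(df.columns))  # columns include question / subject / choices / answer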
@@ -8,12 +8,57 @@ from datasets import load_dataset
import os
import concurrent.futures
import threading
import re
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['https_proxy'] = ''
os.environ['http_proxy'] = ''
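# Note: huggingface_hub reads HF_ENDPOINT when it is first imported, so if
# `datasets` (which imports huggingface_hub) is already imported above, these
# overrides may have no effect; exporting them in the shell before launch is
# the safer option.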
hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'
def extract_final_answer(text):
"""
提取模型预测的最终选项(如 A/B/C/D)
支持自然语言、多行、markdown、高亮、非末尾结论等格式
"""
text = text.strip()
# 1. 显式语句匹配(优先)
explicit_patterns = [
r'Answer:\s*([A-D])\b',
r'Correct answer:\s*([A-D])\b',
r'The correct answer is\s*\*?\*?\s*([A-D])\b',
r'Answer is\s*([A-D])\b',
r'Therefore,\s*answer is\s*([A-D])\b',
r'Therefore,\s*the answer should be\s*(?:Option\s*)?([A-D])\b',
r'The answer should be\s*(?:Option\s*)?([A-D])\b',
r'Option\s+([A-D])\s+is correct',
]
for pat in explicit_patterns:
match = re.search(pat, text, re.IGNORECASE)
if match:
return match.group(1).upper()
# 2. Markdown emphasis: **C**, **C. something**
markdown_match = re.findall(r'\*\*\s*([A-D])[\.\s]?', text)
if markdown_match:
return markdown_match[-1].upper()
# 3. Look for 'C' or "C" inside single or double quotes
quote_match = re.findall(r"['\"]([A-D])['\"]", text)
if quote_match:
return quote_match[-1].upper()
# 4. Check whether one of the last few lines starts with "C." or "C"
lines = text.splitlines()
for line in reversed(lines[-5:]):
line = line.strip()
match = re.match(r'^([A-D])([.\s]|$)', line)
if match:
return match.group(1).upper()
# Fall back to None if nothing matched
return None
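As a quick sanity check on the tier ordering above, here is an illustrative set of inputs (hypothetical model outputs, not taken from the test set):

# Each assert exercises one fallback tier of extract_final_answer.
assert extract_final_answer("The correct answer is B.") == "B"        # tier 1: explicit phrase
assert extract_final_answer("So we pick **C**. Done.") == "C"         # tier 2: markdown emphasis
assert extract_final_answer("I would go with 'D' here.") == "D"       # tier 3: quoted letter
assert extract_final_answer("Reasoning...\nA. It fits best.") == "A"  # tier 4: line-initial letter
assert extract_final_answer("No idea, sorry.") is None                # no tier matches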
class DataEvaluator:
def __init__(self):
self.data = []
@@ -22,8 +67,10 @@ class DataEvaluator:
"""
Load data from the data file; each record corresponds to one instance.
"""
ds = load_dataset(file_path, "all")
df = pd.DataFrame(ds['test'])
splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
'dev': 'all/dev-00000-of-00001.parquet',
'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
for _, row in df.iterrows():
self.data.append(row.to_dict())
@@ -73,6 +120,7 @@ def generate_text(api_url, question, model_name, stream=False):
def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
start_total_time = time.time()
total_score = 0
total_exact_score = 0
results = []
file_lock = threading.Lock()
@@ -85,6 +133,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
def worker(index, data_item):
nonlocal total_score
nonlocal total_exact_score
question = data_evaluator.get_prompt(data_item)
start_time = time.time()
try:
@@ -95,13 +144,15 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
answer = chr(data_item['answer'] + 65)
processed_prediction = data_evaluator.post_processing(prediction)
score = data_evaluator.score(processed_prediction, answer)
exact_score = data_evaluator.score(extract_final_answer(prediction), answer)
elapsed_time = time.time() - start_time
result_data = {
"question_id": index,
"answer": answer,
"prediction": processed_prediction,
"real_prediction": prediction,
"full_prediction": prediction,
"score": score,
"exact_score": exact_score,
"time": elapsed_time
}
# Lock while writing results to keep file writes thread-safe
@@ -124,6 +175,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
if res is not None:
results.append(res)
total_score += res['score']
total_exact_score += res['exact_score']
total_time = time.time() - start_total_time
throughput = len(data_subset) / total_time if total_time > 0 else 0
@@ -133,6 +185,8 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
average_score = total_score / len(data_subset) if data_subset else 0
log_f.write(f"Average Score: {average_score}\n")
average_exact_score = total_exact_score / len(data_subset) if data_subset else 0
log_f.write(f"Average Exact Score: {average_exact_score}\n")
log_f.write('-' * 40 + '\n')
print(f"Results saved to {result_file}")
@@ -152,4 +206,4 @@ if __name__ == "__main__":
data_evaluator = DataEvaluator()
data_evaluator.load_data(args.file)
main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
\ No newline at end of file
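Two scores are now logged per run: score compares the evaluator's post-processed prediction against the answer, while exact_score re-extracts a letter from the raw completion with extract_final_answer. Note also that after this change load_data no longer uses its file_path argument; it always pulls the MMLU test split from the Hub. Invocation looks roughly like this (script name illustrative; flag names inferred from the args.* references above):

python mmlu_evaluate.py --file cais/mmlu --concurrent 8 \
    --result results.json --log eval.log \
    --api_url http://localhost:10006/v1/chat/completions \
    --model DeepSeek-V3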
@@ -2,23 +2,18 @@ import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep
decodesz = 128
# Server URL (replace with your server URL)
SERVER_URL = "http://localhost:10002/v1/chat/completions"
bf_list = [1]
decodesz_list = [128]
prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']
async def fetch_event_stream(session, payload, request_id):
try:
prompt_list = [
'Please elaborate on modern world history.',
'Please introduce Harry Potter.',
'I want to learn Python. Please give me some advice.',
'Please tell me a joke '
]
async def fetch_event_stream(session, payload, request_id, stream):
try:
headers = {
'accept': 'application/json',
'Content-Type': 'application/json'
@@ -31,104 +26,80 @@ async def fetch_event_stream(session, payload, request_id):
print(f"Request {request_id}: Error, status {response.status}")
return
output_text = "" # 存储当前 response 的所有 token
total_tokens = 0 # 统计总 tokens 数
decode_start_time = None # 记录 decode 阶段开始时间
decode_end_time = None # 记录 decode 结束时间
async for line in response.content:
try:
decoded_line = line.decode("utf-8").strip()
# Skip blank lines
if not decoded_line or not decoded_line.startswith("data: "):
continue
decoded_line = decoded_line[6:].strip() # Strip the `data: ` prefix
# Make sure there is JSON payload left to parse
if not decoded_line:
continue
response_data = json.loads(decoded_line) # Parse the JSON chunk
# Make sure choices is present
choices = response_data.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
token = delta.get("content", "")
if token:
if decode_start_time is None:
decode_start_time = time.time() # Record when decoding starts
output_text += token # Append the token
sys.stdout.write(token) # Emit the token directly
sys.stdout.flush() # Flush immediately so the token appears in the terminal
total_tokens += 1 # Bump the token count
decode_end_time = time.time() # Update the decode end time on every token
# Check whether the stream has finished
finish_reason = choices[0].get("finish_reason", None)
if finish_reason:
# print(f"\nRequest {request_id}: Done")
break # Stop consuming the stream
except json.JSONDecodeError as e:
print(f"\nRequest {request_id}: JSON Decode Error - {e}")
except IndexError:
print(f"\nRequest {request_id}: List Index Error - choices is empty")
except Exception as e:
print(f"\nRequest {request_id}: Error parsing stream - {e}")
# Compute decode speed
if decode_start_time and decode_end_time and total_tokens > 0:
decode_time = decode_end_time - decode_start_time
decode_speed = total_tokens / decode_time if decode_time > 0 else 0
# print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
output_text = ""
if stream:
async for line in response.content:
try:
decoded_line = line.decode("utf-8").strip()
if not decoded_line or not decoded_line.startswith("data: "):
continue
decoded_line = decoded_line[6:].strip()
if not decoded_line:
continue
response_data = json.loads(decoded_line)
choices = response_data.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
token = delta.get("content", "")
if token:
output_text += token
sys.stdout.write(token)
sys.stdout.flush()
finish_reason = choices[0].get("finish_reason", None)
if finish_reason:
break
except json.JSONDecodeError as e:
print(f"\nRequest {request_id}: JSON Decode Error - {e}")
except IndexError:
print(f"\nRequest {request_id}: List Index Error - choices is empty")
except Exception as e:
print(f"\nRequest {request_id}: Error parsing stream - {e}")
else:
# In non-stream mode, receive the complete JSON response in one shot
response_data = await response.json()
choices = response_data.get("choices", [])
if choices:
content = choices[0].get("message", {}).get("content", "")
print(f"Request {request_id} Output:\n{content}")
output_text += content
except Exception as e:
print(f"\nRequest {request_id}: Exception - {e}")
async def main(prompt_id):
async def main(prompt_id, model, stream, max_tokens, temperature, top_p):
async with aiohttp.ClientSession() as session:
payload = {
"messages": [
{"role": "system", "content": ""},
{"role": "user", "content": prompt_list[prompt_id]}
],
"model": "DeepSeek-V3",
"stream": True,
"max_completion_tokens": 2,
# "temperature": 0.3,
# "top_p": 1.0,
# "max_tokens" : 20,
"model": model,
"stream": stream,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p
}
tasks = [fetch_event_stream(session, payload, prompt_id)]
await asyncio.gather(*tasks)
payload["temperature"] = 0.3
tasks = [fetch_event_stream(session, payload, prompt_id)]
await asyncio.gather(*tasks)
payload["top_p"] = 1
tasks = [fetch_event_stream(session, payload, prompt_id)]
await asyncio.gather(*tasks)
payload["max_tokens"] = 200
tasks = [fetch_event_stream(session, payload, prompt_id)]
await asyncio.gather(*tasks)
payload["stream"] = False
tasks = [fetch_event_stream(session, payload, prompt_id)]
tasks = [fetch_event_stream(session, payload, prompt_id, stream)]
await asyncio.gather(*tasks)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Event Stream Request Tester")
parser.add_argument("--question_id", type=int, default=0, required=False)
parser.add_argument("--question_id", type=int, default=0)
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--stream", type=bool, default=True)
parser.add_argument("--max_tokens", type=int, default=500)
parser.add_argument("--temperature", type=float, default=0.8)
parser.add_argument("--top_p", type=float, default=1)
parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")
args = parser.parse_args()
output_file = "ktransformer_test_results.txt"
asyncio.run(main(args.question_id))
SERVER_URL = args.api_url
asyncio.run(main(args.question_id, args.model, args.stream, args.max_tokens, args.temperature, args.top_p))
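A typical invocation of this simplified client (script name illustrative; note that --stream is parsed leniently, per the comment on its argparse line):

python test_chat_client.py --model DeepSeek-V3 --question_id 1 \
    --max_tokens 200 --temperature 0.3 --top_p 1.0 \
    --api_url http://localhost:10006/v1/chat/completions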
@@ -45,14 +45,14 @@ They were whispering excitedly together. Mr. Dursley was enraged to see that a c
The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it.
The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills.
Mr. Dursley always sat with his back to the window in his office on the ninth floor."""
async def fetch_event_stream(session, request_id, prompt, max_tokens):
async def fetch_event_stream(session, request_id, prompt, max_tokens, model):
try:
payload = {
"messages": [
{"role": "system", "content": ""},
{"role": "user", "content": prompt}
],
"model": "DeepSeek-V3",
"model": model,
"temperature": 0.3,
"top_p": 1.0,
"stream": True,
@@ -134,17 +134,19 @@ async def fetch_event_stream(session, request_id, prompt, max_tokens):
except Exception as e:
print(f"[Request {request_id}] Exception: {e}")
async def main(concurrent_requests, prompt, max_tokens):
async def main(concurrent_requests, prompt, max_tokens, model):
async with aiohttp.ClientSession() as session:
tasks = [fetch_event_stream(session, i, prompt, max_tokens) for i in range(concurrent_requests)]
tasks = [fetch_event_stream(session, i, prompt, max_tokens, model) for i in range(concurrent_requests)]
await asyncio.gather(*tasks)
if len(prefill_speeds) != 0:
import numpy as np
print(f"average prefill speed: {np.average(prefill_speeds)}\naverage decode speed: {np.average(decode_speeds)}")
print(f"concurrency: {len(prefill_speeds)}")
print(f"total prefill speed: {np.sum(prefill_speeds)}\n total decode speed: {np.sum(decode_speeds)}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Event Stream Request Tester")
parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name", required=True)
parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
@@ -152,9 +154,10 @@ if __name__ == "__main__":
args = parser.parse_args()
SERVER_URL = args.api_url
max_tokens = args.max_tokens
model = args.model
if args.prompt_lens == 1024:
prompt = ktansformer_prompt1024
elif args.prompt_lens == 2048:
prompt = ktansformer_prompt1024 * 2
asyncio.run(main(args.concurrent, prompt, max_tokens))
asyncio.run(main(args.concurrent, prompt, max_tokens, model))
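And the concurrent benchmark is launched along these lines (script name illustrative); with --prompt_lens 2048 the 1024-token prompt is simply concatenated twice:

python benchmark_stream.py --concurrent 4 --model DeepSeek-V3 \
    --prompt_lens 1024 --max_tokens 50 \
    --api_url http://localhost:10002/v1/chat/completions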