Commit f7d93931 authored by Alisehen

Merge remote-tracking branch 'origin/main' into check-para

parents 99540ad0 7e4813e8
@@ -14,15 +14,10 @@ from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
 from ktransformers.server.config.log import logger
 from fastapi.responses import JSONResponse
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage

 # Define own data structure instead of importing from OpenAI
-class CompletionUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-    prompt_tokens_details: Optional[Dict[str, Any]] = None
-    completion_tokens_details: Optional[Dict[str, Any]] = None

 class Choice(BaseModel):
     index: int
@@ -267,6 +262,12 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                     completion_tokens=raw_usage.decode_count,
                     total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                 )
+                if create.return_speed:
+                    chunk.usage.prefill_time = res.prefill_time
+                    chunk.usage.decode_time = res.decode_time
+                else:
+                    chunk.usage.__dict__.pop('prefill_time', None)
+                    chunk.usage.__dict__.pop('decode_time', None)
                 yield chunk
             elif isinstance(res, tuple) and len(res) == 2:
                 token, finish_reason = res
@@ -427,8 +428,15 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
         usage = CompletionUsage(
             prompt_tokens=raw_usage.prefill_count,
             completion_tokens=raw_usage.decode_count,
-            total_tokens=raw_usage.prefill_count + raw_usage.decode_count
+            total_tokens=raw_usage.prefill_count + raw_usage.decode_count,
         )
+        if create.return_speed:
+            usage.prefill_time = res.prefill_time
+            usage.decode_time = res.decode_time
+        else:
+            usage.__dict__.pop('prefill_time', None)
+            usage.__dict__.pop('decode_time', None)
     elif isinstance(res, tuple) and len(res) == 2:
         token, finish_reason = res
         token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token)
...
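For context, the final streamed usage block would then look roughly like the sketch below when the client sets return_speed; the field names follow the schema change in this commit, while the numbers and the assumption that the times are in seconds are illustrative only.

# Hypothetical shape of the last SSE chunk's "usage" object with "return_speed": true.
usage_with_speed = {
    "prompt_tokens": 1024,
    "completion_tokens": 256,
    "total_tokens": 1280,
    "prefill_time": 0.42,   # assumed to be seconds spent in prefill
    "decode_time": 3.10,    # assumed to be seconds spent decoding
}

# Without return_speed the two timing fields are dropped, so clients that parse
# usage strictly keep seeing only the standard OpenAI trio.
usage_default = {k: usage_with_speed[k] for k in ("prompt_tokens", "completion_tokens", "total_tokens")}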
@@ -46,6 +46,8 @@ import pickle
 import subprocess
 import tempfile
 import atexit
+import signal

 ktransformer_rules_dir = (
     os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
@@ -55,6 +57,7 @@ default_optimize_rules = {
     "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
 }

 async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
     streamer = TextStreamer(tokenizer)
     while True:
@@ -293,10 +296,6 @@ class BalanceServeInterface(BackendInterfaceBase):
         kvcache_event.wait()

-        def cleanup():
-            if sched_process.poll() is None:
-                sched_process.terminate()

         with tempfile.NamedTemporaryFile(delete=False) as temp_file:
             pickle.dump(args, temp_file)
             temp_file_path = temp_file.name
@@ -311,7 +310,27 @@ class BalanceServeInterface(BackendInterfaceBase):
             stderr=log
         )
         print("sched_rpc started with PID:", sched_process.pid)
-        atexit.register(cleanup)
+        def signal_handler(signum, frame):
+            print(f"Received signal {signum}, shutting down...")
+            cleanup()
+            os._exit(0)
+
+        def cleanup():
+            print("Cleaning up...")
+            for p in processes:
+                if p.is_alive():
+                    print(f"Terminating subprocess {p.pid}")
+                    p.terminate()
+                    p.join()
+            if sched_process and sched_process.poll() is None:
+                print(f"Terminating sched_process {sched_process.pid}")
+                sched_process.terminate()
+                sched_process.wait()
+
+        signal.signal(signal.SIGINT, signal_handler)
+        signal.signal(signal.SIGTERM, signal_handler)

         start_event.wait()
...
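The same shutdown pattern in isolation, as a minimal standard-library sketch (the worker and the sleep subprocess are placeholders, not the backend's real objects). Unlike atexit hooks, which are skipped when the process is killed by an unhandled SIGTERM, these handlers run for both SIGINT and SIGTERM before exiting.

import os
import signal
import subprocess
import multiprocessing as mp
import time

def _run_worker():
    # placeholder worker loop
    while True:
        time.sleep(1)

if __name__ == "__main__":
    worker_processes = [mp.Process(target=_run_worker) for _ in range(2)]
    for p in worker_processes:
        p.start()
    sched = subprocess.Popen(["sleep", "3600"])  # stands in for the sched_rpc child (POSIX only)

    def cleanup():
        for p in worker_processes:
            if p.is_alive():
                p.terminate()
                p.join()
        if sched.poll() is None:      # child still running
            sched.terminate()
            sched.wait()

    def handle_signal(signum, frame):
        print(f"Received signal {signum}, shutting down...")
        cleanup()
        os._exit(0)                   # exit immediately, mirroring the change above

    signal.signal(signal.SIGINT, handle_signal)
    signal.signal(signal.SIGTERM, handle_signal)
    sched.wait()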
@@ -2,14 +2,22 @@ from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
 from pydantic import BaseModel, Field
+from ktransformers.server.config.config import Config

 from ktransformers.server.schemas.base import Object
-from openai.types.completion_usage import CompletionUsage
 from openai.types.chat.chat_completion_chunk import Choice
 from uuid import uuid4

+class CompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    prompt_tokens_details: Optional[Dict[str, Any]] = None
+    completion_tokens_details: Optional[Dict[str, Any]] = None
+    prefill_time: Optional[float] = None
+    decode_time: Optional[float] = None

 class Role(Enum):
     system = 'system'
@@ -58,16 +66,16 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model: str
     stream: bool = False
-    temperature: Optional[float] = Field(default=0.6)
-    top_p: Optional[float] = Field(default=1.0)
+    temperature: Optional[float] = Field(default=Config().temperature)
+    top_p: Optional[float] = Field(default=Config().top_p)
     tools: Optional[List[Tool]] = None
     tool_choice: Optional[Union[str, Dict[str, Any]]] = None
     stream_options: Optional[Dict[str, Any]] = None
     frequency_penalty: float = 0
     presence_penalty: float = 0
-    max_tokens: Optional[int] = Field(default=50)
-    max_completion_tokens: Optional[int] = Field(default=50)
+    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    return_speed: Optional[bool] = Field(default=False)

     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
...
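A self-contained sketch of the extended CompletionUsage model above, re-declared locally so it runs on its own, validating a usage payload and deriving a decode speed (assuming the timing fields are in seconds).

from typing import Any, Dict, Optional
from pydantic import BaseModel

class CompletionUsage(BaseModel):  # mirrors the schema added in this commit
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: Optional[Dict[str, Any]] = None
    completion_tokens_details: Optional[Dict[str, Any]] = None
    prefill_time: Optional[float] = None
    decode_time: Optional[float] = None

usage = CompletionUsage(prompt_tokens=1024, completion_tokens=256, total_tokens=1280,
                        prefill_time=0.4, decode_time=3.1)
print(usage.completion_tokens / usage.decode_time)  # decode speed in tokens/s, if times are seconds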
 from typing import List, Optional
 from enum import Enum
 from pydantic import BaseModel, Field
+from ktransformers.server.config.config import Config

 from ..base import Object

 class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
-    temperature: Optional[float] = Field(default=0.6)
-    top_p: Optional[float] = Field(default=1)
-    max_tokens: Optional[int] = Field(default=50)
-    max_completion_tokens: Optional[int] = Field(default=50)
+    temperature: Optional[float] = Field(default=Config().temperature)
+    top_p: Optional[float] = Field(default=Config().top_p)
+    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)

     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):
...
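One behavioral note on the Field(default=Config()...) pattern introduced here: the default is captured once, when the schema module is imported, so later changes to the config object do not affect it. The sketch below uses a stand-in config to contrast this with default_factory, which is re-evaluated for every request object; it is an illustration, not something the commit changes.

from typing import Optional
from pydantic import BaseModel, Field

class DummyConfig:                      # stand-in for the real Config() object
    temperature = 0.6
    top_p = 1.0

class CompletionDefaults(BaseModel):
    # default=... is evaluated once, when this class body runs (as in the diff above)
    temperature: Optional[float] = Field(default=DummyConfig.temperature)
    # default_factory is called again for each new instance, picking up later config edits
    top_p: Optional[float] = Field(default_factory=lambda: DummyConfig.top_p)

print(CompletionDefaults())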
@@ -25,19 +25,10 @@ class DataEvaluator:
         """
         # Read the Parquet file
         # dataset = load_dataset('parquet', data_files=file_path)
-        ds = load_dataset(file_path,"all")
-        df = pd.DataFrame(ds['test'])
-        # print(ds)
-        # # ds_1 = ds['train']
-        # ds_2 = ds['validation']
-        # ds_3 = ds['test']
-        # # Convert the dataset to a Pandas DataFrame
-        # df_test = pd.DataFrame(ds['test'])
-        # df_val = pd.DataFrame(ds['validation'])
-        # for _, row in df.iterrows():
-        #     self.data.append(row.to_dict())
-        # df = pd.read_parquet(file_path)
+        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
+                  'dev': 'all/dev-00000-of-00001.parquet',
+                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
+        df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
         for _, row in df.iterrows():
             self.data.append(row.to_dict())
...
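The hf:// path used above resolves through fsspec and huggingface_hub, so any split from the dict can be read directly with pandas. A small sketch for the validation split, assuming huggingface_hub is installed and the cais/mmlu file layout stays as listed:

import pandas as pd

splits = {
    'test': 'all/test-00000-of-00001.parquet',
    'validation': 'all/validation-00000-of-00001.parquet',
    'dev': 'all/dev-00000-of-00001.parquet',
    'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet',
}

# Same access pattern as the diff, but pulling the validation split instead of test
val_df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["validation"])
print(len(val_df), list(val_df.columns))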
@@ -8,12 +8,57 @@ from datasets import load_dataset
 import os
 import concurrent.futures
 import threading
+import re

 os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
 os.environ['https_proxy'] = ''
 os.environ['http_proxy'] = ''

 hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'

+def extract_final_answer(text):
+    """
+    Extract the model's final predicted option (e.g. A/B/C/D).
+    Handles natural language, multi-line text, markdown, emphasis, and conclusions that are not at the end.
+    """
+    text = text.strip()
+
+    # 1. Explicit answer statements (checked first)
+    explicit_patterns = [
+        r'Answer:\s*([A-D])\b',
+        r'Correct answer:\s*([A-D])\b',
+        r'The correct answer is\s*\*?\*?\s*([A-D])\b',
+        r'Answer is\s*([A-D])\b',
+        r'Therefore,\s*answer is\s*([A-D])\b',
+        r'Therefore,\s*the answer should be\s*(?:Option\s*)?([A-D])\b',
+        r'The answer should be\s*(?:Option\s*)?([A-D])\b',
+        r'Option\s+([A-D])\s+is correct',
+    ]
+    for pat in explicit_patterns:
+        match = re.search(pat, text, re.IGNORECASE)
+        if match:
+            return match.group(1).upper()
+
+    # 2. Markdown emphasis such as **C** or **C. something**
+    markdown_match = re.findall(r'\*\*\s*([A-D])[\.\s]?', text)
+    if markdown_match:
+        return markdown_match[-1].upper()
+
+    # 3. A quoted 'C' or "C"
+    quote_match = re.findall(r"['\"]([A-D])['\"]", text)
+    if quote_match:
+        return quote_match[-1].upper()
+
+    # 4. One of the last few lines starting with "C." or "C"
+    lines = text.splitlines()
+    for line in reversed(lines[-5:]):
+        line = line.strip()
+        match = re.match(r'^([A-D])([.\s]|$)', line)
+        if match:
+            return match.group(1).upper()
+
+    # Otherwise give up and return None
+    return None

 class DataEvaluator:
     def __init__(self):
         self.data = []
@@ -22,8 +67,10 @@ class DataEvaluator:
         """
         Load data from the data file; each record corresponds to one instance.
         """
-        ds = load_dataset(file_path, "all")
-        df = pd.DataFrame(ds['test'])
+        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
+                  'dev': 'all/dev-00000-of-00001.parquet',
+                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
+        df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
         for _, row in df.iterrows():
             self.data.append(row.to_dict())
@@ -73,6 +120,7 @@ def generate_text(api_url, question, model_name, stream=False):
 def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
     start_total_time = time.time()
     total_score = 0
+    total_exact_score = 0
     results = []
     file_lock = threading.Lock()
@@ -85,6 +133,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
     def worker(index, data_item):
         nonlocal total_score
+        nonlocal total_exact_score
         question = data_evaluator.get_prompt(data_item)
         start_time = time.time()
         try:
@@ -95,13 +144,15 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
             answer = chr(data_item['answer'] + 65)
             processed_prediction = data_evaluator.post_processing(prediction)
             score = data_evaluator.score(processed_prediction, answer)
+            exact_score = data_evaluator.score(extract_final_answer(prediction), answer)
             elapsed_time = time.time() - start_time
             result_data = {
                 "question_id": index,
                 "answer": answer,
                 "prediction": processed_prediction,
-                "real_prediction": prediction,
+                "full_prediction": prediction,
                 "score": score,
+                "exact_score": exact_score,
                 "time": elapsed_time
             }
             # Take the lock when writing results to keep the file thread-safe
@@ -124,6 +175,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
         if res is not None:
             results.append(res)
             total_score += res['score']
+            total_exact_score += res['exact_score']

     total_time = time.time() - start_total_time
     throughput = len(data_subset) / total_time if total_time > 0 else 0
@@ -133,6 +185,8 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
         log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
         average_score = total_score / len(data_subset) if data_subset else 0
         log_f.write(f"Average Score: {average_score}\n")
+        average_exact_score = total_exact_score / len(data_subset) if data_subset else 0
+        log_f.write(f"Average Exact Score: {average_exact_score}\n")
         log_f.write('-' * 40 + '\n')

     print(f"Results saved to {result_file}")
@@ -152,4 +206,4 @@ if __name__ == "__main__":
     data_evaluator = DataEvaluator()
     data_evaluator.load_data(args.file)
     main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
\ No newline at end of file
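A quick sanity check of extract_final_answer on a few of the answer formats it targets (this assumes the function from the test script above is in scope):

# Assumes extract_final_answer from the evaluation script above has been imported.
samples = [
    "Let's think it through. The correct answer is **B**.",
    "Therefore, answer is C",
    "Reasoning...\nD. Paris",
]
for s in samples:
    print(extract_final_answer(s))   # expected: B, C, D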
@@ -2,23 +2,18 @@ import asyncio
 import json
 import sys
 import aiohttp
-import random
 import argparse
-import yaml
-import os
-import time
-from time import sleep
-
-decodesz = 128
-# Server URL (replace with your server URL)
-SERVER_URL = "http://localhost:10002/v1/chat/completions"
-bf_list = [1]
-decodesz_list = [128]
-prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']
-
-async def fetch_event_stream(session, payload, request_id):
-    try:
+
+prompt_list = [
+    'Please elaborate on modern world history.',
+    'Please introduce Harry Potter.',
+    'I want to learn Python. Please give me some advice.',
+    'Please tell me a joke '
+]
+
+async def fetch_event_stream(session, payload, request_id, stream):
+    try:
         headers = {
             'accept': 'application/json',
             'Content-Type': 'application/json'
@@ -31,104 +26,80 @@ async def fetch_event_stream(session, payload, request_id):
                 print(f"Request {request_id}: Error, status {response.status}")
                 return

-            output_text = ""  # store all tokens of the current response
-            total_tokens = 0  # count the total number of tokens
-            decode_start_time = None  # record when the decode phase starts
-            decode_end_time = None  # record when decode ends
-
-            async for line in response.content:
-                try:
-                    decoded_line = line.decode("utf-8").strip()
-                    # skip empty lines
-                    if not decoded_line or not decoded_line.startswith("data: "):
-                        continue
-                    decoded_line = decoded_line[6:].strip()  # strip the `data: ` prefix
-
-                    # make sure the JSON payload is valid
-                    if not decoded_line:
-                        continue
-
-                    response_data = json.loads(decoded_line)  # parse the JSON
-                    # make sure choices exists
-                    choices = response_data.get("choices", [])
-                    if not choices:
-                        continue
-                    delta = choices[0].get("delta", {})
-                    token = delta.get("content", "")
-
-                    if token:
-                        if decode_start_time is None:
-                            decode_start_time = time.time()  # record decode start time
-                        output_text += token  # append the token
-                        sys.stdout.write(token)  # print the token directly
-                        sys.stdout.flush()  # flush immediately so the token shows up in the terminal right away
-                        total_tokens += 1  # increment the token count
-                        decode_end_time = time.time()  # update the decode end time on every token
-
-                    # check whether generation has finished
-                    finish_reason = choices[0].get("finish_reason", None)
-                    if finish_reason:
-                        # print(f"\nRequest {request_id}: Done")
-                        break  # stop streaming
-                except json.JSONDecodeError as e:
-                    print(f"\nRequest {request_id}: JSON Decode Error - {e}")
-                except IndexError:
-                    print(f"\nRequest {request_id}: List Index Error - choices is empty")
-                except Exception as e:
-                    print(f"\nRequest {request_id}: Error parsing stream - {e}")
-
-            # compute decode speed
-            if decode_start_time and decode_end_time and total_tokens > 0:
-                decode_time = decode_end_time - decode_start_time
-                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
-                # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
+            output_text = ""
+
+            if stream:
+                async for line in response.content:
+                    try:
+                        decoded_line = line.decode("utf-8").strip()
+                        if not decoded_line or not decoded_line.startswith("data: "):
+                            continue
+                        decoded_line = decoded_line[6:].strip()
+                        if not decoded_line:
+                            continue
+                        response_data = json.loads(decoded_line)
+                        choices = response_data.get("choices", [])
+                        if not choices:
+                            continue
+
+                        delta = choices[0].get("delta", {})
+                        token = delta.get("content", "")
+                        if token:
+                            output_text += token
+                            sys.stdout.write(token)
+                            sys.stdout.flush()
+
+                        finish_reason = choices[0].get("finish_reason", None)
+                        if finish_reason:
+                            break
+
+                    except json.JSONDecodeError as e:
+                        print(f"\nRequest {request_id}: JSON Decode Error - {e}")
+                    except IndexError:
+                        print(f"\nRequest {request_id}: List Index Error - choices is empty")
+                    except Exception as e:
+                        print(f"\nRequest {request_id}: Error parsing stream - {e}")
+            else:
+                # In non-stream mode, receive the complete JSON in one shot
+                response_data = await response.json()
+                choices = response_data.get("choices", [])
+                if choices:
+                    content = choices[0].get("message", {}).get("content", "")
+                    print(f"Request {request_id} Output:\n{content}")
+                    output_text += content

     except Exception as e:
         print(f"\nRequest {request_id}: Exception - {e}")

-async def main(prompt_id):
+async def main(prompt_id, model, stream, max_tokens, temperature, top_p):
     async with aiohttp.ClientSession() as session:
         payload = {
             "messages": [
                 {"role": "system", "content": ""},
                 {"role": "user", "content": prompt_list[prompt_id]}
             ],
-            "model": "DeepSeek-V3",
-            "stream": True,
-            "max_completion_tokens": 2,
-            # "temperature": 0.3,
-            # "top_p": 1.0,
-            # "max_tokens" : 20,
+            "model": model,
+            "stream": stream,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "top_p": top_p
         }
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-        payload["temperature"] = 0.3
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-        payload["top_p"] = 1
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-        payload["max_tokens"] = 200
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-        payload["stream"] = False
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        tasks = [fetch_event_stream(session, payload, prompt_id, stream)]
         await asyncio.gather(*tasks)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
-    parser.add_argument("--question_id", type=int, default=0, required=False)
+    parser.add_argument("--question_id", type=int, default=0)
+    parser.add_argument("--model", type=str, required=True)
+    parser.add_argument("--stream", type=bool, default=True)
+    parser.add_argument("--max_tokens", type=int, default=500)
+    parser.add_argument("--temperature", type=float, default=0.8)
+    parser.add_argument("--top_p", type=float, default=1)
+    parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")
     args = parser.parse_args()
-    output_file = "ktransformer_test_results.txt"
-    asyncio.run(main(args.question_id))
+    SERVER_URL = args.api_url
+    asyncio.run(main(args.question_id, args.model, args.stream, args.max_tokens, args.temperature, args.top_p))
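One caveat with the new --stream option: argparse's type=bool treats any non-empty string as True, so --stream False still enables streaming. A common workaround, shown here as a sketch rather than something this commit adds, is an explicit string-to-bool parser:

import argparse

def str2bool(v: str) -> bool:
    # "--stream false" / "--stream 0" actually disable streaming with this parser
    return str(v).strip().lower() in ("1", "true", "yes", "y")

parser = argparse.ArgumentParser()
parser.add_argument("--stream", type=str2bool, default=True)
print(parser.parse_args(["--stream", "false"]).stream)   # False
print(parser.parse_args([]).stream)                      # True

# For comparison, type=bool turns any non-empty string into True:
print(bool("False"))                                      # True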
@@ -12,6 +12,8 @@ from time import sleep
 decodesz = 128
 # Server URL (replace with your server URL)
 decodesz_list = [128]
+prefill_speeds = []
+decode_speeds = []
 ktansformer_prompt1024="""Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.
 They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr. Dursley was the director of a firm called Grunnings, which made drills.
 He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs.
@@ -43,17 +45,19 @@ They were whispering excitedly together. Mr. Dursley was enraged to see that a c
 The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it.
 The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills.
 Mr. Dursley always sat with his back to the window in his office on the ninth floor."""

-async def fetch_event_stream(session, request_id, prompt):
+async def fetch_event_stream(session, request_id, prompt, max_tokens, model):
     try:
         payload = {
             "messages": [
                 {"role": "system", "content": ""},
                 {"role": "user", "content": prompt}
             ],
-            "model": "DeepSeek-V3",
+            "model": model,
             "temperature": 0.3,
             "top_p": 1.0,
-            "stream": True
+            "stream": True,
+            "return_speed": True,
+            "max_tokens": max_tokens,
         }

         headers = {
@@ -70,6 +74,7 @@ async def fetch_event_stream(session, request_id, prompt):
             total_tokens = 0
             decode_start_time = None
             decode_end_time = None
+            usage_info = None

             async for line in response.content:
                 try:
@@ -82,6 +87,10 @@ async def fetch_event_stream(session, request_id, prompt):
                         continue

                     response_data = json.loads(decoded_line)
+
+                    if "usage" in response_data:
+                        usage_info = response_data["usage"]
+
                     choices = response_data.get("choices", [])
                     if not choices:
                         continue
@@ -107,34 +116,48 @@ async def fetch_event_stream(session, request_id, prompt):
                 except Exception as e:
                     print(f"[Request {request_id}] Stream Error: {e}")

             if buffer.strip():
                 print(f"[Request {request_id}] {buffer.strip()}")

-            if decode_start_time and decode_end_time and total_tokens > 0:
-                decode_time = decode_end_time - decode_start_time
-                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
-                print(f"[Request {request_id}] Speed: {decode_speed:.2f} tokens/s")
+            if usage_info:
+                if "prefill_time" in usage_info:
+                    # print(f"[Request {request_id}] Usage:")
+                    # for key, value in usage_info.items():
+                    #     print(f"  {key}: {value}")
+                    prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"]
+                    decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"]
+                    prefill_speeds.append(prefill_speed)
+                    decode_speeds.append(decode_speed)
+                    print(f'[Request {request_id}] prefill speed: {prefill_speed}')
+                    print(f'[Request {request_id}] decode speed: {decode_speed}')

     except Exception as e:
         print(f"[Request {request_id}] Exception: {e}")

-async def main(concurrent_requests , prompt ):
+async def main(concurrent_requests , prompt, max_tokens, model):
     async with aiohttp.ClientSession() as session:
-        tasks = [fetch_event_stream(session, i , prompt) for i in range(concurrent_requests)]
+        tasks = [fetch_event_stream(session, i , prompt, max_tokens, model) for i in range(concurrent_requests)]
         await asyncio.gather(*tasks)
+        if len(prefill_speeds) != 0:
+            import numpy as np
+            print(f"concurrency: {len(prefill_speeds)}")
+            print(f"total prefill speed: {np.sum(prefill_speeds)}\n total decode speed: {np.sum(decode_speeds)}")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
     parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
+    parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name", required=True)
     parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
     parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
+    parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
     args = parser.parse_args()
     SERVER_URL = args.api_url
+    max_tokens = args.max_tokens
+    model = args.model
     if args.prompt_lens == 1024:
         prompt = ktansformer_prompt1024
     elif args.prompt_lens == 2048:
         prompt = ktansformer_prompt1024 * 2
-    asyncio.run(main(args.concurrent, prompt))
+    asyncio.run(main(args.concurrent, prompt, max_tokens, model))
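A side note on the aggregation at the end of main(): np.sum reports the combined throughput across all concurrent streams, while np.mean would give the per-request average. A tiny sketch with made-up numbers:

import numpy as np

# Hypothetical per-request decode speeds (tokens/s) from 4 concurrent streams
decode_speeds = [11.8, 12.1, 11.5, 12.4]

print("total decode speed:", np.sum(decode_speeds))   # aggregate tokens/s served by the system
print("mean decode speed:", np.mean(decode_speeds))   # what a single client observes on average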