Unverified Commit b288f4f4 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Improve `send_sone` script (#11817)

parent 6d6ea5af
...@@ -12,6 +12,7 @@ import dataclasses ...@@ -12,6 +12,7 @@ import dataclasses
import json import json
import requests import requests
import tabulate
from sglang.profiler import run_profile from sglang.profiler import run_profile
...@@ -141,12 +142,16 @@ def send_one_prompt(args): ...@@ -141,12 +142,16 @@ def send_one_prompt(args):
) )
if args.stream: if args.stream:
last_len = 0
for chunk in response.iter_lines(decode_unicode=False): for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8") chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"): if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]": if chunk == "data: [DONE]":
break break
ret = json.loads(chunk[5:].strip("\n")) ret = json.loads(chunk[5:].strip("\n"))
chunk_str = ret["text"][last_len:]
last_len = len(ret["text"])
print(chunk_str, end="", flush=True)
else: else:
ret = response.json() ret = response.json()
...@@ -157,21 +162,25 @@ def send_one_prompt(args): ...@@ -157,21 +162,25 @@ def send_one_prompt(args):
print(ret) print(ret)
return 0, 0 return 0, 0
latency = ret["meta_info"]["e2e_latency"] if "spec_verify_ct" in ret["meta_info"] and ret["meta_info"]["spec_verify_ct"] > 0:
if "spec_verify_ct" in ret["meta_info"]:
acc_length = ( acc_length = (
ret["meta_info"]["completion_tokens"] / ret["meta_info"]["spec_verify_ct"] ret["meta_info"]["completion_tokens"] / ret["meta_info"]["spec_verify_ct"]
) )
else: else:
acc_length = 1.0 acc_length = 1.0
latency = ret["meta_info"]["e2e_latency"]
speed = ret["meta_info"]["completion_tokens"] / latency speed = ret["meta_info"]["completion_tokens"] / latency
tokens = ret["meta_info"]["completion_tokens"]
if not args.stream:
print(ret["text"]) print(ret["text"])
print() print()
print(f"{acc_length=:.2f}") headers = ["Latency (s)", "Tokens", "Acc Length", "Speed (token/s)"]
print(f"{speed=:.2f} token/s") rows = [[f"{latency:.3f}", f"{tokens}", f"{acc_length:.3f}", f"{speed:.2f}"]]
msg = tabulate.tabulate(rows, headers=headers, tablefmt="pretty")
print(msg)
return acc_length, speed return acc_length, speed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment