"vscode:/vscode.git/clone" did not exist on "345b3491a7b6f9a279f9d08c312a22153be8f969"
Commit 4c5a8a74 authored by chenych's avatar chenych
Browse files

stream client

parent e80e947e
import json
import argparse
import requests
import configparser
from typing import Iterable, List
parse = argparse.ArgumentParser()
parse.add_argument('--query', default='请写一首诗')
parse.add_argument('--use_hf', action='store_true')
args = parse.parse_args()
print(args.query)
headers = {"Content-Type": "application/json"}
data = {
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    """Yield the ``"text"`` entry of every JSON message in a streamed reply.

    The body is read through ``response.iter_lines`` using a NUL byte as the
    delimiter. Each non-empty piece is decoded as UTF-8 and parsed as one
    JSON document; empty pieces are skipped.

    Args:
        response: The HTTP response whose body is consumed incrementally.

    Yields:
        The value stored under ``"text"`` in each parsed message.
    """
    pieces = response.iter_lines(
        chunk_size=1024, decode_unicode=False, delimiter=b"\0")
    for piece in pieces:
        if not piece:
            continue
        message = json.loads(piece.decode("utf-8"))
        yield message["text"]
def get_response(response: requests.Response) -> List[str]:
    """Return the ``"text"`` entry of a non-streamed JSON reply.

    Args:
        response: The HTTP response; its full body is decoded as UTF-8 and
            parsed as a single JSON document.

    Returns:
        The value stored under ``"text"`` in the parsed body.
    """
    body = response.content.decode("utf-8")
    return json.loads(body)["text"]
def clear_line(n: int = 1) -> None:
    """Write the "line up" + "line clear" escape pair to stdout ``n`` times.

    Args:
        n: How many times the pair is emitted. Nothing is written when
            ``n`` is zero or negative.
    """
    cursor_up = '\033[1A'
    erase_line = '\x1b[2K'
    for _ in range(n):
        # One write per repetition, flushed immediately.
        print(cursor_up + erase_line, end="", flush=True)
if __name__ == "__main__":
    # Command-line entry point: send one query to the local inference
    # service and print the returned candidates.
    parser = argparse.ArgumentParser()
    parser.add_argument('--query', default='请写一首诗')
    parser.add_argument('--use_hf', action='store_true')
    parser.add_argument(
        '--config_path', default='../config.ini', help='config目录')
    args = parser.parse_args()
    print(args.query)

    headers = {"Content-Type": "application/json"}
    data = {
        "query": args.query,
        "history": []
    }
    json_str = json.dumps(data)

    # Whether to stream is read from the [llm] section of the config file.
    config = configparser.ConfigParser()
    config.read(args.config_path)
    stream_chat = config.getboolean('llm', 'stream_chat')

    # --use_hf selects the hf_inference endpoint; vllm_inference otherwise.
    func = 'vllm_inference'
    if args.use_hf:
        func = 'hf_inference'
    api_url = f"http://localhost:8888/{func}"

    # Issue the request exactly once. The body must not be read here:
    # when streaming, it is consumed piece by piece in the loop below.
    response = requests.post(api_url, headers=headers, data=json_str.encode(
        "utf-8"), verify=False, stream=stream_chat)

    if stream_chat:
        num_printed_lines = 0
        for h in get_streaming_response(response):
            # Erase the candidates printed for the previous message so the
            # new ones overwrite them in place.
            clear_line(num_printed_lines)
            num_printed_lines = 0
            for i, line in enumerate(h):
                num_printed_lines += 1
                print(f"Beam candidate {i}: {line!r}", flush=True)
    else:
        output = get_response(response)
        for i, line in enumerate(output):
            print(f"Beam candidate {i}: {line!r}", flush=True)
"""Example Python client for vllm.entrypoints.api_server"""
import argparse
import json
from typing import Iterable, List
import requests
def clear_line(n: int = 1) -> None:
    """Emit the "line up" and "line clear" escape sequences ``n`` times.

    Args:
        n: Number of repetitions; a value of zero or less writes nothing.
    """
    up_then_clear = ('\033[1A', '\x1b[2K')
    for _ in range(n):
        # Both sequences are written back to back, with no separator and no
        # trailing newline, then flushed.
        print(*up_then_clear, sep="", end="", flush=True)
def post_http_request(query: str, api_url: str, n: int = 1,
                      stream: bool = False) -> requests.Response:
    """POST a generation request to ``api_url`` and return the response.

    Args:
        query: Text sent under the ``"query"`` key of the JSON payload.
        api_url: Full URL of the endpoint to post to.
        n: Value sent under the ``"n"`` key of the payload.
        stream: Sent under the ``"stream"`` key of the payload and also used
            as the transport-level ``stream`` flag of ``requests.post``.

    Returns:
        The ``requests.Response`` object. Its body has not been parsed.
    """
    headers = {"User-Agent": "Test Client"}
    pload = {
        "query": query,
        "n": n,
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": 16,
        "stream": stream,
    }
    # Honour the caller's choice: only defer downloading the body when
    # streaming was requested, instead of always forcing stream=True.
    response = requests.post(
        api_url, headers=headers, json=pload, stream=stream)
    return response
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    """Iterate over a streamed reply, yielding each message's ``"text"``.

    The body is split on NUL bytes via ``response.iter_lines``. Empty
    pieces are dropped; every other piece is decoded as UTF-8 and parsed as
    JSON.

    Args:
        response: The HTTP response to read from incrementally.

    Yields:
        The ``"text"`` value of each parsed message.
    """
    non_empty = filter(None, response.iter_lines(
        chunk_size=8192, decode_unicode=False, delimiter=b"\0"))
    for raw in non_empty:
        yield json.loads(raw.decode("utf-8"))["text"]
def get_response(response: requests.Response) -> List[str]:
    """Parse the complete response body as JSON and return its ``"text"``.

    Args:
        response: The HTTP response whose ``content`` bytes are parsed.

    Returns:
        The value stored under ``"text"`` in the parsed body.
    """
    return json.loads(response.content)["text"]
if __name__ == "__main__":
    # Command-line entry point: post one prompt to the /generate endpoint
    # and print the returned candidates.
    parser = argparse.ArgumentParser()
    for flag, options in (
        ("--host", {"type": str, "default": "localhost"}),
        ("--port", {"type": int, "default": 8888}),
        ("--n", {"type": int, "default": 4}),
        ("--query", {"type": str, "default": "San Francisco is a"}),
        ("--stream", {"action": "store_true"}),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    query, n, stream = args.query, args.n, args.stream
    api_url = f"http://{args.host}:{args.port}/generate"

    print(f"Prompt: {query!r}\n", flush=True)
    response = post_http_request(query, api_url, n, stream)

    if not stream:
        output = get_response(response)
        for idx, candidate in enumerate(output):
            print(f"Beam candidate {idx}: {candidate!r}", flush=True)
    else:
        num_printed_lines = 0
        for candidates in get_streaming_response(response):
            # Wipe what the previous message printed, then redraw.
            clear_line(num_printed_lines)
            num_printed_lines = 0
            for idx, candidate in enumerate(candidates):
                print(f"Beam candidate {idx}: {candidate!r}", flush=True)
                num_printed_lines += 1
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment