Unverified Commit 529e56bd authored by AllentDan, committed by GitHub

fix benchmark serving computation mistake (#630)

* fix benchmark serving computation mistake

* fix timestamps computations

* remove speed up

* no mp

* mp seems faster?

* remove

* update

* remove

* fix

* update

* update print log

* typo

* print first token latency only when stream==True

* remove renew_session

* update AsyncEngine
parent 11d10930
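
The core of the fix in the diff below is how per-request statistics are collected and aggregated: each finished request now records five values (first-token latency, completion tokens actually generated, requested output length, total tokens, and end-to-end latency), and the summary separates completion-token throughput from prompt-plus-completion throughput. A minimal sketch of that aggregation, with made-up sample numbers (not from a real run):

```python
import numpy as np

# One row per finished request, matching the 5-column layout added in this
# commit: [first_token_latency, completion_tokens, requested_output_len,
#          total_tokens, token_latency]. The values below are illustrative.
stats = np.array([
    [0.031, 128, 128, 640, 2.950],
    [0.042, 127, 128, 512, 2.810],
])
elapsed_time = 3.2   # wall-clock seconds for the whole benchmark (example)
n_req = len(stats)

completion_tokens = stats[:, 1].sum()
total_tokens = stats[:, 3].sum()
prompt_tokens = total_tokens - completion_tokens

print(f'first_token latency (ave): {stats[:, 0].mean():.3f}s')
print(f'completion token throughput: {completion_tokens / elapsed_time:.3f} token/s')
print(f'prompt + completion throughput: {total_tokens / elapsed_time:.3f} token/s')
print(f'RPS: {n_req / elapsed_time:.3f} req/s, RPM: {n_req * 60 / elapsed_time:.3f} req/min')
```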
 import json
-import multiprocessing as mp
 import random
 import time
+from queue import Queue
+from threading import Thread

 import fire
 import numpy as np

 from lmdeploy.serve.openai.api_client import get_streaming_response
 from lmdeploy.tokenizer import Tokenizer
-from lmdeploy.utils import get_logger


-def infer(server_addr: str, session_id: int, req_queue: mp.Queue,
-          res_que: mp.Queue):
+def infer(server_addr: str, session_id: int, req_queue: Queue, res_que: Queue,
+          stream_output: bool):
     stats = []
-    while not req_queue.empty():
-        prompt, input_seqlen, output_seqlen = req_queue.get()
-        get_logger('profile_restful_api').info(
-            f'request info: session {session_id}, '
-            f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}')
+    for prompt, input_seqlen, output_seqlen in iter(req_queue.get,
+                                                    [None, None, None]):
+        if prompt is None:
+            break
         timestamps = []
         tokens = []
-        start = time.perf_counter()
+        timestamps.append(time.perf_counter())
         for res, token, status in get_streaming_response(
                 prompt,
                 server_addr,
                 session_id,
                 request_output_len=output_seqlen,
-                interactive_mode=False):
+                interactive_mode=False,
+                ignore_eos=True,
+                stream=stream_output):
             timestamps.append(time.perf_counter())
             tokens.append(token)
-        first_token_latency = timestamps[1] - start
-        token_latency = timestamps[-1] - timestamps[0]
-        token = tokens[-1] - tokens[0]
-        stats.append([first_token_latency, token, token_latency])
+        first_token_latency = np.round(timestamps[1] - timestamps[0], 3)
+        token_latency = np.round(timestamps[-1] - timestamps[0], 3)
+        completion_tokens = tokens[-1]
+        total_tokens = tokens[-1] + input_seqlen
+        stats.append([
+            first_token_latency, completion_tokens, output_seqlen,
+            total_tokens, token_latency
+        ])
+        print(f'session {session_id}: '
+              f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}, '
+              f'completion_tokens {completion_tokens}')
     res_que.put((session_id, stats))


 def warmup(server_addr: str,
            concurrency: int,
            output_seqlen: int,
-           warmup_round: int = 1):
+           warmup_round: int = 1,
+           stream_output: bool = False):
     print('start to warmup ...')

     def _infer(server_addr, session_id):
@@ -50,13 +59,15 @@ def warmup(server_addr: str,
                 server_addr,
                 session_id,
                 request_output_len=output_seqlen,
-                interactive_mode=False):
+                interactive_mode=False,
+                stream=stream_output,
+                ignore_eos=True):
             continue

     _start = time.perf_counter()
     procs = []
     for i in range(concurrency):
-        proc = mp.Process(target=_infer, args=(server_addr, i + 1))
+        proc = Thread(target=_infer, args=(server_addr, i + 1))
         procs.append(proc)
         proc.start()
     for proc in procs:
@@ -79,6 +90,7 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
     print(f'elapsed time for read data: '
           f'{round(time.perf_counter() - start, 2)} s')
+    print('start tokenization. This takes a while, please wait...')
     start = time.perf_counter()
     tokenizer = Tokenizer(tokenizer_path)
     prompts_token_lens = [len(tokenizer.encode(prompt)) for prompt in prompts]
@@ -100,9 +112,10 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
     if samples > 0:
         filtered_dataset = random.sample(filtered_dataset, samples)

-    que = mp.Queue()
+    que = Queue()
     for data in filtered_dataset:
         que.put(data)
+    que.put((None, None, None))
     print(f'elapsed time for filtering: '
           f'{round(time.perf_counter() - start, 2)} s')
     return que, len(filtered_dataset)
@@ -113,17 +126,20 @@ def main(server_addr: str,
          dataset_path: str,
          concurrency: int = 1,
          session_len: int = 2048,
-         samples: int = 1000):
+         samples: int = 1000,
+         stream_output: bool = False):
     api_url = server_addr + '/v1/chat/interactive'
-    warmup(api_url, concurrency, session_len - 1)
+    warmup(api_url, concurrency, session_len - 1, 4, stream_output)
     req_queue, n_req = read_dataset(tokenizer_path, dataset_path, samples,
                                     session_len)
-    res_que = mp.Queue()
+    for i in range(concurrency):
+        req_queue.put([None, None, None])
+    res_que = Queue()
     procs = []
     _start = time.perf_counter()
     for i in range(concurrency):
-        proc = mp.Process(target=infer,
-                          args=(api_url, i + 1, req_queue, res_que))
+        proc = Thread(target=infer,
+                      args=(api_url, i + 1, req_queue, res_que, stream_output))
         procs.append(proc)
         proc.start()
     for proc in procs:
@@ -138,22 +154,40 @@ def main(server_addr: str,
               f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n')
         stats.append(np.array(_stats))

-    stats = np.concatenate(stats).reshape(-1, 3)
+    stats = np.concatenate(stats).reshape(-1, 5)

     first_token_latency_min = np.min(stats[:, 0], axis=0)
     first_token_latency_max = np.max(stats[:, 0], axis=0)
     first_token_latency_ave = np.mean(stats[:, 0], axis=0)
-    token_throughput = np.sum(stats[:, 1], axis=0) / elapsed_time
-    req_throughput = n_req / elapsed_time
+    completion_tokens = np.sum(stats[:, 1], axis=0)
+    request_output_tokens = np.sum(stats[:, 2], axis=0)
+    total_tokens = np.sum(stats[:, 3], axis=0)
+    prompt_tokens = total_tokens - completion_tokens
+    completion_token_throughput = completion_tokens / elapsed_time
+    total_token_throughput = total_tokens / elapsed_time
+    rqs = n_req / elapsed_time
+    rqm = rqs * 60
+
+    if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False:
+        print(f'Did not generate requested number of tokens. '
+              f'Request {request_output_tokens:.0f}, '
+              f'but got {completion_tokens:.0f}')

     print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
-          f'elapsed_time: {elapsed_time:.2f}s\n'
-          f'first_token latency(min, max, ave): '
-          f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, '
-          f'{first_token_latency_ave:.2f}s\n'
-          f'token throughput: {token_throughput:.2f} token/s\n'
-          f'req throughput: {req_throughput:.2f} req/s\n'
-          f'{"-" * 50}\n')
+          f'elapsed_time: {elapsed_time:.3f}s\n')
+    if stream_output:
+        print(f'first_token latency(min, max, ave): '
+              f'{first_token_latency_min:.3f}s, '
+              f'{first_token_latency_max:.3f}s, '
+              f'{first_token_latency_ave:.3f}s\n')
+    print(
+        f'number of prompt tokens: {prompt_tokens:.0f}\n'
+        f'number of completion tokens: {completion_tokens:.0f}\n'
+        f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n'  # noqa
+        f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n'  # noqa
+        f'RPS (request per second): {rqs:.3f} req/s\n'
+        f'RPM (request per minute): {rqm:.3f} req/min\n'
+        f'{"-" * 50}\n')


 if __name__ == '__main__':
......
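
The hunk above (apparently the RESTful-API benchmark script, given the old `get_logger('profile_restful_api')` call) also changes how workers are fed: `mp.Process` workers that polled `req_queue.empty()` are replaced by threads that consume the queue until they read a `[None, None, None]` sentinel, with one sentinel pushed per worker. A self-contained sketch of that pattern, using generic names rather than the script's own:

```python
from queue import Queue
from threading import Thread

NUM_WORKERS = 4
SENTINEL = [None, None, None]          # one sentinel per worker, as in the diff

def worker(req_queue: Queue, res_queue: Queue, worker_id: int):
    stats = []
    # iter(callable, sentinel) keeps calling req_queue.get() until the returned
    # value equals the sentinel, which avoids the race in
    # `while not req_queue.empty()` where another thread drains the queue first.
    for prompt, input_len, output_len in iter(req_queue.get, SENTINEL):
        stats.append((prompt, input_len, output_len))   # real work goes here
    res_queue.put((worker_id, stats))

req_queue, res_queue = Queue(), Queue()
for item in [('a', 1, 8), ('b', 2, 8), ('c', 3, 8)]:
    req_queue.put(item)
for _ in range(NUM_WORKERS):
    req_queue.put(SENTINEL)

threads = [Thread(target=worker, args=(req_queue, res_queue, i))
           for i in range(NUM_WORKERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
while not res_queue.empty():
    print(res_queue.get())
```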
@@ -17,7 +17,7 @@ def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):
                                                     [None, None, None]):
         timestamps = []
         tokens = []
-        start = time.perf_counter()
+        timestamps.append(time.perf_counter())
         for status, res, token in chatbot.stream_infer(
                 session_id,
                 prompt,
@@ -26,13 +26,17 @@ def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):
                 sequence_end=True):
             timestamps.append(time.perf_counter())
             tokens.append(token)
-        first_token_latency = np.round(timestamps[1] - start, 3)
+        first_token_latency = np.round(timestamps[1] - timestamps[0], 3)
         token_latency = np.round(timestamps[-1] - timestamps[0], 3)
-        token = tokens[-1] - tokens[0]
-        stats.append([first_token_latency, token, token_latency])
+        completion_tokens = tokens[-1]
+        total_tokens = tokens[-1] + input_seqlen
+        stats.append([
+            first_token_latency, completion_tokens, output_seqlen,
+            total_tokens, token_latency
+        ])
         print(f'session {session_id}: '
-              f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}')
+              f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}, '
+              f'completion_tokens {completion_tokens}')
     res_que.put((session_id, stats))
@@ -84,6 +88,7 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
     completions = [completion for _, completion in dataset]
     print(f'elapsed time for read data: '
           f'{round(time.perf_counter() - start, 2)} s')
+    print('start tokenization. This takes a while, please wait...')
     start = time.perf_counter()
     tokenizer = Tokenizer(tokenizer_path)
@@ -124,7 +129,6 @@ def main(tritonserver_addr: str,
     res_que = mp.Queue()
     procs = []
-    _start = time.perf_counter()
     for i in range(concurrency):
         chatbot = Chatbot(tritonserver_addr=tritonserver_addr,
                           display=False,
@@ -134,13 +138,15 @@ def main(tritonserver_addr: str,
         proc = mp.Process(target=infer,
                           args=(chatbot, i + 1, req_que, res_que))
         procs.append(proc)
-        proc.start()

     # read data and put it to queue
     n_req = read_dataset(tokenizer_path, dataset_path, samples, session_len,
                          req_que)
     for i in range(concurrency):
         req_que.put([None, None, None])
+    _start = time.perf_counter()
+    for proc in procs:
+        proc.start()

     stats = []
     for i in range(concurrency):
@@ -149,27 +155,42 @@ def main(tritonserver_addr: str,
               f'session {session_id}: processed reqs {len(_stats)}, '
               f'stats: \n{_stats}\n{"-" * 50}\n')
         stats.append(np.array(_stats))
     _end = time.perf_counter()
     elapsed_time = _end - _start

-    stats = np.concatenate(stats).reshape(-1, 3)
+    stats = np.concatenate(stats).reshape(-1, 5)

     first_token_latency_min = np.min(stats[:, 0], axis=0)
     first_token_latency_max = np.max(stats[:, 0], axis=0)
     first_token_latency_ave = np.mean(stats[:, 0], axis=0)
-    token_throughput = np.sum(stats[:, 1], axis=0) / elapsed_time
-    req_throughput = n_req / elapsed_time
-
-    print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
-          f'elapsed_time: {elapsed_time:.3f}s\n'
-          f'first_token latency(min, max, ave): '
-          f'{first_token_latency_min:.3f}s, {first_token_latency_max:.3f}s, '
-          f'{first_token_latency_ave:.3f}s\n'
-          f'token throughput: {token_throughput:.3f} token/s\n'
-          f'req throughput: {req_throughput:.3f} req/s\n'
-          f'{"-" * 50}\n')
+    completion_tokens = np.sum(stats[:, 1], axis=0)
+    request_output_tokens = np.sum(stats[:, 2], axis=0)
+    total_tokens = np.sum(stats[:, 3], axis=0)
+    prompt_tokens = total_tokens - completion_tokens
+    completion_token_throughput = completion_tokens / elapsed_time
+    total_token_throughput = total_tokens / elapsed_time
+    rqs = n_req / elapsed_time
+    rqm = rqs * 60
+
+    if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False:
+        print(f'Did not generate requested number of tokens. '
+              f'Request {request_output_tokens:.0f}, '
+              f'but got {completion_tokens:.0f}')
+
+    print(
+        f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
+        f'elapsed_time: {elapsed_time:.3f}s\n'
+        f'first_token latency(min, max, ave): '
+        f'{first_token_latency_min:.3f}s, {first_token_latency_max:.3f}s, '
+        f'{first_token_latency_ave:.3f}s\n'
+        f'number of prompt tokens: {prompt_tokens:.0f}\n'
+        f'number of completion tokens: {completion_tokens:.0f}\n'
+        f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n'  # noqa
+        f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n'  # noqa
+        f'RPS (request per second): {rqs:.3f} req/s\n'
+        f'RPM (request per minute): {rqm:.3f} req/min\n'
+        f'{"-" * 50}\n')

     for proc in procs:
         proc.join()
......
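
Each benchmark script touched by this commit adds the same sanity check, `if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False:`, before printing its summary. One observation on that line (an editor's note, not part of the commit): `.min()` on a boolean array returns a NumPy scalar, and a NumPy bool is never identical to the Python `False` singleton, so the warning branch cannot fire as written; expressing the intent with `.all()` behaves as presumably intended:

```python
import numpy as np

# Example row: 120 completion tokens were generated out of 128 requested.
stats = np.array([[0.03, 120, 128, 640, 2.9]])

# Identity comparison against the Python singleton never matches a NumPy bool,
# so this prints False and the `if ... is False:` branch is effectively dead code.
print((np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False)

# Equivalent intent written with .all():
if not (np.abs(stats[:, 1] - stats[:, 2]) <= 1).all():
    print('Did not generate requested number of tokens.')
```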
@@ -7,6 +7,7 @@ from threading import Thread
 from typing import List, Tuple

 import fire
+import numpy as np

 from lmdeploy.tokenizer import Tokenizer
 from lmdeploy.turbomind import TurboMind
@@ -24,8 +25,7 @@ def sample_requests(
     dataset = [data for data in dataset if len(data['conversations']) >= 2]
     # Only keep the first two turns of each conversation.
     dataset = [(data['conversations'][0]['value'],
-                data['conversations'][1]['value'])
-               for data in dataset][:num_requests * 2]  # speed up encoding
+                data['conversations'][1]['value']) for data in dataset]

     # Tokenize the prompts and completions.
     prompts = [prompt for prompt, _ in dataset]
@@ -64,80 +64,131 @@ class Engine:
         self.tm_model = tm_model
         self.tokenizer = tokenizer

-    def _inference(self, queue, session_id: int):
+    def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
+                   stream_output: bool):
         model_inst = self.tm_model.create_instance()
-        while True:
-            request = queue.get()
-            if request is None:
-                # stop signal
-                queue.put(None)
-                return
-            else:
-                prompt, _, output_seqlen = request
-                input_ids = self.tokenizer.encode(prompt)
-
-                for outputs in model_inst.stream_infer(
-                        session_id,
-                        input_ids=input_ids,
-                        request_output_len=output_seqlen,
-                        temperature=1.0,
-                        top_p=1.0,
-                        sequence_start=True,
-                        sequence_end=True,
-                        ignore_eos=True):
-                    res, tokens = outputs[0]
-                    self.tokenizer.decode(res)
-
-    def process_request(self, requests, concurrency: int = 1):
-        q = Queue()
+        stats = []
+        timestamps = []
+        tokens = []
+        timestamps.append(time.perf_counter())
+        for prompt, input_seqlen, output_seqlen in iter(
+                req_queue.get, [None, None, None]):
+            input_ids = self.tokenizer.encode(prompt)
+            offset = 0
+            for outputs in model_inst.stream_infer(
+                    session_id,
+                    input_ids=input_ids,
+                    request_output_len=output_seqlen,
+                    temperature=1.0,
+                    top_p=1.0,
+                    sequence_start=True,
+                    sequence_end=True,
+                    ignore_eos=True,
+                    stream_output=stream_output):
+                res, token = outputs[0]
+                self.tokenizer.decode(res, offset)
+                offset = token
+                timestamps.append(time.perf_counter())
+                tokens.append(token)
+            first_token_latency = np.round(timestamps[1] - timestamps[0], 3)
+            token_latency = np.round(timestamps[-1] - timestamps[0], 3)
+            completion_tokens = tokens[-1]
+            total_tokens = tokens[-1] + len(input_ids)
+            stats.append([
+                first_token_latency, completion_tokens, output_seqlen,
+                total_tokens, token_latency
+            ])
+            print(
+                f'session {session_id}: '
+                f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}, '
+                f'completion_tokens {completion_tokens}')
+        res_queue.put((session_id, stats))
+
+    def process_request(self,
+                        requests,
+                        concurrency: int = 1,
+                        stream_output: bool = True):
+        res_queue = Queue()
+        req_queue = Queue()
         threads = []

+        # feed request to q
+        for req in requests:
+            req_queue.put(req)
+        for i in range(concurrency):
+            req_queue.put([None, None, None])
+
         start = time.time()

         # start threads
         for i in range(concurrency):
-            t = Thread(target=self._inference, args=(q, i))
+            t = Thread(target=self._inference,
+                       args=(req_queue, res_queue, i, stream_output))
             t.start()
             threads.append(t)

-        # feed request to q
-        for req in requests:
-            q.put(req)
-        q.put(None)
-
         # wait for finish
         for t in threads:
             t.join()

-        end = time.time()
-        return end - start
+        elapsed_time = time.time() - start
+
+        stats = []
+        while not res_queue.empty():
+            session_id, _stats = res_queue.get()
+            print(f'\n{"-" * 50}\n'
+                  f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n')
+            stats.append(np.array(_stats))
+
+        stats = np.concatenate(stats).reshape(-1, 5)
+
+        first_token_latency_min = np.min(stats[:, 0], axis=0)
+        first_token_latency_max = np.max(stats[:, 0], axis=0)
+        first_token_latency_ave = np.mean(stats[:, 0], axis=0)
+        completion_tokens = np.sum(stats[:, 1], axis=0)
+        request_output_tokens = np.sum(stats[:, 2], axis=0)
+        total_tokens = np.sum(stats[:, 3], axis=0)
+        prompt_tokens = total_tokens - completion_tokens
+        completion_token_throughput = completion_tokens / elapsed_time
+        total_token_throughput = total_tokens / elapsed_time
+        rqs = len(requests) / elapsed_time
+        rqm = rqs * 60
+
+        if (np.abs(stats[:, 1] - stats[:, 2]) <= 1).min() is False:
+            print(f'Did not generate requested number of tokens. '
+                  f'Request {request_output_tokens:.0f}, '
+                  f'but got {completion_tokens:.0f}')
+
+        print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
+              f'elapsed_time: {elapsed_time:.3f}s\n')
+        if stream_output:
+            print(f'first_token latency(min, max, ave): '
+                  f'{first_token_latency_min:.3f}s, '
+                  f'{first_token_latency_max:.3f}s, '
+                  f'{first_token_latency_ave:.3f}s\n')
+        print(
+            f'number of prompt tokens: {prompt_tokens:.0f}\n'
+            f'number of completion tokens: {completion_tokens:.0f}\n'
+            f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n'  # noqa
+            f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n'  # noqa
+            f'RPS (request per second): {rqs:.3f} req/s\n'
+            f'RPM (request per minute): {rqm:.3f} req/min\n'
+            f'{"-" * 50}\n')


 def main(dataset: str,
          model_path: str,
          concurrency: int = 1,
          num_prompts: int = 1000,
-         tp: int = 1):
+         tp: int = 1,
+         stream_output: bool = True):
     engine = Engine(model_path, tp=tp)
     tokenizer = engine.tokenizer
     requests = sample_requests(dataset, num_prompts, tokenizer)

-    elapsed_time = engine.process_request(requests, concurrency)
-    total_num_tokens = sum(prompt_len + output_len
-                           for _, prompt_len, output_len in requests)
-    total_num_out_tokens = sum(output_len for _, _, output_len in requests)
-    print(f'Throughput requests: {len(requests) / elapsed_time:.2f} req/s')
-    print(
-        f'Throughput requests: {len(requests) * 60 / elapsed_time:.2f} req/min'
-    )
-    print(f'Throughput tokens: {total_num_tokens / elapsed_time:.2f} tokens/s')
-    print('Throughput tokens(output only):'
-          f'{total_num_out_tokens / elapsed_time:.2f} tokens/s')
+    engine.process_request(requests, concurrency, stream_output)


 if __name__ == '__main__':
......
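
One way to read the throughput-script change above: the old `main()` derived token throughput from the *requested* prompt and output lengths, while the reworked `process_request()` reports tokens that were actually streamed back. The toy numbers below (purely illustrative, not from a real run) show why the two can differ when generation stops short of the requested length:

```python
# Toy comparison (all numbers made up): requested lengths vs. tokens actually
# generated. The old report used the former; the fixed report uses the latter.
elapsed_time = 10.0                       # seconds
requests = [('prompt a', 32, 256),        # (prompt, input_len, requested_output_len)
            ('prompt b', 48, 256)]
generated = [256, 180]                    # completion tokens actually produced

requested_total = sum(in_len + out_len for _, in_len, out_len in requests)
actual_total = sum(in_len for _, in_len, _ in requests) + sum(generated)

print(f'throughput from requested lengths: {requested_total / elapsed_time:.2f} tokens/s')
print(f'throughput from generated tokens:  {actual_total / elapsed_time:.2f} tokens/s')
```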
@@ -107,6 +107,7 @@ class AsyncEngine:
                     temperature=0.8,
                     repetition_penalty=1.0,
                     ignore_eos=False,
+                    do_preprocess=True,
                     **kwargs):
         """Inference a batch of prompts.
@@ -122,6 +123,7 @@ class AsyncEngine:
             repetition_penalty (float): The parameter for repetition penalty.
                 1.0 means no penalty
             ignore_eos (bool): indicator for ignoring eos
+            do_preprocess (bool): whether pre-process the messages.
         """
         assert isinstance(prompts, List), 'prompts should be a list'
         batch_size = len(prompts)
@@ -139,7 +141,9 @@ class AsyncEngine:
                     top_p=top_p,
                     temperature=temperature,
                     ignore_eos=ignore_eos,
-                    repetition_penalty=repetition_penalty))
+                    repetition_penalty=repetition_penalty,
+                    do_preprocess=do_preprocess,
+                    **kwargs))

         async def _inner_call(i, generator):
             async for out in generator:
@@ -153,22 +157,22 @@ class AsyncEngine:
         return outputs

     async def generate(
             self,
             messages,
             session_id,
             stream_response=True,
             sequence_start=True,
             sequence_end=True,  # no interactive mode by default
             step=0,
             request_output_len=512,
             stop=False,
             top_k=40,
             top_p=0.8,
             temperature=0.8,
             repetition_penalty=1.0,
             ignore_eos=False,
             do_preprocess=True,
-    ):
+            **kwargs):
         """Generate responses.

         Args:
......
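
For context on the AsyncEngine change above: the batch method now accepts `do_preprocess` and forwards it, together with any extra keyword arguments, to `generate()`. A generic sketch of that forwarding pattern; the names other than `do_preprocess` are made up for illustration and are not lmdeploy's API beyond what the diff itself shows:

```python
import asyncio

async def generate(prompt, do_preprocess=True, **kwargs):
    # Stand-in for chat-template preprocessing; the real lmdeploy logic differs.
    text = prompt.strip() if do_preprocess else prompt
    return f'{text!r} generated with {kwargs}'

async def batch_infer(prompts, do_preprocess=True, **kwargs):
    # Forward the flag and any extra keyword arguments to every generate() call,
    # mirroring the shape of the change in the diff above.
    coros = [generate(p, do_preprocess=do_preprocess, **kwargs) for p in prompts]
    return await asyncio.gather(*coros)

print(asyncio.run(batch_infer(['hello ', ' world'],
                              do_preprocess=False,
                              temperature=0.8)))
```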
@@ -71,8 +71,6 @@ class ChatCompletionRequest(BaseModel):
     # additional argument of lmdeploy
     repetition_penalty: Optional[float] = 1.0
     session_id: Optional[int] = -1
-    renew_session: Optional[
-        bool] = False  # lagecy and useless, will be removed
     ignore_eos: Optional[bool] = False
......