Unverified commit 0ed1e4d4 authored by lvhan028, committed by GitHub

Improve postprocessing in TIS serving by applying Incremental de-tokenizing (#197)

* change to incremental decoding

* update
parent 18c386d9
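
The gist of the change: instead of de-tokenizing the session's entire output_ids on every streaming step and string-diffing against the text already shown, the client keeps a per-session sequence_offset and has postprocess decode only the ids past that offset, appending the result to session.response. A toy, self-contained sketch of the two strategies (the decode helper and vocabulary below are made-up stand-ins for the server-side Triton postprocess model, not code from this commit):

# Toy comparison: full re-decoding vs. incremental de-tokenizing.
# `decode` stands in for the server-side postprocess call; the vocab
# and token ids are invented for illustration.
def decode(ids):
    vocab = {0: 'Hel', 1: 'lo', 2: ',', 3: ' wor', 4: 'ld'}
    return ''.join(vocab[i] for i in ids)

stream = [0, 1, 2, 3, 4]  # token ids as they arrive, one per step

# Before: decode the whole sequence each step, then diff against the
# text already shown (O(n^2) decode work over n steps).
shown = ''
for n in range(1, len(stream) + 1):
    text = decode(stream[:n])
    print(text[len(shown):], end='', flush=True)
    shown = text
print()

# After: decode only the ids past the session offset and append,
# so each id is decoded exactly once.
offset, response = 0, ''
while offset < len(stream):
    response += decode(stream[offset:offset + 1])
    offset += 1
print(response)  # Hello, world

One caveat worth noting: decoding disjoint id ranges in isolation can mis-render characters that only form a complete string across a token boundary (common with BPE/SentencePiece byte fallback), an edge case that full re-decoding avoids; the speed win here trades against it.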
@@ -55,7 +55,7 @@ def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):
 def warmup(tritonserver_addr: str,
            concurrency: int,
            output_seqlen: int,
-           warmup_round: int = 4):
+           warmup_round: int = 1):
     print('start to warmup ...')

     def _infer(_chatbot, session_id):
@@ -87,7 +87,7 @@ def warmup(tritonserver_addr: str,
 def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
-                 test_round: int, session_len: int):
+                 session_len: int):
     start = time.perf_counter()
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -119,14 +119,12 @@ def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
     if samples > 0:
         filtered_dataset = random.sample(filtered_dataset, samples)
-    filtered_dataset *= test_round
-    random.shuffle(filtered_dataset)

     que = mp.Queue()
     for data in filtered_dataset:
         que.put(data)
     print(f'elapsed time for filtering: '
           f'{round(time.perf_counter() - start, 2)} s')
-    return que
+    return que, len(filtered_dataset)


 def main(tritonserver_addr: str,
@@ -134,11 +132,10 @@ def main(tritonserver_addr: str,
          dataset_path: str,
          concurrency: int = 1,
          session_len: int = 2048,
-         samples: int = 1000,
-         test_round: int = 1):
+         samples: int = 1000):
     warmup(tritonserver_addr, concurrency, session_len - 1)
-    req_que = read_dataset(tokenizer_path, dataset_path, samples, test_round,
-                           session_len)
+    req_que, n_req = read_dataset(tokenizer_path, dataset_path, samples,
+                                  session_len)
     res_que = mp.Queue()
     procs = []
     _start = time.perf_counter()
@@ -168,13 +165,17 @@ def main(tritonserver_addr: str,
     first_token_latency_min = np.min(stats[:, 0], axis=0)
     first_token_latency_max = np.max(stats[:, 0], axis=0)
     first_token_latency_ave = np.mean(stats[:, 0], axis=0)
-    throughput = np.sum(stats[:, 1], axis=0) / elapsed_time
+    token_throughput = np.sum(stats[:, 1], axis=0) / elapsed_time
+    req_throughput = n_req / elapsed_time
     print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
           f'elapsed_time: {elapsed_time:.2f}s\n'
           f'first_token latency(min, max, ave): '
           f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, '
           f'{first_token_latency_ave:.2f}s\n'
-          f'throughput: {throughput:.2f} token/s\n{"-" * 50}')
+          f'token throughput: {token_throughput:.2f} token/s\n'
+          f'req throughput: {req_throughput} req/s\n'
+          f'{"-" * 50}\n')


 if __name__ == '__main__':
...
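
As a quick sanity check of the two new metrics: if the benchmark drains n_req = 1000 requests in elapsed_time = 400 s and the workers report 200,000 generated tokens in total, the summary would read token throughput 500.00 token/s and req throughput 2.5 req/s (hypothetical numbers, for illustration only).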
@@ -20,7 +20,6 @@ def main(tritonserver_addr: str, session_id: int = 1):
     Args:
         tritonserver_addr (str): the address in format "ip:port" of
             triton inference server
-        model_name (str): the name of the deployed model
         session_id (int): the identical id of a session
     """
     log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
...
@@ -26,6 +26,7 @@ class Session:
     request_id: str = ''
     histories: str = ''  # history conversations of the session
     sequence_length: int = 0  # the total generated token number in the session
+    sequence_offset: int = 0  # the new generated token offset in the session
     prompt: str = ''
     response: str = ''
     status: int = None  # status of the session
@@ -539,14 +540,15 @@ class Chatbot:
         Yields:
             tuple: status, text, generated token number
         """
-        offset = n_input_token + preseq_length
+        session.sequence_offset = n_input_token + preseq_length
+        sentinel = n_input_token + preseq_length
         status, res, n_token = None, '', 0
         while True:
             result = res_queue.get()
             if result is None:
                 status = StatusCode.TRITON_STREAM_END
                 res = session.response
-                n_token = session.sequence_length - offset
+                n_token = session.sequence_length - sentinel
                 session.status = StatusCode.TRITON_STREAM_END
                 break
             if 'errcode' in result:
@@ -569,30 +571,31 @@ class Chatbot:
                 output_ids = result.as_numpy('output_ids')
                 session.sequence_length = sequence_length.squeeze()
-                sequence_length = sequence_length - offset
+                new_token_length = sequence_length - session.sequence_offset
                 last_token_id = output_ids[-1][-1][session.sequence_length - 1]
                 if last_token_id == eos_id:
                     session.sequence_length = session.sequence_length - 1
-                    sequence_length = sequence_length - 1
+                    new_token_length = new_token_length - 1
                 output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
-                sequence_length = sequence_length.reshape(
-                    (1, sequence_length.shape[-1]))
+                new_token_length = new_token_length.reshape(
+                    (1, new_token_length.shape[-1]))
                 if profile_generation:
                     yield (StatusCode.TRITON_STREAM_ING,
                            'postprocessing is ignored during profiling '
-                           'token generation', sequence_length.squeeze())
+                           'token generation', new_token_length.squeeze())
                     continue
-                output_str = postprocess(output_ids[:, :, offset:],
-                                         sequence_length)
+                output_str = postprocess(
+                    output_ids[:, :, session.sequence_offset:],
+                    new_token_length)
+                session.sequence_offset = session.sequence_length
                 text = output_str[0].decode()
                 if display:
-                    new_text = text[len(session.response):]
-                    print(new_text, end='', flush=True)
-                session.response = text
+                    print(text, end='', flush=True)
+                session.response += text
                 yield (StatusCode.TRITON_STREAM_ING, session.response,
-                       sequence_length.squeeze())
+                       session.sequence_offset - sentinel)
             except Exception as e:
                 logger.error(f'catch exception: {e}')
...
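
Why the stream consumer now tracks two markers: sentinel stays fixed at n_input_token + preseq_length so the total generated-token count can still be computed when the stream ends, while session.sequence_offset advances to session.sequence_length after each postprocess call, marking where the next incremental decode starts. For example (made-up numbers), with n_input_token = 5 and preseq_length = 0, sentinel = 5; if the first result carries sequence_length = 8, only output_ids[:, :, 5:8] is decoded, sequence_offset moves to 8, and the yield reports 8 - 5 = 3 tokens so far. Since session.response is now built by appending each decoded delta, display can print text directly instead of slicing off the already-shown prefix.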