Unverified Commit 23c05372 authored by lvhan028, committed by GitHub

Add profile (#15)

* remove constraints on model name

* remove duplicate model converter

* add profile

* get eos and bos from server

* update stop_words

* update sequence_length when the last generated token is eos_id

* fix

* fix

* check-in models

* validate model_name

* make stop_words as property

* debug profiling

* better stats

* fix assistant response

* update profile serving

* update

* update
parent 2700abb3
import multiprocessing as mp
import time

import fire
import numpy as np

from llmdeploy.serve.fastertransformer.chatbot import Chatbot


def infer(chatbot, session_id: int, prompt: str, output_seqlen: int,
          test_round: int, que: mp.Queue):
    stats = []
    for i in range(test_round):
        timestamps = []
        tokens = []
        start = time.perf_counter()
        for status, res, token in chatbot.stream_infer(
                session_id,
                prompt,
                request_output_len=output_seqlen,
                sequence_start=True,
                sequence_end=True):
            timestamps.append(time.perf_counter())
            tokens.append(token)
        first_token_latency = timestamps[0] - start
        token_latency = timestamps[-1] - timestamps[0]
        token = tokens[-1] - tokens[0]
        stats.append([first_token_latency, token, token_latency])
        chatbot.reset_session()
    que.put((session_id, stats))


def warmup(tritonserver_addr: str,
           model_name: str,
           concurrency: int,
           session_len: int,
           output_seqlen: int,
           warmup_round: int = 4):
    print('start to warmup ...')

    def _infer(_chatbot, session_id):
        for _ in range(warmup_round):
            for _, _, _ in _chatbot.stream_infer(
                    session_id,
                    prompt='',
                    request_output_len=output_seqlen,
                    sequence_start=True,
                    sequence_end=True):
                continue
            _chatbot.reset_session()

    _start = time.perf_counter()
    chatbots = [
        Chatbot(tritonserver_addr=tritonserver_addr,
                model_name=model_name,
                session_len=session_len,
                ignore_eos=True,
                profile_generation=True) for _ in range(concurrency)
    ]
    procs = []
    for i, chatbot in enumerate(chatbots):
        proc = mp.Process(target=_infer, args=(chatbot, i + 1))
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()
    _end = time.perf_counter()
    print(f'end warmup, elapsed time: {round(_end - _start, 2)}s')


def main(tritonserver_addr: str,
         model_name: str,
         concurrency: int = 1,
         session_len: int = 2048,
         input_seqlen: int = 0,
         output_seqlen: int = 512,
         test_round: int = 10):
    warmup(tritonserver_addr, model_name, concurrency, session_len,
           output_seqlen)
    # make up a prompt that can be tokenized into {input_seqlen} tokens
    prompt = '' if input_seqlen == 0 else 'hi' + ' hi' * (input_seqlen - 1)
    que = mp.Queue()
    procs = []
    _start = time.perf_counter()
    for i in range(concurrency):
        chatbot = Chatbot(tritonserver_addr=tritonserver_addr,
                          model_name=model_name,
                          session_len=session_len,
                          ignore_eos=True,
                          profile_generation=True)
        proc = mp.Process(target=infer,
                          args=(chatbot, i + 1, prompt, output_seqlen,
                                test_round, que))
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()
    _end = time.perf_counter()
    elapsed_time = _end - _start

    stats = []
    while not que.empty():
        session_id, _stats = que.get()
        print(f'\n{"-" * 50}\n'
              f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n')
        stats.append(_stats)

    stats = np.array(stats).reshape(-1, 3)
    first_token_latency_min = np.min(stats[:, 0], axis=0)
    first_token_latency_max = np.max(stats[:, 0], axis=0)
    first_token_latency_ave = np.mean(stats[:, 0], axis=0)
    token_latency_min = np.min(stats[:, 2], axis=0)
    token_latency_max = np.max(stats[:, 2], axis=0)
    token_latency_ave = np.mean(stats[:, 2], axis=0)
    throughput = np.sum(stats[:, 1], axis=0) / np.sum(stats[:, 2], axis=0)
    print(f'\n{"-" * 50}\nconcurrency: {concurrency}, input_tokens: '
          f'{input_seqlen}, output_tokens: {output_seqlen}\n'
          f'elapsed_time: {elapsed_time:.2f}s\n'
          f'first_token latency(min, max, ave): '
          f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, '
          f'{first_token_latency_ave:.2f}s\ntoken latency(min, max, ave): '
          f'{token_latency_min:.2f}s, {token_latency_max:.2f}s, '
          f'{token_latency_ave:.2f}s\n'
          f'throughput: {throughput} token/s\n{"-" * 50}')


if __name__ == '__main__':
    fire.Fire(main)
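The generation-profiling script above is driven by fire, so every parameter of `main` maps to a command-line flag. A minimal usage sketch follows; the file name, server address, and model name are illustrative assumptions, not values taken from this commit.

# Usage sketch (assumed file name and address; adjust to your deployment):
#
#   python3 profile_generation.py \
#       --tritonserver_addr 0.0.0.0:33337 \
#       --model_name vicuna \
#       --concurrency 8 \
#       --input_seqlen 64 \
#       --output_seqlen 512 \
#       --test_round 10
#
# which is equivalent to calling the function directly:
#
#   main('0.0.0.0:33337', 'vicuna', concurrency=8, input_seqlen=64,
#        output_seqlen=512, test_round=10)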
import json
import multiprocessing as mp
import os
import random
import time
from typing import List

import fire
import numpy as np
from sentencepiece import SentencePieceProcessor

from llmdeploy.serve.fastertransformer.chatbot import Chatbot


class Tokenizer:

    def __init__(self, model_path: str):
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

    def encode(self, prompts: List):
        prompts_token_ids = self.sp_model.Encode(prompts,
                                                 add_bos=False,
                                                 add_eos=False)
        return [len(token_ids) for token_ids in prompts_token_ids]


def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):
    stats = []
    while not req_que.empty():
        prompt, input_seqlen, output_seqlen = req_que.get()
        print(f'request info: session {session_id}, '
              f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}')
        timestamps = []
        tokens = []
        start = time.perf_counter()
        for status, res, token in chatbot.stream_infer(
                session_id,
                prompt,
                request_output_len=output_seqlen,
                sequence_start=True,
                sequence_end=True):
            timestamps.append(time.perf_counter())
            tokens.append(token)
        chatbot.reset_session()
        first_token_latency = timestamps[1] - start
        token_latency = timestamps[-1] - timestamps[0]
        token = tokens[-1] - tokens[0]
        stats.append([first_token_latency, token, token_latency])
    res_que.put((session_id, stats))


def warmup(tritonserver_addr: str,
           model_name: str,
           concurrency: int,
           session_len: int,
           output_seqlen: int,
           warmup_round: int = 4):
    print('start to warmup ...')

    def _infer(_chatbot, session_id):
        for _ in range(warmup_round):
            for _, _, _ in _chatbot.stream_infer(
                    session_id,
                    prompt='',
                    request_output_len=output_seqlen,
                    sequence_start=True,
                    sequence_end=True):
                continue
            _chatbot.reset_session()

    _start = time.perf_counter()
    chatbots = [
        Chatbot(tritonserver_addr=tritonserver_addr,
                model_name=model_name,
                session_len=session_len,
                ignore_eos=True,
                profile_generation=True) for _ in range(concurrency)
    ]
    procs = []
    for i, chatbot in enumerate(chatbots):
        proc = mp.Process(target=_infer, args=(chatbot, i + 1))
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()
    _end = time.perf_counter()
    print(f'end warmup, elapsed time: {round(_end - _start, 2)} s')


def read_dataset(tritonserver_addr, tokenizer_path: str, dataset_path: str,
                 samples: int, test_round: int, session_len: int):
    start = time.perf_counter()
    with open(dataset_path) as f:
        dataset = json.load(f)
    dataset = [data for data in dataset if len(data['conversations']) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data['conversations'][0]['value'],
                data['conversations'][1]['value']) for data in dataset]
    prompts = [prompt for prompt, _ in dataset]
    completions = [completion for _, completion in dataset]
    print(f'elapsed time for read data: '
          f'{round(time.perf_counter() - start, 2)} s')

    start = time.perf_counter()
    tokenizer = Tokenizer(tokenizer_path)
    prompts_token_lens = tokenizer.encode(prompts)
    completions_token_lens = tokenizer.encode(completions)
    print(f'elapsed time for tokenization: '
          f'{round(time.perf_counter() - start, 2)} s')

    start = time.perf_counter()
    filtered_dataset = []
    for (prompt, _), input_len, output_len in zip(dataset, prompts_token_lens,
                                                  completions_token_lens):
        if input_len + output_len > session_len:
            # ignore too long conversation
            continue
        filtered_dataset.append([prompt, input_len, output_len])

    if samples > 0:
        filtered_dataset = random.sample(filtered_dataset, samples)

    filtered_dataset *= test_round
    random.shuffle(filtered_dataset)
    que = mp.Queue()
    for data in filtered_dataset:
        que.put(data)
    print(f'elapsed time for filtering: '
          f'{round(time.perf_counter() - start, 2)} s')
    return que


def main(tritonserver_addr: str,
         model_name: str,
         tokenizer_path: str,
         dataset_path: str,
         concurrency: int = 1,
         session_len: int = 2048,
         samples: int = 2000,
         test_round: int = 1):
    warmup(tritonserver_addr, model_name, concurrency, session_len,
           session_len)
    req_que = read_dataset(tritonserver_addr, tokenizer_path, dataset_path,
                           samples, test_round, session_len)
    res_que = mp.Queue()
    procs = []
    _start = time.perf_counter()
    for i in range(concurrency):
        chatbot = Chatbot(tritonserver_addr=tritonserver_addr,
                          model_name=model_name,
                          session_len=session_len,
                          display=False,
                          profile_serving=True,
                          ignore_eos=True)
        proc = mp.Process(target=infer,
                          args=(chatbot, i + 1, req_que, res_que))
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()
    _end = time.perf_counter()
    elapsed_time = _end - _start

    stats = []
    while not res_que.empty():
        session_id, _stats = res_que.get()
        print(f'\n{"-" * 50}\n'
              f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n')
        stats.append(_stats)

    stats = np.array(stats).reshape(-1, 3)
    first_token_latency_min = np.min(stats[:, 0], axis=0)
    first_token_latency_max = np.max(stats[:, 0], axis=0)
    first_token_latency_ave = np.mean(stats[:, 0], axis=0)
    throughput = np.sum(stats[:, 1], axis=0) / elapsed_time
    print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
          f'elapsed_time: {elapsed_time:.2f}s\n'
          f'first_token latency(min, max, ave): '
          f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, '
          f'{first_token_latency_ave:.2f}s\n'
          f'throughput: {throughput:.2f} token/s\n{"-" * 50}')


if __name__ == '__main__':
    fire.Fire(main)
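The serving-profiling script expects a SentencePiece tokenizer model and a ShareGPT-style JSON file in which every record carries a `conversations` list (only the first two turns are used). A hedged invocation sketch, with the file name, paths, address, and model name as placeholders:

# Usage sketch (paths, address, and model name are placeholders):
#
#   python3 profile_serving.py \
#       --tritonserver_addr 0.0.0.0:33337 \
#       --model_name vicuna \
#       --tokenizer_path /path/to/tokenizer.model \
#       --dataset_path /path/to/sharegpt_conversations.json \
#       --concurrency 32 \
#       --samples 2000 \
#       --test_round 1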
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import Registry

MODELS = Registry('model', locations=['llmdeploy.model'])


@MODELS.register_module(name='vicuna')
class Vicuna:

    def __init__(self):
        self.system = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """  # noqa: E501
        self.user = 'USER'
        self.assistant = 'ASSISTANT'

    def get_prompt(self, prompt, sequence_start=True):
        if sequence_start:
            return f'{self.system} {self.user}: {prompt} {self.assistant}:'
        else:
            return f'</s>{self.user}: {prompt} {self.assistant}:'

    @property
    def stop_words(self):
        return None


@MODELS.register_module(name='puyu')
class Puyu:

    def __init__(self):
        self.system = """meta instruction
You are an AI assistant whose name is InternLM (书生·浦语).
- 书生·浦语 is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
- 书生·浦语 can understand and communicate fluently in the language chosen by the user such as English and 中文.
conversation"""  # noqa: E501
        self.user = '<|Human|>'
        self.eou = 'െ'
        self.assistant = '<|Assistant|>'

    def get_prompt(self, prompt, sequence_start=True):
        if sequence_start:
            return f'{self.system}\n' \
                   f'{self.user}:{prompt}{self.eou}\n' \
                   f'{self.assistant}:'
        else:
            return f'\n{self.user}:{prompt}{self.eou}\n{self.assistant}:'

    @property
    def stop_words(self):
        return [45623]


def main(model_name: str = 'test'):
    assert model_name in MODELS.module_dict.keys(), \
        f"'{model_name}' is not supported. " \
        f'The supported models are: {MODELS.module_dict.keys()}'
    model = MODELS.get(model_name)()
    prompt = model.get_prompt(prompt='hi')
    print(prompt)


if __name__ == '__main__':
    import fire
    fire.Fire(main)
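The registry above is how `Chatbot` resolves `model_name` into a chat template. As an illustration of how another template could be plugged in, here is a minimal sketch; the name `alpaca` and its prompt format are examples, not part of this commit:

# Illustrative only: registering an additional chat template.
@MODELS.register_module(name='alpaca')
class Alpaca:

    def __init__(self):
        self.system = 'Below is an instruction that describes a task. ' \
                      'Write a response that completes the request.'
        self.user = '### Instruction'
        self.assistant = '### Response'

    def get_prompt(self, prompt, sequence_start=True):
        if sequence_start:
            return f'{self.system}\n\n{self.user}:\n{prompt}\n\n' \
                   f'{self.assistant}:\n'
        return f'\n\n{self.user}:\n{prompt}\n\n{self.assistant}:\n'

    @property
    def stop_words(self):
        # token ids that should stop generation; None means rely on eos only
        return None

A class registered this way only needs `get_prompt` and a `stop_words` property, which is the interface `Chatbot` relies on in the diff below.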
@@ -12,9 +12,9 @@ def input_prompt():
     return '\n'.join(iter(input, sentinel))


-def main(triton_server_addr: str, model_name: str, session_id: int):
+def main(tritonserver_addr: str, model_name: str, session_id: int = 1):
     log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
-    chatbot = Chatbot(triton_server_addr,
+    chatbot = Chatbot(tritonserver_addr,
                       model_name,
                       log_level=log_level,
                       display=True)
...
@@ -15,6 +15,7 @@ import numpy as np
 import tritonclient.grpc as grpcclient
 from tritonclient.grpc.service_pb2 import ModelInferResponse

+from llmdeploy.model import MODELS
 from llmdeploy.serve.fastertransformer.utils import (Postprocessor,
                                                      Preprocessor,
                                                      prepare_tensor)
@@ -24,9 +25,9 @@ from llmdeploy.serve.fastertransformer.utils import (Postprocessor,
 class Session:
     session_id: Union[int, str]
     request_id: str = ''
-    prev: str = ''  # history of the session in text format
-    round_prev: str = ''  # previous generated text in the current round
+    histories: str = ''  # history conversations of the session
     sequence_length: int = 0  # the total generated token number in the session
+    prompt: str = ''
     response: str = ''
     status: int = None  # status of the session
@@ -71,11 +72,9 @@ class Chatbot:
         temperature (float): to modulate the next token probability
         repetition_penalty (float): The parameter for repetition penalty.
             1.0 means no penalty
-        stop_words (list): List of token ids that stops the generation
-        bad_words (list): List of token ids that are not allowed to be
-            generated.
         log_level (int): the level of the log
         display (bool): display the generated text on consolo or not
+        profile_generation (bool): profile token generation or not
     """

     def __init__(self,
@@ -86,18 +85,27 @@ class Chatbot:
                  top_k: int = 40,
                  temperature: float = 1.0,
                  repetition_penalty: float = 1.0,
-                 stop_words: List = None,
-                 bad_words: List = None,
+                 ignore_eos: bool = False,
                  log_level: int = logging.INFO,
-                 display: bool = False):
+                 display: bool = False,
+                 profile_generation: bool = False,
+                 profile_serving: bool = False):
+        assert model_name in MODELS.module_dict.keys(), \
+            f"'{model_name}' is not supported. " \
+            f'The supported models are: {MODELS.module_dict.keys()}'
+        self.model_name = model_name
+        self.model = MODELS.get(self.model_name)()
         self._session = None
         self.tritonserver_addr = tritonserver_addr
-        self.model_name = model_name
-        if stop_words is not None:
-            stop_words = np.array(stop_words, dtype=np.int32)
-        if bad_words is not None:
-            bad_words = np.array(bad_words, dtype=np.int32)
+        self.preprocess = Preprocessor(tritonserver_addr)
+        self.postprocess = Postprocessor(tritonserver_addr)
+        self.bos_id = self._get_bos()
+        self.eos_id = self._get_eos()
+        stop_words = self._stop_words(self.model.stop_words)
+        bad_words = None
+        if ignore_eos:
+            stop_words = None
+            bad_words = np.array([[[self.eos_id], [1]]], dtype=np.int32)
         self.cfg = mmengine.Config(
             dict(session_len=session_len,
                  top_p=top_p,
@@ -106,10 +114,10 @@ class Chatbot:
                  repetition_penalty=repetition_penalty,
                  stop_words=stop_words,
                  bad_words=bad_words))
-        self.preprocess = Preprocessor(tritonserver_addr)
-        self.postprocess = Postprocessor(tritonserver_addr)
         self.log_level = log_level
         self.display = display
+        self.profile_generation = profile_generation
+        self.profile_serving = profile_serving

     def stream_infer(self,
                      session_id: int,
@@ -152,13 +160,16 @@ class Chatbot:
         self._session.request_id = request_id
         self._session.response = ''

-        prompt = self._get_prompt(prompt, sequence_start)
-        for status, res, tokens in self._stream_infer(self._session, prompt,
+        self._session.prompt = self._get_prompt(prompt, sequence_start)
+        for status, res, tokens in self._stream_infer(self._session,
+                                                      self._session.prompt,
                                                       request_output_len,
                                                       sequence_start,
                                                       sequence_end):
             yield status, res, tokens
-        self._session.prev = self._session.prev + self._session.round_prev
+        self._session.histories = \
+            self._session.histories + self._session.prompt + \
+            self._session.response

     def end(self, session_id: int, *args, **kwargs):
         """end a session. Triton inference server will release the session's
@@ -237,20 +248,41 @@ class Chatbot:
                 break
         if status == StatusCode.TRITON_STREAM_END:
             logger.info(f'cancel session {session_id} successfully')
-            if prev_session.prev:
+            if prev_session.histories:
                 logger.warn(f'TODO: start to recover session {session_id}')
         else:
             logger.info(f'cancel session {session_id} failed: {res}')
         return status

+    def reset_session(self):
+        self._session = None
+
+    def _get_bos(self):
+        token_ids, _ = self.preprocess('<BOS>')
+        return token_ids[0][0]
+
+    def _get_eos(self):
+        token_ids, _ = self.preprocess('<EOS>')
+        return token_ids[0][0]
+
+    def _stop_words(self, stop_words: List[int]):
+        if stop_words is None:
+            return None
+        assert isinstance(stop_words, List) and \
+            all(isinstance(elem, int) for elem in stop_words), \
+            f'stop_words must be a list but got {type(stop_words)}'
+        # each id in stop_words represents a stop word
+        # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
+        # detailed explanation about fastertransformer's stop_words
+        stop_word_offsets = range(1, len(stop_words) + 1)
+        stop_words = np.array([[stop_words,
+                                stop_word_offsets]]).astype(np.int32)
+        return stop_words
+
     def _get_prompt(self, prompt: str, sequence_start: bool):
-        if self.model_name == 'vicuna':
-            if sequence_start:
-                return f'USER: {prompt} ASSISTANT:'
-            else:
-                return f'</s>USER: {prompt} ASSISTANT:'
-        else:
+        if self.profile_generation or self.profile_serving:
             return prompt
+        return self.model.get_prompt(prompt, sequence_start)

     def _stream_infer(self,
                       session: Session,
@@ -271,9 +303,17 @@ class Chatbot:
             f'request_output_len is supposed to be None or int, ' \
             f'but got {type(request_output_len)}'

+        if sequence_start:
+            logger.info(f'session {session.session_id}, clear history since '
+                        f'sequence_start is True')
+            session.histories = ''
+            session.sequence_length = 0
+
         input_ids, input_lengths = self.preprocess(prompt)
         input_tokens = input_lengths.squeeze()
+        if self.profile_generation:
+            yield StatusCode.TRITON_STREAM_ING, \
+                'ignore preprocessing during profiling generation', 0
         if request_output_len is None:
             request_output_len = max(
                 128,
@@ -293,7 +333,7 @@ class Chatbot:
                     f'history tokens: {session.sequence_length}')

         preseq_length = session.sequence_length
-        session.round_prev = ''
+        session.response = ''

         que = queue.Queue()
         producer = threading.Thread(target=self._stream_producer,
@@ -302,10 +342,9 @@ class Chatbot:
                                           request_output_len, sequence_start,
                                           sequence_end, preseq_length, cancel))
         producer.start()
-        for state, res, tokens in self.stream_consumer(self.postprocess, que,
-                                                       session, preseq_length,
-                                                       cancel, logger,
-                                                       self.display):
+        for state, res, tokens in self.stream_consumer(
+                self.postprocess, que, session, preseq_length, cancel, logger,
+                self.display, self.profile_generation, self.eos_id):
             if state.value < 0:
                 yield state, res, 0
             else:
@@ -382,21 +421,13 @@ class Chatbot:

     @staticmethod
     def stream_consumer(postprocess, res_queue, session, preseq_length, cancel,
-                        logger, display):
-
-        def process_response(res):
-            if session.ai_says is None:
-                return res, True
-            index = res.find(session.ai_says)
-            if index == -1:
-                return res, False
-            res = res[index + len(session.ai_says):].replace(session.eoa, '')
-            return res, True
-
+                        logger, display, profile_generation, eos_id):
         while True:
             result = res_queue.get()
             if result is None:
-                yield StatusCode.TRITON_STREAM_END, session.response, \
+                yield StatusCode.TRITON_STREAM_END, \
+                    session.response[len(session.prompt):], \
                     session.sequence_length - preseq_length
                 break
             if 'errcode' in result:
@@ -417,22 +448,35 @@ class Chatbot:
                 session.sequence_length = sequence_length.squeeze()
                 sequence_length = sequence_length - preseq_length
+                last_token_id = output_ids[-1][-1][session.sequence_length - 1]
+                if last_token_id == eos_id:
+                    session.sequence_length = session.sequence_length - 1
+                    sequence_length = sequence_length - 1

                 output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
                 sequence_length = sequence_length.reshape(
                     (1, sequence_length.shape[-1]))

+                if profile_generation:
+                    yield (StatusCode.TRITON_STREAM_ING,
+                           'postprocessing is ignored during profiling '
+                           'token generation', sequence_length.squeeze())
+                    continue
                 output_str = postprocess(output_ids[:, :, preseq_length:],
                                          sequence_length)
                 text = output_str[0].decode()
                 if display:
-                    new_text = text[len(session.round_prev):]
+                    new_text = text[len(session.response):]
                     print(new_text, end='', flush=True)
-                session.round_prev = text
-                yield (StatusCode.TRITON_STREAM_ING, session.response,
-                       sequence_length.squeeze())
+                session.response = text
+                if len(session.response) > len(session.prompt):
+                    yield (StatusCode.TRITON_STREAM_ING,
+                           session.response[len(session.prompt):],
+                           sequence_length.squeeze())
             except Exception as e:
                 logger.error(f'catch exception: {e}')

+        session.response = session.response[len(session.prompt):]
         # put session back to queue so that `_stream_infer` can update it in
         # `self.sessions`
         while not res_queue.empty():
...
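To make the FasterTransformer stop-word layout built by `_stop_words` above concrete, here is a small standalone sketch using the Puyu stop-word id from model.py. The printed shape is what numpy produces, and the semantics (concatenated token ids in the first row, cumulative end offsets in the second) follow the fauxpilot discussion linked in the code:

import numpy as np

# stop word id taken from the Puyu template; row 0 holds the flattened token
# ids of all stop words, row 1 holds each stop word's cumulative end offset.
stop_words = [45623]
stop_word_offsets = range(1, len(stop_words) + 1)
tensor = np.array([[stop_words, stop_word_offsets]]).astype(np.int32)
print(tensor.shape)  # (1, 2, 1)
print(tensor)        # [[[45623]
                     #   [    1]]]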
@@ -21,10 +21,14 @@ class Tokenizer:

     def encode(self, s: str):
         add_bos = False
+        add_eos = False
         if s.find('<BOS>') != -1:
             s = s.replace('<BOS>', '')
             add_bos = True
-        return self.model.Encode(s, add_bos=add_bos)
+        if s == '<EOS>':
+            s = ''
+            add_eos = True
+        return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)

     def decode(self, t: List[int]):
         return self.model.Decode(t)
...
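The `<BOS>`/`<EOS>` sentinels handled above are what `Chatbot._get_bos` and `_get_eos` send through the preprocessor to recover the special-token ids from the server-side tokenizer. A local sketch of the same behaviour with SentencePiece directly (the tokenizer path and the printed ids are assumptions):

from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor(model_file='/path/to/tokenizer.model')  # assumed path
# '<BOS>' is stripped and add_bos=True, so only the bos id comes back
print(sp.Encode('', add_bos=True, add_eos=False))   # e.g. [1]
# '<EOS>' is replaced by '' with add_eos=True, so only the eos id comes back
print(sp.Encode('', add_bos=False, add_eos=True))   # e.g. [2]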