# Copyright (c) OpenMMLab. All rights reserved. import json import logging import queue import random import threading from dataclasses import dataclass from enum import Enum from functools import partial from typing import List, Union import google.protobuf.json_format import mmengine import numpy as np import tritonclient.grpc as grpcclient from tritonclient.grpc.service_pb2 import ModelInferResponse from llmdeploy.model import MODELS from llmdeploy.serve.fastertransformer.utils import (Postprocessor, Preprocessor, prepare_tensor) @dataclass class Session: session_id: Union[int, str] request_id: str = '' histories: str = '' # history conversations of the session sequence_length: int = 0 # the total generated token number in the session prompt: str = '' response: str = '' status: int = None # status of the session class StatusCode(Enum): TRITON_STREAM_END = 0 # end of streaming TRITON_STREAM_ING = 1 # response is in streaming TRITON_SERVER_ERR = -1 # triton server's error TRITON_SESSION_CLOSED = -2 # session has been closed TRITON_SESSION_OUT_OF_LIMIT = -3 # request length out of limit TRITON_SESSION_INVALID_ARG = -4 # invalid argument def stream_callback(que, result, error): if error: print(error) que.put(dict(errcode=StatusCode.TRITON_SERVER_ERR, errmsg=f'{error}')) else: que.put(result.get_response(as_json=True)) def get_logger(log_file=None, log_level=logging.INFO): from .utils import get_logger logger = get_logger('service.ft', log_file=log_file, log_level=log_level) return logger class Chatbot: """Chatbot for LLaMA series models with fastertransformer as inference engine. Args: tritonserver_addr (str): communicating address ':' of triton inference server model_name (str): name of the to-be-deployed mode session_len (int): the maximum context length of the model top_p (float): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. top_k (int): The number of the highest probability vocabulary tokens to keep for top-k-filtering temperature (float): to modulate the next token probability repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty log_level (int): the level of the log display (bool): display the generated text on consolo or not profile_generation (bool): profile token generation or not """ def __init__(self, tritonserver_addr: str, model_name: str, session_len: int = 2048, top_p: float = 1.0, top_k: int = 40, temperature: float = 1.0, repetition_penalty: float = 1.0, ignore_eos: bool = False, log_level: int = logging.INFO, display: bool = False, profile_generation: bool = False, profile_serving: bool = False): assert model_name in MODELS.module_dict.keys(), \ f"'{model_name}' is not supported. " \ f'The supported models are: {MODELS.module_dict.keys()}' self.model_name = model_name self.model = MODELS.get(self.model_name)() self._session = None self.tritonserver_addr = tritonserver_addr self.preprocess = Preprocessor(tritonserver_addr) self.postprocess = Postprocessor(tritonserver_addr) self.bos_id = self._get_bos() self.eos_id = self._get_eos() stop_words = self._stop_words(self.model.stop_words) bad_words = None if ignore_eos: stop_words = None bad_words = np.array([[[self.eos_id], [1]]], dtype=np.int32) self.cfg = mmengine.Config( dict( session_len=session_len, top_p=top_p, top_k=top_k, temperature=temperature, repetition_penalty=repetition_penalty, stop_words=stop_words, bad_words=bad_words)) self.log_level = log_level self.display = display self.profile_generation = profile_generation self.profile_serving = profile_serving def stream_infer(self, session_id: int, prompt: str, request_id: str = '', request_output_len: int = None, sequence_start: bool = False, sequence_end: bool = False, *args, **kwargs): """Start a new round conversion of a session. Args: session_id (int): the identical id of a session prompt (str): user's prompt in this round conversation request_id (str): the identical id of this round conversation request_output_len (int): the expected generated token numbers sequence_start (bool): start flag of a session sequence_end (bool): end flag of a session Returns: iterator: The generated content by chatbot """ assert isinstance(session_id, int), \ f'INT session id is required, but got {type(session_id)}' logger = get_logger(log_level=self.log_level) logger.info(f'session {session_id}, request_id {request_id}, ' f'request_output_len {request_output_len}') if self._session is None: sequence_start = True self._session = Session(session_id=session_id) elif self._session.status == 0: logger.error(f'session {session_id} has been ended. Please set ' f'`sequence_start` be True if you want to restart it') yield StatusCode.TRITON_SESSION_CLOSED, '', 0 return self._session.status = 1 self._session.request_id = request_id self._session.response = '' self._session.prompt = self._get_prompt(prompt, sequence_start) for status, res, tokens in self._stream_infer(self._session, self._session.prompt, request_output_len, sequence_start, sequence_end): yield status, res, tokens if status.value < 0: return self._session.histories = \ self._session.histories + self._session.prompt + \ self._session.response def end(self, session_id: int, *args, **kwargs): """end a session. Triton inference server will release the session's occupied resource when it is ended. Args: session_id (int): the identical id of a session Returns: int: 0: success, -1: session not found """ assert isinstance(session_id, int), \ f'INT session id is required, but got {type(session_id)}' logger = get_logger(log_level=self.log_level) logger.info(f'end session: {session_id}') if self._session is None: logger.error( f"session {session_id} doesn't exist. It cannot be ended") return StatusCode.TRITON_SESSION_INVALID_ARG if self._session.session_id != session_id: logger.error(f'you cannot end session {session_id}, because this ' f'session is {self._session.session_id}') return StatusCode.TRITON_SESSION_INVALID_ARG if self._session.status == 0: logger.warning(f'session {session_id} has already been ended') return StatusCode.TRITON_SESSION_CLOSED self._session.status = 0 for status, _, _ in self._stream_infer( self._session, prompt='', request_output_len=0, sequence_start=False, sequence_end=True): if status != StatusCode.TRITON_STREAM_END: return status self.reset_session() return StatusCode.TRITON_STREAM_END def cancel(self, session_id: int, *args, **kwargs): """Cancel the session during generating tokens. Args: session_id (int): the identical id of a session Returns: int: 0: success, -1: session not found """ assert isinstance(session_id, int), \ f'INT session id is required, but got {type(session_id)}' logger = get_logger(log_level=self.log_level) logger.info(f'cancel session: {session_id}') if self._session is None: logger.error( f"session {session_id} doesn't exist. It cannot be cancelled") return StatusCode.TRITON_SESSION_INVALID_ARG if self._session.session_id != session_id: logger.error( f'you cannot cancel session {session_id}, because this ' f'session is {self._session.session_id}') return StatusCode.TRITON_SESSION_INVALID_ARG if self._session.status == 0: logger.error(f'session {session_id} has already been ended. ' f'It cannot be cancelled') return StatusCode.TRITON_SESSION_CLOSED prev_session = self._session for status, res, _ in self._stream_infer( self._session, prompt='', request_output_len=0, sequence_start=False, sequence_end=False, cancel=True): if status.value < 0: break if status == StatusCode.TRITON_STREAM_END: logger.info(f'cancel session {session_id} successfully') if prev_session.histories: logger.warn(f'TODO: start to recover session {session_id}') else: logger.info(f'cancel session {session_id} failed: {res}') return status def reset_session(self): self._session = None def _get_bos(self): token_ids, _ = self.preprocess('') return token_ids[0][0] def _get_eos(self): token_ids, _ = self.preprocess('') return token_ids[0][0] def _stop_words(self, stop_words: List[int]): if stop_words is None: return None assert isinstance(stop_words, List) and \ all(isinstance(elem, int) for elem in stop_words), \ f'stop_words must be a list but got {type(stop_words)}' # each id in stop_words represents a stop word # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for # detailed explanation about fastertransformer's stop_words stop_word_offsets = range(1, len(stop_words) + 1) stop_words = np.array([[stop_words, stop_word_offsets]]).astype(np.int32) return stop_words def _get_prompt(self, prompt: str, sequence_start: bool): if self.profile_generation or self.profile_serving: return prompt return self.model.get_prompt(prompt, sequence_start) def _stream_infer(self, session: Session, prompt: str, request_output_len: int = 512, sequence_start: bool = True, sequence_end: bool = False, cancel: bool = False): logger = get_logger(log_level=self.log_level) logger.info(f'session {session.session_id}, ' f'request id {session.request_id}, ' f'request_output_len {request_output_len}, ' f'start {sequence_start}, ' f'end {sequence_end}, cancel {cancel}') assert request_output_len is None or \ isinstance(request_output_len, int), \ f'request_output_len is supposed to be None or int, ' \ f'but got {type(request_output_len)}' if sequence_start: logger.info(f'session {session.session_id}, clear history since ' f'sequence_start is True') session.histories = '' session.sequence_length = 0 input_ids, input_lengths = self.preprocess(prompt) input_tokens = input_lengths.squeeze() if self.profile_generation: yield StatusCode.TRITON_STREAM_ING, \ 'ignore preprocessing during profiling generation', 0 if request_output_len is None: request_output_len = max( 128, self.cfg.session_len - session.sequence_length - input_tokens) if input_tokens + request_output_len + \ session.sequence_length > self.cfg.session_len: errmsg = f'session {session.session_id}, ' \ f'out of max sequence length {self.cfg.session_len}, ' \ f'#input tokens {input_tokens}, ' \ f'history tokens {session.sequence_length}, ' \ f'request length {request_output_len}' yield StatusCode.TRITON_SESSION_OUT_OF_LIMIT, errmsg, 0 return logger.info(f'session {session.session_id}, ' f'max length: {self.cfg.session_len}, ' f'input tokens: {input_tokens}, ' f'request tokens: {request_output_len}, ' f'history tokens: {session.sequence_length}') preseq_length = session.sequence_length session.response = '' que = queue.Queue() producer = threading.Thread( target=self._stream_producer, args=(self.tritonserver_addr, session, que, self.cfg, input_ids, input_lengths, request_output_len, sequence_start, sequence_end, preseq_length, cancel)) producer.start() for state, res, tokens in self.stream_consumer( self.postprocess, que, session, preseq_length, cancel, logger, self.display, self.profile_generation, self.eos_id): if state.value < 0: yield state, res, 0 else: yield state, res, tokens - input_tokens producer.join() self._session = que.get() curseq_length = self._session.sequence_length logger.info(f'session {session.session_id}, pre seq_len ' f'{preseq_length}, cur seq_len {curseq_length}, ' f'diff {curseq_length - preseq_length}') @staticmethod def _stream_producer(tritonserver_addr, session, que, cfg, input_ids, input_lengths, request_output_len, sequence_start, sequence_end, preseq_length, cancel): request_output_len = np.full(input_lengths.shape, request_output_len).astype(np.uint32) callback = partial(stream_callback, que) with grpcclient.InferenceServerClient(tritonserver_addr) as client: inputs = [ prepare_tensor('input_ids', input_ids), prepare_tensor('input_lengths', input_lengths), prepare_tensor('request_output_len', request_output_len), prepare_tensor('runtime_top_k', cfg.top_k * np.ones((1, 1), dtype=np.uint32)), prepare_tensor('runtime_top_p', cfg.top_p * np.ones((1, 1), dtype=np.float32)), prepare_tensor( 'temperature', cfg.temperature * np.ones((1, 1), dtype=np.float32)), prepare_tensor( 'repetition_penalty', cfg.repetition_penalty * np.ones( (1, 1), dtype=np.float32)), prepare_tensor('step', preseq_length * np.ones((1, 1), dtype=np.int32)) ] if cfg.stop_words is not None: inputs += [prepare_tensor('stop_words_list', cfg.stop_words)] if cfg.bad_words is not None: inputs += [prepare_tensor('bad_words_list', cfg.bad_words)] inputs += [ prepare_tensor( 'session_len', cfg.session_len * np.ones([input_ids.shape[0], 1], dtype=np.uint32)), prepare_tensor('START', (1 if sequence_start else 0) * np.ones( (1, 1), dtype=np.int32)), prepare_tensor('END', (1 if sequence_end else 0) * np.ones( (1, 1), dtype=np.int32)), prepare_tensor( 'CORRID', session.session_id * np.ones((1, 1), dtype=np.uint64)), prepare_tensor('STOP', (1 if cancel else 0) * np.ones( (1, 1), dtype=np.int32)) ] if sequence_start: random_seed = random.getrandbits(64) inputs += [ prepare_tensor( 'random_seed', random_seed * np.ones((1, 1), dtype=np.uint64)) ] client.start_stream(callback) client.async_stream_infer( 'fastertransformer', inputs, sequence_id=session.session_id, request_id=session.request_id, sequence_start=sequence_start, sequence_end=sequence_end) que.put(None) @staticmethod def stream_consumer(postprocess, res_queue, session, preseq_length, cancel, logger, display, profile_generation, eos_id): while True: result = res_queue.get() if result is None: yield StatusCode.TRITON_STREAM_END, \ session.response[len(session.prompt):], \ session.sequence_length - preseq_length break if 'errcode' in result: logger.error(f'got error from fastertransformer, code ' f"{result['errcode']}, {result['errmsg']}, " f'token {session.sequence_length}') session.sequence_length = preseq_length yield result['errcode'], result['errmsg'], 0 break if cancel: continue try: message = ModelInferResponse() google.protobuf.json_format.Parse(json.dumps(result), message) result = grpcclient.InferResult(message) sequence_length = result.as_numpy('sequence_length') output_ids = result.as_numpy('output_ids') session.sequence_length = sequence_length.squeeze() sequence_length = sequence_length - preseq_length last_token_id = output_ids[-1][-1][session.sequence_length - 1] if last_token_id == eos_id: session.sequence_length = session.sequence_length - 1 sequence_length = sequence_length - 1 output_ids = output_ids.reshape((1, 1, output_ids.shape[-1])) sequence_length = sequence_length.reshape( (1, sequence_length.shape[-1])) if profile_generation: yield (StatusCode.TRITON_STREAM_ING, 'postprocessing is ignored during profiling ' 'token generation', sequence_length.squeeze()) continue output_str = postprocess(output_ids[:, :, preseq_length:], sequence_length) text = output_str[0].decode() if display: new_text = text[len(session.response):] print(new_text, end='', flush=True) session.response = text if len(session.response) > len(session.prompt): yield (StatusCode.TRITON_STREAM_ING, session.response[len(session.prompt):], sequence_length.squeeze()) except Exception as e: logger.error(f'catch exception: {e}') session.response = session.response[len(session.prompt):] # put session back to queue so that `_stream_infer` can update it in # `self.sessions` while not res_queue.empty(): res_queue.get() res_queue.put(session) if display: print('\n')