Commit d7117b95 authored by zhouxiang

Sync 0.2.6 code

parent 5f83e392
# Copyright (c) OpenMMLab. All rights reserved.
"""Helpers for parallel and distributed inference."""
import functools
import os
import torch
from torch.distributed import broadcast, broadcast_object_list, is_initialized
def get_local_rank():
"""Get local rank of current process.
Assume environment variable ``LOCAL_RANK`` is properly set by some launcher.
See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
""" # noqa: E501
return int(os.getenv('LOCAL_RANK', '0'))
def get_rank():
"""Get rank of current process.
Assume environment variable ``RANK`` is properly set by some launcher.
See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
""" # noqa: E501
return int(os.getenv('RANK', '0'))
def get_world_size():
"""Get rank of current process.
Assume environment variable ``WORLD_SIZE`` is properly set by some launcher.
See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
""" # noqa: E501
return int(os.getenv('WORLD_SIZE', '1'))
def master_only(func):
"""Decorator to run a function only on the master process."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_initialized():
if get_rank() != 0:
return None
return func(*args, **kwargs)
return wrapper
def master_only_and_broadcast_general(func):
"""Decorator to run a function only on the master process and broadcast the
result to all processes."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_initialized():
if get_rank() == 0:
result = [func(*args, **kwargs)]
else:
result = [None]
broadcast_object_list(result, src=0)
result = result[0]
else:
result = func(*args, **kwargs)
return result
return wrapper
def master_only_and_broadcast_tensor(func):
"""Decorator to run a function only on the master process and broadcast the
result to all processes.
Note: Requires a CUDA tensor.
Note: Of limited use because the receiving ranks must know the shape and
dtype beforehand; for CPU tensors, use master_only_and_broadcast_general
"""
@functools.wraps(func)
def wrapper(*args, size, dtype, **kwargs):
if is_initialized():
if get_rank() == 0:
result = func(*args, **kwargs)
else:
result = torch.empty(size=size,
dtype=dtype,
device=get_local_rank())
broadcast(result, src=0)
# print(f'rank {get_rank()} received {result}')
else:
result = func(*args, **kwargs)
return result
return wrapper
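# Usage sketch (illustrative only, not from the original file): how the
# decorators above are meant to compose under torch.distributed, assuming the
# process group is initialized by a launcher such as torchrun. The two demo
# functions below are hypothetical.
@master_only
def _demo_log(msg):
    """Print only on rank 0; the other ranks receive None."""
    print(msg)


@master_only_and_broadcast_general
def _demo_read_input():
    """Rank 0 reads from stdin; every rank receives the same string."""
    return input('>>> ')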
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import time
import warnings
from typing import Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from .dist import get_local_rank
logger = logging.getLogger(__name__)
class LoadWoInit:
"""Context manager that disable parameter initialization."""
def __init__(self):
self.constant_ = torch.nn.init.constant_
self.zeros_ = torch.nn.init.zeros_
self.ones_ = torch.nn.init.ones_
self.uniform_ = torch.nn.init.uniform_
self.normal_ = torch.nn.init.normal_
self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_
self.kaiming_normal_ = torch.nn.init.kaiming_normal_
def __enter__(self, *args, **kwargs):
torch.nn.init.constant_ = lambda *args, **kwargs: None
torch.nn.init.zeros_ = lambda *args, **kwargs: None
torch.nn.init.ones_ = lambda *args, **kwargs: None
torch.nn.init.uniform_ = lambda *args, **kwargs: None
torch.nn.init.normal_ = lambda *args, **kwargs: None
torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None
torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None
def __exit__(self, *args, **kwargs):
torch.nn.init.constant_ = self.constant_
torch.nn.init.zeros_ = self.zeros_
torch.nn.init.ones_ = self.ones_
torch.nn.init.uniform_ = self.uniform_
torch.nn.init.normal_ = self.normal_
torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_
torch.nn.init.kaiming_normal_ = self.kaiming_normal_
def init_model(model_path: str,
tokenizer_path: Optional[str] = None,
use_fast_tokenizer=True):
"""Initialize model and tokenizer from given model path.
Args:
model_path (str): Path to model.
tokenizer_path (str): Path to tokenizer.
use_fast_tokenizer (bool): Whether to use fast tokenizer.
Note:
If the model is converted from a new version of transformers,
use_fast_tokenizer should be True.
If using decapoda-research/llama-xb-hf, use_fast_tokenizer should be False.
"""
start = time.monotonic()
if not tokenizer_path:
tokenizer_path = model_path
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
use_fast=use_fast_tokenizer,
trust_remote_code=True)
with LoadWoInit():
model = AutoModelForCausalLM.from_pretrained(model_path,
torch_dtype=torch.float16,
trust_remote_code=True)
logger.info(f'Model loaded in {time.monotonic() - start:.1f} seconds')
logger.info(f'Model loaded from {model_path}')
logger.debug(model)
return model, tokenizer
def accel_model(model,
accel: Optional[str] = None,
gpu_id=None,
max_alloc=2048,
tp_size=1):
"""Accelerate model with given accelerator.
Note:
Currently only DeepSpeed or no acceleration (``accel=None``) is supported.
"""
logger.info(f'Accelerate model with {accel}')
if accel is None:
# No acceleration, just to cuda
# assume single gpu single process
# user is responsible to assign the gpu id via CUDA_VISIBLE_DEVICES # noqa: E501
gpu_id = gpu_id if gpu_id is not None else get_local_rank()
model = model.cuda(gpu_id)
elif accel.lower() == 'deepspeed':
# Use deepspeed inference inject fast kernel and/or tensor parallel
try:
import deepspeed
except ImportError as e:
raise ImportError('--accel=deepspeed is specified but '
'deepspeed is not installed.\n'
'Install with `pip install deepspeed`.') from e
config = dict(
tensor_parallel=dict(tp_size=tp_size), # Use world size in general
dtype=torch.float16,
replace_with_kernel_inject=True,
max_out_tokens=max_alloc,
)
if 'InternLM' in model.__class__.__name__:
try:
# Use customized deepspeed supporting InternLM
# https://github.com/wangruohui/DeepSpeed/tree/support_internlm_0.10.0 (commit cdef2ce) # noqa: E501
from deepspeed.module_inject.containers.internlm import \
InternLMLayerPolicy # noqa: E501
except ImportError:
# InternLM is not officially supported by DeepSpeed
# Set replace_with_kernel_inject=False to use AutoTP
config.update({'replace_with_kernel_inject': False})
warnings.warn(
'\033[0;93m'
'Current installation of deepspeed does not '
'support InternLM. Disable kernel injection. '
'To support InternLM, install customized deepspeed with '
'`pip install git+https://github.com/wangruohui/DeepSpeed@support_internlm_0.10.0`' # noqa: E501
'\033[0m')
else:
for module in model.modules():
# Since remote code is dynamically located,
# we need to do this dynamically
if module.__class__.__name__ == 'InternLMDecoderLayer':
InternLMLayerPolicy._orig_layer_class = module.__class__ # noqa: E501
break
logger.debug(f'Using deepspeed config\n{config}')
model = deepspeed.init_inference(
model=model, # Transformers models
config=config,
)
# for k, v in model.named_parameters():
# logger.debug(f"{k}: v.device")
else:
raise ValueError(f'Unsupported accelerator {accel}.')
logger.debug(model)
return model
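# Usage sketch (illustrative only, not from the original file): the intended
# call order of the helpers above. The checkpoint path is a placeholder.
def _demo_load_and_accelerate(model_path='/path/to/internlm-chat-7b'):
    model, tokenizer = init_model(model_path)  # fp16 weights, init skipped
    model = accel_model(model, accel=None)  # plain single-GPU placement
    # model = accel_model(model, accel='deepspeed', tp_size=2)  # tensor parallel
    return model, tokenizer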
# Copyright (c) OpenMMLab. All rights reserved.
from .linear import WeightOnlyQLinear
__all__ = ['WeightOnlyQLinear']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Type, TypeVar
import torch
from torch import nn
try:
import awq_inference_engine
except ModuleNotFoundError:
awq_inference_engine = None
class WeightOnlyQLinear(nn.Module):
"""This class implements weight only quantization linear.
Args:
w_bit (int): number of bits for quantization.
symmetry (bool): If true, use symmetric quantization,
otherwise use asymmetric quantization.
group_size (int): size of the quantization group.
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): whether to register a bias buffer. Defaults to True.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
w_bit: int = 4,
symmetry: bool = False,
group_size: int = 128,
) -> None:
super().__init__()
if w_bit not in [2, 4, 8]:
raise NotImplementedError('Only 2,4,8 bit are supported for now.')
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
w_pack_oc = out_features // (32 // self.w_bit)
w_inc = in_features
weight = torch.zeros((w_inc, w_pack_oc), dtype=torch.int32)
self.register_buffer('qweight', weight)
if bias:
self.register_buffer('bias', torch.zeros(out_features))
else:
self.bias = None
s_inc = in_features // self.group_size
s_oc = out_features
scales = torch.zeros((s_inc, s_oc), dtype=torch.float16)
self.register_buffer('scales', scales)
if not symmetry:
z_inc = in_features // self.group_size
z_oc = out_features // (32 // self.w_bit)
zeros = torch.zeros((z_inc, z_oc), dtype=torch.int32)
self.register_buffer('qzeros', zeros)
else:
self.qzeros = None
@classmethod
def from_linear(cls: Type['WeightOnlyQLinear'],
linear: nn.Linear,
quantizer: TypeVar('Quantizer'),
awq_layout: bool = True) -> 'WeightOnlyQLinear':
"""Create a WeightOnlyQLinear object from a PyTorch Linear object.
Args:
linear (nn.Linear): PyTorch Linear object.
quantizer (Quantizer): Object that handles quantization.
awq_layout (bool): AWQ layout. Defaults to True.
Returns:
WeightOnlyQLinear: A WeightOnlyQLinear object.
"""
device = linear.weight.device
w_bit = quantizer.bits
pack_num = 32 // w_bit
if awq_layout:
assert w_bit == 4
pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
else:
pack_order = torch.arange(pack_num)
group_size = quantizer.group_size
symmetry = quantizer.symmetry
in_features = linear.in_features
out_features = linear.out_features
bias = False if linear.bias is None else True
qlinear = cls(in_features, out_features, bias, w_bit, symmetry,
group_size)
qlinear.bias = linear.bias
qparams = quantizer.calculate_qparams(linear.weight)
i32_w = quantizer.quant(linear.weight, qparams, real=True)
i32_w = i32_w.t().contiguous()
pack_int_w = torch.zeros_like(qlinear.qweight).to(device)
for col in range(pack_int_w.shape[1]):
for i in range(pack_num):
pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
pack_int_w[:, col] |= pack_int_w_col << (i * w_bit)
qlinear.qweight = pack_int_w
qlinear.scales = qparams.scales.squeeze(-1).t().contiguous()
if qparams.zero_points is not None:
zeros = qparams.zero_points.to(torch.int32).to(device)
zeros = zeros.squeeze(-1).t().contiguous()
pack_int_zeros = torch.zeros_like(qlinear.qzeros).to(device)
for col in range(pack_int_zeros.shape[1]):
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + pack_order[i]]
pack_int_zeros[:, col] |= qzero_col << (i * w_bit)
qlinear.qzeros = pack_int_zeros
qlinear.to('cpu')
return qlinear
@torch.no_grad()
def forward(self, x):
if awq_inference_engine is None:
raise RuntimeError(
'Run the following command to install '
'the kernel for 4bit inference\n\n'
'git clone https://github.com/mit-han-lab/llm-awq.git\n'
'cd awq/kernels\n'
'python setup.py install\n')
out_shape = x.shape[:-1] + (self.out_features, )
inputs = x.reshape(-1, x.shape[-1])
out = awq_inference_engine.gemm_forward_cuda(inputs.half(),
self.qweight,
self.scales.half(),
self.qzeros,
self.group_size)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
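# Usage sketch (illustrative only, not from the original file): packing an fp16
# nn.Linear into the quantized layer above. `my_quantizer` is a placeholder for
# any object exposing .bits, .symmetry, .group_size, .calculate_qparams() and
# .quant() as used by from_linear(); it is not an lmdeploy API defined here.
def _demo_pack_linear(my_quantizer):
    fp16_linear = nn.Linear(4096, 4096, bias=False, dtype=torch.float16)
    qlinear = WeightOnlyQLinear.from_linear(fp16_linear, my_quantizer)
    # For 4-bit weights, qweight is int32-packed with shape
    # (in_features, out_features // 8).
    return qlinear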
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import torch
from transformers.generation.utils import ModelOutput
logger = logging.getLogger(__name__)
class BasicSessionManager:
"""Basic session manager without history."""
def prepend_history(self, input_ids):
return input_ids
def add_to_history(self, output):
pass
class BasicSessionManagerWithHistory:
"""Basic session manager with chat history.
Args:
max_session_len (int): Maximum number of tokens allowed for all chat sessions.
reduce_size (int): Number of tokens to be trimmed when reaching maximum
session length. Default: 256.
start_ids (list[int]): Sequences of ids at the start of the chat session.
sep_ids (list[int]): Sequences of ids separating chat sessions.
""" # noqa: E501
bs = 1
def __init__(self,
max_session_len=2048,
reduce_size=256,
start_ids=[1],
sep_ids=[13]) -> None:
self.start_ids = torch.tensor(start_ids, dtype=torch.long)
self.sep_ids = torch.tensor(sep_ids, dtype=torch.long)
assert self.start_ids.ndim == 1
assert self.sep_ids.ndim == 1
self.max_session_len = max(len(start_ids), max_session_len)
self.reduce_size = min(reduce_size, max_session_len - len(start_ids))
assert self.max_session_len > self.reduce_size
self.new_session()
def new_session(self):
self.history_ids = self.start_ids.repeat(self.bs, 1)
def prepend_history(self, input_ids: torch.Tensor):
"""Prepend history ids to input ids and trim if over-length."""
input_ids = input_ids.to(self.history_ids.device).long()
sep_ids = self.sep_ids.to(self.history_ids.device).long().repeat(1, 1)
input_ids = torch.cat([self.history_ids, sep_ids, input_ids], dim=1)
if input_ids.shape[1] > self.max_session_len:
input_ids = input_ids[:,
(self.reduce_size - self.max_session_len):]
input_ids[:, :len(self.start_ids)] = self.start_ids.repeat(
self.bs, 1)
return input_ids
def add_to_history(self, output):
"""Save history output ids.
Note:
Output returned by HuggingFace generator contains both input
and output ids.
"""
if isinstance(output, ModelOutput):
self.history_ids = output.sequences
elif isinstance(output, torch.Tensor):
self.history_ids = output
else:
raise ValueError(f'Unknown output type {type(output)}')
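# Usage sketch (illustrative only, not from the original file): the trimming
# behaviour of the session manager above, with tiny limits for clarity.
def _demo_history_trimming():
    sm = BasicSessionManagerWithHistory(max_session_len=16, reduce_size=4)
    turn = torch.arange(2, 12).unsqueeze(0)  # a fake 10-token prompt
    first = sm.prepend_history(turn)  # [start] + [sep] + prompt -> 12 tokens
    sm.add_to_history(first)
    second = sm.prepend_history(turn)  # 12 + 1 + 10 = 23 > 16, so it is trimmed
    assert second.shape[1] == 16 - 4  # only the newest 12 tokens are kept
    assert second[0, 0].item() == 1  # start id restored at the front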
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
import inspect
from inspect import Parameter, Signature
from typing import Dict, Sequence
import psutil
def get_gpu_memory(id: int = 0) -> int:
"""Returns the free and total physical memory of the GPU in bytes."""
import torch
return torch.cuda.mem_get_info(id)
def get_cpu_memory() -> int:
"""Returns the total CPU memory of the node in bytes."""
return psutil.virtual_memory().total
def bind_sigature(input_names: str, args: Sequence, kwargs: Dict):
"""Bind args and kwargs to given input names."""
kind = inspect._ParameterKind.POSITIONAL_OR_KEYWORD
sig = Signature([Parameter(name, kind) for name in input_names])
bind = sig.bind(*args, **kwargs)
return bind.arguments
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from transformers.generation.streamers import BaseStreamer
from .dist import get_rank, master_only, master_only_and_broadcast_general
try:
import readline  # To support command line history  # noqa: F401
except ImportError:  # readline not available
pass
logger = logging.getLogger(__name__)
class TerminalIO:
"""Terminal input and output."""
end_of_output = '\n'
@master_only_and_broadcast_general
def input(self):
"""Read input from terminal."""
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
try:
return '\n'.join(iter(input, sentinel))
except EOFError:
print('Detect EOF, exit')
exit()
@master_only
def output(self, string):
"""Output to terminal with flush."""
print(string, end='', flush=True)
class BasicStreamer(BaseStreamer):
"""Basic streamer for HuggingFace models."""
def __init__(self,
decode_func,
output_func,
end_of_output='\n',
skip_prompt=True):
self.decode = decode_func
self.output = output_func
self.end_of_output = end_of_output
self.skip_prompt = skip_prompt
self.gen_len = 0
def put(self, value):
"""Callback before forwarding current token id to model."""
if self.gen_len == 0 and self.skip_prompt:
pass
else:
token = self.decode(value)
self.output(token)
self.gen_len += 1
def end(self):
"""Callback at the end of generation."""
self.output(self.end_of_output)
self.gen_len = 0
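# Usage sketch (illustrative only, not from the original file): wiring the
# streamer to a tokenizer and the terminal I/O defined above. `tokenizer` is
# assumed to be any HuggingFace tokenizer and is not defined here.
def _demo_build_streamer(tokenizer):
    io = TerminalIO()
    streamer = BasicStreamer(
        decode_func=lambda ids: tokenizer.decode(ids, skip_special_tokens=True),
        output_func=io.output)
    # pass `streamer` to `model.generate(..., streamer=streamer)`
    return streamer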
def control(prompt, gen_config, sm):
"""Allow user to control generation config and session manager.
Return:
True if control command applied, False otherwise.
"""
if prompt == 'exit':
exit(0)
if prompt == 'clear':
sm.new_session()
logger.info('Session cleared')
return True
# Re-config during runtime
if prompt.startswith('config set'):
try:
keqv = prompt.split()[-1]
k, v = keqv.split('=')
v = eval(v)
gen_config.__setattr__(k, v)
logger.info(f'Worker {get_rank()} set {k} to {repr(v)}')
logger.info(f'Generator config changed to: {gen_config}')
return True
except: # noqa
logger.info(
'Illegal instruction; treated as normal conversation.')
return False
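# Usage sketch (illustrative only, not from the original file): the runtime
# commands understood by control() above. `gen_config` can be any object with
# settable attributes (e.g. a generation config) and `sm` a session manager.
def _demo_control(gen_config, sm):
    assert control('clear', gen_config, sm) is True  # resets the session
    assert control('config set top_p=0.9', gen_config, sm) is True  # re-config
    assert not control('hello there', gen_config, sm)  # normal conversation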
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import dataclasses
import os
import random
from argparse import ArgumentError
from contextlib import asynccontextmanager
from itertools import count
from queue import Empty, Queue
from threading import Thread
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from lmdeploy.messages import (EngineGenerationConfig, GenerationConfig,
PytorchEngineConfig, Response,
TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig, best_match_model
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import _stop_words, get_logger
logger = get_logger('lmdeploy')
def get_model_name_from_workspace_model(model_dir: str):
"""Get model name from workspace model."""
from configparser import ConfigParser
triton_model_path = os.path.join(model_dir, 'triton_models', 'weights')
if not os.path.exists(triton_model_path):
return None
ini_path = os.path.join(triton_model_path, 'config.ini')
# load cfg
with open(ini_path, 'r') as f:
parser = ConfigParser()
parser.read_file(f)
return parser['llama']['model_name']
def deduce_a_name(
model_path: str,
model_name: Optional[str] = None,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None) -> str:
"""Deduce a model name from all the possible arguments."""
def _config_model_name(config):
if config and config.model_name:
return config.model_name
return None
backend_config_model_name = _config_model_name(backend_config)
chat_template_config_model_name = _config_model_name(chat_template_config)
model_name = model_name or chat_template_config_model_name or backend_config_model_name # noqa
if model_name is None:
# model maybe from workspace for turbomind
model_name = get_model_name_from_workspace_model(model_path)
# may get a model name from model_path
if model_name is None:
model_name = best_match_model(model_path)
if model_name is None:
raise ArgumentError(None,
f'Please set model_name for {model_path}')
else:
logger.info(f'matched chat template name: {model_name}')
return model_name
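# Usage sketch (illustrative only, not from the original file): the precedence
# implemented above is explicit model_name, then chat_template_config, then
# backend_config, and finally a match deduced from the model path.
def _demo_deduce_a_name():
    name = deduce_a_name(
        'internlm/internlm-chat-7b',
        model_name=None,
        backend_config=TurbomindEngineConfig(model_name='internlm-chat-7b'),
        chat_template_config=None)
    return name  # -> 'internlm-chat-7b', taken from backend_config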
@dataclasses.dataclass
...@@ -16,6 +74,55 @@ class GenOut:
finish_reason: Optional[Literal['stop', 'length']] = None
class Session:
"""Session for AsyncEngine.chat.
Args:
_id (int): session_id for internal use.
_step (int): the offset of the k/v cache for internal use.
_prompt (Any): input prompt for internal use.
_response (Response): model output for prompt.
_engine (Any): engine for internal use.
history (List[Tuple[Any, str]]): chat history.
"""
_ids = count(0)
def __init__(self):
self._id: int = next(self._ids)
self._step: int = 0
self._prompt: Any = None
self._response: Response = None
self._engine: Any = None
self.history: List[Tuple[Any, str]] = []
def _merge_response(self, resp: Response, step: Union[Response, GenOut]):
"""merge response."""
resp.text += step.text if isinstance(step, Response) else step.response
resp.input_token_len = step.input_token_len
resp.generate_token_len = step.generate_token_len
resp.finish_reason = step.finish_reason
return resp
@property
def response(self) -> Response:
"""return response."""
return self._response
def close(self):
"""release engine storage for this session."""
if self._engine:
inst = self._engine.create_instance()
inst.end(self._id)
def __repr__(self) -> str:
res = ''
for user, assistant in self.history:
if isinstance(user, list):
user = str(user)
res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n'
return res
class AsyncEngine:
"""Async inference engine. Maintaining a bunch of tm_model instances.
...@@ -30,51 +137,150 @@ class AsyncEngine:
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
backend (str): either `turbomind` or `pytorch` backend. Default to
`turbomind` backend.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
config instance. Default to None.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
tp (int): tensor parallel
"""
def __init__(self,
model_path: str,
model_name: Optional[str] = None,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
tp: int = 1,
**kwargs) -> None:
logger.info(
f'input backend={backend}, backend_config={backend_config}')
logger.info(f'input chat_template_config={chat_template_config}')
self.model_name = deduce_a_name(model_path, model_name, backend_config,
chat_template_config)
# build chat template config
if chat_template_config is None:
chat_template_config = ChatTemplateConfig(self.model_name)
elif chat_template_config.model_name is None:
chat_template_config.model_name = self.model_name
self.chat_template = chat_template_config.chat_template
# backward compatibility: map deprecated kwargs onto chat_template_config
for k in list(kwargs.keys()):
if hasattr(chat_template_config, k):
logger.warning(f'{k} was deprecated. Please use '
'chat_template_config instead')
v = kwargs.pop(k)
setattr(chat_template_config, k, v)
logger.info(f'updated chat_template_config={chat_template_config}')
# build backend engine
if backend == 'turbomind':
self._build_turbomind(model_path=model_path,
backend_config=backend_config,
chat_template_config=chat_template_config,
tp=tp,
**kwargs)
elif backend == 'pytorch':
self._build_pytorch(model_path=model_path,
backend_config=backend_config,
**kwargs)
else:
raise ValueError(f'unsupported backend {backend}')
logger.info(f'updated backend_config={self.backend_config}')
# parameters for member functions
self.session_len = self.backend_config.session_len
self.stop_words = _stop_words(self.chat_template.stop_words,
self.engine.tokenizer)
if self.stop_words is not None:
self.stop_words = self.stop_words[0][0].tolist()
self.backend = backend
self.instance_num = self.backend_config.max_batch_size
self.tokenizer = self.engine.tokenizer
self.id2step = {}
self.id2generator = {}
self.loop = asyncio.get_event_loop()
self.running_session_ids = set()
self.gens_set = set()
for i in range(self.instance_num):
self.gens_set.add(self.engine.create_instance())
def _build_turbomind(
self,
model_path: str,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
tp: int = 1,
**kwargs):
"""Innter build method for turbomind backend."""
if backend_config is None:
backend_config = TurbomindEngineConfig(model_name=self.model_name,
tp=tp)
assert isinstance(backend_config, TurbomindEngineConfig), 'Please'\
' use TurbomindEngineConfig imported from lmdeploy.messages for ' \
'turbomind backend'
if backend_config.session_len is None:
backend_config.session_len = self.chat_template.session_len
from lmdeploy import turbomind as tm
self.engine = tm.TurboMind.from_pretrained(
model_path,
engine_config=backend_config,
chat_template_config=chat_template_config,
**kwargs)
self.backend_config = backend_config
def _build_pytorch(
self,
model_path: str,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
**kwargs):
"""Innter build method for pytorch backend."""
from lmdeploy.pytorch.engine import Engine
if backend_config is None:
backend_config = PytorchEngineConfig(self.model_name)
assert isinstance(backend_config, PytorchEngineConfig), 'Please '\
'use PytorchEngineConfig imported from lmdeploy.messages for ' \
'pytorch backend'
if backend_config.session_len is None:
backend_config.session_len = self.chat_template.session_len
self.engine = Engine(model_path=model_path,
engine_config=backend_config)
self.backend_config = backend_config
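# Usage sketch (illustrative only, not from the original file): constructing
# the engine for either backend; the model path is a placeholder.
#
#     engine = AsyncEngine('internlm/internlm-chat-7b',
#                          backend='turbomind',
#                          backend_config=TurbomindEngineConfig(session_len=4096))
#     # or, for the pytorch backend:
#     # engine = AsyncEngine('internlm/internlm-chat-7b', backend='pytorch',
#     #                      backend_config=PytorchEngineConfig(session_len=4096))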
def __call__(self,
prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
gen_config: Optional[GenerationConfig] = None,
request_output_len=512,
top_k: int = 40,
top_p: float = 0.8,
temperature: float = 0.8,
repetition_penalty: float = 1.0,
ignore_eos: bool = False,
do_preprocess: bool = True,
**kwargs):
"""Inference a batch of prompts.
Args:
prompts (List[str] | str | List[Dict] | List[List[Dict]]): a batch of
prompts. It accepts: string prompt, a list of string prompts,
a chat history in OpenAI format or a list of chat history.
gen_config (GenerationConfig | None): an instance of
GenerationConfig. Default to None.
chat_template_config (ChatTemplateConfig | None): an instance of
ChatTemplateConfig. Default to None.
request_output_len (int): output token nums
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
...@@ -85,245 +291,363 @@ class AsyncEngine:
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
if gen_config is None:
gen_config = GenerationConfig(
max_new_tokens=request_output_len,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=ignore_eos)
return self.batch_infer(prompts,
gen_config=gen_config,
do_preprocess=do_preprocess,
**kwargs)
async def stop_session(self, session_id: int):
"""Stop a session by a session_id."""
if str(session_id) in self.id2generator:
await self.id2generator[str(session_id)].async_cancel(session_id)
self.gens_set.add(self.id2generator[str(session_id)])
self.running_session_ids.discard(session_id)
async def end_session(self, session_id: int):
"""Clear a session by a session_id."""
if str(session_id) in self.id2generator:
await self.id2generator[str(session_id)].async_end(session_id)
self.id2step[str(session_id)] = 0
self.gens_set.add(self.id2generator[str(session_id)])
self.running_session_ids.discard(session_id)
@asynccontextmanager
async def safe_run(self, session_id: Optional[int] = None):
"""A context manager to make sure server's safe running."""
try:
yield
except (Exception, asyncio.CancelledError) as e:  # noqa
await self.stop_session(session_id)
raise e
if str(session_id) in self.id2generator:
self.gens_set.add(self.id2generator[str(session_id)])
self.running_session_ids.discard(session_id)
async def get_embeddings(self, prompt, do_prerpocess=False):
if do_prerpocess:
prompt = self.model.get_prompt(prompt)
input_ids = self.tokenizer.encode(prompt)
return input_ids
async def get_generator(self, stop: bool, session_id: int):
"""Only return the model instance if it is available."""
if stop:
return self.engine.create_instance()
# wait until a generator is free and the same session_id is not running
while self.gens_set == set() or session_id in self.running_session_ids:
await asyncio.sleep(0.1)
generator = self.gens_set.pop()
self.id2generator[str(session_id)] = generator
self.running_session_ids.add(session_id)
return generator
def batch_infer(self,
prompts: Union[List[str], str, List[Dict],
List[List[Dict]]],
gen_config: Optional[Union[GenerationConfig,
EngineGenerationConfig]] = None,
do_preprocess: bool = True,
**kwargs):
"""Inference a batch of prompts.
Args:
prompts (List[str] | str | List[Dict] | List[List[Dict]]): a batch of
prompts. It accepts: string prompt, a list of string prompts,
a chat history in OpenAI format or a list of chat history.
gen_config (GenerationConfig | None): an instance of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
need_list_wrap = isinstance(prompts, str) or isinstance(
prompts[0], Dict)
prompts = [prompts] if need_list_wrap else prompts
assert isinstance(prompts, List), 'prompts should be a list'
if gen_config is None:
gen_config = GenerationConfig()
if type(gen_config) is GenerationConfig:
gen_config = EngineGenerationConfig.From(gen_config,
self.tokenizer)
# set random if it is not set
if gen_config.random_seed is None:
gen_config.random_seed = random.getrandbits(64)
prompt_num = len(prompts)
outputs = [Response('', 0, 0, i) for i in range(prompt_num)]
for j in range(0, prompt_num, self.instance_num):
batch_prompts = prompts[j:j + self.instance_num]
generators = []
for i, prompt in enumerate(batch_prompts):
generators.append(
self.generate(prompt,
i,
gen_config=gen_config,
stream_response=True,
sequence_start=True,
sequence_end=True,
do_preprocess=do_preprocess,
**kwargs))
async def _inner_call(i, generator):
async for out in generator:
outputs[i + j].text += out.response
outputs[i + j].generate_token_len = out.generate_token_len
outputs[i + j].input_token_len = out.input_token_len
outputs[i + j].finish_reason = out.finish_reason
async def gather():
await asyncio.gather(*[
_inner_call(i, generators[i])
for i in range(len(batch_prompts))
])
self.loop.run_until_complete(gather())
outputs = outputs[0] if need_list_wrap else outputs
return outputs
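# Usage sketch (illustrative only, not from the original file): batched
# inference with an explicit GenerationConfig, assuming `engine` is an
# AsyncEngine instance built as above.
#
#     gen_config = GenerationConfig(max_new_tokens=128, top_p=0.8,
#                                   temperature=0.7)
#     responses = engine.batch_infer(['Hi, pls intro yourself', 'Shanghai is'],
#                                    gen_config=gen_config)
#     print([r.text for r in responses])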
def stream_infer(
self,
prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
gen_config: Optional[Union[GenerationConfig,
EngineGenerationConfig]] = None,
do_preprocess: bool = True,
**kwargs):
"""Inference a batch of prompts with stream mode.
Args:
prompts (List[str] | str | List[Dict] | List[List[Dict]]): a batch of
prompts. It accepts: string prompt, a list of string prompts,
a chat history in OpenAI format or a list of chat history.
gen_config (GenerationConfig | None): an instance of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
need_list_wrap = isinstance(prompts, str) or isinstance(
prompts[0], Dict)
prompts = [prompts] if need_list_wrap else prompts
assert isinstance(prompts, List), 'prompts should be a list'
if gen_config is None:
gen_config = GenerationConfig()
if type(gen_config) is GenerationConfig:
gen_config = EngineGenerationConfig.From(gen_config,
self.tokenizer)
# set random if it is not set
if gen_config.random_seed is None:
gen_config.random_seed = random.getrandbits(64)
prompt_num = len(prompts)
outputs = Queue()
generators = []
for j in range(0, prompt_num, self.instance_num):
batch_prompts = prompts[j:j + self.instance_num]
generators = []
for i, prompt in enumerate(batch_prompts):
generators.append(
self.generate(prompt,
i,
gen_config=gen_config,
stream_response=True,
sequence_start=True,
sequence_end=True,
do_preprocess=do_preprocess,
**kwargs))
async def _inner_call(i, generator):
async for out in generator:
outputs.put(
Response(out.response, out.generate_token_len,
out.input_token_len, i + j,
out.finish_reason))
async def gather():
await asyncio.gather(*[
_inner_call(i, generators[i])
for i in range(len(batch_prompts))
])
outputs.put(None)
proc = Thread(
target=lambda: self.loop.run_until_complete(gather()))
proc.start()
while True:
try:
out = outputs.get(timeout=0.001)
if out is None:
break
yield out
except Empty:
pass
proc.join()
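# Usage sketch (illustrative only, not from the original file): consuming the
# blocking streaming iterator above; each yielded item is a Response that
# carries one incremental piece of text.
#
#     for out in engine.stream_infer(['Hi, pls intro yourself', 'Shanghai is']):
#         print(out.text, end='', flush=True)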
async def _get_prompt_input(self, prompt: str, do_preprocess: bool,
sequence_start: bool):
if do_preprocess:
prompt = self.chat_template.messages2prompt(prompt, sequence_start)
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
return {'prompt': prompt, 'input_ids': input_ids}
async def generate(
self,
messages,
session_id: int,
gen_config: Optional[Union[GenerationConfig,
EngineGenerationConfig]] = None,
stream_response: bool = True,
sequence_start: bool = True,
sequence_end: bool = True,  # no interactive mode by default
step: int = 0,
do_preprocess: bool = True,
**kwargs):
"""Generate responses.
Args:
messages (str | List): chat history or prompt
session_id (int): the session id
gen_config (GenerationConfig | None): an instance of
GenerationConfig. Default to None.
stream_response (bool): whether return responses streamingly
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
step (int): the offset of the k/v cache
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
if str(session_id) not in self.id2step:
self.id2step[str(session_id)] = 0
if step != 0:
self.id2step[str(session_id)] = step
if gen_config is None:
gen_config = GenerationConfig()
if type(gen_config) is GenerationConfig:
gen_config = EngineGenerationConfig.From(gen_config,
self.tokenizer)
if gen_config.stop_words is None:
gen_config.stop_words = self.stop_words
# set random if it is not set and sequence_start is True
if gen_config.random_seed is None and sequence_start:
gen_config.random_seed = random.getrandbits(64)
prompt = messages
prompt_input = await self._get_prompt_input(prompt, do_preprocess,
sequence_start)
prompt = prompt_input['prompt']
logger.info(f'Prompt with applied chat template:\n{prompt}')
input_ids = prompt_input['input_ids']
if gen_config.max_new_tokens is None:
# for interactive endpoint, will try maximum possible token num
gen_config.max_new_tokens = max(
128, self.session_len - self.id2step[str(session_id)] -
len(input_ids))
finish_reason = None
logger.info(f'session_id={session_id}, '
f'history_tokens={self.id2step[str(session_id)]}, '
f'input_tokens={len(input_ids)}, '
f'max_new_tokens={gen_config.max_new_tokens}, '
f'seq_start={sequence_start}, seq_end={sequence_end}, '
f'step={step}, prep={do_preprocess}')
if self.id2step[str(session_id)] + len(
input_ids) + gen_config.max_new_tokens > self.session_len:
logger.warning(f'run out of tokens. session_id={session_id}')
finish_reason = 'length'
yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
finish_reason)
if sequence_end is True and sequence_start is False:
await self.end_session(session_id)
else:
generator = await self.get_generator(False, session_id)
async with self.safe_run(session_id):
state = DetokenizeState()
async for outputs in generator.async_stream_infer(
session_id=session_id,
**prompt_input,
gen_config=gen_config,
stream_output=stream_response,
sequence_start=sequence_start,
sequence_end=sequence_end,
step=self.id2step[str(session_id)]):
_, res, tokens = outputs
# decode res
response, state = self.tokenizer.detokenize_incrementally(
res,
state,
skip_special_tokens=gen_config.skip_special_tokens)
# response, history token len,
# input token len, gen token len
yield GenOut(response, self.id2step[str(session_id)],
len(input_ids), tokens, finish_reason)
finish_reason = 'length' \
if tokens >= gen_config.max_new_tokens else 'stop'
# utf-8 char at the end means it's a potential unfinished
# byte sequence
if not response.endswith('�'):
response = ''  # avoid returning the last response twice
yield GenOut(response, self.id2step[str(session_id)],
len(input_ids), tokens, finish_reason)
# update step
self.id2step[str(session_id)] += len(input_ids) + tokens
if sequence_end:
self.id2step[str(session_id)] = 0
# manually end pytorch session
# TODO modify pytorch or turbomind api
if self.backend == 'pytorch' and sequence_end:
await self.end_session(session_id)
def chat(self,
prompt: str,
session=None,
gen_config: Optional[Union[GenerationConfig,
EngineGenerationConfig]] = None,
do_preprocess: bool = True,
**kwargs) -> Session:
"""Chat.
Args:
prompt (str): prompt
session (Session): the chat session
gen_config (GenerationConfig | None): a instance of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
**kwargs (dict): ad hoc parametrization of `gen_config`
"""
if session is None:
session = Session()
session._engine = self.engine
# sync & init
session._prompt = prompt
session._response = None
sequence_start = session._step == 0
async def _work():
resp = Response('', -1, -1, session._id)
async for output in self.generate(prompt,
session_id=session._id,
gen_config=gen_config,
stream_response=False,
sequence_start=sequence_start,
sequence_end=False,
step=session._step,
do_preprocess=do_preprocess,
**kwargs):
resp = session._merge_response(resp, output)
return resp
from lmdeploy.pytorch.engine.request import _run_until_complete
resp = _run_until_complete(_work())
session._response = resp
session._step += resp.generate_token_len + resp.input_token_len
session.history.append((session._prompt, resp.text))
return session
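# Usage sketch (illustrative only, not from the original file): multi-turn chat
# with the method above; passing the returned Session back keeps the k/v-cache
# offset and the history across turns.
#
#     session = engine.chat('Hi, pls intro yourself')
#     session = engine.chat('Condense that to one sentence', session=session)
#     print(session.response.text)
#     session.close()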
...@@ -17,7 +17,8 @@ class InterFace:
def chat_stream_restful(instruction: str, state_chatbot: Sequence,
cancel_btn: gr.Button, reset_btn: gr.Button,
session_id: int, top_p: float, temperature: float,
request_output_len: int):
"""Chat with AI assistant.
Args:
...@@ -33,9 +34,11 @@ def chat_stream_restful(instruction: str, state_chatbot: Sequence,
instruction,
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=request_output_len,
interactive_mode=True,
top_p=top_p,
temperature=temperature):
if finish_reason == 'length' and tokens == 0:
gr.Warning('WARNING: exceed session max length.'
' Please restart the session by reset button.')
if tokens < 0:
...@@ -94,7 +97,7 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=0,
cancel=True,
interactive_mode=True):
pass
# end the session
...@@ -106,6 +109,7 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
interactive_mode=False):
pass
# resume the session
# TODO this is not proper if api server is running pytorch backend
messages = []
for qa in state_chatbot:
messages.append(dict(role='user', content=qa[0]))
...@@ -155,10 +159,22 @@ def run_api_server(api_server_url: str,
with gr.Row():
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
with gr.Row():
request_output_len = gr.Slider(1,
2048,
value=512,
step=1,
label='Maximum new tokens')
top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
temperature = gr.Slider(0.01,
1.5,
value=0.7,
step=0.01,
label='Temperature')
send_event = instruction_txtbox.submit(chat_stream_restful, [
instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
state_session_id, top_p, temperature, request_output_len
], [state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
...
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Literal, Optional, Union
from lmdeploy.archs import get_task
from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.model import ChatTemplateConfig
def run(model_path_or_server: str,
server_name: str = '0.0.0.0',
server_port: int = 6006,
batch_size: int = 32,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[PytorchEngineConfig,
TurbomindEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
tp: int = 1,
model_name: str = None,
**kwargs):
...@@ -19,6 +28,12 @@ def run(model_path_or_server: str,
server_name (str): the ip address of gradio server
server_port (int): the port of gradio server
batch_size (int): batch size for running Turbomind directly
backend (str): either `turbomind` or `pytorch` backend. Default to
`turbomind` backend.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
config instance. Default to None.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
tp (int): tensor parallel for Turbomind
"""
if ':' in model_path_or_server:
...@@ -31,11 +46,22 @@ def run(model_path_or_server: str,
run_triton_server
run_triton_server(model_path_or_server, server_name, server_port)
else:
pipeline_type, _ = get_task(model_path_or_server)
if pipeline_type == 'vlm':
from lmdeploy.serve.gradio.vl import run_local
assert backend == 'turbomind', 'vlm only support turbomind backend'
if backend_config is not None and \
backend_config.session_len is None:
backend_config.session_len = 8192
else:
from lmdeploy.serve.gradio.turbomind_coupled import run_local
run_local(model_path_or_server,
server_name=server_name,
server_port=server_port,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config,
model_name=model_name,
batch_size=batch_size,
tp=tp,
**kwargs)
...
...@@ -24,5 +24,5 @@ THEME = gr.themes.Soft(
secondary_hue=gr.themes.colors.sky,
font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])
enable_btn = gr.update(interactive=True)
disable_btn = gr.update(interactive=False)
...@@ -16,7 +16,8 @@ class InterFace:
def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int,
top_p: float, temperature: float, request_output_len: int):
"""Chat with AI assistant.
Args:
...@@ -30,7 +31,12 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
instruction = state_chatbot[-1][0]
bot_response = llama_chatbot.stream_infer(
session_id,
instruction,
f'{session_id}-{len(state_chatbot)}',
request_output_len=request_output_len,
top_p=top_p,
temperature=temperature)
for status, tokens, _ in bot_response:
state_chatbot[-1] = (state_chatbot[-1][0], tokens)
...@@ -108,12 +114,24 @@ def run_triton_server(triton_server_addr: str,
with gr.Row():
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
with gr.Row():
request_output_len = gr.Slider(1,
2048,
value=512,
step=1,
label='Maximum new tokens')
top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
temperature = gr.Slider(0.01,
1.5,
value=0.7,
step=0.01,
label='Temperature')
send_event = instruction_txtbox.submit(
add_instruction, [instruction_txtbox, state_chatbot],
[instruction_txtbox, state_chatbot]).then(chat_stream, [
state_chatbot, llama_chatbot, cancel_btn, reset_btn,
state_session_id, top_p, temperature, request_output_len
], [state_chatbot, chatbot, cancel_btn, reset_btn])
cancel_btn.click(cancel_func,
...
# Copyright (c) OpenMMLab. All rights reserved.
import random
from threading import Lock
from typing import Literal, Optional, Sequence, Union

import gradio as gr

from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
                               TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
@@ -14,13 +18,10 @@ class InterFace:
    lock = Lock()


async def chat_stream_local(instruction: str, state_chatbot: Sequence,
                            cancel_btn: gr.Button, reset_btn: gr.Button,
                            session_id: int, top_p: float, temperature: float,
                            request_output_len: int):
    """Chat with AI assistant.

    Args:
@@ -33,15 +34,23 @@ async def chat_stream_local(
    state_chatbot = state_chatbot + [(instruction, None)]

    yield (state_chatbot, state_chatbot, disable_btn, enable_btn)

    gen_config = GenerationConfig(max_new_tokens=request_output_len,
                                  top_p=top_p,
                                  top_k=40,
                                  temperature=temperature,
                                  random_seed=random.getrandbits(64)
                                  if len(state_chatbot) == 1 else None)
    async for outputs in InterFace.async_engine.generate(
            instruction,
            session_id,
            gen_config=gen_config,
            stream_response=True,
            sequence_start=(len(state_chatbot) == 1),
            sequence_end=False):
        response = outputs.response
        if outputs.finish_reason == 'length' and \
                outputs.generate_token_len == 0:
            gr.Warning('WARNING: exceed session max length.'
                       ' Please restart the session by reset button.')
        if outputs.generate_token_len < 0:
@@ -69,7 +78,7 @@ async def reset_local_func(instruction_txtbox: gr.Textbox,
    """
    state_chatbot = []
    # end the session
    await InterFace.async_engine.end_session(session_id)
    return (state_chatbot, state_chatbot, gr.Textbox.update(value=''))
@@ -85,28 +94,36 @@ async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
        session_id (int): the session id
    """
    yield (state_chatbot, disable_btn, disable_btn)
    await InterFace.async_engine.stop_session(session_id)
    # the pytorch backend does not support resuming chat history yet
    if InterFace.async_engine.backend == 'pytorch':
        yield (state_chatbot, disable_btn, enable_btn)
    else:
        await InterFace.async_engine.end_session(session_id)
        messages = []
        for qa in state_chatbot:
            messages.append(dict(role='user', content=qa[0]))
            if qa[1] is not None:
                messages.append(dict(role='assistant', content=qa[1]))
        gen_config = GenerationConfig(max_new_tokens=0)
        async for out in InterFace.async_engine.generate(messages,
                                                         session_id,
                                                         gen_config=gen_config,
                                                         stream_response=True,
                                                         sequence_start=True,
                                                         sequence_end=False):
            pass
        yield (state_chatbot, disable_btn, enable_btn)
def run_local(model_path: str,
              model_name: Optional[str] = None,
              backend: Literal['turbomind', 'pytorch'] = 'turbomind',
              backend_config: Optional[Union[PytorchEngineConfig,
                                             TurbomindEngineConfig]] = None,
              chat_template_config: Optional[ChatTemplateConfig] = None,
              server_name: str = '0.0.0.0',
              server_port: int = 6006,
              tp: int = 1,
              **kwargs):
    """Chat with AI assistant through web ui.

@@ -122,22 +139,32 @@ def run_local(model_path: str,
                "InternLM/internlm-chat-20b-4bit",
                "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo
                on huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm/internlm-chat-7b",
            "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat" and so on.
        backend (str): either `turbomind` or `pytorch` backend. Default to
            `turbomind` backend.
        backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
            config instance. Default to None.
        chat_template_config (ChatTemplateConfig): chat template configuration.
            Default to None.
        server_name (str): the ip address of gradio server. Default to
            "0.0.0.0". For huggingface space demo, it should be
            "huggingface-space".
        server_port (int): the port of gradio server
        tp (int): tensor parallel for Turbomind
    """
    InterFace.async_engine = AsyncEngine(
        model_path=model_path,
        backend=backend,
        backend_config=backend_config,
        chat_template_config=chat_template_config,
        model_name=model_name,
        tp=tp,
        **kwargs)

    with gr.Blocks(css=CSS, theme=THEME) as demo:
        state_chatbot = gr.State([])
@@ -148,17 +175,29 @@ def run_local(model_path: str,
            chatbot = gr.Chatbot(
                elem_id='chatbot',
                label=InterFace.async_engine.engine.model_name)
            instruction_txtbox = gr.Textbox(
                placeholder='Please input the instruction',
                label='Instruction')
            with gr.Row():
                cancel_btn = gr.Button(value='Cancel', interactive=False)
                reset_btn = gr.Button(value='Reset')
            with gr.Row():
                request_output_len = gr.Slider(1,
                                               2048,
                                               value=512,
                                               step=1,
                                               label='Maximum new tokens')
                top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
                temperature = gr.Slider(0.01,
                                        1.5,
                                        value=0.7,
                                        step=0.01,
                                        label='Temperature')

        send_event = instruction_txtbox.submit(chat_stream_local, [
            instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
            state_session_id, top_p, temperature, request_output_len
        ], [state_chatbot, chatbot, cancel_btn, reset_btn])
        instruction_txtbox.submit(
            lambda: gr.Textbox.update(value=''),
@@ -184,14 +223,19 @@ def run_local(model_path: str,
        demo.load(init, inputs=None, outputs=[state_session_id])

    if server_name == 'huggingface-space':
        demo.queue(concurrency_count=InterFace.async_engine.instance_num,
                   max_size=100).launch()
    else:
        print(f'server is gonna mount on: http://{server_name}:{server_port}')
        demo.queue(concurrency_count=InterFace.async_engine.instance_num,
                   max_size=100,
                   api_open=True).launch(
                       max_threads=10,
                       share=True,
                       server_port=server_port,
                       server_name=server_name,
                   )


if __name__ == '__main__':
...
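
For reference, a hedged launch sketch for the refactored `run_local()` entry point above; the import path, model id and port are assumptions rather than values taken from this commit.

# Hypothetical launch script for run_local(); the module path, model id and
# port are placeholders, not part of this commit.
from lmdeploy.serve.gradio.turbomind_coupled import run_local

if __name__ == '__main__':
    run_local('internlm/internlm-chat-7b',
              backend='turbomind',  # or 'pytorch'
              server_name='0.0.0.0',
              server_port=6006,
              tp=1)
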
@@ -4,8 +4,11 @@ from typing import Any, Dict, Iterable, List, Optional, Union
import requests

from lmdeploy.utils import get_logger


def get_model_list(api_url: str):
    """Get model list from api server."""
    response = requests.get(api_url)
    if hasattr(response, 'text'):
        model_list = json.loads(response.text)
@@ -14,15 +17,31 @@ def get_model_list(api_url: str):
    return None


def json_loads(content):
    """Loads content to json format."""
    try:
        content = json.loads(content)
        return content
    except:  # noqa
        logger = get_logger('lmdeploy')
        logger.warning(f'weird json content {content}')
        return ''
class APIClient:
    """Chatbot for LLaMA series models with turbomind as inference engine.

    Args:
        api_server_url (str): communicating address 'http://<ip>:<port>' of
            api_server
        api_key (str | None): api key. Default to None, which means no
            api key will be used.
    """

    def __init__(self,
                 api_server_url: str,
                 api_key: Optional[str] = None,
                 **kwargs):
        self.api_server_url = api_server_url
        self.chat_intractive_v1_url = f'{api_server_url}/v1/chat/interactive'
        self.chat_completions_v1_url = f'{api_server_url}/v1/chat/completions'
@@ -30,6 +49,10 @@ class APIClient:
        self.models_v1_url = f'{api_server_url}/v1/models'
        self.encode_v1_url = f'{api_server_url}/v1/encode'
        self._available_models = None
        self.api_key = api_key
        self.headers = {'content-type': 'application/json'}
        if api_key is not None:
            self.headers['Authorization'] = f'Bearer {api_key}'

    @property
    def available_models(self):
@@ -38,7 +61,7 @@ class APIClient:
            return self._available_models
        response = requests.get(self.models_v1_url)
        if hasattr(response, 'text'):
            model_list = json_loads(response.text)
            model_list = model_list.pop('data', [])
            self._available_models = [item['id'] for item in model_list]
        return self._available_models
@@ -57,15 +80,14 @@ class APIClient:
                when it is not. Default to True.
        Return: (input_ids, length)
        """
        response = requests.post(self.encode_v1_url,
                                 headers=self.headers,
                                 json=dict(input=input,
                                           do_preprocess=do_preprocess,
                                           add_bos=add_bos),
                                 stream=False)
        if hasattr(response, 'text'):
            output = json_loads(response.text)
            return output['input_ids'], output['length']
        return None, None
@@ -75,8 +97,8 @@ class APIClient:
                            temperature: Optional[float] = 0.7,
                            top_p: Optional[float] = 1.0,
                            n: Optional[int] = 1,
                            max_tokens: Optional[int] = None,
                            stop: Optional[Union[str, List[str]]] = None,
                            stream: Optional[bool] = False,
                            presence_penalty: Optional[float] = 0.0,
                            frequency_penalty: Optional[float] = 0.0,
@@ -84,12 +106,14 @@ class APIClient:
                            repetition_penalty: Optional[float] = 1.0,
                            session_id: Optional[int] = -1,
                            ignore_eos: Optional[bool] = False,
                            skip_special_tokens: Optional[bool] = True,
                            **kwargs):
        """Chat completion v1.

        Args:
            model: model name. Available from self.available_models.
            messages: string prompt or chat history in OpenAI format. Chat
                history example: `[{"role": "user", "content": "hi"}]`.
            temperature (float): to modulate the next token probability
            top_p (float): If set to float < 1, only the smallest set of most
                probable tokens with probabilities that add up to top_p or
@@ -97,11 +121,15 @@ class APIClient:
            n (int): How many chat completion choices to generate for each
                input message. Only support one here.
            stream: whether to stream the results or not. Default to false.
            max_tokens (int | None): output token nums. Default to None.
            stop (str | List[str] | None): To stop generating further tokens.
                Only accept stop words that are encoded to a single token
                index.
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to True.
            session_id (int): Deprecated.

        Yields:
            json objects in openai formats
@@ -111,9 +139,8 @@ class APIClient:
            for k, v in locals().copy().items()
            if k[:2] != '__' and k not in ['self']
        }
        response = requests.post(self.chat_completions_v1_url,
                                 headers=self.headers,
                                 json=pload,
                                 stream=stream)
        for chunk in response.iter_lines(chunk_size=8192,
@@ -126,11 +153,11 @@ class APIClient:
                    continue
                if decoded[:6] == 'data: ':
                    decoded = decoded[6:]
                output = json_loads(decoded)
                yield output
            else:
                decoded = chunk.decode('utf-8')
                output = json_loads(decoded)
                yield output
    def chat_interactive_v1(self,
@@ -138,13 +165,14 @@ class APIClient:
                            session_id: int = -1,
                            interactive_mode: bool = False,
                            stream: bool = False,
                            stop: Optional[Union[str, List[str]]] = None,
                            request_output_len: Optional[int] = None,
                            top_p: float = 0.8,
                            top_k: int = 40,
                            temperature: float = 0.8,
                            repetition_penalty: float = 1.0,
                            ignore_eos: bool = False,
                            skip_special_tokens: Optional[bool] = True,
                            **kwargs):
        """Interactive completions.

@@ -162,8 +190,10 @@ class APIClient:
                interactive mode, session history is kept on the server (and
                vice versa).
            stream: whether to stream the results or not.
            stop (str | List[str] | None): To stop generating further tokens.
                Only accept stop words that are encoded to a single token
                index.
            request_output_len (int): output token nums. If not specified,
                will use maximum possible number for a session.
            top_p (float): If set to float < 1, only the smallest set of most
                probable tokens with probabilities that add up to top_p or
                higher are kept for generation.
@@ -173,18 +203,20 @@ class APIClient:
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to True.

        Yields:
            json objects consisting of text, tokens, input_tokens,
            history_tokens, finish_reason
        """
        pload = {
            k: v
            for k, v in locals().copy().items()
            if k[:2] != '__' and k not in ['self']
        }
        response = requests.post(self.chat_intractive_v1_url,
                                 headers=self.headers,
                                 json=pload,
                                 stream=stream)
        for chunk in response.iter_lines(chunk_size=8192,
@@ -192,7 +224,7 @@ class APIClient:
                                         delimiter=b'\n'):
            if chunk:
                decoded = chunk.decode('utf-8')
                output = json_loads(decoded)
                yield output
    def completions_v1(
@@ -204,12 +236,15 @@ class APIClient:
            n: Optional[int] = 1,
            max_tokens: Optional[int] = 16,
            stream: Optional[bool] = False,
            stop: Optional[Union[str, List[str]]] = None,
            top_p: Optional[float] = 1.0,
            top_k: Optional[int] = 40,
            user: Optional[str] = None,
            # additional argument of lmdeploy
            repetition_penalty: Optional[float] = 1.0,
            session_id: Optional[int] = -1,
            ignore_eos: Optional[bool] = False,
            skip_special_tokens: Optional[bool] = True,
            **kwargs):
        """Completion v1.

@@ -223,14 +258,20 @@ class APIClient:
            top_p (float): If set to float < 1, only the smallest set of most
                probable tokens with probabilities that add up to top_p or
                higher are kept for generation.
            top_k (int): The number of the highest probability vocabulary
                tokens to keep for top-k-filtering
            n (int): How many chat completion choices to generate for each
                input message. Only support one here.
            stream: whether to stream the results or not. Default to false.
            stop (str | List[str] | None): To stop generating further tokens.
                Only accept stop words that are encoded to a single token
                index.
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            user (str): A unique identifier representing your end-user.
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to True.
            session_id (int): Deprecated.

        Yields:
            json objects in openai formats
@@ -240,9 +281,8 @@ class APIClient:
            for k, v in locals().copy().items()
            if k[:2] != '__' and k not in ['self']
        }
        response = requests.post(self.completions_v1_url,
                                 headers=self.headers,
                                 json=pload,
                                 stream=stream)
        for chunk in response.iter_lines(chunk_size=8192,
@@ -250,16 +290,16 @@ class APIClient:
                                         delimiter=b'\n'):
            if chunk:
                if stream:
                    decoded = chunk.decode('utf-8')
                    if decoded == 'data: [DONE]':
                        continue
                    if decoded[:6] == 'data: ':
                        decoded = decoded[6:]
                    output = json_loads(decoded)
                    yield output
                else:
                    decoded = chunk.decode('utf-8')
                    output = json_loads(decoded)
                    yield output
    def chat(self,
@@ -307,7 +347,7 @@ class APIClient:
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                ignore_eos=ignore_eos):
            if outputs['finish_reason'] == 'length' and outputs['tokens'] == 0:
                print('WARNING: exceed session max length.'
                      ' Please end the session.')
            yield outputs['text'], outputs['tokens'], outputs['finish_reason']
@@ -334,15 +374,21 @@ def input_prompt():
    return '\n'.join(iter(input, sentinel))


def get_streaming_response(
        prompt: str,
        api_url: str,
        session_id: int,
        request_output_len: int = 512,
        stream: bool = True,
        interactive_mode: bool = False,
        ignore_eos: bool = False,
        cancel: bool = False,
        top_p: float = 0.8,
        temperature: float = 0.7,
        api_key: Optional[str] = None) -> Iterable[List[str]]:
    headers = {'User-Agent': 'Test Client'}
    if api_key is not None:
        headers['Authorization'] = f'Bearer {api_key}'
    pload = {
        'prompt': prompt,
        'stream': stream,
@@ -350,7 +396,9 @@ def get_streaming_response(prompt: str,
        'request_output_len': request_output_len,
        'interactive_mode': interactive_mode,
        'ignore_eos': ignore_eos,
        'cancel': cancel,
        'top_p': top_p,
        'temperature': temperature
    }
    response = requests.post(api_url,
                             headers=headers,
@@ -360,15 +408,18 @@ def get_streaming_response(prompt: str,
                             decode_unicode=False,
                             delimiter=b'\n'):
        if chunk:
            data = json_loads(chunk.decode('utf-8'))
            output = data.pop('text', '')
            tokens = data.pop('tokens', 0)
            finish_reason = data.pop('finish_reason', None)
            yield output, tokens, finish_reason


def main(api_server_url: str,
         session_id: int = 0,
         api_key: Optional[str] = None):
    """Main function to chat in terminal."""
    api_client = APIClient(api_server_url, api_key=api_key)
    while True:
        prompt = input_prompt()
        if prompt in ['exit', 'end']:
...
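
A short usage sketch for the updated client; the server address, api key and prompt are placeholders.

# Example use of APIClient with the new api_key argument; URL and key are
# placeholders.
from lmdeploy.serve.openai.api_client import APIClient

client = APIClient('http://0.0.0.0:23333', api_key='YOUR_API_KEY')
model_name = client.available_models[0]
for chunk in client.chat_completions_v1(
        model=model_name,
        messages=[{'role': 'user', 'content': 'hi'}],
        stream=True,
        max_tokens=64):
    # every chunk is an OpenAI-style dict already parsed by json_loads
    print(chunk)
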
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
import time
from http import HTTPStatus
from typing import AsyncGenerator, List, Literal, Optional, Union

import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer

from lmdeploy.archs import get_task
from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
                               TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.openai.protocol import (  # noqa: E501
    ChatCompletionRequest, ChatCompletionRequestQos, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, CompletionRequest,
    CompletionRequestQos, CompletionResponse, CompletionResponseChoice,
    CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage,
    EmbeddingsRequest, EncodeRequest, EncodeResponse, ErrorResponse,
    GenerateRequest, GenerateRequestQos, GenerateResponse, ModelCard,
    ModelList, ModelPermission, UsageInfo)
from lmdeploy.serve.qos_engine.qos_engine import QosEngine
from lmdeploy.utils import get_logger


class VariableInterface:
    """An IO interface maintaining variables."""
    async_engine: AsyncEngine = None
    session_id: int = 0
    api_keys: Optional[List[str]] = None
    qos_engine: QosEngine = None
    request_hosts = []


app = FastAPI(docs_url='/')
get_bearer_token = HTTPBearer(auto_error=False)


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> str:
    """Check if the client provides a valid api key.

    Adopted from https://github.com/lm-sys/FastChat/blob/v0.2.35/fastchat/serve/openai_api_server.py#L108-L127
    """  # noqa
    if VariableInterface.api_keys:
        if auth is None or (
                token := auth.credentials) not in VariableInterface.api_keys:
            raise HTTPException(
                status_code=401,
                detail={
                    'error': {
                        'message': 'Please request with valid api key!',
                        'type': 'invalid_request_error',
                        'param': None,
                        'code': 'invalid_api_key',
                    }
                },
            )
        return token
    else:
        # api_keys not set; allow all
        return None


def get_model_list():
@@ -37,10 +74,10 @@ def get_model_list():
    Only provided one now.
    """
    return [VariableInterface.async_engine.model_name]


@app.get('/v1/models', dependencies=[Depends(check_api_key)])
def available_models():
    """Show available models."""
    model_cards = []
@@ -74,17 +111,149 @@ async def check_request(request) -> Optional[JSONResponse]:
    return ret
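
With `api_keys` configured on the server, the routes guarded by `check_api_key` reject requests that lack a valid bearer token. A quick client-side check follows; the address and key are placeholders, and it assumes the server was launched with that key.

# Assumes the server was launched with api_keys=['sk-test'] (placeholder).
import requests

base = 'http://0.0.0.0:23333'
print(requests.get(f'{base}/v1/models').status_code)  # 401 without a key
print(requests.get(f'{base}/v1/models',
                   headers={'Authorization': 'Bearer sk-test'}).status_code)
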

@app.post('/v1/chat/completions_qos')
async def chat_completions_v1_qos(request: ChatCompletionRequestQos,
raw_request: Request = None):
"""Completion API similar to OpenAI's API.

Refer to `https://platform.openai.com/docs/api-reference/chat/create`
for the API specification.
The request should be a JSON object with the following fields:
- model: model name. Available from /v1/models.
- messages: string prompt or chat history in OpenAI format.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
- stream: whether to stream the results or not. Default to false.
- max_tokens (int): output token nums
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
Additional arguments supported by LMDeploy:
- ignore_eos (bool): indicator for ignoring eos
- user_id (str): for qos; if not specified, will set to "default"
Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
VariableInterface.session_id += 1
request.session_id = VariableInterface.session_id
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = str(request.session_id)
created_time = int(time.time())
if VariableInterface.qos_engine is None:
return create_error_response(
HTTPStatus.NOT_FOUND,
'cannot parse qos engine config, this api is not working')
result_generator = await VariableInterface.qos_engine.generate_with_qos(
request)
if result_generator is None:
return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR,
'Failed to generate completions')
def create_stream_response_json(
index: int,
text: str,
finish_reason: Optional[str] = None,
) -> str:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(role='assistant', content=text),
finish_reason=finish_reason,
)
response = ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.model_dump_json()
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role='assistant'),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f'data: {data}\n\n'
async for res in result_generator:
response_json = create_stream_response_json(
index=0,
text=res.response,
)
yield f'data: {response_json}\n\n'
yield 'data: [DONE]\n\n'
# Streaming response
if request.stream:
return StreamingResponse(completion_stream_generator(),
media_type='text/event-stream')
# Non-streaming response
final_res = None
text = ''
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await VariableInterface.async_engine.stop_session(
request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
final_res = res
text += res.response
assert final_res is not None
choices = []
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role='assistant', content=text),
finish_reason=final_res.finish_reason,
)
choices.append(choice_data)
total_tokens = sum([
final_res.history_token_len, final_res.input_token_len,
final_res.generate_token_len
])
usage = UsageInfo(
prompt_tokens=final_res.input_token_len,
completion_tokens=final_res.generate_token_len,
total_tokens=total_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
@app.post('/v1/chat/completions', dependencies=[Depends(check_api_key)])
async def chat_completions_v1(request: ChatCompletionRequest,
                              raw_request: Request = None):
    """Completion API similar to OpenAI's API.
@@ -94,7 +263,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    The request should be a JSON object with the following fields:
    - model: model name. Available from /v1/models.
    - messages: string prompt or chat history in OpenAI format. Chat history
        example: `[{"role": "user", "content": "hi"}]`.
    - temperature (float): to modulate the next token probability
    - top_p (float): If set to float < 1, only the smallest set of most
        probable tokens with probabilities that add up to top_p or higher
@@ -102,13 +272,18 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    - n (int): How many chat completion choices to generate for each input
        message. Only support one here.
    - stream: whether to stream the results or not. Default to false.
    - max_tokens (int | None): output token nums. Default to None.
    - repetition_penalty (float): The parameter for repetition penalty.
        1.0 means no penalty
    - stop (str | List[str] | None): To stop generating further tokens.
        Only accept stop words that are encoded to a single token index.

    Additional arguments supported by LMDeploy:
    - top_k (int): The number of the highest probability vocabulary
        tokens to keep for top-k-filtering
    - ignore_eos (bool): indicator for ignoring eos
    - skip_special_tokens (bool): Whether or not to remove special tokens
        in the decoding. Default to True.

    Currently we do not support the following features:
    - function_call (Users should implement this by themselves)
@@ -116,8 +291,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    - presence_penalty (replaced with repetition_penalty)
    - frequency_penalty (replaced with repetition_penalty)
    """
    VariableInterface.session_id += 1
    request.session_id = VariableInterface.session_id
    error_check_ret = await check_request(request)
    if error_check_ret is not None:
        return error_check_ret
@@ -126,18 +301,26 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    request_id = str(request.session_id)
    created_time = int(time.time())

    if isinstance(request.stop, str):
        request.stop = [request.stop]

    gen_config = GenerationConfig(
        max_new_tokens=request.max_tokens,
        top_k=request.top_k,
        top_p=request.top_p,
        temperature=request.temperature,
        repetition_penalty=request.repetition_penalty,
        ignore_eos=request.ignore_eos,
        stop_words=request.stop,
        skip_special_tokens=request.skip_special_tokens)
    result_generator = VariableInterface.async_engine.generate(
        request.messages,
        request.session_id,
        gen_config=gen_config,
        stream_response=True,  # always use stream to enable batching
        sequence_start=True,
        sequence_end=True,
        do_preprocess=not isinstance(request.messages,
                                     str),  # text completion for string input
    )
@@ -196,7 +379,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    async for res in result_generator:
        if await raw_request.is_disconnected():
            # Abort the request if the client disconnects.
            await VariableInterface.async_engine.stop_session(
                request.session_id)
            return create_error_response(HTTPStatus.BAD_REQUEST,
                                         'Client disconnected')
        final_res = res
@@ -230,7 +414,155 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    return response
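
An illustrative raw request against the endpoint above, exercising the newly exposed `top_k`, `stop` and `skip_special_tokens` fields; the address, model name and stop word are placeholders.

# Placeholder values throughout; stop words must map to a single token.
import requests

resp = requests.post(
    'http://0.0.0.0:23333/v1/chat/completions',
    json={
        'model': 'internlm-chat-7b',
        'messages': [{'role': 'user', 'content': 'hi'}],
        'max_tokens': 64,
        'top_k': 40,
        'stop': ['<|im_end|>'],
        'skip_special_tokens': True,
        'stream': False,
    })
print(resp.json())
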
@app.post('/v1/completions_qos')
async def completions_v1_qos(request: CompletionRequestQos,
raw_request: Request = None):
"""Completion API similar to OpenAI's API.
Go to `https://platform.openai.com/docs/api-reference/completions/create`
for the API specification.
The request should be a JSON object with the following fields:
- model (str): model name. Available from /v1/models.
- prompt (str): the input prompt.
- suffix (str): The suffix that comes after a completion of inserted text.
- max_tokens (int): output token nums
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
- stream: whether to stream the results or not. Default to false.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- user (str): A unique identifier representing your end-user.
Additional arguments supported by LMDeploy:
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
- ignore_eos (bool): indicator for ignoring eos
- user_id (str): for qos; if not specified, will set to "default"
Currently we do not support the following features:
- logprobs (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
VariableInterface.session_id += 1
request.session_id = VariableInterface.session_id
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = str(request.session_id)
created_time = int(time.time())
if isinstance(request.prompt, str):
request.prompt = [request.prompt]
if VariableInterface.qos_engine is None:
return create_error_response(
HTTPStatus.NOT_FOUND,
'cannot parse qos engine config, this api is not working')
generators = await VariableInterface.qos_engine.generate_with_qos(request)
def create_stream_response_json(
index: int,
text: str,
finish_reason: Optional[str] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.model_dump_json()
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
for generator in generators:
for i in range(request.n):
choice_data = CompletionResponseStreamChoice(
index=i,
text='',
finish_reason=None,
)
chunk = CompletionStreamResponse(id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f'data: {data}\n\n'
async for res in generator:
response_json = create_stream_response_json(
index=0,
text=res.response,
)
yield f'data: {response_json}\n\n'
yield 'data: [DONE]\n\n'
# Streaming response
if request.stream:
return StreamingResponse(completion_stream_generator(),
media_type='text/event-stream')
# Non-streaming response
usage = UsageInfo()
choices = []
async def _inner_call(i, generator):
final_res = None
text = ''
async for res in generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await VariableInterface.async_engine.stop_session(
request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
final_res = res
text += res.response
assert final_res is not None
choice_data = CompletionResponseChoice(
index=0,
text=text,
finish_reason=final_res.finish_reason,
)
choices.append(choice_data)
total_tokens = sum([
final_res.history_token_len, final_res.input_token_len,
final_res.generate_token_len
])
usage.prompt_tokens += final_res.input_token_len
usage.completion_tokens += final_res.generate_token_len
usage.total_tokens += total_tokens
await asyncio.gather(
*[_inner_call(i, generators[i]) for i in range(len(generators))])
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
@app.post('/v1/completions', dependencies=[Depends(check_api_key)])
async def completions_v1(request: CompletionRequest,
                         raw_request: Request = None):
    """Completion API similar to OpenAI's API.
@@ -242,7 +574,7 @@ async def completions_v1(request: CompletionRequest,
    - model (str): model name. Available from /v1/models.
    - prompt (str): the input prompt.
    - suffix (str): The suffix that comes after a completion of inserted text.
    - max_tokens (int): output token nums. Default to 16.
    - temperature (float): to modulate the next token probability
    - top_p (float): If set to float < 1, only the smallest set of most
        probable tokens with probabilities that add up to top_p or higher
@@ -253,18 +585,23 @@ async def completions_v1(request: CompletionRequest,
    - repetition_penalty (float): The parameter for repetition penalty.
        1.0 means no penalty
    - user (str): A unique identifier representing your end-user.
    - stop (str | List[str] | None): To stop generating further tokens.
        Only accept stop words that are encoded to a single token index.

    Additional arguments supported by LMDeploy:
    - ignore_eos (bool): indicator for ignoring eos
    - skip_special_tokens (bool): Whether or not to remove special tokens
        in the decoding. Default to True.
    - top_k (int): The number of the highest probability vocabulary
        tokens to keep for top-k-filtering

    Currently we do not support the following features:
    - logprobs (not supported yet)
    - presence_penalty (replaced with repetition_penalty)
    - frequency_penalty (replaced with repetition_penalty)
    """
    VariableInterface.session_id += 1
    request.session_id = VariableInterface.session_id
    error_check_ret = await check_request(request)
    if error_check_ret is not None:
        return error_check_ret
@@ -274,21 +611,26 @@ async def completions_v1(request: CompletionRequest,
    created_time = int(time.time())
    if isinstance(request.prompt, str):
        request.prompt = [request.prompt]
    if isinstance(request.stop, str):
        request.stop = [request.stop]

    gen_config = GenerationConfig(
        max_new_tokens=request.max_tokens if request.max_tokens else 512,
        top_k=request.top_k,
        top_p=request.top_p,
        temperature=request.temperature,
        repetition_penalty=request.repetition_penalty,
        ignore_eos=request.ignore_eos,
        stop_words=request.stop,
        skip_special_tokens=request.skip_special_tokens)
    generators = []
    for i in range(len(request.prompt)):
        result_generator = VariableInterface.async_engine.generate(
            request.prompt[i],
            request.session_id + i,
            gen_config=gen_config,
            stream_response=True,  # always use stream to enable batching
            sequence_start=True,
            sequence_end=True,
            do_preprocess=False)
        generators.append(result_generator)
@@ -351,7 +693,8 @@ async def completions_v1(request: CompletionRequest,
        async for res in generator:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await VariableInterface.async_engine.stop_session(
                    request.session_id)
                return create_error_response(HTTPStatus.BAD_REQUEST,
                                             'Client disconnected')
            final_res = res
@@ -394,7 +737,7 @@ async def create_embeddings(request: EmbeddingsRequest,
                                    'Unsupported by turbomind.')


@app.post('/v1/encode', dependencies=[Depends(check_api_key)])
async def encode(request: EncodeRequest, raw_request: Request = None):
    """Encode prompts.
@@ -407,7 +750,7 @@ async def encode(request: EncodeRequest, raw_request: Request = None):
    def encode(prompt: str, do_preprocess: bool, add_bos: bool):
        if do_preprocess:
            prompt = VariableInterface.async_engine.chat_template.get_prompt(
                prompt, sequence_start=add_bos)
        input_ids = VariableInterface.async_engine.tokenizer.encode(
            prompt, add_bos=add_bos)
@@ -425,12 +768,9 @@ async def encode(request: EncodeRequest, raw_request: Request = None):
    return EncodeResponse(input_ids=encoded, length=length)
@app.post('/v1/chat/interactive_qos')
async def chat_interactive_v1_qos(request: GenerateRequestQos,
                                  raw_request: Request = None):
    """Generate completion for the request.

    - On interactive mode, the chat history is kept on the server. Please set
@@ -456,33 +796,134 @@ async def chat_interactive_v1(request: GenerateRequest,
    - repetition_penalty (float): The parameter for repetition penalty.
        1.0 means no penalty
    - ignore_eos (bool): indicator for ignoring eos
    - user_id (str): for qos; if not specified, will set to "default"
    """
    if request.session_id == -1:
        VariableInterface.session_id += 1
        request.session_id = VariableInterface.session_id

    if VariableInterface.qos_engine is None:
        return create_error_response(
            HTTPStatus.NOT_FOUND,
            'cannot parse qos engine config, this api is not working')

    generation = await VariableInterface.qos_engine.generate_with_qos(request)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for out in generation:
            chunk = GenerateResponse(text=out.response,
                                     tokens=out.generate_token_len,
                                     input_tokens=out.input_token_len,
                                     history_tokens=out.history_token_len,
                                     finish_reason=out.finish_reason)
            data = chunk.model_dump_json()
            yield f'{data}\n'

    if request.stream:
        return StreamingResponse(stream_results(),
                                 media_type='text/event-stream')
    else:
        ret = {}
        text = ''
        tokens = 0
        finish_reason = None
        async for out in generation:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await VariableInterface.qos_engine.stop_session(
                    request.session_id)
                return create_error_response(HTTPStatus.BAD_REQUEST,
                                             'Client disconnected')
            text += out.response
            tokens = out.generate_token_len
            finish_reason = out.finish_reason
        ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason}
        return JSONResponse(ret)
@app.post('/v1/chat/interactive', dependencies=[Depends(check_api_key)])
async def chat_interactive_v1(request: GenerateRequest,
raw_request: Request = None):
"""Generate completion for the request.
- On interactive mode, the chat history is kept on the server. Please set
`interactive_mode = True`.
- On normal mode, no chat history is kept on the server. Set
`interactive_mode = False`.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- session_id: determine which instance will be called. If not specified
with a value other than -1, using random value directly.
- interactive_mode (bool): turn on interactive mode or not. On interactive
mode, session history is kept on the server (and vice versa).
- stream: whether to stream the results or not.
- stop (str | List[str] | None): To stop generating further
tokens. Only accept stop words that's encoded to one token idex.
- request_output_len (int): output token nums. If not specified, will use
maximum possible number for a session.
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
- temperature (float): to modulate the next token probability
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- ignore_eos (bool): indicator for ignoring eos
- skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
"""
if request.cancel:
if request.session_id != -1:
await VariableInterface.async_engine.stop_session(
request.session_id)
return {
'text': '',
'tokens': 0,
'input_tokens': 0,
'history_tokens': 0,
'finish_reason': 'stop'
}
else:
return create_error_response(
HTTPStatus.BAD_REQUEST,
'please set a session_id to cancel a request')
if request.session_id == -1:
VariableInterface.session_id += 1
request.session_id = VariableInterface.session_id
async_engine = VariableInterface.async_engine async_engine = VariableInterface.async_engine
sequence_start = async_engine.id2step.get(str(request.session_id), 0) == 0 sequence_start = async_engine.id2step.get(str(request.session_id), 0) == 0
sequence_end = not request.interactive_mode sequence_end = not request.interactive_mode
if isinstance(request.stop, str):
request.stop = [request.stop]
gen_config = GenerationConfig(
max_new_tokens=request.request_output_len,
top_p=request.top_p,
top_k=request.top_k,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
stop_words=request.stop,
skip_special_tokens=request.skip_special_tokens)
generation = async_engine.generate(
request.prompt,
request.session_id,
gen_config=gen_config,
stream_response=True, # always use stream to enable batching
sequence_start=sequence_start,
sequence_end=sequence_end)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for out in generation:
chunk = GenerateResponse(text=out.response,
tokens=out.generate_token_len,
input_tokens=out.input_token_len,
history_tokens=out.history_token_len,
finish_reason=out.finish_reason)
data = chunk.model_dump_json()
yield f'{data}\n'
...@@ -493,32 +934,46 @@ async def chat_interactive_v1(request: GenerateRequest,
else:
ret = {}
text = ''
tokens, input_tokens, history_tokens = 0, 0, 0
finish_reason = None
async for out in generation:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await async_engine.stop_session(request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
text += out.response
tokens = out.generate_token_len
input_tokens = out.input_token_len
history_tokens = out.history_token_len
finish_reason = out.finish_reason
ret = {
'text': text,
'tokens': tokens,
'input_tokens': input_tokens,
'history_tokens': history_tokens,
'finish_reason': finish_reason
}
return JSONResponse(ret)
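# A minimal usage sketch of the /v1/chat/interactive endpoint above. The
# host, port and prompt are illustrative assumptions, and the `requests`
# package is assumed to be available on the client side.
def _example_interactive_client():
    import requests
    resp = requests.post(
        'http://0.0.0.0:23333/v1/chat/interactive',
        json={
            'prompt': 'Hello!',
            'session_id': 1,
            'interactive_mode': True,  # keep chat history on the server
            'stream': False
        })
    # a non-streaming reply carries text, tokens, input_tokens,
    # history_tokens and finish_reason (see GenerateResponse)
    return resp.json()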
def serve(model_path: str,
model_name: Optional[str] = None,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[PytorchEngineConfig,
TurbomindEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
server_name: str = '0.0.0.0',
server_port: int = 23333,
tp: int = 1,
allow_origins: List[str] = ['*'],
allow_credentials: bool = True,
allow_methods: List[str] = ['*'],
allow_headers: List[str] = ['*'],
log_level: str = 'ERROR',
api_keys: Optional[Union[List[str], str]] = None,
ssl: bool = False,
qos_config_path: str = '',
**kwargs):
"""An example to perform model inference through the command line """An example to perform model inference through the command line
interface. interface.
...@@ -534,22 +989,34 @@ def serve(model_path: str, ...@@ -534,22 +989,34 @@ def serve(model_path: str,
"InternLM/internlm-chat-20b-4bit", "InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc. "lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo - iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "InternLM/internlm-chat-7b", on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on. and so on.
model_name (str): needed when model_path is a pytorch model on model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "InternLM/internlm-chat-7b" huggingface.co, such as "InternLM/internlm-chat-7b"
backend (str): either `turbomind` or `pytorch` backend. Default to
`turbomind` backend.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend
config instance. Default to none.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
server_name (str): host ip for serving server_name (str): host ip for serving
server_port (int): server port server_port (int): server port
instance_num (int): number of instances of turbomind model
tp (int): tensor parallel tp (int): tensor parallel
allow_origins (List[str]): a list of allowed origins for CORS allow_origins (List[str]): a list of allowed origins for CORS
allow_credentials (bool): whether to allow credentials for CORS allow_credentials (bool): whether to allow credentials for CORS
allow_methods (List[str]): a list of allowed HTTP methods for CORS allow_methods (List[str]): a list of allowed HTTP methods for CORS
allow_headers (List[str]): a list of allowed HTTP headers for CORS allow_headers (List[str]): a list of allowed HTTP headers for CORS
log_level(str): set log level whose value among [CRITICAL, ERROR, WARNING, INFO, DEBUG] log_level(str): set log level whose value among [CRITICAL, ERROR, WARNING, INFO, DEBUG]
api_keys (List[str] | str | None): Optional list of API keys. Accepts string type as
a single api_key. Default to None, which means no api key applied.
ssl (bool): Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.
qos_config_path (str): qos policy config path
""" # noqa E501 """ # noqa E501
if os.getenv('TM_LOG_LEVEL') is None:
os.environ['TM_LOG_LEVEL'] = log_level
logger = get_logger('lmdeploy')
logger.setLevel(log_level)
if allow_origins:
app.add_middleware(
...@@ -559,16 +1026,55 @@ def serve(model_path: str,
allow_methods=allow_methods,
allow_headers=allow_headers,
)
if api_keys is not None:
if isinstance(api_keys, str):
api_keys = api_keys.split(',')
VariableInterface.api_keys = api_keys
ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http'
if ssl:
ssl_keyfile = os.environ['SSL_KEYFILE']
ssl_certfile = os.environ['SSL_CERTFILE']
http_or_https = 'https'
pipeline_type, pipeline_class = get_task(model_path)
VariableInterface.async_engine = pipeline_class(
model_path=model_path,
model_name=model_name,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config,
tp=tp,
**kwargs)
if qos_config_path:
try:
with open(qos_config_path, 'r') as file:
qos_config_str = file.read()
VariableInterface.qos_engine = QosEngine(
qos_tag=qos_config_str,
engine=VariableInterface.async_engine,
**kwargs)
VariableInterface.qos_engine.start()
except FileNotFoundError:
VariableInterface.qos_engine = None
else:
# hide qos functions if not applied
for i in range(len(app.router.routes)):
if 'qos' in app.router.routes[i].path:
app.router.routes[i].include_in_schema = False
for i in range(3):
print(
f'HINT: Please open \033[93m\033[1m{http_or_https}://'
f'{server_name}:{server_port}\033[0m in a browser for detailed api'
' usage!!!')
uvicorn.run(app=app,
host=server_name,
port=server_port,
log_level='info',
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile)
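# A minimal usage sketch of the `serve` entrypoint above; the model id is a
# placeholder taken from the docstring examples, and the remaining arguments
# simply restate their defaults.
def _example_serve():
    serve('internlm/internlm-chat-7b',  # placeholder model id
          backend='turbomind',
          server_name='0.0.0.0',
          server_port=23333,
          tp=1)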
if __name__ == '__main__':
...
...@@ -55,23 +55,48 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0
class ChatCompletionRequestQos(BaseModel):
"""Chat completion request."""
model: str
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = Field(default=None, examples=[None])
stop: Optional[bool] = False
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
user_id: Optional[str] = None
# additional argument of lmdeploy
repetition_penalty: Optional[float] = 1.0
session_id: Optional[int] = -1
ignore_eos: Optional[bool] = False
top_k: Optional[int] = 40
class ChatCompletionRequest(BaseModel):
"""Chat completion request."""
model: str
# yapf: disable
messages: Union[str, List[Dict[str, Any]]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]]) # noqa
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = Field(default=None, examples=[None])
stop: Optional[Union[str, List[str]]] = Field(default=None, examples=[None]) # noqa
# yapf: enable
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
# additional argument of lmdeploy
repetition_penalty: Optional[float] = 1.0
session_id: Optional[int] = -1
ignore_eos: Optional[bool] = False
skip_special_tokens: Optional[bool] = True
top_k: Optional[int] = 40
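# A minimal usage sketch of ChatCompletionRequest; the model name is a
# placeholder and the message mirrors the Field example above.
def _example_chat_completion_request():
    req = ChatCompletionRequest(
        model='internlm-chat-7b',  # placeholder model name
        messages=[{'role': 'user', 'content': 'hi'}],
        temperature=0.7,
        stream=False)
    return req.model_dump_json()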
class ChatMessage(BaseModel):
...@@ -120,6 +145,31 @@ class ChatCompletionStreamResponse(BaseModel):
class CompletionRequest(BaseModel):
"""Completion request."""
model: str
prompt: Union[str, List[Any]]
suffix: Optional[str] = None
temperature: Optional[float] = 0.7
n: Optional[int] = 1
max_tokens: Optional[int] = 16
stop: Optional[Union[str, List[str]]] = Field(default=None,
examples=[None])
stream: Optional[bool] = False
top_p: Optional[float] = 1.0
logprobs: Optional[int] = None
echo: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
# additional argument of lmdeploy
repetition_penalty: Optional[float] = 1.0
session_id: Optional[int] = -1
ignore_eos: Optional[bool] = False
skip_special_tokens: Optional[bool] = True
top_k: Optional[int] = 40 # for opencompass
class CompletionRequestQos(BaseModel):
"""Completion request.""" """Completion request."""
model: str model: str
prompt: Union[str, List[Any]] prompt: Union[str, List[Any]]
...@@ -136,9 +186,11 @@ class CompletionRequest(BaseModel): ...@@ -136,9 +186,11 @@ class CompletionRequest(BaseModel):
frequency_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None user: Optional[str] = None
# additional argument of lmdeploy # additional argument of lmdeploy
top_k: Optional[int] = 40
repetition_penalty: Optional[float] = 1.0 repetition_penalty: Optional[float] = 1.0
session_id: Optional[int] = -1 session_id: Optional[int] = -1
ignore_eos: Optional[bool] = False ignore_eos: Optional[bool] = False
user_id: Optional[str] = None
class CompletionResponseChoice(BaseModel):
...@@ -205,6 +257,25 @@ class EncodeResponse(BaseModel):
class GenerateRequest(BaseModel):
"""Generate request."""
prompt: Union[str, List[Dict[str, Any]]]
session_id: int = -1
interactive_mode: bool = False
stream: bool = False
stop: Optional[Union[str, List[str]]] = Field(default=None,
examples=[None])
request_output_len: Optional[int] = Field(default=None,
examples=[None]) # noqa
top_p: float = 0.8
top_k: int = 40
temperature: float = 0.8
repetition_penalty: float = 1.0
ignore_eos: bool = False
skip_special_tokens: Optional[bool] = True
cancel: Optional[bool] = False # cancel a responding request
class GenerateRequestQos(BaseModel):
"""Generate request.""" """Generate request."""
prompt: Union[str, List[Dict[str, str]]] prompt: Union[str, List[Dict[str, str]]]
session_id: int = -1 session_id: int = -1
...@@ -217,10 +288,13 @@ class GenerateRequest(BaseModel): ...@@ -217,10 +288,13 @@ class GenerateRequest(BaseModel):
temperature: float = 0.8 temperature: float = 0.8
repetition_penalty: float = 1.0 repetition_penalty: float = 1.0
ignore_eos: bool = False ignore_eos: bool = False
user_id: Optional[str] = None
class GenerateResponse(BaseModel):
"""Generate response."""
text: str
tokens: int
input_tokens: int
history_tokens: int
finish_reason: Optional[Literal['stop', 'length']] = None
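# A minimal sketch of how the interactive endpoint's request body and reply
# map onto these two models; all field values are illustrative.
def _example_generate_models():
    body = GenerateRequest(prompt='hi',
                           session_id=1,
                           interactive_mode=True,
                           request_output_len=128).model_dump_json()
    reply = GenerateResponse(text='hello', tokens=2, input_tokens=1,
                             history_tokens=0, finish_reason='stop')
    return body, reply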
...@@ -18,7 +18,7 @@ from tritonclient.grpc.service_pb2 import ModelInferResponse
from lmdeploy.model import MODELS
from lmdeploy.serve.turbomind.utils import (Postprocessor, Preprocessor,
prepare_tensor)
from lmdeploy.utils import filter_suffix, get_logger
@dataclass
...@@ -51,13 +51,6 @@ def stream_callback(que, result, error):
que.put(result.get_response(as_json=True))
class Chatbot:
"""Chatbot for LLaMA series models with turbomind as inference engine.
...@@ -75,6 +68,10 @@ class Chatbot:
ignore_eos: bool = False,
log_level: int = logging.INFO,
display: bool = False,
top_p: float = 1.0,
top_k: int = 1,
temperature: float = 0.8,
repetition_penalty: float = 1.0,
**model_kwargs):
self.tritonserver_addr = tritonserver_addr
self.model_name = model_name
...@@ -97,10 +94,10 @@ class Chatbot:
self.eos_id = -1
self.cfg = mmengine.Config(
dict(session_len=self.model.session_len,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
stop_words=stop_words,
bad_words=bad_words))
self.log_level = log_level
...@@ -113,6 +110,7 @@ class Chatbot:
request_output_len: int = None,
sequence_start: bool = False,
sequence_end: bool = False,
skip_special_tokens: bool = True,
*args,
**kwargs):
"""Start a new round conversation of a session.
...@@ -124,13 +122,15 @@ class Chatbot:
request_output_len (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Returns:
iterator: The generated content by chatbot
"""
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'request_output_len {request_output_len}')
...@@ -149,11 +149,13 @@ class Chatbot:
self.cfg.update(**kwargs)
self._session.prompt = self._get_prompt(prompt, sequence_start)
for status, res, tokens in self._stream_infer(
self._session,
self._session.prompt,
request_output_len,
sequence_start,
sequence_end,
skip_special_tokens=skip_special_tokens):
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
res = filter_suffix(res, self.model.stop_words)
if status.value < 0:
...@@ -180,7 +182,7 @@ class Chatbot:
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'end session: {session_id}')
if self._session is None:
...@@ -218,7 +220,7 @@ class Chatbot:
"""
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'cancel session: {session_id}')
if self._session is None:
...@@ -267,7 +269,7 @@ class Chatbot:
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'resume session: {session_id}')
if self._session is None:
...@@ -301,6 +303,7 @@ class Chatbot:
request_output_len: int = None,
sequence_start: bool = False,
sequence_end: bool = False,
skip_special_tokens: bool = True,
*args,
**kwargs):
"""Start a new round conversation of a session. Return the chat
...@@ -313,6 +316,8 @@ class Chatbot:
request_output_len (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Returns:
tuple(Status, str, int): status, text/chat completion,
generated token number
...@@ -320,7 +325,7 @@ class Chatbot:
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'request_output_len {request_output_len}')
...@@ -338,11 +343,13 @@ class Chatbot:
self._session.prompt = self._get_prompt(prompt, sequence_start)
status, res, tokens = None, '', 0
for status, res, tokens in self._stream_infer(
self._session,
self._session.prompt,
request_output_len,
sequence_start,
sequence_end,
skip_special_tokens=skip_special_tokens):
if status.value < 0:
break
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
...@@ -420,6 +427,7 @@ class Chatbot:
request_output_len: int = 512,
sequence_start: bool = True,
sequence_end: bool = False,
skip_special_tokens: bool = True,
cancel: bool = False):
"""communicate with inference server to chat, or cancel a session, or
end a session.
...@@ -431,10 +439,12 @@ class Chatbot:
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
cancel (bool): indicator for cancelling the session
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Yields:
tuple: status, text, generated token number
"""
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'session {session.session_id}, '
f'request id {session.request_id}, '
f'request_output_len {request_output_len}, '
...@@ -498,7 +508,8 @@ class Chatbot:
producer.start()
for status, res, n_token in self.stream_consumer(
self.postprocess, que, session, input_tokens, preseq_length,
cancel, logger, self.display, self.eos_id,
skip_special_tokens):
yield status, res, n_token
producer.join()
...@@ -591,7 +602,8 @@ class Chatbot:
@staticmethod
def stream_consumer(postprocess, res_queue, session, n_input_token,
preseq_length, cancel, logger, display, eos_id,
skip_special_tokens):
"""Consume the response from the triton inference server.
Args:
...@@ -605,11 +617,15 @@ class Chatbot:
logger (util.Logger):
display (bool): display the text in the console interface or not
eos_id (int): eos token id
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Yields:
tuple: status, text, generated token number
"""
status, res, n_token = None, '', 0
output_ids = np.zeros((1, 1, 0), dtype=np.uint32)
text = ''
while True:
result = res_queue.get()
if result is None:
...@@ -648,7 +664,8 @@ class Chatbot:
output_ids = output_ids[:, :, :-1]
output_str = postprocess(
output_ids, np.array([[n_token]], dtype=np.uint32),
np.array([[int(skip_special_tokens)]], dtype=np.int32))
text = output_str[0].decode()
# utf-8 char at the end means it's a potential unfinished
# byte sequence, continue to concatenate it with the next
...
...@@ -84,10 +84,13 @@ class TritonPythonModel:
request, 'TOKENS_BATCH').as_numpy()
sequence_length = pb_utils.get_input_tensor_by_name(
request, 'sequence_length').as_numpy()
skip_special_tokens = pb_utils.get_input_tensor_by_name(
request, 'skip_special_tokens').as_numpy()
# Postprocessing output data.
outputs = self._postprocessing(tokens_batch.tolist(),
sequence_length,
skip_special_tokens)
# Create output tensors. You need pb_utils.Tensor
# objects to create pb_utils.InferenceResponse.
...@@ -118,12 +121,16 @@ class TritonPythonModel:
"""
print('Cleaning up...')
def _postprocessing(self, tokens_batch, sequence_length,
skip_special_tokens):
"""decode token ids into texts."""
outputs = []
for beam_tokens, beam_len, beam_skip_special in zip(
tokens_batch, sequence_length, skip_special_tokens):
for tokens, _len, skip_special in zip(beam_tokens, beam_len,
beam_skip_special):
output = self.tokenizer.decode(
tokens, _len, skip_special_tokens=bool(skip_special))
output = output.encode('utf8')
outputs.append(output)
return outputs
...@@ -11,6 +11,11 @@ input [
name: "sequence_length"
data_type: TYPE_UINT32
dims: [ -1 ]
},
{
name: "skip_special_tokens"
data_type: TYPE_INT32
dims: [ -1 ]
}
]
output [
...
...@@ -72,22 +72,29 @@ class Postprocessor:
def __call__(self, *args, **kwargs):
return self.infer(*args, **kwargs)
def infer(self,
output_ids: np.ndarray,
seqlen: np.ndarray,
skip_special_tokens: bool = True):
"""De-tokenize tokens for text.
Args:
output_ids(np.ndarray): tokens' id
seqlen(np.ndarray): sequence length
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Returns:
str: decoded tokens
"""
inputs = [
prepare_tensor('TOKENS_BATCH', output_ids),
prepare_tensor('sequence_length', seqlen),
prepare_tensor('skip_special_tokens', skip_special_tokens)
]
inputs[0].set_data_from_numpy(output_ids)
inputs[1].set_data_from_numpy(seqlen)
inputs[2].set_data_from_numpy(skip_special_tokens)
model_name = 'postprocessing'
with grpcclient.InferenceServerClient(self.tritonserver_addr) \
as client:
...
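# A minimal sketch of how the chatbot's stream_consumer (earlier in this
# commit) feeds the postprocessing step: token ids, sequence length and the
# skip_special_tokens flag are all passed as numpy tensors, the flag as an
# int32 array despite the bool annotation on Postprocessor.infer.
def _example_postprocess_call(postprocess, output_ids, n_token,
                              skip_special_tokens=True):
    import numpy as np
    return postprocess(output_ids,
                       np.array([[n_token]], dtype=np.uint32),
                       np.array([[int(skip_special_tokens)]], dtype=np.int32))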
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from collections import deque
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union
import torch
from lmdeploy.utils import get_logger
# this file will be copied to triton server, make sure all
# importing are starting from the package root lmdeploy
@dataclass
class DetokenizeState:
"""A state collection of incrementally detekenization.
Args:
ids_offset (int): offset to all input ids. In LMDeploy, the output
ids length is not one by one. It could be random by random.
prev_tokens (List[str] | None): for incrementally decoding.
Default to None, which means the first round.
prefix_offset (int): the start index of tokens to be converted to
string (prev + new tokens). Default to 0 for the first round.
read_offset (int): the end index of tokens to be converted to
string (prev token). Default to 0 for the first round.
"""
ids_offset: int = 0
prev_tokens: Optional[List[str]] = None
prefix_offset: int = 0
read_offset: int = 0
def as_tuple(self) -> Tuple:
"""Return a tuple of states."""
return (self.ids_offset, self.prev_tokens, self.prefix_offset,
self.read_offset)
class SentencePieceTokenizer:
"""Tokenizer of sentencepiece.
...@@ -18,6 +49,12 @@ class SentencePieceTokenizer:
from sentencepiece import SentencePieceProcessor
self.model = SentencePieceProcessor(model_file=model_file)
self._prefix_space_tokens = None
# for stop words
self._maybe_decode_bytes: bool = None
# TODO maybe lack a constant.py
self._indexes_tokens_deque = deque(maxlen=10)
self.max_indexes_num = 5
self.logger = get_logger('lmdeploy')
@property
def vocab_size(self):
...@@ -53,6 +90,27 @@ class SentencePieceTokenizer:
else:
return decoded
def indexes_containing_token(self, token: str):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
# traversing vocab is time consuming, can not be accelerated with
# multi threads (computation) or multi process (can't pickle tokenizer)
# so, we maintain latest 10 stop words and return directly if matched
for _token, _indexes in self._indexes_tokens_deque:
if token == _token:
return _indexes
if token == ' ': # ' ' is special
token = '▁'
vocab = self.model.IdToPiece(list(range(self.vocab_size)))
indexes = [i for i, voc in enumerate(vocab) if token in voc]
if len(indexes) > self.max_indexes_num:
indexes = self.encode(token, add_bos=False)[-1:]
self.logger.warning(
f'There are too many (>{self.max_indexes_num}) possible '
f'indexes that may decode to {token}; only {indexes} will be used')
self._indexes_tokens_deque.append((token, indexes))
return indexes
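# Usage note: candidate vocabulary indexes for a stop word can be looked up
# with, e.g., `tokenizer.indexes_containing_token('</s>')` (the token is a
# placeholder); at most `max_indexes_num` indexes are kept per stop word.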
def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt.
...@@ -63,13 +121,18 @@ class SentencePieceTokenizer:
"""
return self.model.Encode(s, add_bos=add_bos, **kwargs)
def decode(self,
t: Sequence[int],
offset: Optional[int] = None,
skip_special_tokens: bool = True,
**kwargs):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
offset (int): for incrementally decoding. Default to None, which
means not applied.
skip_special_tokens (bool): not used in SentencePieceTokenizer.
Returns:
str: text of decoding tokens
"""
...@@ -81,6 +144,34 @@ class SentencePieceTokenizer:
out_string = self._maybe_add_prefix_space(t, out_string)
return out_string
def detokenize_incrementally(self,
all_input_ids: Sequence[int],
state: DetokenizeState,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True):
"""Incrementally detokenize the input indexes.
Args:
all_input_ids (List[int]): a list of token ids. Expected to be
different sections of a long sequence.
state (DetokenizeState): an instance of DetokenizeState. Consists
of incrementally decoding states.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
spaces_between_special_tokens (bool): Whether or not to add spaces
between special tokens. Default to be True.
Returns:
str: decoding output string of the current round.
state (DetokenizeState): an instance of DetokenizeState. Consists
of incrementally decoding states.
"""
out_string = self.model.Decode(all_input_ids)
if state.prev_tokens is not None:
out_string = self._maybe_add_prefix_space(all_input_ids,
out_string)
state.prev_tokens = [] # not None for the above condition
return out_string, state
def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
...@@ -106,20 +197,10 @@ class HuggingFaceTokenizer:
def __init__(self, model_dir: str):
from transformers import AutoTokenizer
self.logger = get_logger('lmdeploy')
self.model = AutoTokenizer.from_pretrained(model_dir,
trust_remote_code=True)
self._prefix_space_tokens = None
if self.model.eos_token_id is None:
generation_config_file = osp.join(model_dir,
...@@ -131,11 +212,27 @@ class HuggingFaceTokenizer:
elif hasattr(self.model, 'eod_id'): # Qwen remote
self.model.eos_token_id = self.model.eod_id
# for stop words
self._vocab_size_with_added: int = None
self._maybe_decode_bytes: bool = None
# TODO maybe lack a constant.py
self._indexes_tokens_deque = deque(maxlen=10)
self.max_indexes_num = 5
self.token2id = {}
@property
def vocab_size(self):
"""vocabulary size."""
return self.model.vocab_size
@property
def vocab_size_with_added(self):
"""vocabulary size with added vocab."""
if self._vocab_size_with_added is not None:
return self._vocab_size_with_added
self._vocab_size_with_added = len(self.model.get_vocab())
return self._vocab_size_with_added
@property
def bos_token_id(self):
"""beginning of the sentence token id."""
...@@ -159,7 +256,7 @@ class HuggingFaceTokenizer:
}
return self._prefix_space_tokens
def _maybe_add_prefix_space(self, tokens: List[int], decoded: str):
"""maybe add prefix space for incremental decoding."""
if len(tokens) and not decoded.startswith(' ') and\
tokens[0] in self.prefix_space_tokens:
...@@ -167,6 +264,66 @@ class HuggingFaceTokenizer:
else:
return decoded
@property
def maybe_decode_bytes(self):
"""Check if self.model.convert_ids_to_tokens return not a str value."""
if self._maybe_decode_bytes is None:
self._maybe_decode_bytes = False
vocab = self.model.convert_ids_to_tokens(
list(range(self.vocab_size)))
for tok in vocab:
if not isinstance(tok, str):
self._maybe_decode_bytes = True
break
return self._maybe_decode_bytes
def indexes_containing_token(self, token: str):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
# traversing vocab is time consuming, can not be accelerated with
# multi threads (computation) or multi process (can't pickle tokenizer)
# so, we maintain latest 10 stop words and return directly if matched
for _token, _indexes in self._indexes_tokens_deque:
if token == _token:
return _indexes
if self.token2id == {}:
# decode is slower than convert_ids_to_tokens
if self.maybe_decode_bytes:
try:
self.token2id = {
self.model.decode(i): i
for i in range(self.vocab_size)
}
except Exception as e:
# qwen-vl
assert str(e) == 'Unclosed image token'
else:
self.token2id = {
self.model.convert_ids_to_tokens(i): i
for i in range(self.vocab_size)
}
if token == ' ': # ' ' is special
token = '▁'
indexes = [i for _token, i in self.token2id.items() if token in _token]
if len(indexes) > self.max_indexes_num:
# multiple id decode to same token
indexes = [i for i in indexes if self.decode([i]) == token]
indexes = indexes[:self.max_indexes_num]
self.logger.warning(
f'There are too many (>{self.max_indexes_num}) possible '
f'indexes that may decode to {token}; only {indexes} will be used')
# there might be token id that exceeds self.vocab_size
if len(indexes) == 0:
indexes = self.encode(token, False)
if len(indexes) != 1:
self.logger.warning(
f'The token {token} maps to indexes {indexes}, whose length is '
'not 1. Currently it can not be used as a stop word')
indexes = []
self._indexes_tokens_deque.append((token, indexes))
return indexes
def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt.
...@@ -182,7 +339,10 @@ class HuggingFaceTokenizer:
encoded = encoded[1:]
return encoded
def decode(self,
t: Sequence[int],
offset: Optional[int] = None,
skip_special_tokens: bool = True):
"""De-tokenize.
Args:
...@@ -192,14 +352,121 @@ class HuggingFaceTokenizer:
Returns:
str: text of decoding tokens
"""
t = t[offset:]
out_string = self.model.decode(t,
skip_special_tokens=skip_special_tokens)
if offset:
logger = get_logger('lmdeploy')
logger.warning('For incremental detokenization, please try the '
'detokenize_incrementally function instead.')
out_string = self._maybe_add_prefix_space(t, out_string)
return out_string
@staticmethod
def _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens: List[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> str:
if tokenizer.is_fast or not tokenizer.get_added_vocab():
return tokenizer.convert_tokens_to_string(output_tokens)
# Adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L68-L99
sub_texts = []
current_sub_text = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
continue
if token in tokenizer.get_added_vocab():
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(
current_sub_text)
sub_texts.append(sub_text)
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
if spaces_between_special_tokens:
return ' '.join(sub_texts)
else:
return ''.join(sub_texts)
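# Illustration (hypothetical tokens): for a slow tokenizer with added vocab,
# output tokens such as ['Hello', '<added_token>', 'world'] are decoded
# piecewise -- plain pieces go through convert_tokens_to_string, added tokens
# are kept verbatim -- and then joined with spaces when
# spaces_between_special_tokens is True, or concatenated otherwise.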
# Based on
# https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L105-L165
def detokenize_incrementally(self,
all_input_ids: Sequence[int],
state: DetokenizeState,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True):
"""Incrementally detokenize the input indexes.
Args:
all_input_ids (List[int]): a list of token ids. Expected to be
different sections of a long sequence.
state (DetokenizeState): an instance of DetokenizeState. Consists
of incrementally decoding states.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
spaces_between_special_tokens (bool): Whether or not to add spaces
between special tokens. Default to be True.
Returns:
str: decoding output string of the current round.
state (DetokenizeState): an instance of DetokenizeState. Consists
of incrementally decoding states.
"""
tokenizer = self.model
ids_offset, prev_tokens, prefix_offset, read_offset = state.as_tuple()
# This is the first iteration for this sequence
new_tokens = tokenizer.convert_ids_to_tokens(
all_input_ids[ids_offset:],
skip_special_tokens=skip_special_tokens)
if prev_tokens is None:
# Please notice that in VLLM, indexes are detokenized one by one
# while in LMDeploy, every turn, the detokenized indexes length
# can be different.
if skip_special_tokens and new_tokens and new_tokens[
0] in tokenizer.all_special_ids:
read_offset = 1 # skip special token
output_tokens = new_tokens
prev_tokens = new_tokens
else:
# Put new_token_id in a list so skip_special_tokens is respected
output_tokens = prev_tokens + new_tokens
prev_tokens += new_tokens
prefix_text = self._convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
new_text = self._convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
# update state and get final decoded output
if len(new_text) > len(prefix_text) and not new_text.endswith('�'):
# utf-8 char at the end means it's a potential unfinished byte
# sequence from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
prefix_offset = read_offset
read_offset = len(output_tokens)
new_text = new_text[len(prefix_text):]
else:
new_text = ''
return new_text, DetokenizeState(len(all_input_ids), prev_tokens,
prefix_offset, read_offset)
def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
...@@ -230,7 +497,7 @@ class Tokenizer:
model_file_exists = osp.exists(model_file)
config_exists = osp.exists(tokenizer_config_file)
use_hf_model = config_exists or not model_file_exists
self.logger = get_logger('lmdeploy')
if not use_hf_model:
self.model = SentencePieceTokenizer(model_file)
else:
...@@ -261,7 +528,12 @@ class Tokenizer:
"""
return self.model.encode(s, add_bos, **kwargs)
def decode(
self,
t: Sequence[int],
offset: Optional[int] = None,
skip_special_tokens: bool = True,
):
"""De-tokenize.
Args:
...@@ -271,7 +543,34 @@ class Tokenizer:
Returns:
str: text of decoding tokens
"""
return self.model.decode(t, offset, skip_special_tokens)
def detokenize_incrementally(self,
all_input_ids: Sequence[int],
state: DetokenizeState,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True):
"""Incrementally detokenize the input indexes.
Args:
all_input_ids (List[int]): a list of token ids. Expected to be
different sections of a long sequence.
state (DetokenizeState): an instance of DetokenizeState. Consists
of incrementally decoding states.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
spaces_between_special_tokens (bool): Whether or not to add spaces
between special tokens. Default to be True.
Returns:
str: decoding output string of the current round.
state (DetokenizeState): an instance of DetokenizeState. Consists
of incrementally decoding states.
"""
return self.model.detokenize_incrementally(
all_input_ids,
state=state,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens)
def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
...@@ -282,3 +581,14 @@ class Tokenizer:
list[int]: token ids
"""
return self.model(s)
def indexes_containing_token(self, token):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
encoded = self.encode(token, add_bos=False)
if len(encoded) > 1:
self.logger.warning(
f'The token {token} encodes to indexes {encoded}, whose length is '
'more than 1. Currently it can not be used as a stop word')
return []
return self.model.indexes_containing_token(token)
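# A minimal sketch of incremental detokenization with the classes above; the
# model path and the id chunks are placeholders, each chunk standing for the
# token ids produced in one generation step.
def _example_incremental_detokenize(model_path, id_chunks):
    tok = Tokenizer(model_path)
    state = DetokenizeState()
    all_ids, text = [], ''
    for chunk in id_chunks:
        all_ids += list(chunk)
        new_text, state = tok.detokenize_incrementally(
            all_ids, state, skip_special_tokens=True)
        text += new_text
    return text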