Commit d7117b95 authored by zhouxiang

Sync 0.2.6 code

parent 5f83e392
# Copyright (c) OpenMMLab. All rights reserved.
"""Helpers for parallel and distributed inference."""
import functools
import os
import torch
from torch.distributed import broadcast, broadcast_object_list, is_initialized
def get_local_rank():
"""Get local rank of current process.
Assume environment variable ``LOCAL_RANK`` is properly set by some launcher.
See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
""" # noqa: E501
return int(os.getenv('LOCAL_RANK', '0'))
def get_rank():
"""Get rank of current process.
Assume environment variable ``RANK`` is properly set by some launcher.
See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
""" # noqa: E501
return int(os.getenv('RANK', '0'))
def get_world_size():
"""Get rank of current process.
Assume environment variable ``WORLD_SIZE`` is properly set by some launcher.
See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
""" # noqa: E501
return int(os.getenv('WORLD_SIZE', '1'))
def master_only(func):
"""Decorator to run a function only on the master process."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_initialized():
if get_rank() != 0:
return None
return func(*args, **kwargs)
return wrapper
def master_only_and_broadcast_general(func):
"""Decorator to run a function only on the master process and broadcast the
result to all processes."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_initialized():
if get_rank() == 0:
result = [func(*args, **kwargs)]
else:
result = [None]
broadcast_object_list(result, src=0)
result = result[0]
else:
result = func(*args, **kwargs)
return result
return wrapper
def master_only_and_broadcast_tensor(func):
"""Decorator to run a function only on the master process and broadcast the
result to all processes.
Note: Require CUDA tensor.
Note: Does not really work because we don't know the shape beforehand;
for CPU tensors, use master_only_and_broadcast_general.
"""
@functools.wraps(func)
def wrapper(*args, size, dtype, **kwargs):
if is_initialized():
if get_rank() == 0:
result = func(*args, **kwargs)
else:
result = torch.empty(size=size,
dtype=dtype,
device=get_local_rank())
broadcast(result, src=0)
# print(f'rank {get_rank()} received {result}')
else:
result = func(*args, **kwargs)
return result
return wrapper
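# Illustrative usage sketch (added for clarity; the function names below are
# hypothetical, not part of the original module). Under a launcher such as
# ``torchrun`` with ``torch.distributed`` initialized, the decorators behave
# as follows.
@master_only
def _example_log(msg: str):
    """Print only on rank 0; other ranks simply return None."""
    print(f'[rank {get_rank()}/{get_world_size()}] {msg}')


@master_only_and_broadcast_general
def _example_read_config():
    """Built on rank 0 only, then broadcast to every rank as a Python object."""
    return {'max_new_tokens': 512, 'temperature': 0.8}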
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import time
import warnings
from typing import Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from .dist import get_local_rank
logger = logging.getLogger(__name__)
class LoadWoInit:
"""Context manager that disable parameter initialization."""
def __init__(self):
self.constant_ = torch.nn.init.constant_
self.zeros_ = torch.nn.init.zeros_
self.ones_ = torch.nn.init.ones_
self.uniform_ = torch.nn.init.uniform_
self.normal_ = torch.nn.init.normal_
self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_
self.kaiming_normal_ = torch.nn.init.kaiming_normal_
def __enter__(self, *args, **kwargs):
torch.nn.init.constant_ = lambda *args, **kwargs: None
torch.nn.init.zeros_ = lambda *args, **kwargs: None
torch.nn.init.ones_ = lambda *args, **kwargs: None
torch.nn.init.uniform_ = lambda *args, **kwargs: None
torch.nn.init.normal_ = lambda *args, **kwargs: None
torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None
torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None
def __exit__(self, *args, **kwargs):
torch.nn.init.constant_ = self.constant_
torch.nn.init.zeros_ = self.zeros_
torch.nn.init.ones_ = self.ones_
torch.nn.init.uniform_ = self.uniform_
torch.nn.init.normal_ = self.normal_
torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_
torch.nn.init.kaiming_normal_ = self.kaiming_normal_
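# Illustrative sketch (an assumption, not part of the original module): the
# context manager above makes building a large module cheap when the weights
# are immediately overwritten by a checkpoint, as ``init_model`` below does.
def _example_build_without_init(hidden_size: int = 4096):
    """Construct a Linear layer with all ``torch.nn.init`` calls no-op'ed."""
    with LoadWoInit():
        layer = torch.nn.Linear(hidden_size, hidden_size)
    return layer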
def init_model(model_path: str,
tokenizer_path: Optional[str] = None,
use_fast_tokenizer=True):
"""Initialize model and tokenizer from given model path.
Args:
model_path (str): Path to model.
tokenizer_path (str): Path to tokenizer.
use_fast_tokenizer (bool): Whether to use fast tokenizer.
Note:
If the model is converted from new version of transformers,
use_fast_tokenizer should be True.
If using decapoda-research/llama-xb-hf, use_fast_tokenizer should be False.
"""
start = time.monotonic()
if not tokenizer_path:
tokenizer_path = model_path
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
use_fast=use_fast_tokenizer,
trust_remote_code=True)
with LoadWoInit():
model = AutoModelForCausalLM.from_pretrained(model_path,
torch_dtype=torch.float16,
trust_remote_code=True)
logger.info(f'Model loaded in {time.monotonic() - start:.1f} seconds')
logger.info(f'Model loaded from {model_path}')
logger.debug(model)
return model, tokenizer
def accel_model(model,
accel: Optional[str] = None,
gpu_id=None,
max_alloc=2048,
tp_size=1):
"""Accelerate model with given accelerator.
Note:
Currently we support only deepspeed or just no acceleration.
"""
logger.info(f'Accelerate model with {accel}')
if accel is None:
# No acceleration, just to cuda
# assume single gpu single process
# user is responsible to assign the gpu id via CUDA_VISIBLE_DEVICES # noqa: E501
gpu_id = gpu_id if gpu_id is not None else get_local_rank()
model = model.cuda(gpu_id)
elif accel.lower() == 'deepspeed':
# Use deepspeed inference inject fast kernel and/or tensor parallel
try:
import deepspeed
except ImportError as e:
raise ImportError('--accel=deepspeed is specified but '
'deepspeed is not installed.\n'
'Install with `pip install deepspeed`.') from e
config = dict(
tensor_parallel=dict(tp_size=tp_size), # Use world size in general
dtype=torch.float16,
replace_with_kernel_inject=True,
max_out_tokens=max_alloc,
)
if 'InternLM' in model.__class__.__name__:
try:
# Use customized deepspeed supporting InternLM
# https://github.com/wangruohui/DeepSpeed/tree/support_internlm_0.10.0 (commit cdef2ce) # noqa: E501
from deepspeed.module_inject.containers.internlm import \
InternLMLayerPolicy # noqa: E501
except ImportError:
# InternLM is not officially supported by DeepSpeed
# Set replace_with_kernel_inject=False to use AutoTP
config.update({'replace_with_kernel_inject': False})
warnings.warn(
'\033[0;93m'
'Current installation of deepspeed does not '
'support InternLM. Disable kernel injection. '
'To support InternLM, install customized deepspeed with '
'`pip install git+https://github.com/wangruohui/DeepSpeed@support_internlm_0.10.0`' # noqa: E501
'\033[0m')
else:
for module in model.modules():
# Since remote code is dynamically located,
# we need to do this dynamically
if module.__class__.__name__ == 'InternLMDecoderLayer':
InternLMLayerPolicy._orig_layer_class = module.__class__ # noqa: E501
break
logger.debug(f'Using deepspeed config\n{config}')
model = deepspeed.init_inference(
model=model, # Transformers models
config=config,
)
# for k, v in model.named_parameters():
# logger.debug(f"{k}: v.device")
else:
raise ValueError(f'Unsupported accelerator {accel}.')
logger.debug(model)
return model
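# Illustrative sketch (hypothetical paths; not part of the original module):
# a typical call sequence combining ``init_model`` and ``accel_model``. It
# needs a local HuggingFace-format checkpoint and, if the deepspeed branch is
# chosen, an installed ``deepspeed`` package.
def _example_load_and_accelerate(model_path: str = '/path/to/internlm-chat-7b',
                                 use_deepspeed: bool = False):
    model, tokenizer = init_model(model_path)
    # accel=None only moves the fp16 model to the local GPU;
    # accel='deepspeed' additionally injects kernels / tensor parallelism.
    model = accel_model(model,
                        accel='deepspeed' if use_deepspeed else None,
                        max_alloc=2048,
                        tp_size=1)
    return model, tokenizer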
# Copyright (c) OpenMMLab. All rights reserved.
from .linear import WeightOnlyQLinear
__all__ = ['WeightOnlyQLinear']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Type, TypeVar
import torch
from torch import nn
try:
import awq_inference_engine
except ModuleNotFoundError:
awq_inference_engine = None
class WeightOnlyQLinear(nn.Module):
"""This class implements weight only quantization linear.
Args:
w_bit (int): number of bits for quantization.
symmetry (bool): If true, use symmetric quantization,
otherwise use asymmetric quantization.
group_size (int): size of the quantization group.
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool): whether to include a bias term. Defaults to True.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
w_bit: int = 4,
symmetry: bool = False,
group_size: int = 128,
) -> None:
super().__init__()
if w_bit not in [2, 4, 8]:
raise NotImplementedError('Only 2,4,8 bit are supported for now.')
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
w_pack_oc = out_features // (32 // self.w_bit)
w_inc = in_features
weight = torch.zeros((w_inc, w_pack_oc), dtype=torch.int32)
self.register_buffer('qweight', weight)
if bias:
self.register_buffer('bias', torch.zeros(out_features))
else:
self.bias = None
s_inc = in_features // self.group_size
s_oc = out_features
scales = torch.zeros((s_inc, s_oc), dtype=torch.float16)
self.register_buffer('scales', scales)
if not symmetry:
z_inc = in_features // self.group_size
z_oc = out_features // (32 // self.w_bit)
zeros = torch.zeros((z_inc, z_oc), dtype=torch.int32)
self.register_buffer('qzeros', zeros)
else:
self.qzeros = None
@classmethod
def from_linear(cls: Type['WeightOnlyQLinear'],
linear: nn.Linear,
quantizer: TypeVar('Quantizer'),
awq_layout: bool = True) -> 'WeightOnlyQLinear':
"""Create a WeightOnlyQLinear object from a PyTorch Linear object.
Args:
linear (nn.Linear): PyTorch Linear object.
quantizer (Quantizer): Object that handles quantization.
awq_layout (bool): AWQ layout. Defaults to True.
Returns:
WeightOnlyQLinear: A WeightOnlyQLinear object.
"""
device = linear.weight.device
w_bit = quantizer.bits
pack_num = 32 // w_bit
if awq_layout:
assert w_bit == 4
pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
else:
pack_order = torch.arange(pack_num)
group_size = quantizer.group_size
symmetry = quantizer.symmetry
in_features = linear.in_features
out_features = linear.out_features
bias = False if linear.bias is None else True
qlinear = cls(in_features, out_features, bias, w_bit, symmetry,
group_size)
qlinear.bias = linear.bias
qparams = quantizer.calculate_qparams(linear.weight)
i32_w = quantizer.quant(linear.weight, qparams, real=True)
i32_w = i32_w.t().contiguous()
pack_int_w = torch.zeros_like(qlinear.qweight).to(device)
for col in range(pack_int_w.shape[1]):
for i in range(pack_num):
pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
pack_int_w[:, col] |= pack_int_w_col << (i * w_bit)
qlinear.qweight = pack_int_w
qlinear.scales = qparams.scales.squeeze(-1).t().contiguous()
if qparams.zero_points is not None:
zeros = qparams.zero_points.to(torch.int32).to(device)
zeros = zeros.squeeze(-1).t().contiguous()
pack_int_zeros = torch.zeros_like(qlinear.qzeros).to(device)
for col in range(pack_int_zeros.shape[1]):
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + pack_order[i]]
pack_int_zeros[:, col] |= qzero_col << (i * w_bit)
qlinear.qzeros = pack_int_zeros
qlinear.to('cpu')
return qlinear
@torch.no_grad()
def forward(self, x):
if awq_inference_engine is None:
raise RuntimeError(
'Run the following command to install '
'the kernel for 4bit inference\n\n'
'git clone https://github.com/mit-han-lab/llm-awq.git\n'
'cd awq/kernels\n'
'python setup.py install\n')
out_shape = x.shape[:-1] + (self.out_features, )
inputs = x.reshape(-1, x.shape[-1])
out = awq_inference_engine.gemm_forward_cuda(inputs.half(),
self.qweight,
self.scales.half(),
self.qzeros,
self.group_size)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
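# Illustrative sketch (added for clarity, not part of the original module):
# ``from_linear`` above packs ``32 // w_bit`` integer weights into one int32,
# reordered by ``pack_order`` to match the AWQ kernel layout. The helper below
# reproduces that arithmetic for a single group of eight 4-bit values.
def _example_pack_int4(values: torch.Tensor) -> int:
    """Pack 8 int4 values (shape [8], each in [0, 15]) into one int32."""
    pack_order = [0, 2, 4, 6, 1, 3, 5, 7]  # AWQ layout
    packed = 0
    for i, src in enumerate(pack_order):
        packed |= int(values[src]) << (i * 4)
    return packed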
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import torch
from transformers.generation.utils import ModelOutput
logger = logging.getLogger(__name__)
class BasicSessionManager:
"""Basic session manager without history."""
def prepend_history(self, input_ids):
return input_ids
def add_to_history(self, output):
pass
class BasicSessionManagerWithHistory:
"""Basic session manager with chat history.
Args:
max_session_len (int): Maximum number of tokens allowed for all chat sessions.
reduce_size (int): Number of tokens to be trimmed when reaching maximum
session length. Default: 256.
start_ids (list[int]): Sequences of ids at the start of the chat session.
sep_ids (list[int]): Sequences of ids separating chat sessions.
""" # noqa: E501
bs = 1
def __init__(self,
max_session_len=2048,
reduce_size=256,
start_ids=[1],
sep_ids=[13]) -> None:
self.start_ids = torch.tensor(start_ids, dtype=torch.long)
self.sep_ids = torch.tensor(sep_ids, dtype=torch.long)
assert self.start_ids.ndim == 1
assert self.sep_ids.ndim == 1
self.max_session_len = max(len(start_ids), max_session_len)
self.reduce_size = min(reduce_size, max_session_len - len(start_ids))
assert self.max_session_len > self.reduce_size
self.new_session()
def new_session(self):
self.history_ids = self.start_ids.repeat(self.bs, 1)
def prepend_history(self, input_ids: torch.Tensor):
"""Prepend history ids to input ids and trim if over-length."""
input_ids = input_ids.to(self.history_ids.device).long()
sep_ids = self.sep_ids.to(self.history_ids.device).long().repeat(1, 1)
input_ids = torch.cat([self.history_ids, sep_ids, input_ids], dim=1)
if input_ids.shape[1] > self.max_session_len:
input_ids = input_ids[:,
(self.reduce_size - self.max_session_len):]
input_ids[:, :len(self.start_ids)] = self.start_ids.repeat(
self.bs, 1)
return input_ids
def add_to_history(self, output):
"""Save history output ids.
Note:
Output returned by HuggingFace generator contains both input
and output ids.
"""
if isinstance(output, ModelOutput):
self.history_ids = output.sequences
elif isinstance(output, torch.Tensor):
self.history_ids = output
else:
raise ValueError(f'Unknown output type {type(output)}')
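# Illustrative sketch (not part of the original module): the intended
# round-trip around a HuggingFace ``generate`` call. Token ids are made up.
def _example_history_roundtrip():
    sm = BasicSessionManagerWithHistory(max_session_len=32, reduce_size=8)
    user_ids = torch.tensor([[101, 102, 103]], dtype=torch.long)
    input_ids = sm.prepend_history(user_ids)  # start ids + sep + user ids
    # ... output = model.generate(input_ids, ...) would run here ...
    fake_output = torch.cat([input_ids, torch.tensor([[7, 8, 9]])], dim=1)
    sm.add_to_history(fake_output)  # history keeps both input and output ids
    return sm.history_ids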
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
import inspect
from inspect import Parameter, Signature
from typing import Dict, Sequence
import logging
import psutil
from transformers.generation.streamers import BaseStreamer
from .dist import get_rank, master_only, master_only_and_broadcast_general
def get_gpu_memory(id: int = 0) -> int:
"""Returns the free and total physical memory of the GPU in bytes."""
import torch
return torch.cuda.mem_get_info(id)
try:
import readline # To support command line history # noqa: F401
except ImportError: # readline not available
pass
logger = logging.getLogger(__name__)
def get_cpu_memory() -> int:
"""Returns the total CPU memory of the node in bytes."""
return psutil.virtual_memory().total
def bind_sigature(input_names: str, args: Sequence, kwargs: Dict):
"""Bind args and kwargs to given input names."""
kind = inspect._ParameterKind.POSITIONAL_OR_KEYWORD
sig = Signature([Parameter(name, kind) for name in input_names])
bind = sig.bind(*args, **kwargs)
return bind.arguments
class TerminalIO:
"""Terminal input and output."""
end_of_output = '\n'
@master_only_and_broadcast_general
def input(self):
"""Read input from terminal."""
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
try:
return '\n'.join(iter(input, sentinel))
except EOFError:
print('Detect EOF, exit')
exit()
@master_only
def output(self, string):
"""Output to terminal with flush."""
print(string, end='', flush=True)
class BasicStreamer(BaseStreamer):
"""Basic streamer for HuggingFace models."""
def __init__(self,
decode_func,
output_func,
end_of_output='\n',
skip_prompt=True):
self.decode = decode_func
self.output = output_func
self.end_of_output = end_of_output
self.skip_prompt = skip_prompt
self.gen_len = 0
def put(self, value):
"""Callback before forwarding current token id to model."""
if self.gen_len == 0 and self.skip_prompt:
pass
else:
token = self.decode(value)
self.output(token)
self.gen_len += 1
def end(self):
"""Callback at the end of generation."""
self.output(self.end_of_output)
self.gen_len = 0
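# Illustrative sketch (not part of the original module): wiring the streamer
# to a tokenizer and TerminalIO. ``tokenizer`` is assumed to expose a
# ``decode`` method (e.g. a HuggingFace tokenizer); the returned streamer is
# then passed to ``model.generate(..., streamer=...)``.
def _example_make_streamer(tokenizer, io=None):
    io = io or TerminalIO()
    return BasicStreamer(decode_func=tokenizer.decode,
                         output_func=io.output,
                         end_of_output=TerminalIO.end_of_output)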
def control(prompt, gen_config, sm):
"""Allow user to control generation config and session manager.
Return:
True if control command applied, False otherwise.
"""
if prompt == 'exit':
exit(0)
if prompt == 'clear':
sm.new_session()
logger.info('Session cleared')
return True
# Re-config during runtime
if prompt.startswith('config set'):
try:
keqv = prompt.split()[-1]
k, v = keqv.split('=')
v = eval(v)
gen_config.__setattr__(k, v)
logger.info(f'Worker {get_rank()} set {k} to {repr(v)}')
logger.info(f'Generator config changed to: {gen_config}')
return True
except: # noqa
logger.info(
'illegal instruction, treated as normal conversation. ')
return False
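# Illustrative sketch (not part of the original module): ``bind_sigature``
# maps positional and keyword arguments onto a list of expected parameter
# names, which is convenient when forwarding loosely-typed request payloads.
def _example_bind():
    # -> {'prompt': 'hello', 'max_new_tokens': 42, 'temperature': 0.8}
    return bind_sigature(['prompt', 'max_new_tokens', 'temperature'],
                         ('hello', 42), {'temperature': 0.8})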
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import dataclasses
import os
import random
from contextlib import contextmanager
from typing import List, Literal, Optional, Union
from argparse import ArgumentError
from contextlib import asynccontextmanager
from itertools import count
from queue import Empty, Queue
from threading import Thread
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from lmdeploy.messages import (EngineGenerationConfig, GenerationConfig,
PytorchEngineConfig, Response,
TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig, best_match_model
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import _stop_words, get_logger
logger = get_logger('lmdeploy')
def get_model_name_from_workspace_model(model_dir: str):
"""Get model name from workspace model."""
from configparser import ConfigParser
triton_model_path = os.path.join(model_dir, 'triton_models', 'weights')
if not os.path.exists(triton_model_path):
return None
ini_path = os.path.join(triton_model_path, 'config.ini')
# load cfg
with open(ini_path, 'r') as f:
parser = ConfigParser()
parser.read_file(f)
return parser['llama']['model_name']
def deduce_a_name(
model_path: str,
model_name: Optional[str] = None,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None) -> str:
"""Deduce a model name from all the possible arguments."""
def _config_model_name(config):
if config and config.model_name:
return config.model_name
return None
backend_config_model_name = _config_model_name(backend_config)
chat_template_config_model_name = _config_model_name(chat_template_config)
model_name = model_name or chat_template_config_model_name or backend_config_model_name # noqa
if model_name is None:
# model maybe from workspace for turbomind
model_name = get_model_name_from_workspace_model(model_path)
# may get a model name from model_path
if model_name is None:
model_name = best_match_model(model_path)
if model_name is None:
raise ArgumentError(None,
f'Please set model_name for {model_path}')
else:
logger.info(f'matched chat template name: {model_name}')
return model_name
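# Illustrative sketch (not part of the original module): the resolution order
# in ``deduce_a_name`` is explicit model_name, then chat_template_config,
# then backend_config, then the workspace config.ini, and finally a fuzzy
# match on the model path. The call below is hypothetical.
def _example_deduce_name():
    cfg = TurbomindEngineConfig()
    # with no explicit names set, this falls through to
    # best_match_model('internlm/internlm-chat-7b')
    return deduce_a_name('internlm/internlm-chat-7b',
                         model_name=None,
                         backend_config=cfg,
                         chat_template_config=None)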
@dataclasses.dataclass
......@@ -16,6 +74,55 @@ class GenOut:
finish_reason: Optional[Literal['stop', 'length']] = None
class Session:
"""Session for AsyncEngine.chat.
Args:
_id (int): session_id for internal use.
_step (int): the offset of the k/v cache for internal use.
_prompt (Any): input prompt for internal use.
_response (Response): model output for prompt.
_engine (Any): engine for internal use.
history (List[Any, str]): chat history.
"""
_ids = count(0)
def __init__(self):
self._id: int = next(self._ids)
self._step: int = 0
self._prompt: Any = None
self._response: Response = None
self._engine: Any = None
self.history: List[Tuple[Any, str]] = []
def _merge_response(self, resp: Response, step: Union[Response, GenOut]):
"""merge response."""
resp.text += step.text if isinstance(step, Response) else step.response
resp.input_token_len = step.input_token_len
resp.generate_token_len = step.generate_token_len
resp.finish_reason = step.finish_reason
return resp
@property
def response(self) -> Response:
"""return response."""
return self._response
def close(self):
"""release engine storage for this session."""
if self._engine:
inst = self._engine.create_instance()
inst.end(self._id)
def __repr__(self) -> str:
res = ''
for user, assistant in self.history:
if isinstance(user, list):
user = str(user)
res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n'
return res
class AsyncEngine:
"""Async inference engine. Maintaining a bunch of tm_model instances.
......@@ -30,51 +137,150 @@ class AsyncEngine:
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "InternLM/internlm-chat-7b",
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "InternLM/internlm-chat-7b",
huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
instance_num (int): number of engine instances to be created
backend (str): either `turbomind` or `pytorch` backend. Default to
`turbomind` backend.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
config instance. Default to None.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
tp (int): tensor parallel
"""
def __init__(self,
model_path: str,
model_name: Optional[str] = None,
instance_num: int = 32,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
tp: int = 1,
**kwargs) -> None:
from lmdeploy import turbomind as tm
self.tm_model = tm.TurboMind.from_pretrained(model_path,
model_name=model_name,
tp=tp,
**kwargs)
self.tokenizer = self.tm_model.tokenizer
self.instance_num = instance_num
self.model = self.tm_model.model
logger.info(
f'input backend={backend}, backend_config={backend_config}')
logger.info(f'input chat_template_config={chat_template_config}')
self.model_name = deduce_a_name(model_path, model_name, backend_config,
chat_template_config)
# build chat template config
if chat_template_config is None:
chat_template_config = ChatTemplateConfig(self.model_name)
elif chat_template_config.model_name is None:
chat_template_config.model_name = self.model_name
self.chat_template = chat_template_config.chat_template
# prevent bc
for k in list(kwargs.keys()):
if hasattr(chat_template_config, k):
logger.warning(f'{k} was deprecated. Please use '
'chat_template_config instead')
v = kwargs.pop(k)
setattr(chat_template_config, k, v)
logger.info(f'updated chat_template_config={chat_template_config}')
# build backend engine
if backend == 'turbomind':
self._build_turbomind(model_path=model_path,
backend_config=backend_config,
chat_template_config=chat_template_config,
tp=tp,
**kwargs)
elif backend == 'pytorch':
self._build_pytorch(model_path=model_path,
backend_config=backend_config,
**kwargs)
else:
raise ValueError(f'unsupported backend {backend}')
logger.info(f'updated backend_config={self.backend_config}')
# parameters for member functions
self.session_len = self.backend_config.session_len
self.stop_words = _stop_words(self.chat_template.stop_words,
self.engine.tokenizer)
if self.stop_words is not None:
self.stop_words = self.stop_words[0][0].tolist()
self.backend = backend
self.instance_num = self.backend_config.max_batch_size
self.tokenizer = self.engine.tokenizer
self.id2step = {}
self.id2generator = {}
self.loop = asyncio.get_event_loop()
self.running_session_ids = set()
self.gens_set = set()
for i in range(instance_num):
self.gens_set.add(self.tm_model.create_instance())
for i in range(self.instance_num):
self.gens_set.add(self.engine.create_instance())
def _build_turbomind(
self,
model_path: str,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
tp: int = 1,
**kwargs):
"""Innter build method for turbomind backend."""
if backend_config is None:
backend_config = TurbomindEngineConfig(model_name=self.model_name,
tp=tp)
assert isinstance(backend_config, TurbomindEngineConfig), 'Please'\
' use TurbomindEngineConfig imported from lmdeploy.messages for ' \
'turbomind backend'
if backend_config.session_len is None:
backend_config.session_len = self.chat_template.session_len
from lmdeploy import turbomind as tm
self.engine = tm.TurboMind.from_pretrained(
model_path,
engine_config=backend_config,
chat_template_config=chat_template_config,
**kwargs)
self.backend_config = backend_config
def _build_pytorch(
self,
model_path: str,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
**kwargs):
"""Innter build method for pytorch backend."""
from lmdeploy.pytorch.engine import Engine
if backend_config is None:
backend_config = PytorchEngineConfig(self.model_name)
assert isinstance(backend_config, PytorchEngineConfig), 'Please '\
'use PytorchEngineConfig imported from lmdeploy.messages for ' \
'pytorch backend'
if backend_config.session_len is None:
backend_config.session_len = self.chat_template.session_len
self.engine = Engine(model_path=model_path,
engine_config=backend_config)
self.backend_config = backend_config
def __call__(self,
prompts: List[str],
prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
gen_config: Optional[GenerationConfig] = None,
request_output_len=512,
top_k=40,
top_p=0.8,
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
do_preprocess=True,
top_k: int = 40,
top_p: float = 0.8,
temperature: float = 0.8,
repetition_penalty: float = 1.0,
ignore_eos: bool = False,
do_preprocess: bool = True,
**kwargs):
"""Inference a batch of prompts.
Args:
prompts (List[str]): a batch of prompts
prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
prompts. It accepts: string prompt, a list of string prompts,
a chat history in OpenAI format or a list of chat history.
gen_config (GenerationConfig | None): an instance of
GenerationConfig. Default to None.
chat_template_config (ChatTemplateConfig | None): an instance of
ChatTemplateConfig. Default to None.
request_output_len (int): output token nums
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
......@@ -85,245 +291,363 @@ class AsyncEngine:
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
do_preprocess (bool): whether pre-process the messages.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
if gen_config is None:
gen_config = GenerationConfig(
max_new_tokens=request_output_len,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=ignore_eos)
return self.batch_infer(prompts,
request_output_len=request_output_len,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=ignore_eos,
gen_config=gen_config,
do_preprocess=do_preprocess,
**kwargs)
def stop_session(self, session_id: int):
async def stop_session(self, session_id: int):
"""Stop a session by a session_id."""
input_ids = [self.tm_model.eos_id]
stop_generator = self.tm_model.create_instance()
for outputs in stop_generator.stream_infer(session_id,
input_ids,
request_output_len=0,
sequence_start=False,
sequence_end=False,
stop=True):
pass
if str(session_id) in self.id2generator and self.id2generator[str(
session_id)] not in self.gens_set:
if str(session_id) in self.id2generator:
await self.id2generator[str(session_id)].async_cancel(session_id)
self.gens_set.add(self.id2generator[str(session_id)])
def end_session(self, session_id: int):
self.running_session_ids.discard(session_id)
async def end_session(self, session_id: int):
"""Clear a session by a session_id."""
input_ids = [self.tm_model.eos_id]
end_generator = self.tm_model.create_instance()
for outputs in end_generator.stream_infer(session_id,
input_ids,
request_output_len=0,
sequence_start=False,
sequence_end=True):
pass
self.id2step[str(session_id)] = 0
if str(session_id) in self.id2generator and self.id2generator[str(
session_id)] not in self.gens_set:
if str(session_id) in self.id2generator:
await self.id2generator[str(session_id)].async_end(session_id)
self.id2step[str(session_id)] = 0
self.gens_set.add(self.id2generator[str(session_id)])
@contextmanager
def safe_run(self, session_id: Optional[int] = None):
self.running_session_ids.discard(session_id)
@asynccontextmanager
async def safe_run(self, session_id: Optional[int] = None):
"""A context manager to make sure server's safe running."""
try:
yield
except (Exception, asyncio.CancelledError) as e: # noqa
self.stop_session(session_id)
await self.stop_session(session_id)
raise e
if str(session_id) in self.id2generator and self.id2generator[str(
session_id)] not in self.gens_set:
if str(session_id) in self.id2generator:
self.gens_set.add(self.id2generator[str(session_id)])
async def get_embeddings(self, prompt, do_preprocess=False):
if do_preprocess:
prompt = self.model.get_prompt(prompt)
input_ids = self.tokenizer.encode(prompt)
return input_ids
self.running_session_ids.discard(session_id)
async def get_generator(self, stop: bool, session_id: int):
"""Only return the model instance if it is available."""
if stop:
return self.tm_model.create_instance()
while self.gens_set == set():
await asyncio.sleep(0)
return self.engine.create_instance()
# wait while no generator is free or the same session_id is still running
while self.gens_set == set() or session_id in self.running_session_ids:
await asyncio.sleep(0.1)
generator = self.gens_set.pop()
self.id2generator[str(session_id)] = generator
self.running_session_ids.add(session_id)
return generator
def batch_infer(self,
prompts: Union[List[str], str],
request_output_len=512,
top_k=40,
top_p=0.8,
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
do_preprocess=True,
prompts: Union[List[str], str, List[Dict],
List[List[Dict]]],
gen_config: Optional[Union[GenerationConfig,
EngineGenerationConfig]] = None,
do_preprocess: bool = True,
**kwargs):
"""Inference a batch of prompts.
Args:
prompts (List[str] | str): a batch of prompts
request_output_len (int): output token nums
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
do_preprocess (bool): whether pre-process the messages.
prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
prompts. It accepts: string prompt, a list of string prompts,
a chat history in OpenAI format or a list of chat history.
gen_config (GenerationConfig | None): an instance of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
input_str = isinstance(prompts, str)
prompts = [prompts] if input_str else prompts
need_list_wrap = isinstance(prompts, str) or isinstance(
prompts[0], Dict)
prompts = [prompts] if need_list_wrap else prompts
assert isinstance(prompts, List), 'prompts should be a list'
batch_size = len(prompts)
outputs = [''] * batch_size
generators = []
for i, prompt in enumerate(prompts):
generators.append(
self.generate(prompt,
i,
stream_response=True,
sequence_start=True,
sequence_end=True,
request_output_len=request_output_len,
top_k=top_k,
top_p=top_p,
temperature=temperature,
ignore_eos=ignore_eos,
repetition_penalty=repetition_penalty,
do_preprocess=do_preprocess,
**kwargs))
async def _inner_call(i, generator):
async for out in generator:
outputs[i] += out.response
async def gather():
await asyncio.gather(
*[_inner_call(i, generators[i]) for i in range(batch_size)])
self.loop.run_until_complete(gather())
outputs = outputs[0] if input_str else outputs
if gen_config is None:
gen_config = GenerationConfig()
if type(gen_config) is GenerationConfig:
gen_config = EngineGenerationConfig.From(gen_config,
self.tokenizer)
# set random if it is not set
if gen_config.random_seed is None:
gen_config.random_seed = random.getrandbits(64)
prompt_num = len(prompts)
outputs = [Response('', 0, 0, i) for i in range(prompt_num)]
for j in range(0, prompt_num, self.instance_num):
batch_prompts = prompts[j:j + self.instance_num]
generators = []
for i, prompt in enumerate(batch_prompts):
generators.append(
self.generate(prompt,
i,
gen_config=gen_config,
stream_response=True,
sequence_start=True,
sequence_end=True,
do_preprocess=do_preprocess,
**kwargs))
async def _inner_call(i, generator):
async for out in generator:
outputs[i + j].text += out.response
outputs[i + j].generate_token_len = out.generate_token_len
outputs[i + j].input_token_len = out.input_token_len
outputs[i + j].finish_reason = out.finish_reason
async def gather():
await asyncio.gather(*[
_inner_call(i, generators[i])
for i in range(len(batch_prompts))
])
self.loop.run_until_complete(gather())
outputs = outputs[0] if need_list_wrap else outputs
return outputs
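    # Illustrative usage sketch (comment only, not part of the original class),
    # assuming an engine built elsewhere. ``batch_infer`` blocks until every
    # prompt has finished and returns one Response per prompt:
    #
    #     engine = AsyncEngine('internlm/internlm-chat-7b', backend='turbomind')
    #     cfg = GenerationConfig(max_new_tokens=128, top_p=0.8, temperature=0.7)
    #     for resp in engine.batch_infer(
    #             ['Hi, please introduce yourself', 'Shanghai is'],
    #             gen_config=cfg):
    #         print(resp.text, resp.finish_reason)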
def stream_infer(
self,
prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
gen_config: Optional[Union[GenerationConfig,
EngineGenerationConfig]] = None,
do_preprocess: bool = True,
**kwargs):
"""Inference a batch of prompts with stream mode.
Args:
prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
prompts. It accepts: string prompt, a list of string prompts,
a chat history in OpenAI format or a list of chat history.
gen_config (GenerationConfig | None): an instance of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
need_list_wrap = isinstance(prompts, str) or isinstance(
prompts[0], Dict)
prompts = [prompts] if need_list_wrap else prompts
assert isinstance(prompts, List), 'prompts should be a list'
if gen_config is None:
gen_config = GenerationConfig()
if type(gen_config) is GenerationConfig:
gen_config = EngineGenerationConfig.From(gen_config,
self.tokenizer)
# set random if it is not set
if gen_config.random_seed is None:
gen_config.random_seed = random.getrandbits(64)
prompt_num = len(prompts)
outputs = Queue()
generators = []
for j in range(0, prompt_num, self.instance_num):
batch_prompts = prompts[j:j + self.instance_num]
generators = []
for i, prompt in enumerate(batch_prompts):
generators.append(
self.generate(prompt,
i,
gen_config=gen_config,
stream_response=True,
sequence_start=True,
sequence_end=True,
do_preprocess=do_preprocess,
**kwargs))
async def _inner_call(i, generator):
async for out in generator:
outputs.put(
Response(out.response, out.generate_token_len,
out.input_token_len, i + j,
out.finish_reason))
async def gather():
await asyncio.gather(*[
_inner_call(i, generators[i])
for i in range(len(batch_prompts))
])
outputs.put(None)
proc = Thread(
target=lambda: self.loop.run_until_complete(gather()))
proc.start()
while True:
try:
out = outputs.get(timeout=0.001)
if out is None:
break
yield out
except Empty:
pass
proc.join()
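    # Illustrative usage sketch (comment only, not part of the original class).
    # ``stream_infer`` yields partial Response objects as they are produced;
    # the fourth Response field (session_id here) records which prompt in the
    # batch a chunk belongs to:
    #
    #     for chunk in engine.stream_infer(
    #             ['Hi, please introduce yourself', 'Shanghai is']):
    #         print(chunk.session_id, chunk.text)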
async def _get_prompt_input(self, prompt: str, do_preprocess: bool,
sequence_start: bool):
if do_preprocess:
prompt = self.chat_template.messages2prompt(prompt, sequence_start)
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
return {'prompt': prompt, 'input_ids': input_ids}
async def generate(
self,
messages,
session_id,
stream_response=True,
sequence_start=True,
sequence_end=True, # no interactive mode by default
step=0,
request_output_len=512,
stop=False,
top_k=40,
top_p=0.8,
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
do_preprocess=True,
session_id: int,
gen_config: Optional[Union[GenerationConfig,
EngineGenerationConfig]] = None,
stream_response: bool = True,
sequence_start: bool = True,
sequence_end: bool = True, # no interactive mode by default
step: int = 0,
do_preprocess: bool = True,
**kwargs):
"""Generate responses.
Args:
messages (str | List): chat history or prompt
session_id (int): the session id
gen_config (GenerationConfig | None): an instance of
GenerationConfig. Default to None.
stream_response (bool): whether return responses streamingly
request_output_len (int): output token nums
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
step (int): the offset of the k/v cache
stop (bool): whether stop inference
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
do_preprocess (bool): whether pre-process the messages.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
if str(session_id) not in self.id2step:
self.id2step[str(session_id)] = 0
if step != 0:
self.id2step[str(session_id)] = step
seed = random.getrandbits(64)
if gen_config is None:
gen_config = GenerationConfig()
if type(gen_config) is GenerationConfig:
gen_config = EngineGenerationConfig.From(gen_config,
self.tokenizer)
if gen_config.stop_words is None:
gen_config.stop_words = self.stop_words
# set random if it is not set and sequence_start is True
if gen_config.random_seed is None and sequence_start:
gen_config.random_seed = random.getrandbits(64)
prompt = messages
if do_preprocess:
prompt = self.model.messages2prompt(prompt, sequence_start)
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
finish_reason = None
request_output_len = min(
request_output_len, self.tm_model.session_len - self.id2step[str(session_id)] -
prompt_input = await self._get_prompt_input(prompt, do_preprocess,
sequence_start)
prompt = prompt_input['prompt']
logger.info(f'Prompt with applied chat template:\n{prompt}')
input_ids = prompt_input['input_ids']
if gen_config.max_new_tokens is None:
# for interactive endpoint, will try maximum possible token num
gen_config.max_new_tokens = max(
128, self.session_len - self.id2step[str(session_id)] -
len(input_ids))
request_output_len = max(0, request_output_len)
if stop is True:
self.stop_session(session_id)
yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
finish_reason)
elif self.id2step[str(session_id)] + len(
input_ids) + request_output_len > self.tm_model.session_len:
finish_reason = None
logger.info(f'session_id={session_id}, '
f'history_tokens={self.id2step[str(session_id)]}, '
f'input_tokens={len(input_ids)}, '
f'max_new_tokens={gen_config.max_new_tokens}, '
f'seq_start={sequence_start}, seq_end={sequence_end}, '
f'step={step}, prep={do_preprocess}')
if self.id2step[str(session_id)] + len(
input_ids) + gen_config.max_new_tokens > self.session_len:
logger.warning(f'run out of tokens. session_id={session_id}')
finish_reason = 'length'
yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
finish_reason)
if sequence_end is True and sequence_start is False:
self.end_session(session_id)
await self.end_session(session_id)
else:
generator = await self.get_generator(stop, session_id)
with self.safe_run(session_id):
response_size = 0
generator = await self.get_generator(False, session_id)
async with self.safe_run(session_id):
state = DetokenizeState()
async for outputs in generator.async_stream_infer(
session_id=session_id,
input_ids=[input_ids],
**prompt_input,
gen_config=gen_config,
stream_output=stream_response,
request_output_len=request_output_len,
sequence_start=(sequence_start),
sequence_start=sequence_start,
sequence_end=sequence_end,
step=self.id2step[str(session_id)],
stop=stop,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=ignore_eos,
random_seed=seed if sequence_start else None):
res, tokens = outputs[0]
step=self.id2step[str(session_id)]):
_, res, tokens = outputs
# decode res
response = self.tokenizer.decode(res.tolist(),
offset=response_size)
# utf-8 char at the end means it's a potential unfinished
# byte sequence, continue to concatenate it with the next
# sequence and decode them together
if response.endswith('�'):
continue
response, state = self.tokenizer.detokenize_incrementally(
res,
state,
skip_special_tokens=gen_config.skip_special_tokens)
# response, history token len,
# input token len, gen token len
yield GenOut(response, self.id2step[str(session_id)],
len(input_ids), tokens, finish_reason)
response_size = tokens
finish_reason = 'length' \
if tokens >= request_output_len else 'stop'
# `response_size` might not be updated since
# ` if response.endswith('�')`
if response_size == tokens:
if tokens >= gen_config.max_new_tokens else 'stop'
# utf-8 char at the end means it's a potential unfinished
# byte sequence
if not response.endswith('�'):
response = ''  # avoid returning the last response twice
yield GenOut(response, self.id2step[str(session_id)],
len(input_ids), tokens, finish_reason)
# update step
self.id2step[str(session_id)] += len(input_ids) + tokens
if sequence_end or stop:
if sequence_end:
self.id2step[str(session_id)] = 0
# manually end pytorch session
# TODO modify pytorch or turbomind api
if self.backend == 'pytorch' and sequence_end:
await self.end_session(session_id)
def chat(self,
prompt: str,
session=None,
gen_config: Optional[Union[GenerationConfig,
EngineGenerationConfig]] = None,
do_preprocess: bool = True,
**kwargs) -> Session:
"""Chat.
Args:
prompt (str): prompt
session (Session): the chat session
gen_config (GenerationConfig | None): an instance of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
**kwargs (dict): ad hoc parametrization of `gen_config`.
"""
if session is None:
session = Session()
session._engine = self.engine
# sync & init
session._prompt = prompt
session._response = None
sequence_start = session._step == 0
async def _work():
resp = Response('', -1, -1, session._id)
async for output in self.generate(prompt,
session_id=session._id,
gen_config=gen_config,
stream_response=False,
sequence_start=sequence_start,
sequence_end=False,
step=session._step,
do_preprocess=do_preprocess,
**kwargs):
resp = session._merge_response(resp, output)
return resp
from lmdeploy.pytorch.engine.request import _run_until_complete
resp = _run_until_complete(_work())
session._response = resp
session._step += resp.generate_token_len + resp.input_token_len
session.history.append((session._prompt, resp.text))
return session
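# Illustrative sketch (hypothetical model path, not part of the original
# module): a synchronous multi-turn conversation built on ``AsyncEngine.chat``
# and ``Session``; reusing the session keeps the k/v cache via ``_step``.
def _example_chat(model_path: str = 'internlm/internlm-chat-7b'):
    engine = AsyncEngine(model_path, backend='turbomind')
    session = engine.chat('Hi, please introduce yourself')
    print(session.response.text)
    session = engine.chat('Shanghai is', session=session)
    print(session.response.text)
    session.close()
    return session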
......@@ -17,7 +17,8 @@ class InterFace:
def chat_stream_restful(instruction: str, state_chatbot: Sequence,
cancel_btn: gr.Button, reset_btn: gr.Button,
session_id: int):
session_id: int, top_p: float, temperature: float,
request_output_len: int):
"""Chat with AI assistant.
Args:
......@@ -33,9 +34,11 @@ def chat_stream_restful(instruction: str, state_chatbot: Sequence,
instruction,
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=512,
interactive_mode=True):
if finish_reason == 'length':
request_output_len=request_output_len,
interactive_mode=True,
top_p=top_p,
temperature=temperature):
if finish_reason == 'length' and tokens == 0:
gr.Warning('WARNING: exceed session max length.'
' Please restart the session by reset button.')
if tokens < 0:
......@@ -94,7 +97,7 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=0,
stop=True,
cancel=True,
interactive_mode=True):
pass
# end the session
......@@ -106,6 +109,7 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
interactive_mode=False):
pass
# resume the session
# TODO this is not proper if api server is running pytorch backend
messages = []
for qa in state_chatbot:
messages.append(dict(role='user', content=qa[0]))
......@@ -155,10 +159,22 @@ def run_api_server(api_server_url: str,
with gr.Row():
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
with gr.Row():
request_output_len = gr.Slider(1,
2048,
value=512,
step=1,
label='Maximum new tokens')
top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
temperature = gr.Slider(0.01,
1.5,
value=0.7,
step=0.01,
label='Temperature')
send_event = instruction_txtbox.submit(chat_stream_restful, [
instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
state_session_id
state_session_id, top_p, temperature, request_output_len
], [state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Literal, Optional, Union
from lmdeploy.archs import get_task
from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.model import ChatTemplateConfig
def run(model_path_or_server: str,
server_name: str = '0.0.0.0',
server_port: int = 6006,
batch_size: int = 32,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[PytorchEngineConfig,
TurbomindEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
tp: int = 1,
model_name: str = None,
**kwargs):
......@@ -19,6 +28,12 @@ def run(model_path_or_server: str,
server_name (str): the ip address of gradio server
server_port (int): the port of gradio server
batch_size (int): batch size for running Turbomind directly
backend (str): either `turbomind` or `pytorch` backend. Default to
`turbomind` backend.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
config instance. Default to None.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
tp (int): tensor parallel for Turbomind
"""
if ':' in model_path_or_server:
......@@ -31,11 +46,22 @@ def run(model_path_or_server: str,
run_triton_server
run_triton_server(model_path_or_server, server_name, server_port)
else:
from lmdeploy.serve.gradio.turbomind_coupled import run_local
pipeline_type, _ = get_task(model_path_or_server)
if pipeline_type == 'vlm':
from lmdeploy.serve.gradio.vl import run_local
assert backend == 'turbomind', 'vlm only support turbomind backend'
if backend_config is not None and \
backend_config.session_len is None:
backend_config.session_len = 8192
else:
from lmdeploy.serve.gradio.turbomind_coupled import run_local
run_local(model_path_or_server,
model_name=model_name,
server_name=server_name,
server_port=server_port,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config,
model_name=model_name,
batch_size=batch_size,
tp=tp,
**kwargs)
......
......@@ -24,5 +24,5 @@ THEME = gr.themes.Soft(
secondary_hue=gr.themes.colors.sky,
font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
enable_btn = gr.update(interactive=True)
disable_btn = gr.update(interactive=False)
......@@ -16,7 +16,8 @@ class InterFace:
def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int):
cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int,
top_p: float, temperature: float, request_output_len: int):
"""Chat with AI assistant.
Args:
......@@ -30,7 +31,12 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
instruction = state_chatbot[-1][0]
bot_response = llama_chatbot.stream_infer(
session_id, instruction, f'{session_id}-{len(state_chatbot)}')
session_id,
instruction,
f'{session_id}-{len(state_chatbot)}',
request_output_len=request_output_len,
top_p=top_p,
temperature=temperature)
for status, tokens, _ in bot_response:
state_chatbot[-1] = (state_chatbot[-1][0], tokens)
......@@ -108,12 +114,24 @@ def run_triton_server(triton_server_addr: str,
with gr.Row():
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
with gr.Row():
request_output_len = gr.Slider(1,
2048,
value=512,
step=1,
label='Maximum new tokens')
top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
temperature = gr.Slider(0.01,
1.5,
value=0.7,
step=0.01,
label='Temperature')
send_event = instruction_txtbox.submit(
add_instruction, [instruction_txtbox, state_chatbot],
[instruction_txtbox, state_chatbot]).then(chat_stream, [
state_chatbot, llama_chatbot, cancel_btn, reset_btn,
state_session_id
state_session_id, top_p, temperature, request_output_len
], [state_chatbot, chatbot, cancel_btn, reset_btn])
cancel_btn.click(cancel_func,
......
# Copyright (c) OpenMMLab. All rights reserved.
import random
from threading import Lock
from typing import Optional, Sequence
from typing import Literal, Optional, Sequence, Union
import gradio as gr
from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
......@@ -14,13 +18,10 @@ class InterFace:
lock = Lock()
async def chat_stream_local(
instruction: str,
state_chatbot: Sequence,
cancel_btn: gr.Button,
reset_btn: gr.Button,
session_id: int,
):
async def chat_stream_local(instruction: str, state_chatbot: Sequence,
cancel_btn: gr.Button, reset_btn: gr.Button,
session_id: int, top_p: float, temperature: float,
request_output_len: int):
"""Chat with AI assistant.
Args:
......@@ -33,15 +34,23 @@ async def chat_stream_local(
state_chatbot = state_chatbot + [(instruction, None)]
yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
gen_config = GenerationConfig(max_new_tokens=request_output_len,
top_p=top_p,
top_k=40,
temperature=temperature,
random_seed=random.getrandbits(64)
if len(state_chatbot) == 1 else None)
async for outputs in InterFace.async_engine.generate(
instruction,
session_id,
gen_config=gen_config,
stream_response=True,
sequence_start=(len(state_chatbot) == 1),
sequence_end=False):
response = outputs.response
if outputs.finish_reason == 'length':
if outputs.finish_reason == 'length' and \
outputs.generate_token_len == 0:
gr.Warning('WARNING: exceed session max length.'
' Please restart the session by reset button.')
if outputs.generate_token_len < 0:
......@@ -69,7 +78,7 @@ async def reset_local_func(instruction_txtbox: gr.Textbox,
"""
state_chatbot = []
# end the session
InterFace.async_engine.end_session(session_id)
await InterFace.async_engine.end_session(session_id)
return (state_chatbot, state_chatbot, gr.Textbox.update(value=''))
......@@ -85,28 +94,36 @@ async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
session_id (int): the session id
"""
yield (state_chatbot, disable_btn, disable_btn)
InterFace.async_engine.stop_session(session_id)
InterFace.async_engine.end_session(session_id)
messages = []
for qa in state_chatbot:
messages.append(dict(role='user', content=qa[0]))
if qa[1] is not None:
messages.append(dict(role='assistant', content=qa[1]))
async for out in InterFace.async_engine.generate(messages,
session_id,
request_output_len=0,
stream_response=True,
sequence_start=True,
sequence_end=False):
pass
yield (state_chatbot, disable_btn, enable_btn)
await InterFace.async_engine.stop_session(session_id)
# pytorch backend does not support resume chat history now
if InterFace.async_engine.backend == 'pytorch':
yield (state_chatbot, disable_btn, enable_btn)
else:
await InterFace.async_engine.end_session(session_id)
messages = []
for qa in state_chatbot:
messages.append(dict(role='user', content=qa[0]))
if qa[1] is not None:
messages.append(dict(role='assistant', content=qa[1]))
gen_config = GenerationConfig(max_new_tokens=0)
async for out in InterFace.async_engine.generate(messages,
session_id,
gen_config=gen_config,
stream_response=True,
sequence_start=True,
sequence_end=False):
pass
yield (state_chatbot, disable_btn, enable_btn)
def run_local(model_path: str,
model_name: Optional[str] = None,
server_name: str = 'localhost',
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[PytorchEngineConfig,
TurbomindEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
server_name: str = '0.0.0.0',
server_port: int = 6006,
batch_size: int = 4,
tp: int = 1,
**kwargs):
"""chat with AI assistant through web ui.
......@@ -122,22 +139,32 @@ def run_local(model_path: str,
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "InternLM/internlm-chat-7b",
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "InternLM/internlm-chat-7b",
huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
server_name (str): the ip address of gradio server
backend (str): either `turbomind` or `pytorch` backend. Default to
`turbomind` backend.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
config instance. Default to None.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
server_name (str): the ip address of gradio server. Default to
"0.0.0.0". For huggingface space demo, it should be
"huggingface-space".
server_port (int): the port of gradio server
batch_size (int): batch size for running Turbomind directly
tp (int): tensor parallel for Turbomind
"""
InterFace.async_engine = AsyncEngine(model_path=model_path,
model_name=model_name,
instance_num=batch_size,
tp=tp,
**kwargs)
InterFace.async_engine = AsyncEngine(
model_path=model_path,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config,
model_name=model_name,
tp=tp,
**kwargs)
with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([])
......@@ -148,17 +175,29 @@ def run_local(model_path: str,
chatbot = gr.Chatbot(
elem_id='chatbot',
label=InterFace.async_engine.tm_model.model_name)
label=InterFace.async_engine.engine.model_name)
instruction_txtbox = gr.Textbox(
placeholder='Please input the instruction',
label='Instruction')
with gr.Row():
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
with gr.Row():
request_output_len = gr.Slider(1,
2048,
value=512,
step=1,
label='Maximum new tokens')
top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
temperature = gr.Slider(0.01,
1.5,
value=0.7,
step=0.01,
label='Temperature')
send_event = instruction_txtbox.submit(chat_stream_local, [
instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
state_session_id
state_session_id, top_p, temperature, request_output_len
], [state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
......@@ -184,14 +223,19 @@ def run_local(model_path: str,
demo.load(init, inputs=None, outputs=[state_session_id])
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=batch_size, max_size=100,
api_open=True).launch(
max_threads=10,
share=True,
server_port=server_port,
server_name=server_name,
)
if server_name == 'huggingface-space':
demo.queue(concurrency_count=InterFace.async_engine.instance_num,
max_size=100).launch()
else:
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=InterFace.async_engine.instance_num,
max_size=100,
api_open=True).launch(
max_threads=10,
share=True,
server_port=server_port,
server_name=server_name,
)
if __name__ == '__main__':
......
......@@ -4,8 +4,11 @@ from typing import Any, Dict, Iterable, List, Optional, Union
import requests
from lmdeploy.utils import get_logger
def get_model_list(api_url: str):
"""Get model list from api server."""
response = requests.get(api_url)
if hasattr(response, 'text'):
model_list = json.loads(response.text)
......@@ -14,15 +17,31 @@ def get_model_list(api_url: str):
return None
def json_loads(content):
"""Loads content to json format."""
try:
content = json.loads(content)
return content
except: # noqa
logger = get_logger('lmdeploy')
logger.warning(f'weird json content {content}')
return ''
class APIClient:
"""Chatbot for LLaMA series models with turbomind as inference engine.
Args:
api_server_url (str): communicating address 'http://<ip>:<port>' of
api_server
api_key (str | None): api key. Default to None, which means no
api key will be used.
"""
def __init__(self, api_server_url: str, **kwargs):
def __init__(self,
api_server_url: str,
api_key: Optional[str] = None,
**kwargs):
self.api_server_url = api_server_url
self.chat_intractive_v1_url = f'{api_server_url}/v1/chat/interactive'
self.chat_completions_v1_url = f'{api_server_url}/v1/chat/completions'
......@@ -30,6 +49,10 @@ class APIClient:
self.models_v1_url = f'{api_server_url}/v1/models'
self.encode_v1_url = f'{api_server_url}/v1/encode'
self._available_models = None
self.api_key = api_key
self.headers = {'content-type': 'application/json'}
if api_key is not None:
self.headers['Authorization'] = f'Bearer {api_key}'
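# Illustrative usage sketch (the URL and key are placeholders):
#
#     client = APIClient('http://0.0.0.0:23333', api_key='YOUR_API_KEY')
#     print(client.available_models)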
@property
def available_models(self):
......@@ -38,7 +61,7 @@ class APIClient:
return self._available_models
response = requests.get(self.models_v1_url)
if hasattr(response, 'text'):
model_list = json.loads(response.text)
model_list = json_loads(response.text)
model_list = model_list.pop('data', [])
self._available_models = [item['id'] for item in model_list]
return self._available_models
......@@ -57,15 +80,14 @@ class APIClient:
when it is not. Default to True.
Return: (input_ids, length)
"""
headers = {'content-type': 'application/json'}
response = requests.post(self.encode_v1_url,
headers=headers,
headers=self.headers,
json=dict(input=input,
do_preprocess=do_preprocess,
add_bos=add_bos),
stream=False)
if hasattr(response, 'text'):
output = json.loads(response.text)
output = json_loads(response.text)
return output['input_ids'], output['length']
return None, None
......@@ -75,8 +97,8 @@ class APIClient:
temperature: Optional[float] = 0.7,
top_p: Optional[float] = 1.0,
n: Optional[int] = 1,
max_tokens: Optional[int] = 512,
stop: Optional[bool] = False,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = False,
presence_penalty: Optional[float] = 0.0,
frequency_penalty: Optional[float] = 0.0,
......@@ -84,12 +106,14 @@ class APIClient:
repetition_penalty: Optional[float] = 1.0,
session_id: Optional[int] = -1,
ignore_eos: Optional[bool] = False,
skip_special_tokens: Optional[bool] = True,
**kwargs):
"""Chat completion v1.
Args:
model: model name. Available from self.available_models.
messages: string prompt or chat history in OpenAI format.
messages: string prompt or chat history in OpenAI format. Chat
history example: `[{"role": "user", "content": "hi"}]`.
temperature (float): to modulate the next token probability
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or
......@@ -97,11 +121,15 @@ class APIClient:
n (int): How many chat completion choices to generate for each
input message. Only support one here.
stream: whether to stream the results or not. Default to false.
max_tokens (int): output token nums
max_tokens (int | None): output token nums. Default to None.
stop (str | List[str] | None): To stop generating further
tokens. Only accept stop words that are encoded to a single token index.
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
session_id (int): if not specified, will set random value
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
session_id (int): Deprecated.
Yields:
json objects in openai formats
......@@ -111,9 +139,8 @@ class APIClient:
for k, v in locals().copy().items()
if k[:2] != '__' and k not in ['self']
}
headers = {'content-type': 'application/json'}
response = requests.post(self.chat_completions_v1_url,
headers=headers,
headers=self.headers,
json=pload,
stream=stream)
for chunk in response.iter_lines(chunk_size=8192,
......@@ -126,11 +153,11 @@ class APIClient:
continue
if decoded[:6] == 'data: ':
decoded = decoded[6:]
output = json.loads(decoded)
output = json_loads(decoded)
yield output
else:
decoded = chunk.decode('utf-8')
output = json.loads(decoded)
output = json_loads(decoded)
yield output
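# Illustrative usage sketch for the streaming path above; `api_client` is
# assumed to be an already-constructed APIClient:
#
#     model_name = api_client.available_models[0]
#     messages = [{'role': 'user', 'content': 'hi'}]
#     for chunk in api_client.chat_completions_v1(model=model_name,
#                                                 messages=messages,
#                                                 stream=True):
#         print(chunk)  # each chunk is an OpenAI-style json object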
def chat_interactive_v1(self,
......@@ -138,13 +165,14 @@ class APIClient:
session_id: int = -1,
interactive_mode: bool = False,
stream: bool = False,
stop: bool = False,
request_output_len: int = 512,
stop: Optional[Union[str, List[str]]] = None,
request_output_len: Optional[int] = None,
top_p: float = 0.8,
top_k: int = 40,
temperature: float = 0.8,
repetition_penalty: float = 1.0,
ignore_eos: bool = False,
skip_special_tokens: Optional[bool] = True,
**kwargs):
"""Interactive completions.
......@@ -162,8 +190,10 @@ class APIClient:
interactive mode, session history is kept on the server (and
vice versa).
stream: whether to stream the results or not.
stop: whether to stop the session response or not.
request_output_len (int): output token nums
stop (str | List[str] | None): To stop generating further tokens.
Only accept stop words that are encoded to a single token index.
request_output_len (int): output token nums. If not specified,
will use maximum possible number for a session.
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or
higher are kept for generation.
......@@ -173,18 +203,20 @@ class APIClient:
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Yields:
json objects consist of text, tokens, finish_reason
json objects consisting of text, tokens, input_tokens,
history_tokens, finish_reason
"""
pload = {
k: v
for k, v in locals().copy().items()
if k[:2] != '__' and k not in ['self']
}
headers = {'content-type': 'application/json'}
response = requests.post(self.chat_intractive_v1_url,
headers=headers,
headers=self.headers,
json=pload,
stream=stream)
for chunk in response.iter_lines(chunk_size=8192,
......@@ -192,7 +224,7 @@ class APIClient:
delimiter=b'\n'):
if chunk:
decoded = chunk.decode('utf-8')
output = json.loads(decoded)
output = json_loads(decoded)
yield output
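# Illustrative usage sketch; `api_client` is assumed to be an existing
# APIClient and the stop word is a placeholder that must map to a single
# token index, as documented above:
#
#     for out in api_client.chat_interactive_v1(prompt='hi',
#                                               interactive_mode=True,
#                                               stream=True,
#                                               stop=['</s>'],
#                                               skip_special_tokens=False):
#         print(out['text'], end='', flush=True)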
def completions_v1(
......@@ -204,12 +236,15 @@ class APIClient:
n: Optional[int] = 1,
max_tokens: Optional[int] = 16,
stream: Optional[bool] = False,
stop: Optional[Union[str, List[str]]] = None,
top_p: Optional[float] = 1.0,
top_k: Optional[int] = 40,
user: Optional[str] = None,
# additional argument of lmdeploy
repetition_penalty: Optional[float] = 1.0,
session_id: Optional[int] = -1,
ignore_eos: Optional[bool] = False,
skip_special_tokens: Optional[bool] = True,
**kwargs):
"""Chat completion v1.
......@@ -223,14 +258,20 @@ class APIClient:
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or
higher are kept for generation.
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
n (int): How many chat completion choices to generate for each
input message. Only support one here.
stream: whether to stream the results or not. Default to false.
stop (str | List[str] | None): To stop generating further
tokens. Only accept stop words that are encoded to a single token index.
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
user (str): A unique identifier representing your end-user.
ignore_eos (bool): indicator for ignoring eos
session_id (int): if not specified, will set random value
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
session_id (int): Deprecated.
Yields:
json objects in openai formats
......@@ -240,9 +281,8 @@ class APIClient:
for k, v in locals().copy().items()
if k[:2] != '__' and k not in ['self']
}
headers = {'content-type': 'application/json'}
response = requests.post(self.completions_v1_url,
headers=headers,
headers=self.headers,
json=pload,
stream=stream)
for chunk in response.iter_lines(chunk_size=8192,
......@@ -250,16 +290,16 @@ class APIClient:
delimiter=b'\n'):
if chunk:
if stream:
decoded = chunk.decode('utf-8')[6:]
decoded = chunk.decode('utf-8')
if decoded == 'data: [DONE]':
continue
if decoded[:6] == 'data: ':
decoded = decoded[6:]
output = json.loads(decoded)
output = json_loads(decoded)
yield output
else:
decoded = chunk.decode('utf-8')
output = json.loads(decoded)
output = json_loads(decoded)
yield output
def chat(self,
......@@ -307,7 +347,7 @@ class APIClient:
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=ignore_eos):
if outputs['finish_reason'] == 'length':
if outputs['finish_reason'] == 'length' and outputs['tokens'] == 0:
print('WARNING: exceed session max length.'
' Please end the session.')
yield outputs['text'], outputs['tokens'], outputs['finish_reason']
......@@ -334,15 +374,21 @@ def input_prompt():
return '\n'.join(iter(input, sentinel))
def get_streaming_response(prompt: str,
api_url: str,
session_id: int,
request_output_len: int = 512,
stream: bool = True,
interactive_mode: bool = False,
ignore_eos: bool = False,
stop: bool = False) -> Iterable[List[str]]:
def get_streaming_response(
prompt: str,
api_url: str,
session_id: int,
request_output_len: int = 512,
stream: bool = True,
interactive_mode: bool = False,
ignore_eos: bool = False,
cancel: bool = False,
top_p: float = 0.8,
temperature: float = 0.7,
api_key: Optional[str] = None) -> Iterable[List[str]]:
headers = {'User-Agent': 'Test Client'}
if api_key is not None:
headers['Authorization'] = f'Bearer {api_key}'
pload = {
'prompt': prompt,
'stream': stream,
......@@ -350,7 +396,9 @@ def get_streaming_response(prompt: str,
'request_output_len': request_output_len,
'interactive_mode': interactive_mode,
'ignore_eos': ignore_eos,
'stop': stop
'cancel': cancel,
'top_p': top_p,
'temperature': temperature
}
response = requests.post(api_url,
headers=headers,
......@@ -360,15 +408,18 @@ def get_streaming_response(prompt: str,
decode_unicode=False,
delimiter=b'\n'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
data = json_loads(chunk.decode('utf-8'))
output = data.pop('text', '')
tokens = data.pop('tokens', 0)
finish_reason = data.pop('finish_reason', None)
yield output, tokens, finish_reason
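# Illustrative usage sketch of the helper above; the URL, session id and api
# key are placeholders.
def _example_get_streaming_response():
    url = 'http://0.0.0.0:23333/v1/chat/interactive'
    for text, tokens, finish_reason in get_streaming_response(
            'hi', url, session_id=0, api_key='YOUR_API_KEY'):
        print(text, tokens, finish_reason)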
def main(api_server_url: str, session_id: int = 0):
api_client = APIClient(api_server_url)
def main(api_server_url: str,
session_id: int = 0,
api_key: Optional[str] = None):
"""Main function to chat in terminal."""
api_client = APIClient(api_server_url, api_key=api_key)
while True:
prompt = input_prompt()
if prompt in ['exit', 'end']:
......
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
import random
import time
from http import HTTPStatus
from typing import AsyncGenerator, List, Optional
from typing import AsyncGenerator, List, Literal, Optional, Union
import uvicorn
from fastapi import FastAPI, Request
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from lmdeploy.archs import get_task
from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.openai.protocol import ( # noqa: E501
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionRequest, ChatCompletionRequestQos, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, CompletionRequest,
CompletionResponse, CompletionResponseChoice,
CompletionRequestQos, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage,
EmbeddingsRequest, EncodeRequest, EncodeResponse, ErrorResponse,
GenerateRequest, GenerateResponse, ModelCard, ModelList, ModelPermission,
UsageInfo)
GenerateRequest, GenerateRequestQos, GenerateResponse, ModelCard,
ModelList, ModelPermission, UsageInfo)
from lmdeploy.serve.qos_engine.qos_engine import QosEngine
from lmdeploy.utils import get_logger
class VariableInterface:
"""A IO interface maintaining variables."""
async_engine: AsyncEngine = None
session_id: int = 0
api_keys: Optional[List[str]] = None
qos_engine: QosEngine = None
request_hosts = []
app = FastAPI(docs_url='/')
get_bearer_token = HTTPBearer(auto_error=False)
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> str:
"""Check if client provide valid api key.
Adopted from https://github.com/lm-sys/FastChat/blob/v0.2.35/fastchat/serve/openai_api_server.py#L108-L127
""" # noqa
if VariableInterface.api_keys:
if auth is None or (
token := auth.credentials) not in VariableInterface.api_keys:
raise HTTPException(
status_code=401,
detail={
'error': {
'message': 'Please request with valid api key!',
'type': 'invalid_request_error',
'param': None,
'code': 'invalid_api_key',
}
},
)
return token
else:
# api_keys not set; allow all
return None
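# Illustrative sketch: when `api_keys` is configured, every protected route
# expects a matching Bearer token. The URL and key below are placeholders.
def _example_authorized_request():
    import requests
    return requests.get('http://0.0.0.0:23333/v1/models',
                        headers={'Authorization': 'Bearer YOUR_API_KEY'})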
def get_model_list():
......@@ -37,10 +74,10 @@ def get_model_list():
Only provided one now.
"""
return [VariableInterface.async_engine.tm_model.model_name]
return [VariableInterface.async_engine.model_name]
@app.get('/v1/models')
@app.get('/v1/models', dependencies=[Depends(check_api_key)])
def available_models():
"""Show available models."""
model_cards = []
......@@ -74,17 +111,149 @@ async def check_request(request) -> Optional[JSONResponse]:
return ret
def ip2id(host_ip: str):
"""Convert host ip address to session id."""
if '.' in host_ip: # IPv4
return int(host_ip.replace('.', '')[-8:])
if ':' in host_ip: # IPv6
return int(host_ip.replace(':', '')[-8:], 16)
print('Warning, could not get session id from ip, set it to 0')
return 0
@app.post('/v1/chat/completions_qos')
async def chat_completions_v1_qos(request: ChatCompletionRequestQos,
raw_request: Request = None):
"""Completion API similar to OpenAI's API.
Refer to `https://platform.openai.com/docs/api-reference/chat/create`
for the API specification.
The request should be a JSON object with the following fields:
- model: model name. Available from /v1/models.
- messages: string prompt or chat history in OpenAI format.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
- stream: whether to stream the results or not. Default to false.
- max_tokens (int): output token nums
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
Additional arguments supported by LMDeploy:
- ignore_eos (bool): indicator for ignoring eos
- user_id (str): for qos; if not specified, will set to "default"
Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
VariableInterface.session_id += 1
request.session_id = VariableInterface.session_id
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = str(request.session_id)
created_time = int(time.time())
if VariableInterface.qos_engine is None:
return create_error_response(
HTTPStatus.NOT_FOUND,
'cannot parse qos engine config, this api does not work')
result_generator = await VariableInterface.qos_engine.generate_with_qos(
request)
if result_generator is None:
return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR,
'Failed to generate completions')
def create_stream_response_json(
index: int,
text: str,
finish_reason: Optional[str] = None,
) -> str:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(role='assistant', content=text),
finish_reason=finish_reason,
)
response = ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.model_dump_json()
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role='assistant'),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f'data: {data}\n\n'
async for res in result_generator:
response_json = create_stream_response_json(
index=0,
text=res.response,
)
yield f'data: {response_json}\n\n'
yield 'data: [DONE]\n\n'
# Streaming response
if request.stream:
return StreamingResponse(completion_stream_generator(),
media_type='text/event-stream')
# Non-streaming response
final_res = None
text = ''
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await VariableInterface.async_engine.stop_session(
request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
final_res = res
text += res.response
assert final_res is not None
choices = []
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role='assistant', content=text),
finish_reason=final_res.finish_reason,
)
choices.append(choice_data)
total_tokens = sum([
final_res.history_token_len, final_res.input_token_len,
final_res.generate_token_len
])
usage = UsageInfo(
prompt_tokens=final_res.input_token_len,
completion_tokens=final_res.generate_token_len,
total_tokens=total_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
@app.post('/v1/chat/completions')
@app.post('/v1/chat/completions', dependencies=[Depends(check_api_key)])
async def chat_completions_v1(request: ChatCompletionRequest,
raw_request: Request = None):
"""Completion API similar to OpenAI's API.
......@@ -94,7 +263,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
The request should be a JSON object with the following fields:
- model: model name. Available from /v1/models.
- messages: string prompt or chat history in OpenAI format.
- messages: string prompt or chat history in OpenAI format. Chat history
example: `[{"role": "user", "content": "hi"}]`.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
......@@ -102,13 +272,18 @@ async def chat_completions_v1(request: ChatCompletionRequest,
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
- stream: whether to stream the results or not. Default to false.
- max_tokens (int): output token nums
- max_tokens (int | None): output token nums. Default to None.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- stop (str | List[str] | None): To stop generating further
tokens. Only accept stop words that are encoded to a single token index.
Additional arguments supported by LMDeploy:
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
- ignore_eos (bool): indicator for ignoring eos
- session_id (int): if not specified, will set random value
- skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Currently we do not support the following features:
- function_call (Users should implement this by themselves)
......@@ -116,8 +291,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
if request.session_id == -1:
request.session_id = random.randint(1, 10086)
VariableInterface.session_id += 1
request.session_id = VariableInterface.session_id
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
......@@ -126,18 +301,26 @@ async def chat_completions_v1(request: ChatCompletionRequest,
request_id = str(request.session_id)
created_time = int(time.time())
if isinstance(request.stop, str):
request.stop = [request.stop]
gen_config = GenerationConfig(
max_new_tokens=request.max_tokens,
top_k=request.top_k,
top_p=request.top_p,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
stop_words=request.stop,
skip_special_tokens=request.skip_special_tokens)
result_generator = VariableInterface.async_engine.generate(
request.messages,
request.session_id,
True, # always use stream to enable batching
gen_config=gen_config,
stream_response=True, # always use stream to enable batching
sequence_start=True,
sequence_end=True,
request_output_len=request.max_tokens if request.max_tokens else 512,
stop=request.stop,
top_p=request.top_p,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
do_preprocess=not isinstance(request.messages,
str), # text completion for string input
)
......@@ -196,7 +379,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
VariableInterface.async_engine.stop_session(request.session_id)
await VariableInterface.async_engine.stop_session(
request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
final_res = res
......@@ -230,7 +414,155 @@ async def chat_completions_v1(request: ChatCompletionRequest,
return response
@app.post('/v1/completions')
@app.post('/v1/completions_qos')
async def completions_v1_qos(request: CompletionRequestQos,
raw_request: Request = None):
"""Completion API similar to OpenAI's API.
Go to `https://platform.openai.com/docs/api-reference/completions/create`
for the API specification.
The request should be a JSON object with the following fields:
- model (str): model name. Available from /v1/models.
- prompt (str): the input prompt.
- suffix (str): The suffix that comes after a completion of inserted text.
- max_tokens (int): output token nums
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
- stream: whether to stream the results or not. Default to false.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- user (str): A unique identifier representing your end-user.
Additional arguments supported by LMDeploy:
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
- ignore_eos (bool): indicator for ignoring eos
- user_id (str): for qos; if not specified, will set to "default"
Currently we do not support the following features:
- logprobs (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
VariableInterface.session_id += 1
request.session_id = VariableInterface.session_id
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = str(request.session_id)
created_time = int(time.time())
if isinstance(request.prompt, str):
request.prompt = [request.prompt]
if VariableInterface.qos_engine is None:
return create_error_response(
HTTPStatus.NOT_FOUND,
'cannot parse qos engine config, this api does not work')
generators = await VariableInterface.qos_engine.generate_with_qos(request)
def create_stream_response_json(
index: int,
text: str,
finish_reason: Optional[str] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.model_dump_json()
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
for generator in generators:
for i in range(request.n):
choice_data = CompletionResponseStreamChoice(
index=i,
text='',
finish_reason=None,
)
chunk = CompletionStreamResponse(id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f'data: {data}\n\n'
async for res in generator:
response_json = create_stream_response_json(
index=0,
text=res.response,
)
yield f'data: {response_json}\n\n'
yield 'data: [DONE]\n\n'
# Streaming response
if request.stream:
return StreamingResponse(completion_stream_generator(),
media_type='text/event-stream')
# Non-streaming response
usage = UsageInfo()
choices = []
async def _inner_call(i, generator):
final_res = None
text = ''
async for res in generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await VariableInterface.async_engine.stop_session(
request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
final_res = res
text += res.response
assert final_res is not None
choice_data = CompletionResponseChoice(
index=0,
text=text,
finish_reason=final_res.finish_reason,
)
choices.append(choice_data)
total_tokens = sum([
final_res.history_token_len, final_res.input_token_len,
final_res.generate_token_len
])
usage.prompt_tokens += final_res.input_token_len
usage.completion_tokens += final_res.generate_token_len
usage.total_tokens += total_tokens
await asyncio.gather(
*[_inner_call(i, generators[i]) for i in range(len(generators))])
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
@app.post('/v1/completions', dependencies=[Depends(check_api_key)])
async def completions_v1(request: CompletionRequest,
raw_request: Request = None):
"""Completion API similar to OpenAI's API.
......@@ -242,7 +574,7 @@ async def completions_v1(request: CompletionRequest,
- model (str): model name. Available from /v1/models.
- prompt (str): the input prompt.
- suffix (str): The suffix that comes after a completion of inserted text.
- max_tokens (int): output token nums
- max_tokens (int): output token nums. Default to 16.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
......@@ -253,18 +585,23 @@ async def completions_v1(request: CompletionRequest,
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- user (str): A unique identifier representing your end-user.
- stop (str | List[str] | None): To stop generating further
tokens. Only accept stop words that are encoded to a single token index.
Additional arguments supported by LMDeploy:
- ignore_eos (bool): indicator for ignoring eos
- session_id (int): if not specified, will set random value
- skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
Currently we do not support the following features:
- logprobs (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
if request.session_id == -1:
request.session_id = random.randint(1, 10086)
VariableInterface.session_id += 1
request.session_id = VariableInterface.session_id
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
......@@ -274,21 +611,26 @@ async def completions_v1(request: CompletionRequest,
created_time = int(time.time())
if isinstance(request.prompt, str):
request.prompt = [request.prompt]
if isinstance(request.stop, str):
request.stop = [request.stop]
gen_config = GenerationConfig(
max_new_tokens=request.max_tokens if request.max_tokens else 512,
top_k=request.top_k,
top_p=request.top_p,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
stop_words=request.stop,
skip_special_tokens=request.skip_special_tokens)
generators = []
for i in range(len(request.prompt)):
result_generator = VariableInterface.async_engine.generate(
request.prompt[i],
request.session_id + i,
True, # always use stream to enable batching
gen_config=gen_config,
stream_response=True, # always use stream to enable batching
sequence_start=True,
sequence_end=True,
request_output_len=request.max_tokens
if request.max_tokens else 512,
stop=False,
top_p=request.top_p,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
do_preprocess=False)
generators.append(result_generator)
......@@ -351,7 +693,8 @@ async def completions_v1(request: CompletionRequest,
async for res in generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
VariableInterface.async_engine.stop_session(request.session_id)
await VariableInterface.async_engine.stop_session(
request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
final_res = res
......@@ -394,7 +737,7 @@ async def create_embeddings(request: EmbeddingsRequest,
'Unsupported by turbomind.')
@app.post('/v1/encode')
@app.post('/v1/encode', dependencies=[Depends(check_api_key)])
async def encode(request: EncodeRequest, raw_request: Request = None):
"""Encode prompts.
......@@ -407,7 +750,7 @@ async def encode(request: EncodeRequest, raw_request: Request = None):
def encode(prompt: str, do_preprocess: bool, add_bos: bool):
if do_preprocess:
prompt = VariableInterface.async_engine.model.get_prompt(
prompt = VariableInterface.async_engine.chat_template.get_prompt(
prompt, sequence_start=add_bos)
input_ids = VariableInterface.async_engine.tokenizer.encode(
prompt, add_bos=add_bos)
......@@ -425,12 +768,9 @@ async def encode(request: EncodeRequest, raw_request: Request = None):
return EncodeResponse(input_ids=encoded, length=length)
@app.post('/generate',
tags=['deprecated'],
description='please use /v1/chat/interactive')
@app.post('/v1/chat/interactive')
async def chat_interactive_v1(request: GenerateRequest,
raw_request: Request = None):
@app.post('/v1/chat/interactive_qos')
async def chat_interactive_v1_qos(request: GenerateRequestQos,
raw_request: Request = None):
"""Generate completion for the request.
- On interactive mode, the chat history is kept on the server. Please set
......@@ -456,33 +796,134 @@ async def chat_interactive_v1(request: GenerateRequest,
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- ignore_eos (bool): indicator for ignoring eos
- user_id (str): for qos; if not specified, will set to "default"
"""
if request.session_id == -1:
request.session_id = random.randint(10087, 23333)
VariableInterface.session_id += 1
request.session_id = VariableInterface.session_id
if VariableInterface.qos_engine is None:
return create_error_response(
HTTPStatus.NOT_FOUND,
'cannot parse qos engine config, this api does not work')
generation = await VariableInterface.qos_engine.generate_with_qos(request)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for out in generation:
chunk = GenerateResponse(text=out.response,
tokens=out.generate_token_len,
input_tokens=out.input_token_len,
history_tokens=out.history_token_len,
finish_reason=out.finish_reason)
data = chunk.model_dump_json()
yield f'{data}\n'
if request.stream:
return StreamingResponse(stream_results(),
media_type='text/event-stream')
else:
ret = {}
text = ''
tokens = 0
finish_reason = None
async for out in generation:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await VariableInterface.qos_engine.stop_session(
request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
text += out.response
tokens = out.generate_token_len
finish_reason = out.finish_reason
ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason}
return JSONResponse(ret)
@app.post('/v1/chat/interactive', dependencies=[Depends(check_api_key)])
async def chat_interactive_v1(request: GenerateRequest,
raw_request: Request = None):
"""Generate completion for the request.
- In interactive mode, the chat history is kept on the server. Please set
`interactive_mode = True`.
- In normal mode, no chat history is kept on the server. Set
`interactive_mode = False`.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- session_id: determine which instance will be called. If not specified
with a value other than -1, a session id will be assigned automatically.
- interactive_mode (bool): turn on interactive mode or not. On interactive
mode, session history is kept on the server (and vice versa).
- stream: whether to stream the results or not.
- stop (str | List[str] | None): To stop generating further
tokens. Only accept stop words that are encoded to a single token index.
- request_output_len (int): output token nums. If not specified, will use
maximum possible number for a session.
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
- temperature (float): to modulate the next token probability
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- ignore_eos (bool): indicator for ignoring eos
- skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
"""
if request.cancel:
if request.session_id != -1:
await VariableInterface.async_engine.stop_session(
request.session_id)
return {
'text': '',
'tokens': 0,
'input_tokens': 0,
'history_tokens': 0,
'finish_reason': 'stop'
}
else:
return create_error_response(
HTTPStatus.BAD_REQUEST,
'please set a session_id to cancel a request')
if request.session_id == -1:
VariableInterface.session_id += 1
request.session_id = VariableInterface.session_id
async_engine = VariableInterface.async_engine
sequence_start = async_engine.id2step.get(str(request.session_id), 0) == 0
sequence_end = not request.interactive_mode
if isinstance(request.stop, str):
request.stop = [request.stop]
gen_config = GenerationConfig(
max_new_tokens=request.request_output_len,
top_p=request.top_p,
top_k=request.top_k,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
stop_words=request.stop,
skip_special_tokens=request.skip_special_tokens)
generation = async_engine.generate(
request.prompt,
request.session_id,
gen_config=gen_config,
stream_response=True, # always use stream to enable batching
sequence_start=sequence_start,
sequence_end=sequence_end,
request_output_len=request.request_output_len,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos)
sequence_end=sequence_end)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for out in generation:
chunk = GenerateResponse(text=out.response,
tokens=out.generate_token_len,
input_tokens=out.input_token_len,
history_tokens=out.history_token_len,
finish_reason=out.finish_reason)
data = chunk.model_dump_json()
yield f'{data}\n'
......@@ -493,32 +934,46 @@ async def chat_interactive_v1(request: GenerateRequest,
else:
ret = {}
text = ''
tokens = 0
tokens, input_tokens, history_tokens = 0, 0, 0
finish_reason = None
async for out in generation:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
async_engine.stop_session(request.session_id)
await async_engine.stop_session(request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
text += out.response
tokens = out.generate_token_len
input_tokens = out.input_token_len
history_tokens = out.history_token_len
finish_reason = out.finish_reason
ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason}
ret = {
'text': text,
'tokens': tokens,
'input_tokens': input_tokens,
'history_tokens': history_tokens,
'finish_reason': finish_reason
}
return JSONResponse(ret)
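# Illustrative sketch: a responding interactive request can be cancelled by
# posting `cancel=True` together with its session_id; the URL below is a
# placeholder.
def _example_cancel_interactive(session_id: int):
    import requests
    return requests.post('http://0.0.0.0:23333/v1/chat/interactive',
                         json={
                             'prompt': '',
                             'session_id': session_id,
                             'cancel': True
                         })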
def serve(model_path: str,
model_name: Optional[str] = None,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[PytorchEngineConfig,
TurbomindEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
server_name: str = '0.0.0.0',
server_port: int = 23333,
instance_num: int = 64,
tp: int = 1,
allow_origins: List[str] = ['*'],
allow_credentials: bool = True,
allow_methods: List[str] = ['*'],
allow_headers: List[str] = ['*'],
log_level: str = 'ERROR',
api_keys: Optional[Union[List[str], str]] = None,
ssl: bool = False,
qos_config_path: str = '',
**kwargs):
"""An example to perform model inference through the command line
interface.
......@@ -534,22 +989,34 @@ def serve(model_path: str,
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "InternLM/internlm-chat-7b",
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "InternLM/internlm-chat-7b"
backend (str): either `turbomind` or `pytorch` backend. Default to
`turbomind` backend.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
config instance. Default to None.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
server_name (str): host ip for serving
server_port (int): server port
instance_num (int): number of instances of turbomind model
tp (int): tensor parallel
allow_origins (List[str]): a list of allowed origins for CORS
allow_credentials (bool): whether to allow credentials for CORS
allow_methods (List[str]): a list of allowed HTTP methods for CORS
allow_headers (List[str]): a list of allowed HTTP headers for CORS
log_level (str): set the log level; one of [CRITICAL, ERROR, WARNING, INFO, DEBUG]
api_keys (List[str] | str | None): Optional list of API keys. Accepts string type as
a single api_key. Default to None, which means no api key applied.
ssl (bool): Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.
qos_config_path (str): qos policy config path
""" # noqa E501
os.environ['TM_LOG_LEVEL'] = log_level
if os.getenv('TM_LOG_LEVEL') is None:
os.environ['TM_LOG_LEVEL'] = log_level
logger = get_logger('lmdeploy')
logger.setLevel(log_level)
if allow_origins:
app.add_middleware(
......@@ -559,16 +1026,55 @@ def serve(model_path: str,
allow_methods=allow_methods,
allow_headers=allow_headers,
)
if api_keys is not None:
if isinstance(api_keys, str):
api_keys = api_keys.split(',')
VariableInterface.api_keys = api_keys
ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http'
if ssl:
ssl_keyfile = os.environ['SSL_KEYFILE']
ssl_certfile = os.environ['SSL_CERTFILE']
http_or_https = 'https'
pipeline_type, pipeline_class = get_task(model_path)
VariableInterface.async_engine = pipeline_class(
model_path=model_path,
model_name=model_name,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config,
tp=tp,
**kwargs)
if qos_config_path:
try:
with open(qos_config_path, 'r') as file:
qos_config_str = file.read()
VariableInterface.qos_engine = QosEngine(
qos_tag=qos_config_str,
engine=VariableInterface.async_engine,
**kwargs)
VariableInterface.qos_engine.start()
except FileNotFoundError:
VariableInterface.qos_engine = None
else:
# hide qos functions if not applied
for i in range(len(app.router.routes)):
if 'qos' in app.router.routes[i].path:
app.router.routes[i].include_in_schema = False
VariableInterface.async_engine = AsyncEngine(model_path=model_path,
model_name=model_name,
instance_num=instance_num,
tp=tp,
**kwargs)
for i in range(3):
print(f'HINT: Please open \033[93m\033[1mhttp://{server_name}:'
f'{server_port}\033[0m in a browser for detailed api usage!!!')
uvicorn.run(app=app, host=server_name, port=server_port, log_level='info')
print(
f'HINT: Please open \033[93m\033[1m{http_or_https}://'
f'{server_name}:{server_port}\033[0m in a browser for detailed api'
' usage!!!')
uvicorn.run(app=app,
host=server_name,
port=server_port,
log_level='info',
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile)
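# Illustrative sketch of launching the server programmatically; the model
# path and api key are placeholders, and a TurbomindEngineConfig or
# PytorchEngineConfig instance may be passed as `backend_config`.
def _example_serve():
    serve('internlm/internlm-chat-7b',
          backend='turbomind',
          server_name='0.0.0.0',
          server_port=23333,
          api_keys='YOUR_API_KEY')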
if __name__ == '__main__':
......
......@@ -55,23 +55,48 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
class ChatCompletionRequestQos(BaseModel):
"""Chat completion request."""
model: str
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = 512
max_tokens: Optional[int] = Field(default=None, examples=[None])
stop: Optional[bool] = False
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
user_id: Optional[str] = None
# additional argument of lmdeploy
repetition_penalty: Optional[float] = 1.0
session_id: Optional[int] = -1
ignore_eos: Optional[bool] = False
top_k: Optional[int] = 40
class ChatCompletionRequest(BaseModel):
"""Chat completion request."""
model: str
# yapf: disable
messages: Union[str, List[Dict[str, Any]]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]]) # noqa
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = Field(default=None, examples=[None])
stop: Optional[Union[str, List[str]]] = Field(default=None, examples=[None]) # noqa
# yapf: enable
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
# additional argument of lmdeploy
repetition_penalty: Optional[float] = 1.0
session_id: Optional[int] = -1
ignore_eos: Optional[bool] = False
skip_special_tokens: Optional[bool] = True
top_k: Optional[int] = 40
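# Illustrative sketch of a /v1/chat/completions request body built with the
# model above; all field values are examples only.
def _example_chat_completion_request() -> ChatCompletionRequest:
    return ChatCompletionRequest(
        model='internlm-chat-7b',
        messages=[{'role': 'user', 'content': 'hi'}],
        max_tokens=128,
        stop=['<|im_end|>'])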
class ChatMessage(BaseModel):
......@@ -120,6 +145,31 @@ class ChatCompletionStreamResponse(BaseModel):
class CompletionRequest(BaseModel):
"""Completion request."""
model: str
prompt: Union[str, List[Any]]
suffix: Optional[str] = None
temperature: Optional[float] = 0.7
n: Optional[int] = 1
max_tokens: Optional[int] = 16
stop: Optional[Union[str, List[str]]] = Field(default=None,
examples=[None])
stream: Optional[bool] = False
top_p: Optional[float] = 1.0
logprobs: Optional[int] = None
echo: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
# additional argument of lmdeploy
repetition_penalty: Optional[float] = 1.0
session_id: Optional[int] = -1
ignore_eos: Optional[bool] = False
skip_special_tokens: Optional[bool] = True
top_k: Optional[int] = 40 # for opencompass
class CompletionRequestQos(BaseModel):
"""Completion request."""
model: str
prompt: Union[str, List[Any]]
......@@ -136,9 +186,11 @@ class CompletionRequest(BaseModel):
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
# additional argument of lmdeploy
top_k: Optional[int] = 40
repetition_penalty: Optional[float] = 1.0
session_id: Optional[int] = -1
ignore_eos: Optional[bool] = False
user_id: Optional[str] = None
class CompletionResponseChoice(BaseModel):
......@@ -205,6 +257,25 @@ class EncodeResponse(BaseModel):
class GenerateRequest(BaseModel):
"""Generate request."""
prompt: Union[str, List[Dict[str, Any]]]
session_id: int = -1
interactive_mode: bool = False
stream: bool = False
stop: Optional[Union[str, List[str]]] = Field(default=None,
examples=[None])
request_output_len: Optional[int] = Field(default=None,
examples=[None]) # noqa
top_p: float = 0.8
top_k: int = 40
temperature: float = 0.8
repetition_penalty: float = 1.0
ignore_eos: bool = False
skip_special_tokens: Optional[bool] = True
cancel: Optional[bool] = False # cancel a responding request
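# Illustrative sketch of a /v1/chat/interactive request body; field values
# are examples only.
def _example_generate_request() -> GenerateRequest:
    return GenerateRequest(prompt='hi',
                           session_id=1,
                           interactive_mode=True,
                           request_output_len=128,
                           stop=['</s>'])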
class GenerateRequestQos(BaseModel):
"""Generate request."""
prompt: Union[str, List[Dict[str, str]]]
session_id: int = -1
......@@ -217,10 +288,13 @@ class GenerateRequest(BaseModel):
temperature: float = 0.8
repetition_penalty: float = 1.0
ignore_eos: bool = False
user_id: Optional[str] = None
class GenerateResponse(BaseModel):
"""Generate response."""
text: str
tokens: int
input_tokens: int
history_tokens: int
finish_reason: Optional[Literal['stop', 'length']] = None
......@@ -18,7 +18,7 @@ from tritonclient.grpc.service_pb2 import ModelInferResponse
from lmdeploy.model import MODELS
from lmdeploy.serve.turbomind.utils import (Postprocessor, Preprocessor,
prepare_tensor)
from lmdeploy.utils import filter_suffix
from lmdeploy.utils import filter_suffix, get_logger
@dataclass
......@@ -51,13 +51,6 @@ def stream_callback(que, result, error):
que.put(result.get_response(as_json=True))
def get_logger(log_file=None, log_level=logging.INFO):
"""Return the logger."""
from lmdeploy.utils import get_logger
logger = get_logger('service.ft', log_file=log_file, log_level=log_level)
return logger
class Chatbot:
"""Chatbot for LLaMA series models with turbomind as inference engine.
......@@ -75,6 +68,10 @@ class Chatbot:
ignore_eos: bool = False,
log_level: int = logging.INFO,
display: bool = False,
top_p: float = 1.0,
top_k: int = 1,
temperature: float = 0.8,
repetition_penalty: float = 1.0,
**model_kwargs):
self.tritonserver_addr = tritonserver_addr
self.model_name = model_name
......@@ -97,10 +94,10 @@ class Chatbot:
self.eos_id = -1
self.cfg = mmengine.Config(
dict(session_len=self.model.session_len,
top_p=self.model.top_p,
top_k=self.model.top_k,
temperature=self.model.temperature,
repetition_penalty=self.model.repetition_penalty,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
stop_words=stop_words,
bad_words=bad_words))
self.log_level = log_level
......@@ -113,6 +110,7 @@ class Chatbot:
request_output_len: int = None,
sequence_start: bool = False,
sequence_end: bool = False,
skip_special_tokens: bool = True,
*args,
**kwargs):
"""Start a new round conversion of a session.
......@@ -124,13 +122,15 @@ class Chatbot:
request_output_len (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Returns:
iterator: The generated content by chatbot
"""
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.log_level)
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'request_output_len {request_output_len}')
......@@ -149,11 +149,13 @@ class Chatbot:
self.cfg.update(**kwargs)
self._session.prompt = self._get_prompt(prompt, sequence_start)
for status, res, tokens in self._stream_infer(self._session,
self._session.prompt,
request_output_len,
sequence_start,
sequence_end):
for status, res, tokens in self._stream_infer(
self._session,
self._session.prompt,
request_output_len,
sequence_start,
sequence_end,
skip_special_tokens=skip_special_tokens):
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
res = filter_suffix(res, self.model.stop_words)
if status.value < 0:
......@@ -180,7 +182,7 @@ class Chatbot:
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.log_level)
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'end session: {session_id}')
if self._session is None:
......@@ -218,7 +220,7 @@ class Chatbot:
"""
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.log_level)
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'cancel session: {session_id}')
if self._session is None:
......@@ -267,7 +269,7 @@ class Chatbot:
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.log_level)
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'resume session: {session_id}')
if self._session is None:
......@@ -301,6 +303,7 @@ class Chatbot:
request_output_len: int = None,
sequence_start: bool = False,
sequence_end: bool = False,
skip_special_tokens: bool = True,
*args,
**kwargs):
"""Start a new round conversion of a session. Return the chat
......@@ -313,6 +316,8 @@ class Chatbot:
request_output_len (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Returns:
tuple(Status, str, int): status, text/chat completion,
generated token number
......@@ -320,7 +325,7 @@ class Chatbot:
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.log_level)
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'request_output_len {request_output_len}')
......@@ -338,11 +343,13 @@ class Chatbot:
self._session.prompt = self._get_prompt(prompt, sequence_start)
status, res, tokens = None, '', 0
for status, res, tokens in self._stream_infer(self._session,
self._session.prompt,
request_output_len,
sequence_start,
sequence_end):
for status, res, tokens in self._stream_infer(
self._session,
self._session.prompt,
request_output_len,
sequence_start,
sequence_end,
skip_special_tokens=skip_special_tokens):
if status.value < 0:
break
if status == StatusCode.TRITON_STREAM_END: # remove stop_words
......@@ -420,6 +427,7 @@ class Chatbot:
request_output_len: int = 512,
sequence_start: bool = True,
sequence_end: bool = False,
skip_special_tokens: bool = True,
cancel: bool = False):
"""communicate with inference server to chat, or cancel a session, or
end a session.
......@@ -431,10 +439,12 @@ class Chatbot:
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
cancel (bool): indicator for cancelling the session
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Yields:
tuple: status, text, generated token number
"""
logger = get_logger(log_level=self.log_level)
logger = get_logger('service.ft', log_level=self.log_level)
logger.info(f'session {session.session_id}, '
f'request id {session.request_id}, '
f'request_output_len {request_output_len}, '
......@@ -498,7 +508,8 @@ class Chatbot:
producer.start()
for status, res, n_token in self.stream_consumer(
self.postprocess, que, session, input_tokens, preseq_length,
cancel, logger, self.display, self.eos_id):
cancel, logger, self.display, self.eos_id,
skip_special_tokens):
yield status, res, n_token
producer.join()
......@@ -591,7 +602,8 @@ class Chatbot:
@staticmethod
def stream_consumer(postprocess, res_queue, session, n_input_token,
preseq_length, cancel, logger, display, eos_id):
preseq_length, cancel, logger, display, eos_id,
skip_special_tokens):
"""Consume the response from the triton inference server.
Args:
......@@ -605,11 +617,15 @@ class Chatbot:
logger (util.Logger):
display (bool): display the text in the console interface or not
eos_id (int): eos token id
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Yields:
tuple: status, text, generated token number
"""
status, res, n_token = None, '', 0
output_ids = np.zeros((1, 1, 0), dtype=np.uint32)
text = ''
while True:
result = res_queue.get()
if result is None:
......@@ -648,7 +664,8 @@ class Chatbot:
output_ids = output_ids[:, :, :-1]
output_str = postprocess(
output_ids, np.array([[n_token]], dtype=np.uint32))
output_ids, np.array([[n_token]], dtype=np.uint32),
np.array([[int(skip_special_tokens)]], dtype=np.int32))
text = output_str[0].decode()
# utf-8 char at the end means it's a potential unfinished
# byte sequence, continue to concatenate it with the next
......
......@@ -84,10 +84,13 @@ class TritonPythonModel:
request, 'TOKENS_BATCH').as_numpy()
sequence_length = pb_utils.get_input_tensor_by_name(
request, 'sequence_length').as_numpy()
skip_special_tokens = pb_utils.get_input_tensor_by_name(
request, 'skip_special_tokens').as_numpy()
# Postprocessing output data.
outputs = self._postprocessing(tokens_batch.tolist(),
sequence_length)
sequence_length,
skip_special_tokens)
# Create output tensors. You need pb_utils.Tensor
# objects to create pb_utils.InferenceResponse.
......@@ -118,12 +121,16 @@ class TritonPythonModel:
"""
print('Cleaning up...')
def _postprocessing(self, tokens_batch, sequence_length):
def _postprocessing(self, tokens_batch, sequence_length,
skip_special_tokens):
"""decode token ids into texts."""
outputs = []
for beam_tokens, beam_len in zip(tokens_batch, sequence_length):
for tokens, _len in zip(beam_tokens, beam_len):
output = self.tokenizer.decode(tokens, _len)
for beam_tokens, beam_len, beam_skip_special in zip(
tokens_batch, sequence_length, skip_special_tokens):
for tokens, _len, skip_special in zip(beam_tokens, beam_len,
beam_skip_special):
output = self.tokenizer.decode(
tokens, _len, skip_special_tokens=bool(skip_special))
output = output.encode('utf8')
outputs.append(output)
return outputs
......@@ -11,6 +11,11 @@ input [
name: "sequence_length"
data_type: TYPE_UINT32
dims: [ -1 ]
},
{
name: "skip_special_tokens"
data_type: TYPE_INT32
dims: [ -1 ]
}
]
output [
......
......@@ -72,22 +72,29 @@ class Postprocessor:
def __call__(self, *args, **kwargs):
return self.infer(*args, **kwargs)
def infer(self, output_ids: np.ndarray, seqlen: np.ndarray):
def infer(self,
output_ids: np.ndarray,
seqlen: np.ndarray,
skip_special_tokens: bool = True):
"""De-tokenize tokens for text.
Args:
output_ids(np.ndarray): tokens' id
seqlen(np.ndarray): sequence length
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Returns:
str: decoded tokens
"""
inputs = [
prepare_tensor('TOKENS_BATCH', output_ids),
prepare_tensor('sequence_length', seqlen)
prepare_tensor('sequence_length', seqlen),
prepare_tensor('skip_special_tokens', skip_special_tokens)
]
inputs[0].set_data_from_numpy(output_ids)
inputs[1].set_data_from_numpy(seqlen)
inputs[2].set_data_from_numpy(skip_special_tokens)
model_name = 'postprocessing'
with grpcclient.InferenceServerClient(self.tritonserver_addr) \
as client:
......
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import os.path as osp
from typing import Optional, Sequence, Union
from collections import deque
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union
import torch
from lmdeploy.utils import get_logger
# this file will be copied to the triton server, make sure all
# imports start from the package root lmdeploy
@dataclass
class DetokenizeState:
"""A state collection of incrementally detekenization.
Args:
ids_offset (int): offset into all input ids. In LMDeploy, output ids
do not arrive one at a time; an arbitrary number may come per step.
prev_tokens (List[str] | None): for incremental decoding.
Default to None, which means the first round.
prefix_offset (int): the start index of tokens to be converted to
string (prev + new tokens). Default to 0 for the first round.
read_offset (int): the end index of tokens to be converted to
string (prev token). Default to 0 for the first round.
"""
ids_offset: int = 0
prev_tokens: Optional[List[str]] = None
prefix_offset: int = 0
read_offset: int = 0
def as_tuple(self) -> Tuple:
"""Return a tuple of states."""
return (self.ids_offset, self.prev_tokens, self.prefix_offset,
self.read_offset)
class SentencePieceTokenizer:
"""Tokenizer of sentencepiece.
......@@ -18,6 +49,12 @@ class SentencePieceTokenizer:
from sentencepiece import SentencePieceProcessor
self.model = SentencePieceProcessor(model_file=model_file)
self._prefix_space_tokens = None
# for stop words
self._maybe_decode_bytes: bool = None
# TODO: maybe a constants.py is needed for these
self._indexes_tokens_deque = deque(maxlen=10)
self.max_indexes_num = 5
self.logger = get_logger('lmdeploy')
@property
def vocab_size(self):
......@@ -53,6 +90,27 @@ class SentencePieceTokenizer:
else:
return decoded
def indexes_containing_token(self, token: str):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
# traversing the vocab is time consuming and cannot be accelerated with
# multithreading (computation-bound) or multiprocessing (the tokenizer
# can't be pickled), so we cache the latest 10 stop words and return
# directly on a match
for _token, _indexes in self._indexes_tokens_deque:
if token == _token:
return _indexes
if token == ' ': # ' ' is special
token = '▁'
vocab = self.model.IdToPiece(list(range(self.vocab_size)))
indexes = [i for i, voc in enumerate(vocab) if token in voc]
if len(indexes) > self.max_indexes_num:
indexes = self.encode(token, add_bos=False)[-1:]
self.logger.warning(
f'There are too many (>{self.max_indexes_num}) possible '
f'indexes that may decode to {token}; only {indexes} will be used')
self._indexes_tokens_deque.append((token, indexes))
return indexes
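# Sketch: looking up stop-word indexes on the SentencePiece path. The model
# file path and stop word are hypothetical. The second lookup for the same
# word is served from the 10-entry deque cache instead of re-scanning the
# vocab.
sp_tokenizer = SentencePieceTokenizer('/path/to/tokenizer.model')  # hypothetical
first = sp_tokenizer.indexes_containing_token('</s>')
cached = sp_tokenizer.indexes_containing_token('</s>')  # served from the cache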
def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt.
......@@ -63,13 +121,18 @@ class SentencePieceTokenizer:
"""
return self.model.Encode(s, add_bos=add_bos, **kwargs)
def decode(self, t: Sequence[int], offset: Optional[int] = None):
def decode(self,
t: Sequence[int],
offset: Optional[int] = None,
skip_special_tokens: bool = True,
**kwargs):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
offset (int): for incremental decoding. Defaults to None, which
means not applied.
skip_special_tokens (bool): not used in SentencePieceTokenizer.
Returns:
str: text of decoding tokens
"""
......@@ -81,6 +144,34 @@ class SentencePieceTokenizer:
out_string = self._maybe_add_prefix_space(t, out_string)
return out_string
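# Sketch: offset-based decoding. Per the docstring above, `offset` is meant
# for incremental decoding, so only the ids after the offset contribute new
# text; the detokenize_incrementally method below is the preferred streaming
# path. Model path is hypothetical.
sp_tokenizer = SentencePieceTokenizer('/path/to/tokenizer.model')  # hypothetical
ids = sp_tokenizer.encode('Hello world', add_bos=False)
full_text = sp_tokenizer.decode(ids)
tail_text = sp_tokenizer.decode(ids, offset=len(ids) - 1)  # only the tail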
def detokenize_incrementally(self,
all_input_ids: Sequence[int],
state: DetokenizeState,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True):
"""Incrementally detokenize the input indexes.
Args:
all_input_ids (List[int]): a list of token ids. Expected to be
different sections of a long sequence.
state (DetokenizeState): an instance of DetokenizeState holding
the incremental decoding state.
skip_special_tokens (bool): Whether or not to remove special tokens
during decoding. Defaults to True.
spaces_between_special_tokens (bool): Whether or not to add spaces
between special tokens. Defaults to True.
Returns:
str: the decoded text of the current round.
state (DetokenizeState): the updated incremental decoding state.
"""
out_string = self.model.Decode(all_input_ids)
if state.prev_tokens is not None:
out_string = self._maybe_add_prefix_space(all_input_ids,
out_string)
state.prev_tokens = []  # ensure non-None so later rounds take the branch above
return out_string, state
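# Sketch: one round of SentencePiece incremental detokenization. This variant
# simply decodes the ids it receives and uses prev_tokens only as a "not the
# first round anymore" flag for the prefix-space fix-up. Model path is
# hypothetical.
sp_tokenizer = SentencePieceTokenizer('/path/to/tokenizer.model')  # hypothetical
state = DetokenizeState()
text, state = sp_tokenizer.detokenize_incrementally(
    sp_tokenizer.encode('Hello', add_bos=False), state)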
def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
......@@ -106,20 +197,10 @@ class HuggingFaceTokenizer:
def __init__(self, model_dir: str):
from transformers import AutoTokenizer
model_file = osp.join(model_dir, 'tokenizer.model')
backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
model_file_exists = osp.exists(model_file)
if not osp.exists(backend_tokenizer_file) and model_file_exists:
print('WARNING: Can not find tokenizer.json. '
'It may take long time to initialize the tokenizer.')
self.logger = get_logger('lmdeploy')
self.model = AutoTokenizer.from_pretrained(model_dir,
trust_remote_code=True)
self._prefix_space_tokens = None
# save tokenizer.json to reuse
if not osp.exists(backend_tokenizer_file) and model_file_exists:
if hasattr(self.model, 'backend_tokenizer'):
if os.access(model_dir, os.W_OK):
self.model.backend_tokenizer.save(backend_tokenizer_file)
if self.model.eos_token_id is None:
generation_config_file = osp.join(model_dir,
......@@ -131,11 +212,27 @@ class HuggingFaceTokenizer:
elif hasattr(self.model, 'eod_id'): # Qwen remote
self.model.eos_token_id = self.model.eod_id
# for stop words
self._vocab_size_with_added: int = None
self._maybe_decode_bytes: bool = None
# TODO: maybe move these constants to a constant.py
self._indexes_tokens_deque = deque(maxlen=10)
self.max_indexes_num = 5
self.token2id = {}
@property
def vocab_size(self):
"""vocabulary size."""
return self.model.vocab_size
@property
def vocab_size_with_added(self):
"""vocabulary size with added vocab."""
if self._vocab_size_with_added is not None:
return self._vocab_size_with_added
self._vocab_size_with_added = len(self.model.get_vocab())
return self._vocab_size_with_added
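# Sketch: the difference between the two vocab sizes. vocab_size is the base
# vocabulary reported by the transformers tokenizer, while
# vocab_size_with_added also counts tokens registered as added vocab and is
# cached after the first access. Model directory is hypothetical.
hf_tokenizer = HuggingFaceTokenizer('/path/to/hf_model_dir')  # hypothetical
base = hf_tokenizer.vocab_size
total = hf_tokenizer.vocab_size_with_added  # usually >= base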
@property
def bos_token_id(self):
"""begine of the sentence token id."""
......@@ -159,7 +256,7 @@ class HuggingFaceTokenizer:
}
return self._prefix_space_tokens
def _maybe_add_prefix_space(self, tokens, decoded):
def _maybe_add_prefix_space(self, tokens: List[int], decoded: str):
"""maybe add prefix space for incremental decoding."""
if len(tokens) and not decoded.startswith(' ') and\
tokens[0] in self.prefix_space_tokens:
......@@ -167,6 +264,66 @@ class HuggingFaceTokenizer:
else:
return decoded
@property
def maybe_decode_bytes(self):
"""Check if self.model.convert_ids_to_tokens return not a str value."""
if self._maybe_decode_bytes is None:
self._maybe_decode_bytes = False
vocab = self.model.convert_ids_to_tokens(
list(range(self.vocab_size)))
for tok in vocab:
if not isinstance(tok, str):
self._maybe_decode_bytes = True
break
return self._maybe_decode_bytes
def indexes_containing_token(self, token: str):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
# traversing vocab is time consuming, can not be accelerated with
# multi threads (computation) or multi process (can't pickle tokenizer)
# so, we maintain latest 10 stop words and return directly if matched
for _token, _indexes in self._indexes_tokens_deque:
if token == _token:
return _indexes
if self.token2id == {}:
# decode is slower than convert_ids_to_tokens
if self.maybe_decode_bytes:
try:
self.token2id = {
self.model.decode(i): i
for i in range(self.vocab_size)
}
except Exception as e:
# qwen-vl
assert str(e) == 'Unclosed image token'
else:
self.token2id = {
self.model.convert_ids_to_tokens(i): i
for i in range(self.vocab_size)
}
if token == ' ': # ' ' is special
token = '▁'
indexes = [i for _token, i in self.token2id.items() if token in _token]
if len(indexes) > self.max_indexes_num:
# multiple ids decode to the same token
indexes = [i for i in indexes if self.decode([i]) == token]
indexes = indexes[:self.max_indexes_num]
self.logger.warning(
f'There are too many (>{self.max_indexes_num}) possible '
f'indexes that may decode to {token}; only {indexes} will be used')
# there might be token ids that exceed self.vocab_size
if len(indexes) == 0:
indexes = self.encode(token, False)
if len(indexes) != 1:
self.logger.warning(
f'The token {token} encodes to indexes {indexes}, whose '
'length is not 1. Currently, it cannot be used as a stop word')
indexes = []
self._indexes_tokens_deque.append((token, indexes))
return indexes
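# Sketch: stop-word lookup on the HuggingFace path. A single space is
# rewritten to the sentencepiece underline '▁' before matching, and when too
# many vocab entries match, the list is narrowed to ids that decode back to
# the token exactly and capped at max_indexes_num. Model directory and stop
# word are hypothetical.
hf_tokenizer = HuggingFaceTokenizer('/path/to/hf_model_dir')  # hypothetical
stop_ids = hf_tokenizer.indexes_containing_token('<|im_end|>')
space_ids = hf_tokenizer.indexes_containing_token(' ')  # matched as '▁'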
def encode(self, s: str, add_bos: bool = True, **kwargs):
"""Tokenize a prompt.
......@@ -182,7 +339,10 @@ class HuggingFaceTokenizer:
encoded = encoded[1:]
return encoded
def decode(self, t: Sequence[int], offset: Optional[int] = None):
def decode(self,
t: Sequence[int],
offset: Optional[int] = None,
skip_special_tokens: bool = True):
"""De-tokenize.
Args:
......@@ -192,14 +352,121 @@ class HuggingFaceTokenizer:
Returns:
str: text of decoding tokens
"""
skip_special_tokens = True
t = t[offset:]
out_string = self.model.decode(t,
skip_special_tokens=skip_special_tokens)
if offset:
logger = get_logger('lmdeploy')
logger.warning('For incremental detokenization, please use the '
'detokenize_incrementally function instead.')
out_string = self._maybe_add_prefix_space(t, out_string)
return out_string
@staticmethod
def _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens: List[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> str:
if tokenizer.is_fast or not tokenizer.get_added_vocab():
return tokenizer.convert_tokens_to_string(output_tokens)
# Adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L68-L99
sub_texts = []
current_sub_text = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
continue
if token in tokenizer.get_added_vocab():
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(
current_sub_text)
sub_texts.append(sub_text)
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
if spaces_between_special_tokens:
return ' '.join(sub_texts)
else:
return ''.join(sub_texts)
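# Sketch of when the helper above kicks in: fast tokenizers (and slow ones
# without an added vocab) go straight through convert_tokens_to_string, while
# slow tokenizers with added tokens get the manual split-and-join so the
# added tokens survive. Tokenizer directory and token strings are
# hypothetical.
hf_tokenizer = HuggingFaceTokenizer('/path/to/hf_model_dir')  # hypothetical
tokens = ['▁Hello', '<|assistant|>', '▁world']  # '<|assistant|>' assumed added
text = HuggingFaceTokenizer._convert_tokens_to_string_with_added_encoders(
    hf_tokenizer.model,  # the underlying transformers tokenizer
    tokens,
    skip_special_tokens=False,
    spaces_between_special_tokens=True)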
# Based on
# https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L105-L165
def detokenize_incrementally(self,
all_input_ids: Sequence[int],
state: DetokenizeState,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True):
"""Incrementally detokenize the input indexes.
Args:
all_input_ids (List[int]): a list of token ids. Expected to be
different sections of a long sequence.
state (DetokenizeState): an instance of DetokenizeState holding
the incremental decoding state.
skip_special_tokens (bool): Whether or not to remove special tokens
during decoding. Defaults to True.
spaces_between_special_tokens (bool): Whether or not to add spaces
between special tokens. Defaults to True.
Returns:
str: the decoded text of the current round.
state (DetokenizeState): the updated incremental decoding state.
"""
tokenizer = self.model
ids_offset, prev_tokens, prefix_offset, read_offset = state.as_tuple()
# convert the newly generated ids (from ids_offset onwards) to tokens
new_tokens = tokenizer.convert_ids_to_tokens(
all_input_ids[ids_offset:],
skip_special_tokens=skip_special_tokens)
if prev_tokens is None:
# This is the first iteration for this sequence. Note that in vLLM,
# indexes are detokenized one by one, while in LMDeploy the number
# of detokenized indexes per round can vary.
if skip_special_tokens and new_tokens and new_tokens[
0] in tokenizer.all_special_ids:
read_offset = 1 # skip special token
output_tokens = new_tokens
prev_tokens = new_tokens
else:
# later rounds: detokenize the previously seen tokens together with the new ones
output_tokens = prev_tokens + new_tokens
prev_tokens += new_tokens
prefix_text = self._convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
new_text = self._convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
# update state and get final decoded output
if len(new_text) > len(prefix_text) and not new_text.endswith('�'):
# A replacement char (�) at the end means a potentially unfinished
# byte sequence from byte fallback tokenization.
# If it appears in the middle, it is probably a real invalid id
# generated by the model.
prefix_offset = read_offset
read_offset = len(output_tokens)
new_text = new_text[len(prefix_text):]
else:
new_text = ''
return new_text, DetokenizeState(len(all_input_ids), prev_tokens,
prefix_offset, read_offset)
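# Sketch: streaming decode loop over the method above. Each round passes the
# full id list generated so far together with the previous state and gets
# back only the newly finalized text; text behind a trailing '�' is withheld
# until the byte sequence completes. Model directory and ids are
# hypothetical.
hf_tokenizer = HuggingFaceTokenizer('/path/to/hf_model_dir')  # hypothetical
all_ids, state, streamed = [], DetokenizeState(), ''
for step_ids in ([9707], [11, 1879], [0]):  # ids a model might emit per step
    all_ids.extend(step_ids)
    delta, state = hf_tokenizer.detokenize_incrementally(all_ids, state)
    streamed += delta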
def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
......@@ -230,7 +497,7 @@ class Tokenizer:
model_file_exists = osp.exists(model_file)
config_exists = osp.exists(tokenizer_config_file)
use_hf_model = config_exists or not model_file_exists
self.logger = get_logger('lmdeploy')
if not use_hf_model:
self.model = SentencePieceTokenizer(model_file)
else:
......@@ -261,7 +528,12 @@ class Tokenizer:
"""
return self.model.encode(s, add_bos, **kwargs)
def decode(self, t: Sequence[int], offset: Optional[int] = None):
def decode(
self,
t: Sequence[int],
offset: Optional[int] = None,
skip_special_tokens: bool = True,
):
"""De-tokenize.
Args:
......@@ -271,7 +543,34 @@ class Tokenizer:
Returns:
str: text of decoding tokens
"""
return self.model.decode(t, offset)
return self.model.decode(t, offset, skip_special_tokens)
def detokenize_incrementally(self,
all_input_ids: Sequence[int],
state: DetokenizeState,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True):
"""Incrementally detokenize the input indexes.
Args:
all_input_ids (List[int]): a list of token ids. Expected to be
different sections of a long sequence.
state (DetokenizeState): an instance of DetokenizeState holding
the incremental decoding state.
skip_special_tokens (bool): Whether or not to remove special tokens
during decoding. Defaults to True.
spaces_between_special_tokens (bool): Whether or not to add spaces
between special tokens. Defaults to True.
Returns:
str: the decoded text of the current round.
state (DetokenizeState): the updated incremental decoding state.
"""
return self.model.detokenize_incrementally(
all_input_ids,
state=state,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens)
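# Sketch: the same pattern through the top-level Tokenizer facade, which
# dispatches to the SentencePiece or HuggingFace backend picked in __init__.
# Model directory is hypothetical.
tokenizer = Tokenizer('/path/to/model_dir')  # hypothetical
state = DetokenizeState()
all_ids = tokenizer.encode('Hello world', add_bos=False)
text, state = tokenizer.detokenize_incrementally(
    all_ids, state, skip_special_tokens=True)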
def __call__(self, s: Union[str, Sequence[str]]):
"""Tokenize prompts.
......@@ -282,3 +581,14 @@ class Tokenizer:
list[int]: token ids
"""
return self.model(s)
def indexes_containing_token(self, token):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
encoded = self.encode(token, add_bos=False)
if len(encoded) > 1:
self.logger.warning(
f'The token {token} encodes to indexes {encoded}, whose '
'length is greater than 1. Currently, it cannot be used as a stop word')
return []
return self.model.indexes_containing_token(token)
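# Sketch: registering a stop word via the facade. Words that encode to more
# than one id are rejected with a warning and an empty list; single-id words
# are delegated to the backend lookup above. Model directory and words are
# hypothetical.
tokenizer = Tokenizer('/path/to/model_dir')  # hypothetical
stop_ids = tokenizer.indexes_containing_token('<|im_end|>')
rejected = tokenizer.indexes_containing_token('stop generating now')  # likely []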