OpenDAS / Lmdeploy / Commits

Commit d7117b95, authored Mar 22, 2024 by zhouxiang
Parent: 5f83e392

    Sync the 0.2.6 code

Changes: 151 files in total; showing 20 changed files with 1849 additions and 1015 deletions (+1849 / -1015).
Changed files (additions / deletions):

    lmdeploy/pytorch/dist.py                                            +0    -98
    lmdeploy/pytorch/model.py                                           +0    -159
    lmdeploy/pytorch/modules/__init__.py                                +0    -4
    lmdeploy/pytorch/modules/linear.py                                  +0    -154
    lmdeploy/pytorch/session.py                                         +0    -81
    lmdeploy/pytorch/utils.py                                           +18   -93
    lmdeploy/serve/async_engine.py                                      +513  -189
    lmdeploy/serve/gradio/api_server_backend.py                         +22   -6
    lmdeploy/serve/gradio/app.py                                        +28   -2
    lmdeploy/serve/gradio/constants.py                                  +2    -2
    lmdeploy/serve/gradio/triton_server_backend.py                      +21   -3
    lmdeploy/serve/gradio/turbomind_coupled.py                          +90   -46
    lmdeploy/serve/openai/api_client.py                                 +92   -41
    lmdeploy/serve/openai/api_server.py                                 +583  -77
    lmdeploy/serve/openai/protocol.py                                   +76   -2
    lmdeploy/serve/turbomind/chatbot.py                                 +48   -31
    lmdeploy/serve/turbomind/triton_models/postprocessing/1/model.py    +12   -5
    lmdeploy/serve/turbomind/triton_models/postprocessing/config.pbtxt  +5    -0
    lmdeploy/serve/turbomind/utils.py                                   +9    -2
    lmdeploy/tokenizer.py                                               +330  -20
lmdeploy/pytorch/dist.py (deleted, 100644 → 0; last revision at parent 5f83e392):

```python
# Copyright (c) OpenMMLab. All rights reserved.
"""Helpers for parallel and distributed inference."""
import functools
import os

import torch
from torch.distributed import broadcast, broadcast_object_list, is_initialized


def get_local_rank():
    """Get local rank of current process.

    Assume environment variable ``LOCAL_RANK`` is properly set by some launcher.
    See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
    """  # noqa: E501
    return int(os.getenv('LOCAL_RANK', '0'))


def get_rank():
    """Get rank of current process.

    Assume environment variable ``RANK`` is properly set by some launcher.
    See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
    """  # noqa: E501
    return int(os.getenv('RANK', '0'))


def get_world_size():
    """Get world size of current process group.

    Assume environment variable ``WORLD_SIZE`` is properly set by some launcher.
    See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
    """  # noqa: E501
    return int(os.getenv('WORLD_SIZE', '1'))


def master_only(func):
    """Decorator to run a function only on the master process."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_initialized():
            if get_rank() != 0:
                return None
        return func(*args, **kwargs)

    return wrapper


def master_only_and_broadcast_general(func):
    """Decorator to run a function only on the master process and broadcast
    the result to all processes."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_initialized():
            if get_rank() == 0:
                result = [func(*args, **kwargs)]
            else:
                result = [None]
            broadcast_object_list(result, src=0)
            result = result[0]
        else:
            result = func(*args, **kwargs)
        return result

    return wrapper


def master_only_and_broadcast_tensor(func):
    """Decorator to run a function only on the master process and broadcast
    the result to all processes.

    Note: Requires a CUDA tensor.
    Note: Does not really work because we don't know the shape beforehand;
    for CPU tensors, use master_only_and_broadcast_general.
    """

    @functools.wraps(func)
    def wrapper(*args, size, dtype, **kwargs):
        if is_initialized():
            if get_rank() == 0:
                result = func(*args, **kwargs)
            else:
                result = torch.empty(size=size, dtype=dtype,
                                     device=get_local_rank())
            broadcast(result, src=0)
            # print(f'rank {get_rank()} received {result}')
        else:
            result = func(*args, **kwargs)
        return result

    return wrapper
```
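For context, a hedged usage sketch (not part of this commit) of how these broadcast decorators were typically used under a `torchrun` launch against the parent revision 5f83e392, where `lmdeploy.pytorch.dist` still exists. `load_sampling_config` is a made-up function for illustration only.

```python
# Hypothetical sketch: run a function on rank 0 and share its result.
# Launch with: torchrun --nproc_per_node=2 demo.py
import torch
import torch.distributed as dist

from lmdeploy.pytorch.dist import (get_local_rank,
                                   master_only_and_broadcast_general)


@master_only_and_broadcast_general
def load_sampling_config():
    # Executed on rank 0 only; the returned object is broadcast to all
    # other ranks via broadcast_object_list inside the decorator.
    return {'max_new_tokens': 512, 'temperature': 0.8}


if __name__ == '__main__':
    dist.init_process_group('nccl')
    torch.cuda.set_device(get_local_rank())
    cfg = load_sampling_config()  # identical dict on every rank
    print(f'rank {dist.get_rank()}: {cfg}')
```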
lmdeploy/pytorch/model.py (deleted, 100644 → 0; last revision at parent 5f83e392):

```python
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import time
import warnings
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .dist import get_local_rank

logger = logging.getLogger(__name__)


class LoadWoInit:
    """Context manager that disables parameter initialization."""

    def __init__(self):
        self.constant_ = torch.nn.init.constant_
        self.zeros_ = torch.nn.init.zeros_
        self.ones_ = torch.nn.init.ones_
        self.uniform_ = torch.nn.init.uniform_
        self.normal_ = torch.nn.init.normal_
        self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_
        self.kaiming_normal_ = torch.nn.init.kaiming_normal_

    def __enter__(self, *args, **kwargs):
        torch.nn.init.constant_ = lambda *args, **kwargs: None
        torch.nn.init.zeros_ = lambda *args, **kwargs: None
        torch.nn.init.ones_ = lambda *args, **kwargs: None
        torch.nn.init.uniform_ = lambda *args, **kwargs: None
        torch.nn.init.normal_ = lambda *args, **kwargs: None
        torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None
        torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None

    def __exit__(self, *args, **kwargs):
        torch.nn.init.constant_ = self.constant_
        torch.nn.init.zeros_ = self.zeros_
        torch.nn.init.ones_ = self.ones_
        torch.nn.init.uniform_ = self.uniform_
        torch.nn.init.normal_ = self.normal_
        torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_
        torch.nn.init.kaiming_normal_ = self.kaiming_normal_


def init_model(model_path: str,
               tokenizer_path: Optional[str] = None,
               use_fast_tokenizer=True):
    """Initialize model and tokenizer from the given model path.

    Args:
        model_path (str): Path to model.
        tokenizer_path (str): Path to tokenizer.
        use_fast_tokenizer (bool): Whether to use the fast tokenizer.

    Note:
        If the model is converted from a new version of transformers,
        use_fast_tokenizer should be True.
        If using depodaca/llama-xb-hf, use_fast_tokenizer should be False.
    """
    start = time.monotonic()
    if not tokenizer_path:
        tokenizer_path = model_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                              use_fast=use_fast_tokenizer,
                                              trust_remote_code=True)
    with LoadWoInit():
        model = AutoModelForCausalLM.from_pretrained(model_path,
                                                     torch_dtype=torch.float16,
                                                     trust_remote_code=True)
    logger.info(f'Model loaded in {time.monotonic() - start:.1f} seconds')
    logger.info(f'Model loaded from {model_path}')
    logger.debug(model)

    return model, tokenizer


def accel_model(model,
                accel: Optional[str] = None,
                gpu_id=None,
                max_alloc=2048,
                tp_size=1):
    """Accelerate model with the given accelerator.

    Note:
        Currently we support only deepspeed or no acceleration at all.
    """
    logger.info(f'Accelerate model with {accel}')

    if accel is None:
        # No acceleration, just move to cuda.
        # Assume single gpu, single process; the user is responsible for
        # assigning the gpu id via CUDA_VISIBLE_DEVICES  # noqa: E501
        gpu_id = gpu_id if gpu_id is not None else get_local_rank()
        model = model.cuda(gpu_id)
    elif accel.lower() == 'deepspeed':
        # Use deepspeed inference to inject fast kernels and/or tensor parallelism
        try:
            import deepspeed
        except ImportError as e:
            raise ImportError('--accel=deepspeed is specified but '
                              'deepspeed is not installed.\n'
                              'Install with `pip install deepspeed`.') from e

        config = dict(
            tensor_parallel=dict(tp_size=tp_size),  # Use world size in general
            dtype=torch.float16,
            replace_with_kernel_inject=True,
            max_out_tokens=max_alloc,
        )

        if 'InternLM' in model.__class__.__name__:
            try:
                # Use customized deepspeed supporting InternLM
                # https://github.com/wangruohui/DeepSpeed/tree/support_internlm_0.10.0 (commit cdef2ce)  # noqa: E501
                from deepspeed.module_inject.containers.internlm import \
                    InternLMLayerPolicy  # noqa: E501
            except ImportError:
                # InternLM is not officially supported by DeepSpeed.
                # Set replace_with_kernel_inject=False to use AutoTP.
                config.update({'replace_with_kernel_inject': False})
                warnings.warn(
                    '\033[0;93m'
                    'Current installation of deepspeed does not '
                    'support InternLM. Disable kernel injection. '
                    'To support InternLM, install customized deepspeed with '
                    '`pip install git+https://github.com/wangruohui/DeepSpeed@support_internlm_0.10.0`'  # noqa: E501
                    '\033[0m')
            else:
                for module in model.modules():
                    # Since remote code is dynamically located,
                    # we need to do this dynamically.
                    if module.__class__.__name__ == 'InternLMDecoderLayer':
                        InternLMLayerPolicy._orig_layer_class = module.__class__  # noqa: E501
                        break

        logger.debug(f'Using deepspeed config\n{config}')

        model = deepspeed.init_inference(
            model=model,  # Transformers models
            config=config,
        )
        # for k, v in model.named_parameters():
        #     logger.debug(f"{k}: v.device")
    else:
        raise ValueError(f'Unsupported accelerator {accel}.')

    logger.debug(model)

    return model
```
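A hedged sketch of how `init_model` and `accel_model` fit together against the parent revision (the modules are removed by this commit); the model path is a placeholder and the flow assumes a single GPU.

```python
# Sketch only: load a HF causal LM without weight-init overhead, then move
# it to GPU (pass accel='deepspeed' instead to wrap it with deepspeed inference).
from lmdeploy.pytorch.model import accel_model, init_model

model, tokenizer = init_model('internlm/internlm-chat-7b')  # placeholder path
model = accel_model(model, accel=None, gpu_id=0)            # plain .cuda(0)

inputs = tokenizer('Hello, world!', return_tensors='pt').to('cuda:0')
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```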
lmdeploy/pytorch/modules/__init__.py (deleted, 100644 → 0; last revision at parent 5f83e392):

```python
# Copyright (c) OpenMMLab. All rights reserved.
from .linear import WeightOnlyQLinear

__all__ = ['WeightOnlyQLinear']
```
lmdeploy/pytorch/modules/linear.py (deleted, 100644 → 0; last revision at parent 5f83e392):

```python
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Type, TypeVar

import torch
from torch import nn

try:
    import awq_inference_engine
except ModuleNotFoundError:
    awq_inference_engine = None


class WeightOnlyQLinear(nn.Module):
    """This class implements a weight-only quantized linear layer.

    Args:
        w_bit (int): number of bits for quantization.
        symmetry (bool): If true, use symmetric quantization,
            otherwise use asymmetric quantization.
        group_size (int): size of the quantization group.
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): whether to allocate a bias buffer.
            Defaults to True.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: Optional[torch.Tensor] = True,
        w_bit: int = 4,
        symmetry: bool = False,
        group_size: int = 128,
    ) -> None:
        super().__init__()

        if w_bit not in [2, 4, 8]:
            raise NotImplementedError('Only 2, 4, 8 bit are supported for now.')

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features

        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0

        w_pack_oc = out_features // (32 // self.w_bit)
        w_inc = in_features
        weight = torch.zeros((w_inc, w_pack_oc), dtype=torch.int32)
        self.register_buffer('qweight', weight)

        if bias:
            self.register_buffer('bias', torch.zeros(out_features))
        else:
            self.bias = None

        s_inc = in_features // self.group_size
        s_oc = out_features
        scales = torch.zeros((s_inc, s_oc), dtype=torch.float16)
        self.register_buffer('scales', scales)

        if not symmetry:
            z_inc = in_features // self.group_size
            z_oc = out_features // (32 // self.w_bit)
            zeros = torch.zeros((z_inc, z_oc), dtype=torch.int32)
            self.register_buffer('qzeros', zeros)
        else:
            self.qzeros = None

    @classmethod
    def from_linear(cls: Type['WeightOnlyQLinear'],
                    linear: nn.Linear,
                    quantizer: TypeVar('Quantizer'),
                    awq_layout: bool = True) -> 'WeightOnlyQLinear':
        """Create a WeightOnlyQLinear object from a PyTorch Linear object.

        Args:
            linear (nn.Linear): PyTorch Linear object.
            quantizer (Quantizer): Object that handles quantization.
            awq_layout (bool): AWQ layout. Defaults to True.

        Returns:
            WeightOnlyQLinear: A WeightOnlyQLinear object.
        """
        device = linear.weight.device

        w_bit = quantizer.bits
        pack_num = 32 // w_bit
        if awq_layout:
            assert w_bit == 4
            pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
        else:
            pack_order = torch.arange(pack_num)
        group_size = quantizer.group_size
        symmetry = quantizer.symmetry

        in_features = linear.in_features
        out_features = linear.out_features
        bias = False if linear.bias is None else True

        qlinear = cls(in_features, out_features, bias, w_bit, symmetry,
                      group_size)
        qlinear.bias = linear.bias

        qparams = quantizer.calculate_qparams(linear.weight)
        i32_w = quantizer.quant(linear.weight, qparams, real=True)
        i32_w = i32_w.t().contiguous()

        pack_int_w = torch.zeros_like(qlinear.qweight).to(device)

        for col in range(pack_int_w.shape[1]):
            for i in range(pack_num):
                pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
                pack_int_w[:, col] |= pack_int_w_col << (i * w_bit)

        qlinear.qweight = pack_int_w
        qlinear.scales = qparams.scales.squeeze(-1).t().contiguous()

        if qparams.zero_points is not None:
            zeros = qparams.zero_points.to(torch.int32).to(device)
            zeros = zeros.squeeze(-1).t().contiguous()
            pack_int_zeros = torch.zeros_like(qlinear.qzeros).to(device)

            for col in range(pack_int_zeros.shape[1]):
                for i in range(pack_num):
                    qzero_col = zeros[:, col * pack_num + pack_order[i]]
                    pack_int_zeros[:, col] |= qzero_col << (i * w_bit)
            qlinear.qzeros = pack_int_zeros

        qlinear.to('cpu')

        return qlinear

    @torch.no_grad()
    def forward(self, x):
        if awq_inference_engine is None:
            raise RuntimeError(
                'Run the following command to install '
                'the kernel for 4bit inference\n\n'
                'git clone https://github.com/mit-han-lab/llm-awq.git\n'
                'cd awq/kernels\n'
                'python setup.py install\n')
        out_shape = x.shape[:-1] + (self.out_features, )
        inputs = x.reshape(-1, x.shape[-1])

        out = awq_inference_engine.gemm_forward_cuda(inputs.half(),
                                                     self.qweight,
                                                     self.scales.half(),
                                                     self.qzeros,
                                                     self.group_size)
        out = out + self.bias if self.bias is not None else out

        return out.reshape(out_shape)
```
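A minimal sketch of the layer's buffer layout, runnable against the parent revision (the module is removed by this commit). It only constructs the module; the forward pass would additionally require the AWQ CUDA kernel mentioned in the error message above.

```python
# Sketch: a 4-bit, group-size-128 weight-only linear and its packed buffers.
from lmdeploy.pytorch.modules import WeightOnlyQLinear

qlinear = WeightOnlyQLinear(in_features=4096, out_features=4096, bias=False,
                            w_bit=4, symmetry=False, group_size=128)
print(qlinear.qweight.shape)  # (4096, 512): eight 4-bit values per int32
print(qlinear.scales.shape)   # (32, 4096): one fp16 scale per 128-wide group
print(qlinear.qzeros.shape)   # (32, 512): packed asymmetric zero points

# With awq_inference_engine installed, inference would look like:
#   y = qlinear.cuda()(x_fp16_on_cuda)   # dispatches to gemm_forward_cuda
```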
lmdeploy/pytorch/session.py (deleted, 100644 → 0; last revision at parent 5f83e392):

```python
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch
from transformers.generation.utils import ModelOutput

logger = logging.getLogger(__name__)


class BasicSessionManager:
    """Basic session manager without history."""

    def prepend_history(self, input_ids):
        return input_ids

    def add_to_history(self, output):
        pass


class BasicSessionManagerWithHistory:
    """Basic session manager with chat history.

    Args:
        max_session_len (int): Maximum number of tokens allowed for all chat sessions.
        reduce_size (int): Number of tokens to be trimmed when reaching maximum
            session length. Default: 256.
        start_ids (list[int]): Sequences of ids at the start of the chat session.
        sep_ids (list[int]): Sequences of ids separating chat sessions.
    """  # noqa: E501

    bs = 1

    def __init__(self,
                 max_session_len=2048,
                 reduce_size=256,
                 start_ids=[1],
                 sep_ids=[13]) -> None:
        self.start_ids = torch.tensor(start_ids, dtype=torch.long)
        self.sep_ids = torch.tensor(sep_ids, dtype=torch.long)
        assert self.start_ids.ndim == 1
        assert self.sep_ids.ndim == 1

        self.max_session_len = max(len(start_ids), max_session_len)
        self.reduce_size = min(reduce_size, max_session_len - len(start_ids))
        assert self.max_session_len > self.reduce_size

        self.new_session()

    def new_session(self):
        self.history_ids = self.start_ids.repeat(self.bs, 1)

    def prepend_history(self, input_ids: torch.Tensor):
        """Prepend history ids to input ids and trim if over-length."""
        input_ids = input_ids.to(self.history_ids.device).long()
        sep_ids = self.sep_ids.to(self.history_ids.device).long().repeat(1, 1)
        input_ids = torch.cat([self.history_ids, sep_ids, input_ids], dim=1)

        if input_ids.shape[1] > self.max_session_len:
            input_ids = input_ids[:, (self.reduce_size - self.max_session_len):]
            input_ids[:, :len(self.start_ids)] = self.start_ids.repeat(self.bs, 1)

        return input_ids

    def add_to_history(self, output):
        """Save history output ids.

        Note:
            Output returned by HuggingFace generator contains both input
            and output ids.
        """
        if isinstance(output, ModelOutput):
            self.history_ids = output.sequences
        elif isinstance(output, torch.Tensor):
            self.history_ids = output
        else:
            raise ValueError(f'Unknown output type {type(output)}')
```
lmdeploy/pytorch/utils.py (modified, +18 / -93; view file @ d7117b95)

The inline diff for this file interleaves removed and added lines, and the rendering does not distinguish them. A best-effort reading is that the vllm-derived helpers below are what remains (and is added) after the change:

```python
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
import inspect
from inspect import Parameter, Signature
from typing import Dict, Sequence

import psutil


def get_gpu_memory(id: int = 0) -> int:
    """Returns the free and total physical memory of the GPU in bytes."""
    import torch
    return torch.cuda.mem_get_info(id)


def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    return psutil.virtual_memory().total


def bind_sigature(input_names: str, args: Sequence, kwargs: Dict):
    """Bind args and kwargs to given input names."""
    kind = inspect._ParameterKind.POSITIONAL_OR_KEYWORD
    sig = Signature([Parameter(name, kind) for name in input_names])
    bind = sig.bind(*args, **kwargs)
    return bind.arguments
```

Removed in this commit are the old terminal chat helpers, which depended on the deleted `.dist` module:

- the `import logging` / `readline` imports, the `BaseStreamer` import from `transformers.generation.streamers`, the `from .dist import get_rank, master_only, master_only_and_broadcast_general` import, and the module-level `logger`;
- `class TerminalIO`, whose `input()` read multi-line terminal input (ending on a double enter, exiting on EOF) under `@master_only_and_broadcast_general`, and whose `output()` printed with flush under `@master_only`;
- `class BasicStreamer(BaseStreamer)`, a basic streamer for HuggingFace models holding `decode_func`, `output_func`, `end_of_output`, and `skip_prompt`, emitting decoded tokens in `put()` and the end-of-output marker in `end()`;
- `control(prompt, gen_config, sm)`, which handled the `exit`, `clear`, and runtime `config set k=v` commands, logging the updated generation config and falling back to normal conversation on illegal instructions.
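A short sketch of what the retained `bind_sigature` helper does, assuming it is exposed at module level in `lmdeploy.pytorch.utils` as the rendered diff suggests; the argument names are illustrative only.

```python
# Sketch: bind positional/keyword call arguments onto a list of input names.
from lmdeploy.pytorch.utils import bind_sigature  # assumed import path

args = ([101, 102, 103],)
kwargs = {'attention_mask': [1, 1, 1]}
print(bind_sigature(['input_ids', 'attention_mask'], args, kwargs))
# -> {'input_ids': [101, 102, 103], 'attention_mask': [1, 1, 1]}
#    (a plain dict on recent Python versions)
```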
lmdeploy/serve/async_engine.py (modified, +513 / -189; view file @ d7117b95)

The inline diff interleaves removed and added lines. The module-level additions are reconstructed below; the changes inside `class AsyncEngine`, where old and new lines are mixed line by line, are summarized afterwards.

New imports and module-level helpers:

```python
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import dataclasses
import os
import random
from argparse import ArgumentError
from contextlib import asynccontextmanager
from itertools import count
from queue import Empty, Queue
from threading import Thread
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from lmdeploy.messages import (EngineGenerationConfig, GenerationConfig,
                               PytorchEngineConfig, Response,
                               TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig, best_match_model
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import _stop_words, get_logger

logger = get_logger('lmdeploy')


def get_model_name_from_workspace_model(model_dir: str):
    """Get model name from workspace model."""
    from configparser import ConfigParser
    triton_model_path = os.path.join(model_dir, 'triton_models', 'weights')
    if not os.path.exists(triton_model_path):
        return None
    ini_path = os.path.join(triton_model_path, 'config.ini')
    # load cfg
    with open(ini_path, 'r') as f:
        parser = ConfigParser()
        parser.read_file(f)
    return parser['llama']['model_name']


def deduce_a_name(model_path: str,
                  model_name: Optional[str] = None,
                  backend_config: Optional[Union[TurbomindEngineConfig,
                                                 PytorchEngineConfig]] = None,
                  chat_template_config: Optional[ChatTemplateConfig] = None
                  ) -> str:
    """Deduce a model name from all the possible arguments."""

    def _config_model_name(config):
        if config and config.model_name:
            return config.model_name
        return None

    backend_config_model_name = _config_model_name(backend_config)
    chat_template_config_model_name = _config_model_name(chat_template_config)
    model_name = model_name or chat_template_config_model_name \
        or backend_config_model_name  # noqa
    if model_name is None:
        # model maybe from workspace for turbomind
        model_name = get_model_name_from_workspace_model(model_path)
        # may get a model name from model_path
        if model_name is None:
            model_name = best_match_model(model_path)
            if model_name is None:
                raise ArgumentError(None,
                                    f'Please set model_name for {model_path}')
            else:
                logger.info(f'matched chat template name: {model_name}')
    return model_name
```

In hunk `@@ -16,6 +74,55 @@`, the `GenOut` dataclass keeps its unchanged `finish_reason: Optional[Literal['stop', 'length']] = None` field, and a new `Session` class is added for `AsyncEngine.chat`:

```python
class Session:
    """Session for AsyncEngine.chat.

    Args:
        _id (int): session_id for internal use.
        _step (int): the offset of the k/v cache for internal use.
        _prompt (Any): input prompt for internal use.
        _response (Response): model output for prompt.
        _engine (Any): engine for internal use.
        history (List[Tuple[Any, str]]): chat history.
    """
    _ids = count(0)

    def __init__(self):
        self._id: int = next(self._ids)
        self._step: int = 0
        self._prompt: Any = None
        self._response: Response = None
        self._engine: Any = None
        self.history: List[Tuple[Any, str]] = []

    def _merge_response(self, resp: Response, step: Union[Response, GenOut]):
        """merge response."""
        resp.text += step.text if isinstance(step, Response) else step.response
        resp.input_token_len = step.input_token_len
        resp.generate_token_len = step.generate_token_len
        resp.finish_reason = step.finish_reason
        return resp

    @property
    def response(self) -> Response:
        """return response."""
        return self._response

    def close(self):
        """release engine storage for this session."""
        if self._engine:
            inst = self._engine.create_instance()
            inst.end(self._id)

    def __repr__(self) -> str:
        res = ''
        for user, assistant in self.history:
            if isinstance(user, list):
                user = str(user)
            res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n'
        return res
```

Changes inside `class AsyncEngine` (hunks `@@ -30,51 +137,150 @@` and `@@ -85,245 +291,363 @@`):

- The class docstring drops `instance_num`, documents the new `backend`, `backend_config` and `chat_template_config` arguments, and lowercases the example model ids (e.g. "internlm/internlm-chat-7b").
- `__init__` changes from `(self, model_path, model_name=None, instance_num=32, tp=1, **kwargs)` to `(self, model_path, model_name=None, backend='turbomind', backend_config=None, chat_template_config=None, tp=1, **kwargs)`. Instead of constructing `tm.TurboMind.from_pretrained(...)` directly, it logs the requested backend and chat-template configs, resolves the model name via `deduce_a_name`, builds a default `ChatTemplateConfig` (mapping deprecated keyword arguments onto it with a warning), and dispatches to the new `_build_turbomind` or `_build_pytorch` helper. It then reads `session_len` from the backend config, derives `stop_words` from the chat template via `_stop_words`, sets `instance_num` from `backend_config.max_batch_size`, and adds a `running_session_ids` set next to the existing `id2step`, `id2generator`, event-loop and `gens_set` bookkeeping (instances now come from `self.engine.create_instance()`).
- New `_build_turbomind` and `_build_pytorch` methods assert the matching config type (`TurbomindEngineConfig` / `PytorchEngineConfig`), default `backend_config.session_len` from the chat template, and create the engine via `tm.TurboMind.from_pretrained(model_path, engine_config=backend_config, chat_template_config=..., **kwargs)` or `lmdeploy.pytorch.engine.Engine(model_path=model_path, engine_config=backend_config)`.
- `__call__` and `batch_infer` now accept a single prompt string, a list of strings, an OpenAI-format chat history, or a list of histories, plus an optional `gen_config` (`GenerationConfig` or `EngineGenerationConfig`); `__call__` keeps the old sampling keywords only to build a default `GenerationConfig`. `batch_infer` converts a plain `GenerationConfig` with `EngineGenerationConfig.From(gen_config, self.tokenizer)`, seeds `gen_config.random_seed` when unset, runs prompts in chunks of `instance_num`, and returns `Response` objects carrying `text`, `generate_token_len`, `input_token_len` and `finish_reason` instead of bare strings.
- A new `stream_infer` method streams `Response` objects through a `Queue` that is filled by a background `Thread` running the asyncio gather over the per-prompt generators.
- `stop_session`, `end_session` and the `safe_run` context manager become coroutines (`safe_run` is now an `@asynccontextmanager`); they call the generator's `async_cancel` / `async_end` instead of issuing dummy `stream_infer` calls, return the generator to `gens_set`, and maintain `running_session_ids`. `get_generator` additionally waits (0.1 s polls) while the same session id is still running. A small `get_embeddings(prompt, do_prerpocess=False)` helper (the parameter name is misspelled in the source) encodes a prompt to input ids.
- `generate` takes `gen_config` instead of the per-call sampling keywords, fills in `gen_config.stop_words` and `random_seed` defaults, builds the prompt and input ids through the new `_get_prompt_input` helper (chat template `messages2prompt` plus `tokenizer.encode(..., add_bos=sequence_start)`), defaults `max_new_tokens` to the remaining session budget (`max(128, session_len - step - len(input_ids))`), logs the request, and checks the token budget against `session_len` before running. Inference goes through `async with self.safe_run(...)` and `generator.async_stream_infer(..., gen_config=gen_config, ...)`; decoding uses the new incremental detokenizer (`DetokenizeState` with `tokenizer.detokenize_incrementally`) instead of re-decoding the whole output and skipping on a trailing `'�'`; the final `finish_reason` compares `tokens` with `gen_config.max_new_tokens`, and for the pytorch backend the session is ended explicitly when `sequence_end` is set.
- A new synchronous `chat(prompt, session=None, gen_config=None, do_preprocess=True, **kwargs)` method wraps `generate` in a `Session`: it runs the coroutine with `_run_until_complete` from `lmdeploy.pytorch.engine.request`, merges streamed output into a `Response`, advances `session._step`, and appends the turn to `session.history`.
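Taken together, the reconstructed signatures suggest the following usage. This is a hedged sketch based only on the names visible in this diff (the model path and prompts are placeholders, and the config fields are those referenced above), not a verified example from the lmdeploy documentation.

```python
# Sketch: the backend-agnostic AsyncEngine surface introduced by this commit.
from lmdeploy.messages import GenerationConfig, TurbomindEngineConfig
from lmdeploy.serve.async_engine import AsyncEngine

engine = AsyncEngine('internlm/internlm-chat-7b',          # placeholder path
                     backend='turbomind',
                     backend_config=TurbomindEngineConfig(max_batch_size=32),
                     tp=1)

# Batched, non-streaming inference; prompts may be strings or OpenAI-style
# message lists, and gen_config replaces the old per-call sampling kwargs.
responses = engine.batch_infer(['Hi, please introduce yourself', 'Shanghai is'],
                               gen_config=GenerationConfig(max_new_tokens=64))
for r in responses:
    print(r.text, r.finish_reason)

# Streaming variant yields Response objects incrementally.
for out in engine.stream_infer(['Write a haiku about GPUs'],
                               gen_config=GenerationConfig(max_new_tokens=32)):
    print(out.text, end='', flush=True)

# Stateful multi-turn chat via the new Session object.
session = engine.chat('Hello!')
session = engine.chat('What did I just say?', session=session)
print(session.response.text)
```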
lmdeploy/serve/gradio/api_server_backend.py (modified, +22 / -6; view file @ d7117b95)

The rendered diff interleaves old and new lines; the recoverable changes are:

- `chat_stream_restful` (hunk `@@ -17,7 +17,8 @@`) now takes `top_p: float`, `temperature: float` and `request_output_len: int` in addition to the instruction, chatbot state, buttons and session id.
- In its body (hunk `@@ -33,9 +34,11 @@`), the call to `{InterFace.api_server_url}/v1/chat/interactive` forwards `request_output_len=request_output_len`, `top_p=top_p` and `temperature=temperature` instead of a hardcoded `request_output_len=512`, and the "exceed session max length" warning is raised only when `finish_reason == 'length' and tokens == 0`.
- In `cancel_restful_func` (hunks `@@ -94,7 +97,7 @@` and `@@ -106,6 +109,7 @@`), the cancel request passes `cancel=True` instead of `stop=True`, and a `# TODO this is not proper if api server is running pytorch backend` note is added before the session-resume request.
- In `run_api_server` (hunk `@@ -155,10 +159,22 @@`), a new row of sliders is added to the UI: `request_output_len` (1 to 2048, default 512, labelled 'Maximum new tokens'), `top_p` (0.01 to 1, default 0.8) and `temperature` (0.01 to 1.5, default 0.7). These sliders are appended to the inputs of `instruction_txtbox.submit(chat_stream_restful, [...])`; the rest of the submit wiring is unchanged.
lmdeploy/serve/gradio/app.py  View file @ d7117b95

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Literal, Optional, Union
+
+from lmdeploy.archs import get_task
+from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig
+from lmdeploy.model import ChatTemplateConfig
+
+
 def run(model_path_or_server: str,
         server_name: str = '0.0.0.0',
         server_port: int = 6006,
         batch_size: int = 32,
+        backend: Literal['turbomind', 'pytorch'] = 'turbomind',
+        backend_config: Optional[Union[PytorchEngineConfig,
+                                       TurbomindEngineConfig]] = None,
+        chat_template_config: Optional[ChatTemplateConfig] = None,
         tp: int = 1,
         model_name: str = None,
         **kwargs):
...
@@ -19,6 +28,12 @@ def run(model_path_or_server: str,
         server_name (str): the ip address of gradio server
         server_port (int): the port of gradio server
         batch_size (int): batch size for running Turbomind directly
+        backend (str): either `turbomind` or `pytorch` backend. Default to
+            `turbomind` backend.
+        backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend
+            config instance. Default to none.
+        chat_template_config (ChatTemplateConfig): chat template configuration.
+            Default to None.
         tp (int): tensor parallel for Turbomind
     """
     if ':' in model_path_or_server:
...
@@ -31,11 +46,22 @@ def run(model_path_or_server: str,
             run_triton_server
         run_triton_server(model_path_or_server, server_name, server_port)
     else:
-        from lmdeploy.serve.gradio.turbomind_coupled import run_local
+        pipeline_type, _ = get_task(model_path_or_server)
+        if pipeline_type == 'vlm':
+            from lmdeploy.serve.gradio.vl import run_local
+            assert backend == 'turbomind', 'vlm only support turbomind backend'
+            if backend_config is not None and \
+                    backend_config.session_len is None:
+                backend_config.session_len = 8192
+        else:
+            from lmdeploy.serve.gradio.turbomind_coupled import run_local
         run_local(model_path_or_server,
-                  model_name=model_name,
                   server_name=server_name,
                   server_port=server_port,
+                  backend=backend,
+                  backend_config=backend_config,
+                  chat_template_config=chat_template_config,
+                  model_name=model_name,
                   batch_size=batch_size,
                   tp=tp,
                   **kwargs)
...
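As a small usage sketch (not part of the commit), the new backend arguments of run() could be exercised as follows; the model path and the session_len value are placeholders.

from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.serve.gradio.app import run

# launch the coupled Gradio demo with an explicit engine config
run('internlm/internlm-chat-7b',
    server_name='0.0.0.0',
    server_port=6006,
    backend='turbomind',
    backend_config=TurbomindEngineConfig(session_len=8192))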
lmdeploy/serve/gradio/constants.py  View file @ d7117b95
...
@@ -24,5 +24,5 @@ THEME = gr.themes.Soft(
     secondary_hue=gr.themes.colors.sky,
     font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])

-enable_btn = gr.Button.update(interactive=True)
+enable_btn = gr.update(interactive=True)
-disable_btn = gr.Button.update(interactive=False)
+disable_btn = gr.update(interactive=False)
lmdeploy/serve/gradio/triton_server_backend.py  View file @ d7117b95
...
@@ -16,7 +16,8 @@ class InterFace:
 def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
-                cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int):
+                cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int,
+                top_p: float, temperature: float, request_output_len: int):
     """Chat with AI assistant.

     Args:
...
@@ -30,7 +31,12 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
     instruction = state_chatbot[-1][0]

     bot_response = llama_chatbot.stream_infer(
-        session_id, instruction, f'{session_id}-{len(state_chatbot)}')
+        session_id,
+        instruction,
+        f'{session_id}-{len(state_chatbot)}',
+        request_output_len=request_output_len,
+        top_p=top_p,
+        temperature=temperature)

     for status, tokens, _ in bot_response:
         state_chatbot[-1] = (state_chatbot[-1][0], tokens)
...
@@ -108,12 +114,24 @@ def run_triton_server(triton_server_addr: str,
         with gr.Row():
             cancel_btn = gr.Button(value='Cancel', interactive=False)
             reset_btn = gr.Button(value='Reset')
+        with gr.Row():
+            request_output_len = gr.Slider(1, 2048, value=512, step=1, label='Maximum new tokens')
+            top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
+            temperature = gr.Slider(0.01, 1.5, value=0.7, step=0.01, label='Temperature')

         send_event = instruction_txtbox.submit(
             add_instruction, [instruction_txtbox, state_chatbot],
             [instruction_txtbox, state_chatbot]).then(chat_stream, [
                 state_chatbot, llama_chatbot, cancel_btn, reset_btn,
-                state_session_id
+                state_session_id, top_p, temperature, request_output_len
             ], [state_chatbot, chatbot, cancel_btn, reset_btn])

         cancel_btn.click(cancel_func,
...
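Outside Gradio, the same streaming call with the new sampling arguments might look like the sketch below. The Chatbot import path, its constructor argument (the Triton server address) and the request id format are assumptions based on the call shown above, not something this commit guarantees.

from lmdeploy.serve.turbomind.chatbot import Chatbot

llama_chatbot = Chatbot('0.0.0.0:33337')  # assumed triton server address
session_id = 1
for status, res, _ in llama_chatbot.stream_infer(
        session_id,
        'Hello!',
        f'{session_id}-1',
        request_output_len=512,
        top_p=0.8,
        temperature=0.7):
    # res accumulates the generated reply, mirroring the loop in chat_stream
    print(res)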
lmdeploy/serve/gradio/turbomind_coupled.py  View file @ d7117b95

 # Copyright (c) OpenMMLab. All rights reserved.
+import random
 from threading import Lock
-from typing import Optional, Sequence
+from typing import Literal, Optional, Sequence, Union

 import gradio as gr

+from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
+                               TurbomindEngineConfig)
+from lmdeploy.model import ChatTemplateConfig
 from lmdeploy.serve.async_engine import AsyncEngine
 from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
...
@@ -14,13 +18,10 @@ class InterFace:
     lock = Lock()


-async def chat_stream_local(
-        instruction: str,
-        state_chatbot: Sequence,
-        cancel_btn: gr.Button,
-        reset_btn: gr.Button,
-        session_id: int,
-):
+async def chat_stream_local(instruction: str, state_chatbot: Sequence,
+                            cancel_btn: gr.Button, reset_btn: gr.Button,
+                            session_id: int, top_p: float, temperature: float,
+                            request_output_len: int):
     """Chat with AI assistant.

     Args:
...
@@ -33,15 +34,23 @@ async def chat_stream_local(
     state_chatbot = state_chatbot + [(instruction, None)]

     yield (state_chatbot, state_chatbot, disable_btn, enable_btn)

+    gen_config = GenerationConfig(
+        max_new_tokens=request_output_len,
+        top_p=top_p,
+        top_k=40,
+        temperature=temperature,
+        random_seed=random.getrandbits(64) if len(state_chatbot) == 1 else None)
     async for outputs in InterFace.async_engine.generate(
             instruction,
             session_id,
+            gen_config=gen_config,
             stream_response=True,
             sequence_start=(len(state_chatbot) == 1),
             sequence_end=False):
         response = outputs.response
-        if outputs.finish_reason == 'length':
+        if outputs.finish_reason == 'length' and \
+                outputs.generate_token_len == 0:
             gr.Warning('WARNING: exceed session max length.'
                        ' Please restart the session by reset button.')
         if outputs.generate_token_len < 0:
...
@@ -69,7 +78,7 @@ async def reset_local_func(instruction_txtbox: gr.Textbox,
     """
     state_chatbot = []
     # end the session
-    InterFace.async_engine.end_session(session_id)
+    await InterFace.async_engine.end_session(session_id)
     return (state_chatbot, state_chatbot, gr.Textbox.update(value=''))
...
@@ -85,28 +94,36 @@ async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
         session_id (int): the session id
     """
     yield (state_chatbot, disable_btn, disable_btn)
-    InterFace.async_engine.stop_session(session_id)
-    InterFace.async_engine.end_session(session_id)
-    messages = []
-    for qa in state_chatbot:
-        messages.append(dict(role='user', content=qa[0]))
-        if qa[1] is not None:
-            messages.append(dict(role='assistant', content=qa[1]))
-    async for out in InterFace.async_engine.generate(messages,
-                                                     session_id,
-                                                     request_output_len=0,
-                                                     stream_response=True,
-                                                     sequence_start=True,
-                                                     sequence_end=False):
-        pass
-    yield (state_chatbot, disable_btn, enable_btn)
+    await InterFace.async_engine.stop_session(session_id)
+    # pytorch backend does not support resume chat history now
+    if InterFace.async_engine.backend == 'pytorch':
+        yield (state_chatbot, disable_btn, enable_btn)
+    else:
+        await InterFace.async_engine.end_session(session_id)
+        messages = []
+        for qa in state_chatbot:
+            messages.append(dict(role='user', content=qa[0]))
+            if qa[1] is not None:
+                messages.append(dict(role='assistant', content=qa[1]))
+        gen_config = GenerationConfig(max_new_tokens=0)
+        async for out in InterFace.async_engine.generate(messages,
+                                                         session_id,
+                                                         gen_config=gen_config,
+                                                         stream_response=True,
+                                                         sequence_start=True,
+                                                         sequence_end=False):
+            pass
+        yield (state_chatbot, disable_btn, enable_btn)


 def run_local(model_path: str,
               model_name: Optional[str] = None,
-              server_name: str = 'localhost',
+              backend: Literal['turbomind', 'pytorch'] = 'turbomind',
+              backend_config: Optional[Union[PytorchEngineConfig,
+                                             TurbomindEngineConfig]] = None,
+              chat_template_config: Optional[ChatTemplateConfig] = None,
+              server_name: str = '0.0.0.0',
               server_port: int = 6006,
-              batch_size: int = 4,
               tp: int = 1,
               **kwargs):
     """chat with AI assistant through web ui.
...
@@ -122,22 +139,32 @@ def run_local(model_path: str,
             "InternLM/internlm-chat-20b-4bit",
             "lmdeploy/llama2-chat-70b-4bit", etc.
             - iii) The model_id of a model hosted inside a model repo
-                on huggingface.co, such as "InternLM/internlm-chat-7b",
+                on huggingface.co, such as "internlm/internlm-chat-7b",
             "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
             and so on.
         model_name (str): needed when model_path is a pytorch model on
-            huggingface.co, such as "InternLM/internlm-chat-7b",
+            huggingface.co, such as "internlm/internlm-chat-7b",
             "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
-        server_name (str): the ip address of gradio server
+        backend (str): either `turbomind` or `pytorch` backend. Default to
+            `turbomind` backend.
+        backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend
+            config instance. Default to none.
+        chat_template_config (ChatTemplateConfig): chat template configuration.
+            Default to None.
+        server_name (str): the ip address of gradio server. Default to
+            "0.0.0.0". For huggingface space demo, it should be
+            "huggingface-space".
         server_port (int): the port of gradio server
-        batch_size (int): batch size for running Turbomind directly
         tp (int): tensor parallel for Turbomind
     """
-    InterFace.async_engine = AsyncEngine(model_path=model_path,
-                                         model_name=model_name,
-                                         instance_num=batch_size,
-                                         tp=tp,
-                                         **kwargs)
+    InterFace.async_engine = AsyncEngine(
+        model_path=model_path,
+        backend=backend,
+        backend_config=backend_config,
+        chat_template_config=chat_template_config,
+        model_name=model_name,
+        tp=tp,
+        **kwargs)

     with gr.Blocks(css=CSS, theme=THEME) as demo:
         state_chatbot = gr.State([])
...
@@ -148,17 +175,29 @@ def run_local(model_path: str,
         chatbot = gr.Chatbot(
             elem_id='chatbot',
-            label=InterFace.async_engine.tm_model.model_name)
+            label=InterFace.async_engine.engine.model_name)
         instruction_txtbox = gr.Textbox(
             placeholder='Please input the instruction',
             label='Instruction')
         with gr.Row():
             cancel_btn = gr.Button(value='Cancel', interactive=False)
             reset_btn = gr.Button(value='Reset')
+        with gr.Row():
+            request_output_len = gr.Slider(1, 2048, value=512, step=1, label='Maximum new tokens')
+            top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
+            temperature = gr.Slider(0.01, 1.5, value=0.7, step=0.01, label='Temperature')

         send_event = instruction_txtbox.submit(chat_stream_local, [
             instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
-            state_session_id
+            state_session_id, top_p, temperature, request_output_len
         ], [state_chatbot, chatbot, cancel_btn, reset_btn])
         instruction_txtbox.submit(
             lambda: gr.Textbox.update(value=''),
...
@@ -184,14 +223,19 @@ def run_local(model_path: str,
     demo.load(init, inputs=None, outputs=[state_session_id])

-    print(f'server is gonna mount on: http://{server_name}:{server_port}')
-    demo.queue(concurrency_count=batch_size, max_size=100,
-               api_open=True).launch(
-                   max_threads=10,
-                   share=True,
-                   server_port=server_port,
-                   server_name=server_name,
-               )
+    if server_name == 'huggingface-space':
+        demo.queue(concurrency_count=InterFace.async_engine.instance_num,
+                   max_size=100).launch()
+    else:
+        print(f'server is gonna mount on: http://{server_name}:{server_port}')
+        demo.queue(concurrency_count=InterFace.async_engine.instance_num,
+                   max_size=100,
+                   api_open=True).launch(max_threads=10,
+                                         share=True,
+                                         server_port=server_port,
+                                         server_name=server_name,
+                                         )


 if __name__ == '__main__':
...
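The generate loop used by chat_stream_local above condenses to the following standalone sketch; the model path is a placeholder and the engine is assumed to be constructed the same way run_local now does.

import asyncio

from lmdeploy.messages import GenerationConfig
from lmdeploy.serve.async_engine import AsyncEngine


async def main():
    engine = AsyncEngine(model_path='internlm/internlm-chat-7b',
                         backend='turbomind')
    gen_config = GenerationConfig(max_new_tokens=512,
                                  top_p=0.8,
                                  top_k=40,
                                  temperature=0.7)
    # stream one turn of a fresh session, as chat_stream_local does
    async for outputs in engine.generate('Hello, who are you?',
                                         1,
                                         gen_config=gen_config,
                                         stream_response=True,
                                         sequence_start=True,
                                         sequence_end=True):
        print(outputs.response, end='', flush=True)


asyncio.run(main())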
lmdeploy/serve/openai/api_client.py  View file @ d7117b95
...
@@ -4,8 +4,11 @@ from typing import Any, Dict, Iterable, List, Optional, Union
 import requests

+from lmdeploy.utils import get_logger
+

 def get_model_list(api_url: str):
     """Get model list from api server."""
     response = requests.get(api_url)
     if hasattr(response, 'text'):
         model_list = json.loads(response.text)
...
@@ -14,15 +17,31 @@ def get_model_list(api_url: str):
     return None


+def json_loads(content):
+    """Loads content to json format."""
+    try:
+        content = json.loads(content)
+        return content
+    except:  # noqa
+        logger = get_logger('lmdeploy')
+        logger.warning(f'weird json content {content}')
+        return ''
+
+
 class APIClient:
     """Chatbot for LLaMA series models with turbomind as inference engine.

     Args:
         api_server_url (str): communicating address 'http://<ip>:<port>' of
             api_server
+        api_key (str | None): api key. Default to None, which means no
+            api key will be used.
     """

-    def __init__(self, api_server_url: str, **kwargs):
+    def __init__(self,
+                 api_server_url: str,
+                 api_key: Optional[str] = None,
+                 **kwargs):
         self.api_server_url = api_server_url
         self.chat_intractive_v1_url = f'{api_server_url}/v1/chat/interactive'
         self.chat_completions_v1_url = f'{api_server_url}/v1/chat/completions'
...
@@ -30,6 +49,10 @@ class APIClient:
         self.models_v1_url = f'{api_server_url}/v1/models'
         self.encode_v1_url = f'{api_server_url}/v1/encode'
         self._available_models = None
+        self.api_key = api_key
+        self.headers = {'content-type': 'application/json'}
+        if api_key is not None:
+            self.headers['Authorization'] = f'Bearer {api_key}'

     @property
     def available_models(self):
...
@@ -38,7 +61,7 @@ class APIClient:
             return self._available_models
         response = requests.get(self.models_v1_url)
         if hasattr(response, 'text'):
-            model_list = json.loads(response.text)
+            model_list = json_loads(response.text)
             model_list = model_list.pop('data', [])
             self._available_models = [item['id'] for item in model_list]
         return self._available_models
...
@@ -57,15 +80,14 @@ class APIClient:
             when it is not. Default to True.
         Return: (input_ids, length)
         """
-        headers = {'content-type': 'application/json'}
         response = requests.post(self.encode_v1_url,
-                                 headers=headers,
+                                 headers=self.headers,
                                  json=dict(input=input,
                                            do_preprocess=do_preprocess,
                                            add_bos=add_bos),
                                  stream=False)
         if hasattr(response, 'text'):
-            output = json.loads(response.text)
+            output = json_loads(response.text)
             return output['input_ids'], output['length']
         return None, None
...
@@ -75,8 +97,8 @@ class APIClient:
                            temperature: Optional[float] = 0.7,
                            top_p: Optional[float] = 1.0,
                            n: Optional[int] = 1,
-                           max_tokens: Optional[int] = 512,
-                           stop: Optional[bool] = False,
+                           max_tokens: Optional[int] = None,
+                           stop: Optional[Union[str, List[str]]] = None,
                            stream: Optional[bool] = False,
                            presence_penalty: Optional[float] = 0.0,
                            frequency_penalty: Optional[float] = 0.0,
...
@@ -84,12 +106,14 @@ class APIClient:
                            repetition_penalty: Optional[float] = 1.0,
                            session_id: Optional[int] = -1,
                            ignore_eos: Optional[bool] = False,
+                           skip_special_tokens: Optional[bool] = True,
                            **kwargs):
         """Chat completion v1.

         Args:
             model: model name. Available from self.available_models.
-            messages: string prompt or chat history in OpenAI format.
+            messages: string prompt or chat history in OpenAI format. Chat
+                history example: `[{"role": "user", "content": "hi"}]`.
             temperature (float): to modulate the next token probability
             top_p (float): If set to float < 1, only the smallest set of most
                 probable tokens with probabilities that add up to top_p or
...
@@ -97,11 +121,15 @@ class APIClient:
             n (int): How many chat completion choices to generate for each
                 input message. Only support one here.
             stream: whether to stream the results or not. Default to false.
-            max_tokens (int): output token nums
+            max_tokens (int | None): output token nums. Default to None.
+            stop (str | List[str] | None): To stop generating further
+                tokens. Only accept stop words that's encoded to one token idex.
             repetition_penalty (float): The parameter for repetition penalty.
                 1.0 means no penalty
             ignore_eos (bool): indicator for ignoring eos
-            session_id (int): if not specified, will set random value
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be True.
+            session_id (int): Deprecated.

         Yields:
             json objects in openai formats
...
@@ -111,9 +139,8 @@ class APIClient:
             for k, v in locals().copy().items()
             if k[:2] != '__' and k not in ['self']
         }
-        headers = {'content-type': 'application/json'}
         response = requests.post(self.chat_completions_v1_url,
-                                 headers=headers,
+                                 headers=self.headers,
                                  json=pload,
                                  stream=stream)
         for chunk in response.iter_lines(chunk_size=8192,
...
@@ -126,11 +153,11 @@ class APIClient:
                     continue
                 if decoded[:6] == 'data: ':
                     decoded = decoded[6:]
-                output = json.loads(decoded)
+                output = json_loads(decoded)
                 yield output
             else:
                 decoded = chunk.decode('utf-8')
-                output = json.loads(decoded)
+                output = json_loads(decoded)
                 yield output

     def chat_interactive_v1(self,
...
@@ -138,13 +165,14 @@ class APIClient:
                             session_id: int = -1,
                             interactive_mode: bool = False,
                             stream: bool = False,
-                            stop: bool = False,
-                            request_output_len: int = 512,
+                            stop: Optional[Union[str, List[str]]] = None,
+                            request_output_len: Optional[int] = None,
                             top_p: float = 0.8,
                             top_k: int = 40,
                             temperature: float = 0.8,
                             repetition_penalty: float = 1.0,
                             ignore_eos: bool = False,
+                            skip_special_tokens: Optional[bool] = True,
                             **kwargs):
         """Interactive completions.
...
@@ -162,8 +190,10 @@ class APIClient:
                 interactive mode, session history is kept on the server (and
                 vice versa).
             stream: whether to stream the results or not.
-            stop: whether to stop the session response or not.
-            request_output_len (int): output token nums
+            stop (str | List[str] | None): To stop generating further tokens.
+                Only accept stop words that's encoded to one token idex.
+            request_output_len (int): output token nums. If not specified,
+                will use maximum possible number for a session.
             top_p (float): If set to float < 1, only the smallest set of most
                 probable tokens with probabilities that add up to top_p or
                 higher are kept for generation.
...
@@ -173,18 +203,20 @@ class APIClient:
             repetition_penalty (float): The parameter for repetition penalty.
                 1.0 means no penalty
             ignore_eos (bool): indicator for ignoring eos
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be True.

         Yields:
-            json objects consist of text, tokens, finish_reason
+            json objects consist of text, tokens, input_tokens,
+                history_tokens, finish_reason
         """
         pload = {
             k: v
             for k, v in locals().copy().items()
             if k[:2] != '__' and k not in ['self']
         }
-        headers = {'content-type': 'application/json'}
         response = requests.post(self.chat_intractive_v1_url,
-                                 headers=headers,
+                                 headers=self.headers,
                                  json=pload,
                                  stream=stream)
         for chunk in response.iter_lines(chunk_size=8192,
...
@@ -192,7 +224,7 @@ class APIClient:
                                          delimiter=b'\n'):
             if chunk:
                 decoded = chunk.decode('utf-8')
-                output = json.loads(decoded)
+                output = json_loads(decoded)
                 yield output

     def completions_v1(
...
@@ -204,12 +236,15 @@ class APIClient:
             n: Optional[int] = 1,
             max_tokens: Optional[int] = 16,
             stream: Optional[bool] = False,
+            stop: Optional[Union[str, List[str]]] = None,
             top_p: Optional[float] = 1.0,
+            top_k: Optional[int] = 40,
             user: Optional[str] = None,
             # additional argument of lmdeploy
             repetition_penalty: Optional[float] = 1.0,
             session_id: Optional[int] = -1,
             ignore_eos: Optional[bool] = False,
+            skip_special_tokens: Optional[bool] = True,
             **kwargs):
         """Chat completion v1.
...
@@ -223,14 +258,20 @@ class APIClient:
             top_p (float): If set to float < 1, only the smallest set of most
                 probable tokens with probabilities that add up to top_p or
                 higher are kept for generation.
+            top_k (int): The number of the highest probability vocabulary
+                tokens to keep for top-k-filtering
             n (int): How many chat completion choices to generate for each
                 input message. Only support one here.
             stream: whether to stream the results or not. Default to false.
+            stop (str | List[str] | None): To stop generating further
+                tokens. Only accept stop words that's encoded to one token idex.
             repetition_penalty (float): The parameter for repetition penalty.
                 1.0 means no penalty
             user (str): A unique identifier representing your end-user.
             ignore_eos (bool): indicator for ignoring eos
-            session_id (int): if not specified, will set random value
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be True.
+            session_id (int): Deprecated.

         Yields:
             json objects in openai formats
...
@@ -240,9 +281,8 @@ class APIClient:
             for k, v in locals().copy().items()
             if k[:2] != '__' and k not in ['self']
         }
-        headers = {'content-type': 'application/json'}
         response = requests.post(self.completions_v1_url,
-                                 headers=headers,
+                                 headers=self.headers,
                                  json=pload,
                                  stream=stream)
         for chunk in response.iter_lines(chunk_size=8192,
...
@@ -250,16 +290,16 @@ class APIClient:
                                          delimiter=b'\n'):
             if chunk:
                 if stream:
-                    decoded = chunk.decode('utf-8')[6:]
+                    decoded = chunk.decode('utf-8')
                     if decoded == 'data: [DONE]':
                         continue
                     if decoded[:6] == 'data: ':
                         decoded = decoded[6:]
-                    output = json.loads(decoded)
+                    output = json_loads(decoded)
                     yield output
                 else:
                     decoded = chunk.decode('utf-8')
-                    output = json.loads(decoded)
+                    output = json_loads(decoded)
                     yield output

     def chat(self,
...
@@ -307,7 +347,7 @@ class APIClient:
                 temperature=temperature,
                 repetition_penalty=repetition_penalty,
                 ignore_eos=ignore_eos):
-            if outputs['finish_reason'] == 'length':
+            if outputs['finish_reason'] == 'length' and outputs['tokens'] == 0:
                 print('WARNING: exceed session max length.'
                       ' Please end the session.')
             yield outputs['text'], outputs['tokens'], outputs['finish_reason']
...
@@ -334,15 +374,21 @@ def input_prompt():
     return '\n'.join(iter(input, sentinel))


-def get_streaming_response(prompt: str,
-                           api_url: str,
-                           session_id: int,
-                           request_output_len: int = 512,
-                           stream: bool = True,
-                           interactive_mode: bool = False,
-                           ignore_eos: bool = False,
-                           stop: bool = False) -> Iterable[List[str]]:
+def get_streaming_response(
+        prompt: str,
+        api_url: str,
+        session_id: int,
+        request_output_len: int = 512,
+        stream: bool = True,
+        interactive_mode: bool = False,
+        ignore_eos: bool = False,
+        cancel: bool = False,
+        top_p: float = 0.8,
+        temperature: float = 0.7,
+        api_key: Optional[str] = None) -> Iterable[List[str]]:
     headers = {'User-Agent': 'Test Client'}
+    if api_key is not None:
+        headers['Authorization'] = f'Bearer {api_key}'
     pload = {
         'prompt': prompt,
         'stream': stream,
...
@@ -350,7 +396,9 @@ def get_streaming_response(
         'request_output_len': request_output_len,
         'interactive_mode': interactive_mode,
         'ignore_eos': ignore_eos,
-        'stop': stop
+        'cancel': cancel,
+        'top_p': top_p,
+        'temperature': temperature
     }
     response = requests.post(api_url,
                              headers=headers,
...
@@ -360,15 +408,18 @@ def get_streaming_response(
                                      decode_unicode=False,
                                      delimiter=b'\n'):
         if chunk:
-            data = json.loads(chunk.decode('utf-8'))
+            data = json_loads(chunk.decode('utf-8'))
             output = data.pop('text', '')
            tokens = data.pop('tokens', 0)
            finish_reason = data.pop('finish_reason', None)
            yield output, tokens, finish_reason


-def main(api_server_url: str, session_id: int = 0):
-    api_client = APIClient(api_server_url)
+def main(api_server_url: str,
+         session_id: int = 0,
+         api_key: Optional[str] = None):
+    """Main function to chat in terminal."""
+    api_client = APIClient(api_server_url, api_key=api_key)
     while True:
         prompt = input_prompt()
         if prompt in ['exit', 'end']:
...
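Taken together, the updated client can be used roughly as follows; the server address and api key are placeholders for a running api_server that was configured with API keys.

from lmdeploy.serve.openai.api_client import APIClient

client = APIClient('http://0.0.0.0:23333', api_key='YOUR_API_KEY')
model_name = client.available_models[0]

for output in client.chat_completions_v1(
        model=model_name,
        messages=[{'role': 'user', 'content': 'hi'}],
        temperature=0.7,
        max_tokens=256):
    # non-streaming requests yield a single OpenAI-style json object
    print(output)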
lmdeploy/serve/openai/api_server.py
View file @
d7117b95
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
asyncio
import
asyncio
import
os
import
os
import
random
import
time
import
time
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
AsyncGenerator
,
List
,
Optional
from
typing
import
AsyncGenerator
,
List
,
Literal
,
Optional
,
Union
import
uvicorn
import
uvicorn
from
fastapi
import
FastAPI
,
Request
from
fastapi
import
Depends
,
FastAPI
,
HTTPException
,
Request
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
fastapi.security.http
import
HTTPAuthorizationCredentials
,
HTTPBearer
from
lmdeploy.archs
import
get_task
from
lmdeploy.messages
import
(
GenerationConfig
,
PytorchEngineConfig
,
TurbomindEngineConfig
)
from
lmdeploy.model
import
ChatTemplateConfig
from
lmdeploy.serve.async_engine
import
AsyncEngine
from
lmdeploy.serve.async_engine
import
AsyncEngine
from
lmdeploy.serve.openai.protocol
import
(
# noqa: E501
from
lmdeploy.serve.openai.protocol
import
(
# noqa: E501
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionRequest
,
ChatCompletionRequestQos
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
CompletionRequest
,
ChatCompletionStreamResponse
,
ChatMessage
,
CompletionRequest
,
CompletionResponse
,
CompletionResponseChoice
,
CompletionRequestQos
,
CompletionResponse
,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
DeltaMessage
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
DeltaMessage
,
EmbeddingsRequest
,
EncodeRequest
,
EncodeResponse
,
ErrorResponse
,
EmbeddingsRequest
,
EncodeRequest
,
EncodeResponse
,
ErrorResponse
,
GenerateRequest
,
GenerateResponse
,
ModelCard
,
ModelList
,
ModelPermission
,
GenerateRequest
,
GenerateRequestQos
,
GenerateResponse
,
ModelCard
,
UsageInfo
)
ModelList
,
ModelPermission
,
UsageInfo
)
from
lmdeploy.serve.qos_engine.qos_engine
import
QosEngine
from
lmdeploy.utils
import
get_logger
class
VariableInterface
:
class
VariableInterface
:
"""A IO interface maintaining variables."""
"""A IO interface maintaining variables."""
async_engine
:
AsyncEngine
=
None
async_engine
:
AsyncEngine
=
None
session_id
:
int
=
0
api_keys
:
Optional
[
List
[
str
]]
=
None
qos_engine
:
QosEngine
=
None
request_hosts
=
[]
request_hosts
=
[]
app
=
FastAPI
(
docs_url
=
'/'
)
app
=
FastAPI
(
docs_url
=
'/'
)
get_bearer_token
=
HTTPBearer
(
auto_error
=
False
)
async
def
check_api_key
(
auth
:
Optional
[
HTTPAuthorizationCredentials
]
=
Depends
(
get_bearer_token
),
)
->
str
:
"""Check if client provide valid api key.
Adopted from https://github.com/lm-sys/FastChat/blob/v0.2.35/fastchat/serve/openai_api_server.py#L108-L127
"""
# noqa
if
VariableInterface
.
api_keys
:
if
auth
is
None
or
(
token
:
=
auth
.
credentials
)
not
in
VariableInterface
.
api_keys
:
raise
HTTPException
(
status_code
=
401
,
detail
=
{
'error'
:
{
'message'
:
'Please request with valid api key!'
,
'type'
:
'invalid_request_error'
,
'param'
:
None
,
'code'
:
'invalid_api_key'
,
}
},
)
return
token
else
:
# api_keys not set; allow all
return
None
def
get_model_list
():
def
get_model_list
():
...
@@ -37,10 +74,10 @@ def get_model_list():
...
@@ -37,10 +74,10 @@ def get_model_list():
Only provided one now.
Only provided one now.
"""
"""
return
[
VariableInterface
.
async_engine
.
tm_model
.
model_name
]
return
[
VariableInterface
.
async_engine
.
model_name
]
@
app
.
get
(
'/v1/models'
)
@
app
.
get
(
'/v1/models'
,
dependencies
=
[
Depends
(
check_api_key
)]
)
def
available_models
():
def
available_models
():
"""Show available models."""
"""Show available models."""
model_cards
=
[]
model_cards
=
[]
...
@@ -74,17 +111,149 @@ async def check_request(request) -> Optional[JSONResponse]:
...
@@ -74,17 +111,149 @@ async def check_request(request) -> Optional[JSONResponse]:
return
ret
return
ret
def
ip2id
(
host_ip
:
str
):
@
app
.
post
(
'/v1/chat/completions_qos'
)
"""Convert host ip address to session id."""
async
def
chat_completions_v1_qos
(
request
:
ChatCompletionRequestQos
,
if
'.'
in
host_ip
:
# IPv4
raw_request
:
Request
=
None
):
return
int
(
host_ip
.
replace
(
'.'
,
''
)[
-
8
:])
"""Completion API similar to OpenAI's API.
if
':'
in
host_ip
:
# IPv6
return
int
(
host_ip
.
replace
(
':'
,
''
)[
-
8
:],
16
)
Refer to `https://platform.openai.com/docs/api-reference/chat/create`
print
(
'Warning, could not get session id from ip, set it 0'
)
for the API specification.
return
0
The request should be a JSON object with the following fields:
- model: model name. Available from /v1/models.
- messages: string prompt or chat history in OpenAI format.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
- stream: whether to stream the results or not. Default to false.
- max_tokens (int): output token nums
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
Additional arguments supported by LMDeploy:
- ignore_eos (bool): indicator for ignoring eos
- user_id (str): for qos; if not specified, will set to "default"
Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
VariableInterface
.
session_id
+=
1
request
.
session_id
=
VariableInterface
.
session_id
error_check_ret
=
await
check_request
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
model_name
=
request
.
model
request_id
=
str
(
request
.
session_id
)
created_time
=
int
(
time
.
time
())
if
VariableInterface
.
qos_engine
is
None
:
return
create_error_response
(
HTTPStatus
.
NOT_FOUND
,
'cannot parse qos engine config, this api is not work'
)
result_generator
=
await
VariableInterface
.
qos_engine
.
generate_with_qos
(
request
)
if
result_generator
is
None
:
return
create_error_response
(
HTTPStatus
.
INTERNAL_SERVER_ERROR
,
'Failed to generate completions'
)
def
create_stream_response_json
(
index
:
int
,
text
:
str
,
finish_reason
:
Optional
[
str
]
=
None
,
)
->
str
:
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
index
,
delta
=
DeltaMessage
(
role
=
'assistant'
,
content
=
text
),
finish_reason
=
finish_reason
,
)
response
=
ChatCompletionStreamResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
[
choice_data
],
)
response_json
=
response
.
model_dump_json
()
return
response_json
async
def
completion_stream_generator
()
->
AsyncGenerator
[
str
,
None
]:
# First chunk with role
for
i
in
range
(
request
.
n
):
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
role
=
'assistant'
),
finish_reason
=
None
,
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
choices
=
[
choice_data
],
model
=
model_name
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
'data:
{
data
}
\n\n
'
async
for
res
in
result_generator
:
response_json
=
create_stream_response_json
(
index
=
0
,
text
=
res
.
response
,
)
yield
f
'data:
{
response_json
}
\n\n
'
yield
'data: [DONE]
\n\n
'
# Streaming response
if
request
.
stream
:
return
StreamingResponse
(
completion_stream_generator
(),
media_type
=
'text/event-stream'
)
# Non-streaming response
final_res
=
None
text
=
''
async
for
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
VariableInterface
.
async_engine
.
stop_session
(
request
.
session_id
)
return
create_error_response
(
HTTPStatus
.
BAD_REQUEST
,
'Client disconnected'
)
final_res
=
res
text
+=
res
.
response
assert
final_res
is
not
None
choices
=
[]
choice_data
=
ChatCompletionResponseChoice
(
index
=
0
,
message
=
ChatMessage
(
role
=
'assistant'
,
content
=
text
),
finish_reason
=
final_res
.
finish_reason
,
)
choices
.
append
(
choice_data
)
total_tokens
=
sum
([
final_res
.
history_token_len
,
final_res
.
input_token_len
,
final_res
.
generate_token_len
])
usage
=
UsageInfo
(
prompt_tokens
=
final_res
.
input_token_len
,
completion_tokens
=
final_res
.
generate_token_len
,
total_tokens
=
total_tokens
,
)
response
=
ChatCompletionResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
choices
,
usage
=
usage
,
)
return
response
@
app
.
post
(
'/v1/chat/completions'
)
@
app
.
post
(
'/v1/chat/completions'
,
dependencies
=
[
Depends
(
check_api_key
)]
)
async
def
chat_completions_v1
(
request
:
ChatCompletionRequest
,
async
def
chat_completions_v1
(
request
:
ChatCompletionRequest
,
raw_request
:
Request
=
None
):
raw_request
:
Request
=
None
):
"""Completion API similar to OpenAI's API.
"""Completion API similar to OpenAI's API.
...
@@ -94,7 +263,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
...
@@ -94,7 +263,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
The request should be a JSON object with the following fields:
The request should be a JSON object with the following fields:
- model: model name. Available from /v1/models.
- model: model name. Available from /v1/models.
- messages: string prompt or chat history in OpenAI format.
- messages: string prompt or chat history in OpenAI format. Chat history
example: `[{"role": "user", "content": "hi"}]`.
- temperature (float): to modulate the next token probability
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
probable tokens with probabilities that add up to top_p or higher
...
@@ -102,13 +272,18 @@ async def chat_completions_v1(request: ChatCompletionRequest,
...
@@ -102,13 +272,18 @@ async def chat_completions_v1(request: ChatCompletionRequest,
- n (int): How many chat completion choices to generate for each input
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
message. Only support one here.
- stream: whether to stream the results or not. Default to false.
- stream: whether to stream the results or not. Default to false.
- max_tokens (int): output token nums
- max_tokens (int
| None
): output token nums
. Default to None.
- repetition_penalty (float): The parameter for repetition penalty.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
1.0 means no penalty
- stop (str | List[str] | None): To stop generating further
tokens. Only accept stop words that's encoded to one token idex.
Additional arguments supported by LMDeploy:
Additional arguments supported by LMDeploy:
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
- ignore_eos (bool): indicator for ignoring eos
- ignore_eos (bool): indicator for ignoring eos
- session_id (int): if not specified, will set random value
- skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
Currently we do not support the following features:
Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- function_call (Users should implement this by themselves)
...
@@ -116,8 +291,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
...
@@ -116,8 +291,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
- presence_penalty (replaced with repetition_penalty)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
"""
if
request
.
session_id
=
=
-
1
:
VariableInterface
.
session_id
+
=
1
request
.
session_id
=
random
.
randint
(
1
,
10086
)
request
.
session_id
=
VariableInterface
.
session_id
error_check_ret
=
await
check_request
(
request
)
error_check_ret
=
await
check_request
(
request
)
if
error_check_ret
is
not
None
:
if
error_check_ret
is
not
None
:
return
error_check_ret
return
error_check_ret
...
@@ -126,18 +301,26 @@ async def chat_completions_v1(request: ChatCompletionRequest,
...
@@ -126,18 +301,26 @@ async def chat_completions_v1(request: ChatCompletionRequest,
request_id
=
str
(
request
.
session_id
)
request_id
=
str
(
request
.
session_id
)
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
if
isinstance
(
request
.
stop
,
str
):
request
.
stop
=
[
request
.
stop
]
gen_config
=
GenerationConfig
(
max_new_tokens
=
request
.
max_tokens
,
top_k
=
request
.
top_k
,
top_p
=
request
.
top_p
,
temperature
=
request
.
temperature
,
repetition_penalty
=
request
.
repetition_penalty
,
ignore_eos
=
request
.
ignore_eos
,
stop_words
=
request
.
stop
,
skip_special_tokens
=
request
.
skip_special_tokens
)
result_generator
=
VariableInterface
.
async_engine
.
generate
(
result_generator
=
VariableInterface
.
async_engine
.
generate
(
request
.
messages
,
request
.
messages
,
request
.
session_id
,
request
.
session_id
,
True
,
# always use stream to enable batching
gen_config
=
gen_config
,
stream_response
=
True
,
# always use stream to enable batching
sequence_start
=
True
,
sequence_start
=
True
,
sequence_end
=
True
,
sequence_end
=
True
,
request_output_len
=
request
.
max_tokens
if
request
.
max_tokens
else
512
,
stop
=
request
.
stop
,
top_p
=
request
.
top_p
,
temperature
=
request
.
temperature
,
repetition_penalty
=
request
.
repetition_penalty
,
ignore_eos
=
request
.
ignore_eos
,
do_preprocess
=
not
isinstance
(
request
.
messages
,
do_preprocess
=
not
isinstance
(
request
.
messages
,
str
),
# text completion for string input
str
),
# text completion for string input
)
)
...
@@ -196,7 +379,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
...
@@ -196,7 +379,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
async
for
res
in
result_generator
:
async
for
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
# Abort the request if the client disconnects.
VariableInterface
.
async_engine
.
stop_session
(
request
.
session_id
)
await
VariableInterface
.
async_engine
.
stop_session
(
request
.
session_id
)
return
create_error_response
(
HTTPStatus
.
BAD_REQUEST
,
return
create_error_response
(
HTTPStatus
.
BAD_REQUEST
,
'Client disconnected'
)
'Client disconnected'
)
final_res
=
res
final_res
=
res
...
@@ -230,7 +414,155 @@ async def chat_completions_v1(request: ChatCompletionRequest,
...
@@ -230,7 +414,155 @@ async def chat_completions_v1(request: ChatCompletionRequest,
return
response
return
response
@
app
.
post
(
'/v1/completions'
)
@
app
.
post
(
'/v1/completions_qos'
)
async
def
completions_v1_qos
(
request
:
CompletionRequestQos
,
raw_request
:
Request
=
None
):
"""Completion API similar to OpenAI's API.
Go to `https://platform.openai.com/docs/api-reference/completions/create`
for the API specification.
The request should be a JSON object with the following fields:
- model (str): model name. Available from /v1/models.
- prompt (str): the input prompt.
- suffix (str): The suffix that comes after a completion of inserted text.
- max_tokens (int): output token nums
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
- stream: whether to stream the results or not. Default to false.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- user (str): A unique identifier representing your end-user.
Additional arguments supported by LMDeploy:
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
- ignore_eos (bool): indicator for ignoring eos
- user_id (str): for qos; if not specified, will set to "default"
Currently we do not support the following features:
- logprobs (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
VariableInterface
.
session_id
+=
1
request
.
session_id
=
VariableInterface
.
session_id
error_check_ret
=
await
check_request
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
model_name
=
request
.
model
request_id
=
str
(
request
.
session_id
)
created_time
=
int
(
time
.
time
())
if
isinstance
(
request
.
prompt
,
str
):
request
.
prompt
=
[
request
.
prompt
]
if
VariableInterface
.
qos_engine
is
None
:
return
create_error_response
(
HTTPStatus
.
NOT_FOUND
,
'cannot parse qos engine config, this api is not work'
)
generators
=
await
VariableInterface
.
qos_engine
.
generate_with_qos
(
request
)
def
create_stream_response_json
(
index
:
int
,
text
:
str
,
finish_reason
:
Optional
[
str
]
=
None
,
)
->
str
:
choice_data
=
CompletionResponseStreamChoice
(
index
=
index
,
text
=
text
,
finish_reason
=
finish_reason
,
)
response
=
CompletionStreamResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
choices
=
[
choice_data
],
)
response_json
=
response
.
            model_dump_json()
        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        # First chunk with role
        for generator in generators:
            for i in range(request.n):
                choice_data = CompletionResponseStreamChoice(
                    index=i,
                    text='',
                    finish_reason=None,
                )
                chunk = CompletionStreamResponse(id=request_id,
                                                 choices=[choice_data],
                                                 model=model_name)
                data = chunk.model_dump_json(exclude_unset=True)
                yield f'data: {data}\n\n'

            async for res in generator:
                response_json = create_stream_response_json(
                    index=0,
                    text=res.response,
                )
                yield f'data: {response_json}\n\n'
        yield 'data: [DONE]\n\n'

    # Streaming response
    if request.stream:
        return StreamingResponse(completion_stream_generator(),
                                 media_type='text/event-stream')

    # Non-streaming response
    usage = UsageInfo()
    choices = []

    async def _inner_call(i, generator):
        final_res = None
        text = ''
        async for res in generator:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await VariableInterface.async_engine.stop_session(
                    request.session_id)
                return create_error_response(HTTPStatus.BAD_REQUEST,
                                             'Client disconnected')
            final_res = res
            text += res.response
        assert final_res is not None
        choice_data = CompletionResponseChoice(
            index=0,
            text=text,
            finish_reason=final_res.finish_reason,
        )
        choices.append(choice_data)

        total_tokens = sum([
            final_res.history_token_len, final_res.input_token_len,
            final_res.generate_token_len
        ])
        usage.prompt_tokens += final_res.input_token_len
        usage.completion_tokens += final_res.generate_token_len
        usage.total_tokens += total_tokens

    await asyncio.gather(
        *[_inner_call(i, generators[i]) for i in range(len(generators))])

    response = CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )
    return response


@app.post('/v1/completions', dependencies=[Depends(check_api_key)])
async def completions_v1(request: CompletionRequest,
                         raw_request: Request = None):
    """Completion API similar to OpenAI's API.
...
@@ -242,7 +574,7 @@ async def completions_v1(request: CompletionRequest,
    - model (str): model name. Available from /v1/models.
    - prompt (str): the input prompt.
    - suffix (str): The suffix that comes after a completion of inserted text.
    - max_tokens (int): output token nums. Default to 16.
    - temperature (float): to modulate the next token probability
    - top_p (float): If set to float < 1, only the smallest set of most
      probable tokens with probabilities that add up to top_p or higher
...
@@ -253,18 +585,23 @@ async def completions_v1(request: CompletionRequest,
    - repetition_penalty (float): The parameter for repetition penalty.
      1.0 means no penalty
    - user (str): A unique identifier representing your end-user.
    - stop (str | List[str] | None): To stop generating further
      tokens. Only accept stop words that's encoded to one token index.

    Additional arguments supported by LMDeploy:
    - ignore_eos (bool): indicator for ignoring eos
-   - session_id (int): if not specified, will set random value
    - skip_special_tokens (bool): Whether or not to remove special tokens
      in the decoding. Default to be True.
    - top_k (int): The number of the highest probability vocabulary
      tokens to keep for top-k-filtering

    Currently we do not support the following features:
    - logprobs (not supported yet)
    - presence_penalty (replaced with repetition_penalty)
    - frequency_penalty (replaced with repetition_penalty)
    """
    if request.session_id == -1:
-       request.session_id = random.randint(1, 10086)
        VariableInterface.session_id += 1
        request.session_id = VariableInterface.session_id
    error_check_ret = await check_request(request)
    if error_check_ret is not None:
        return error_check_ret
...
@@ -274,21 +611,26 @@ async def completions_v1(request: CompletionRequest,
    created_time = int(time.time())
    if isinstance(request.prompt, str):
        request.prompt = [request.prompt]
    if isinstance(request.stop, str):
        request.stop = [request.stop]
    gen_config = GenerationConfig(
        max_new_tokens=request.max_tokens if request.max_tokens else 512,
        top_k=request.top_k,
        top_p=request.top_p,
        temperature=request.temperature,
        repetition_penalty=request.repetition_penalty,
        ignore_eos=request.ignore_eos,
        stop_words=request.stop,
        skip_special_tokens=request.skip_special_tokens)
    generators = []
    for i in range(len(request.prompt)):
        result_generator = VariableInterface.async_engine.generate(
            request.prompt[i],
            request.session_id + i,
-           True,  # always use stream to enable batching
            gen_config=gen_config,
            stream_response=True,  # always use stream to enable batching
            sequence_start=True,
            sequence_end=True,
-           request_output_len=request.max_tokens if request.max_tokens else 512,
-           stop=False,
-           top_p=request.top_p,
-           temperature=request.temperature,
-           repetition_penalty=request.repetition_penalty,
-           ignore_eos=request.ignore_eos,
            do_preprocess=False)
        generators.append(result_generator)
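For orientation, the completion route above batches one generator per prompt and streams SSE chunks. A minimal client sketch against it could look like the following; the host, port, API key and model name are illustrative assumptions, not values taken from this commit.

# Minimal sketch of a /v1/completions client (assumed host, port and model name).
import json
import requests

url = 'http://0.0.0.0:23333/v1/completions'           # assumed server address
headers = {'Authorization': 'Bearer YOUR_API_KEY'}    # only needed when api_keys is set
payload = {
    'model': 'internlm-chat-7b',                      # assumed model name
    'prompt': 'Write a haiku about GPUs.',
    'max_tokens': 64,
    'stream': False,
}
resp = requests.post(url, headers=headers, json=payload)
print(json.dumps(resp.json(), indent=2))              # includes choices and usage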
...
@@ -351,7 +693,8 @@ async def completions_v1(request: CompletionRequest,
        async for res in generator:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
-               VariableInterface.async_engine.stop_session(request.session_id)
                await VariableInterface.async_engine.stop_session(
                    request.session_id)
                return create_error_response(HTTPStatus.BAD_REQUEST,
                                             'Client disconnected')
            final_res = res
...
@@ -394,7 +737,7 @@ async def create_embeddings(request: EmbeddingsRequest,
                                 'Unsupported by turbomind.')


-@app.post('/v1/encode')
@app.post('/v1/encode', dependencies=[Depends(check_api_key)])
async def encode(request: EncodeRequest, raw_request: Request = None):
    """Encode prompts.
...
@@ -407,7 +750,7 @@ async def encode(request: EncodeRequest, raw_request: Request = None):
    def encode(prompt: str, do_preprocess: bool, add_bos: bool):
        if do_preprocess:
-           prompt = VariableInterface.async_engine.model.get_prompt(
            prompt = VariableInterface.async_engine.chat_template.get_prompt(
                prompt, sequence_start=add_bos)
        input_ids = VariableInterface.async_engine.tokenizer.encode(
            prompt, add_bos=add_bos)
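The inner encode helper above is what backs the /v1/encode route. A direct POST shows its output shape; the address is an assumption, and the request field names (input, do_preprocess, add_bos) follow the EncodeRequest schema as I understand it and should be checked against protocol.py.

# Sketch of an /v1/encode client (assumed host/port; field names are assumptions).
import requests

resp = requests.post(
    'http://0.0.0.0:23333/v1/encode',
    headers={'Authorization': 'Bearer YOUR_API_KEY'},  # only if api keys are enabled
    json={'input': 'Hello, world', 'do_preprocess': False, 'add_bos': True})
print(resp.json())  # expected fields: input_ids, length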
...
@@ -425,12 +768,9 @@ async def encode(request: EncodeRequest, raw_request: Request = None):
    return EncodeResponse(input_ids=encoded, length=length)


-@app.post('/generate',
-          tags=['deprecated'],
-          description='please use /v1/chat/interactive')
-@app.post('/v1/chat/interactive')
-async def chat_interactive_v1(request: GenerateRequest,
-                              raw_request: Request = None):
@app.post('/v1/chat/interactive_qos')
async def chat_interactive_v1_qos(request: GenerateRequestQos,
                                  raw_request: Request = None):
    """Generate completion for the request.
    - On interactive mode, the chat history is kept on the server. Please set
...
@@ -456,33 +796,134 @@ async def chat_interactive_v1(request: GenerateRequest,
    - repetition_penalty (float): The parameter for repetition penalty.
      1.0 means no penalty
    - ignore_eos (bool): indicator for ignoring eos
    - user_id (str): for qos; if not specified, will set to "default"
    """
    if request.session_id == -1:
-       request.session_id = random.randint(10087, 23333)
        VariableInterface.session_id += 1
        request.session_id = VariableInterface.session_id

    if VariableInterface.qos_engine is None:
        return create_error_response(
            HTTPStatus.NOT_FOUND,
            'cannot parse qos engine config, this api is not work')

    generation = await VariableInterface.qos_engine.generate_with_qos(request)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for out in generation:
            chunk = GenerateResponse(text=out.response,
                                     tokens=out.generate_token_len,
                                     input_tokens=out.input_token_len,
                                     history_tokens=out.history_token_len,
                                     finish_reason=out.finish_reason)
            data = chunk.model_dump_json()
            yield f'{data}\n'

    if request.stream:
        return StreamingResponse(stream_results(),
                                 media_type='text/event-stream')
    else:
        ret = {}
        text = ''
        tokens = 0
        finish_reason = None
        async for out in generation:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await VariableInterface.qos_engine.stop_session(
                    request.session_id)
                return create_error_response(HTTPStatus.BAD_REQUEST,
                                             'Client disconnected')
            text += out.response
            tokens = out.generate_token_len
            finish_reason = out.finish_reason
        ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason}
        return JSONResponse(ret)
@app.post('/v1/chat/interactive', dependencies=[Depends(check_api_key)])
async def chat_interactive_v1(request: GenerateRequest,
                              raw_request: Request = None):
    """Generate completion for the request.
    - On interactive mode, the chat history is kept on the server. Please set
      `interactive_mode = True`.
    - On normal mode, no chat history is kept on the server. Set
      `interactive_mode = False`.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - session_id: determine which instance will be called. If not specified
      with a value other than -1, using random value directly.
    - interactive_mode (bool): turn on interactive mode or not. On interactive
      mode, session history is kept on the server (and vice versa).
    - stream: whether to stream the results or not.
    - stop (str | List[str] | None): To stop generating further
      tokens. Only accept stop words that's encoded to one token index.
    - request_output_len (int): output token nums. If not specified, will use
      maximum possible number for a session.
    - top_p (float): If set to float < 1, only the smallest set of most
      probable tokens with probabilities that add up to top_p or higher
      are kept for generation.
    - top_k (int): The number of the highest probability vocabulary
      tokens to keep for top-k-filtering
    - temperature (float): to modulate the next token probability
    - repetition_penalty (float): The parameter for repetition penalty.
      1.0 means no penalty
    - ignore_eos (bool): indicator for ignoring eos
    - skip_special_tokens (bool): Whether or not to remove special tokens
      in the decoding. Default to be True.
    """
    if request.cancel:
        if request.session_id != -1:
            await VariableInterface.async_engine.stop_session(
                request.session_id)
            return {
                'text': '',
                'tokens': 0,
                'input_tokens': 0,
                'history_tokens': 0,
                'finish_reason': 'stop'
            }
        else:
            return create_error_response(
                HTTPStatus.BAD_REQUEST,
                'please set a session_id to cancel a request')
    if request.session_id == -1:
        VariableInterface.session_id += 1
        request.session_id = VariableInterface.session_id

    async_engine = VariableInterface.async_engine
    sequence_start = async_engine.id2step.get(str(request.session_id), 0) == 0
    sequence_end = not request.interactive_mode
    if isinstance(request.stop, str):
        request.stop = [request.stop]

    gen_config = GenerationConfig(
        max_new_tokens=request.request_output_len,
        top_p=request.top_p,
        top_k=request.top_k,
        temperature=request.temperature,
        repetition_penalty=request.repetition_penalty,
        ignore_eos=request.ignore_eos,
        stop_words=request.stop,
        skip_special_tokens=request.skip_special_tokens)
    generation = async_engine.generate(
        request.prompt,
        request.session_id,
        gen_config=gen_config,
        stream_response=True,  # always use stream to enable batching
        sequence_start=sequence_start,
        sequence_end=sequence_end)
-       request_output_len=request.request_output_len,
-       top_p=request.top_p,
-       top_k=request.top_k,
-       stop=request.stop,
-       temperature=request.temperature,
-       repetition_penalty=request.repetition_penalty,
-       ignore_eos=request.ignore_eos)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for out in generation:
            chunk = GenerateResponse(text=out.response,
                                     tokens=out.generate_token_len,
                                     input_tokens=out.input_token_len,
                                     history_tokens=out.history_token_len,
                                     finish_reason=out.finish_reason)
            data = chunk.model_dump_json()
            yield f'{data}\n'
...
@@ -493,32 +934,46 @@ async def chat_interactive_v1(request: GenerateRequest,
    else:
        ret = {}
        text = ''
-       tokens = 0
        tokens, input_tokens, history_tokens = 0, 0, 0
        finish_reason = None
        async for out in generation:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
-               async_engine.stop_session(request.session_id)
                await async_engine.stop_session(request.session_id)
                return create_error_response(HTTPStatus.BAD_REQUEST,
                                             'Client disconnected')
            text += out.response
            tokens = out.generate_token_len
            input_tokens = out.input_token_len
            history_tokens = out.history_token_len
            finish_reason = out.finish_reason
-       ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason}
        ret = {
            'text': text,
            'tokens': tokens,
            'input_tokens': input_tokens,
            'history_tokens': history_tokens,
            'finish_reason': finish_reason
        }
        return JSONResponse(ret)
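Since the non-streaming branch now also reports input_tokens and history_tokens, a small client makes the new payload visible. The sketch below assumes a locally running server; the field names come from GenerateRequest and GenerateResponse in this commit.

# Sketch of an interactive-chat client (assumed host/port).
import requests

resp = requests.post(
    'http://0.0.0.0:23333/v1/chat/interactive',           # assumed server address
    headers={'Authorization': 'Bearer YOUR_API_KEY'},      # only needed when api_keys is set
    json={
        'prompt': 'Hello there!',
        'session_id': 1,
        'interactive_mode': True,   # keep the chat history on the server
        'stream': False,
        'request_output_len': 64,
    })
out = resp.json()
print(out['text'], out['input_tokens'], out['history_tokens'], out['finish_reason'])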
def serve(model_path: str,
          model_name: Optional[str] = None,
          backend: Literal['turbomind', 'pytorch'] = 'turbomind',
          backend_config: Optional[Union[PytorchEngineConfig,
                                         TurbomindEngineConfig]] = None,
          chat_template_config: Optional[ChatTemplateConfig] = None,
          server_name: str = '0.0.0.0',
          server_port: int = 23333,
-         instance_num: int = 64,
          tp: int = 1,
          allow_origins: List[str] = ['*'],
          allow_credentials: bool = True,
          allow_methods: List[str] = ['*'],
          allow_headers: List[str] = ['*'],
          log_level: str = 'ERROR',
          api_keys: Optional[Union[List[str], str]] = None,
          ssl: bool = False,
          qos_config_path: str = '',
          **kwargs):
    """An example to perform model inference through the command line
    interface.
...
@@ -534,22 +989,34 @@ def serve(model_path: str,
                "InternLM/internlm-chat-20b-4bit",
                "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo
-               on huggingface.co, such as "InternLM/internlm-chat-7b",
                on huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "InternLM/internlm-chat-7b"
        backend (str): either `turbomind` or `pytorch` backend. Default to
            `turbomind` backend.
        backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
            config instance. Default to none.
        chat_template_config (ChatTemplateConfig): chat template configuration.
            Default to None.
        server_name (str): host ip for serving
        server_port (int): server port
-       instance_num (int): number of instances of turbomind model
        tp (int): tensor parallel
        allow_origins (List[str]): a list of allowed origins for CORS
        allow_credentials (bool): whether to allow credentials for CORS
        allow_methods (List[str]): a list of allowed HTTP methods for CORS
        allow_headers (List[str]): a list of allowed HTTP headers for CORS
        log_level(str): set log level whose value among [CRITICAL, ERROR, WARNING, INFO, DEBUG]
        api_keys (List[str] | str | None): Optional list of API keys. Accepts string type as
            a single api_key. Default to None, which means no api key applied.
        ssl (bool): Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.
        qos_config_path (str): qos policy config path
    """  # noqa E501
-   os.environ['TM_LOG_LEVEL'] = log_level
    if os.getenv('TM_LOG_LEVEL') is None:
        os.environ['TM_LOG_LEVEL'] = log_level
    logger = get_logger('lmdeploy')
    logger.setLevel(log_level)

    if allow_origins:
        app.add_middleware(
...
@@ -559,16 +1026,55 @@ def serve(model_path: str,
            allow_methods=allow_methods,
            allow_headers=allow_headers,
        )
    if api_keys is not None:
        if isinstance(api_keys, str):
            api_keys = api_keys.split(',')
        VariableInterface.api_keys = api_keys
    ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http'
    if ssl:
        ssl_keyfile = os.environ['SSL_KEYFILE']
        ssl_certfile = os.environ['SSL_CERTFILE']
        http_or_https = 'https'

    pipeline_type, pipeline_class = get_task(model_path)

    VariableInterface.async_engine = pipeline_class(
        model_path=model_path,
        model_name=model_name,
        backend=backend,
        backend_config=backend_config,
        chat_template_config=chat_template_config,
        tp=tp,
        **kwargs)

    if qos_config_path:
        try:
            with open(qos_config_path, 'r') as file:
                qos_config_str = file.read()
                VariableInterface.qos_engine = QosEngine(
                    qos_tag=qos_config_str,
                    engine=VariableInterface.async_engine,
                    **kwargs)
                VariableInterface.qos_engine.start()
        except FileNotFoundError:
            VariableInterface.qos_engine = None
    else:
        # hide qos functions if not applied
        for i in range(len(app.router.routes)):
            if 'qos' in app.router.routes[i].path:
                app.router.routes[i].include_in_schema = False
-   VariableInterface.async_engine = AsyncEngine(model_path=model_path,
-                                                model_name=model_name,
-                                                instance_num=instance_num,
-                                                tp=tp,
-                                                **kwargs)

    for i in range(3):
-       print(f'HINT: Please open \033[93m\033[1mhttp://{server_name}:'
-             f'{server_port}\033[0m in a browser for detailed api usage!!!')
        print(
            f'HINT: Please open \033[93m\033[1m{http_or_https}://'
            f'{server_name}:{server_port}\033[0m in a browser for detailed api'
            ' usage!!!')
-   uvicorn.run(app=app, host=server_name, port=server_port, log_level='info')
    uvicorn.run(app=app,
                host=server_name,
                port=server_port,
                log_level='info',
                ssl_keyfile=ssl_keyfile,
                ssl_certfile=ssl_certfile)


if __name__ == '__main__':
...
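The reworked serve() wires the api_keys, ssl and qos_config_path options through to uvicorn. A hedged launch sketch follows, with a placeholder model id and API key; the same entry point is usually also reachable through the lmdeploy serve api_server command line, though that is not part of this hunk.

# Sketch of launching the OpenAI-compatible server from Python
# (model id, port and key are placeholder assumptions).
from lmdeploy.serve.openai.api_server import serve

serve('internlm/internlm-chat-7b',   # assumed model id or local path
      server_name='0.0.0.0',
      server_port=23333,
      backend='turbomind',
      api_keys='sk-demo-key',        # optional; activates the check_api_key dependency
      ssl=False)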
lmdeploy/serve/openai/protocol.py  View file @ d7117b95
...
@@ -55,23 +55,48 @@ class UsageInfo(BaseModel):
    completion_tokens: Optional[int] = 0


-class ChatCompletionRequest(BaseModel):
class ChatCompletionRequestQos(BaseModel):
    """Chat completion request."""
    model: str
    messages: Union[str, List[Dict[str, str]]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
-   max_tokens: Optional[int] = 512
    max_tokens: Optional[int] = Field(default=None, examples=[None])
    stop: Optional[bool] = False
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    user_id: Optional[str] = None

    # additional argument of lmdeploy
    repetition_penalty: Optional[float] = 1.0
    session_id: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    top_k: Optional[int] = 40


class ChatCompletionRequest(BaseModel):
    """Chat completion request."""
    model: str
    # yapf: disable
    messages: Union[str, List[Dict[str, Any]]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]])  # noqa
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = Field(default=None, examples=[None])
    stop: Optional[Union[str, List[str]]] = Field(default=None, examples=[None])  # noqa
    # yapf: enable
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    # additional argument of lmdeploy
    repetition_penalty: Optional[float] = 1.0
    session_id: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    skip_special_tokens: Optional[bool] = True
    top_k: Optional[int] = 40


class ChatMessage(BaseModel):
...
@@ -120,6 +145,31 @@ class ChatCompletionStreamResponse(BaseModel):
class CompletionRequest(BaseModel):
    """Completion request."""
    model: str
    prompt: Union[str, List[Any]]
    suffix: Optional[str] = None
    temperature: Optional[float] = 0.7
    n: Optional[int] = 1
    max_tokens: Optional[int] = 16
    stop: Optional[Union[str, List[str]]] = Field(default=None,
                                                  examples=[None])
    stream: Optional[bool] = False
    top_p: Optional[float] = 1.0
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    # additional argument of lmdeploy
    repetition_penalty: Optional[float] = 1.0
    session_id: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    skip_special_tokens: Optional[bool] = True
    top_k: Optional[int] = 40  # for opencompass


class CompletionRequestQos(BaseModel):
    """Completion request."""
    model: str
    prompt: Union[str, List[Any]]
...
@@ -136,9 +186,11 @@ class CompletionRequest(BaseModel):
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    # additional argument of lmdeploy
    top_k: Optional[int] = 40
    repetition_penalty: Optional[float] = 1.0
    session_id: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    user_id: Optional[str] = None


class CompletionResponseChoice(BaseModel):
...
@@ -205,6 +257,25 @@ class EncodeResponse(BaseModel):
class GenerateRequest(BaseModel):
    """Generate request."""
    prompt: Union[str, List[Dict[str, Any]]]
    session_id: int = -1
    interactive_mode: bool = False
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = Field(default=None,
                                                  examples=[None])
    request_output_len: Optional[int] = Field(default=None, examples=[None])  # noqa
    top_p: float = 0.8
    top_k: int = 40
    temperature: float = 0.8
    repetition_penalty: float = 1.0
    ignore_eos: bool = False
    skip_special_tokens: Optional[bool] = True
    cancel: Optional[bool] = False  # cancel a responding request


class GenerateRequestQos(BaseModel):
    """Generate request."""
    prompt: Union[str, List[Dict[str, str]]]
    session_id: int = -1
...
@@ -217,10 +288,13 @@ class GenerateRequest(BaseModel):
    temperature: float = 0.8
    repetition_penalty: float = 1.0
    ignore_eos: bool = False
    user_id: Optional[str] = None


class GenerateResponse(BaseModel):
    """Generate response."""
    text: str
    tokens: int
    input_tokens: int
    history_tokens: int
    finish_reason: Optional[Literal['stop', 'length']] = None
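A quick way to sanity-check the updated pydantic schema is to instantiate the models directly; this sketch only uses fields that are visible in the diff above.

# Sketch: exercising the updated request/response models locally.
from lmdeploy.serve.openai.protocol import GenerateRequest, GenerateResponse

req = GenerateRequest(prompt='hi', session_id=1, interactive_mode=True)
print(req.model_dump())        # skip_special_tokens defaults to True, cancel to False

resp = GenerateResponse(text='hello', tokens=2, input_tokens=5,
                        history_tokens=0, finish_reason='stop')
print(resp.model_dump_json())  # now carries input_tokens and history_tokens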
lmdeploy/serve/turbomind/chatbot.py  View file @ d7117b95
...
@@ -18,7 +18,7 @@ from tritonclient.grpc.service_pb2 import ModelInferResponse
from lmdeploy.model import MODELS
from lmdeploy.serve.turbomind.utils import (Postprocessor, Preprocessor,
                                            prepare_tensor)
-from lmdeploy.utils import filter_suffix
from lmdeploy.utils import filter_suffix, get_logger


@dataclass
...
@@ -51,13 +51,6 @@ def stream_callback(que, result, error):
        que.put(result.get_response(as_json=True))


-def get_logger(log_file=None, log_level=logging.INFO):
-    """Return the logger."""
-    from lmdeploy.utils import get_logger
-    logger = get_logger('service.ft', log_file=log_file, log_level=log_level)
-    return logger


class Chatbot:
    """Chatbot for LLaMA series models with turbomind as inference engine.
...
@@ -75,6 +68,10 @@ class Chatbot:
                 ignore_eos: bool = False,
                 log_level: int = logging.INFO,
                 display: bool = False,
                 top_p: float = 1.0,
                 top_k: int = 1,
                 temperature: float = 0.8,
                 repetition_penalty: float = 1.0,
                 **model_kwargs):
        self.tritonserver_addr = tritonserver_addr
        self.model_name = model_name
...
@@ -97,10 +94,10 @@ class Chatbot:
        self.eos_id = -1
        self.cfg = mmengine.Config(
            dict(session_len=self.model.session_len,
-                top_p=self.model.top_p,
                 top_p=top_p,
-                top_k=self.model.top_k,
                 top_k=top_k,
-                temperature=self.model.temperature,
                 temperature=temperature,
-                repetition_penalty=self.model.repetition_penalty,
                 repetition_penalty=repetition_penalty,
                 stop_words=stop_words,
                 bad_words=bad_words))
        self.log_level = log_level
...
@@ -113,6 +110,7 @@ class Chatbot:
               request_output_len: int = None,
               sequence_start: bool = False,
               sequence_end: bool = False,
               skip_special_tokens: bool = True,
               *args,
               **kwargs):
        """Start a new round conversion of a session.
...
@@ -124,13 +122,15 @@ class Chatbot:
            request_output_len (int): the expected generated token numbers
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.

        Returns:
            iterator: The generated content by chatbot
        """
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

-       logger = get_logger(log_level=self.log_level)
        logger = get_logger('service.ft', log_level=self.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'request_output_len {request_output_len}')
...
@@ -149,11 +149,13 @@ class Chatbot:
        self.cfg.update(**kwargs)
        self._session.prompt = self._get_prompt(prompt, sequence_start)
        for status, res, tokens in self._stream_infer(
                self._session,
                self._session.prompt,
                request_output_len,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            if status == StatusCode.TRITON_STREAM_END:  # remove stop_words
                res = filter_suffix(res, self.model.stop_words)
            if status.value < 0:
...
@@ -180,7 +182,7 @@ class Chatbot:
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

-       logger = get_logger(log_level=self.log_level)
        logger = get_logger('service.ft', log_level=self.log_level)
        logger.info(f'end session: {session_id}')

        if self._session is None:
...
@@ -218,7 +220,7 @@ class Chatbot:
        """
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

-       logger = get_logger(log_level=self.log_level)
        logger = get_logger('service.ft', log_level=self.log_level)
        logger.info(f'cancel session: {session_id}')

        if self._session is None:
...
@@ -267,7 +269,7 @@ class Chatbot:
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

-       logger = get_logger(log_level=self.log_level)
        logger = get_logger('service.ft', log_level=self.log_level)
        logger.info(f'resume session: {session_id}')

        if self._session is None:
...
@@ -301,6 +303,7 @@ class Chatbot:
                    request_output_len: int = None,
                    sequence_start: bool = False,
                    sequence_end: bool = False,
                    skip_special_tokens: bool = True,
                    *args,
                    **kwargs):
        """Start a new round conversion of a session. Return the chat
...
@@ -313,6 +316,8 @@ class Chatbot:
            request_output_len (int): the expected generated token numbers
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.

        Returns:
            tuple(Status, str, int): status, text/chat completion,
                generated token number
...
@@ -320,7 +325,7 @@ class Chatbot:
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

-       logger = get_logger(log_level=self.log_level)
        logger = get_logger('service.ft', log_level=self.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'request_output_len {request_output_len}')
...
@@ -338,11 +343,13 @@ class Chatbot:
        self._session.prompt = self._get_prompt(prompt, sequence_start)
        status, res, tokens = None, '', 0
        for status, res, tokens in self._stream_infer(
                self._session,
                self._session.prompt,
                request_output_len,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            if status.value < 0:
                break
            if status == StatusCode.TRITON_STREAM_END:  # remove stop_words
...
@@ -420,6 +427,7 @@ class Chatbot:
                      request_output_len: int = 512,
                      sequence_start: bool = True,
                      sequence_end: bool = False,
                      skip_special_tokens: bool = True,
                      cancel: bool = False):
        """communicate with inference server to chat, or cancel a session, or
        end a session.
...
@@ -431,10 +439,12 @@ class Chatbot:
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            cancel (bool): indicator for cancelling the session
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.

        Yields:
            tuple: status, text, generated token number
        """
-       logger = get_logger(log_level=self.log_level)
        logger = get_logger('service.ft', log_level=self.log_level)
        logger.info(f'session {session.session_id}, '
                    f'request id {session.request_id}, '
                    f'request_output_len {request_output_len}, '
...
@@ -498,7 +508,8 @@ class Chatbot:
        producer.start()
        for status, res, n_token in self.stream_consumer(
                self.postprocess, que, session, input_tokens, preseq_length,
-               cancel, logger, self.display, self.eos_id):
                cancel, logger, self.display, self.eos_id,
                skip_special_tokens):
            yield status, res, n_token
        producer.join()
...
@@ -591,7 +602,8 @@ class Chatbot:
    @staticmethod
    def stream_consumer(postprocess, res_queue, session, n_input_token,
-                       preseq_length, cancel, logger, display, eos_id):
                        preseq_length, cancel, logger, display, eos_id,
                        skip_special_tokens):
        """Consume the response from the triton inference server.

        Args:
...
@@ -605,11 +617,15 @@ class Chatbot:
            logger (util.Logger):
            display (bool): display the text in the console interface or not
            eos_id (int): eos token id
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.

        Yields:
            tuple: status, text, generated token number
        """
        status, res, n_token = None, '', 0
        output_ids = np.zeros((1, 1, 0), dtype=np.uint32)
        text = ''
        while True:
            result = res_queue.get()
            if result is None:
...
@@ -648,7 +664,8 @@ class Chatbot:
                    output_ids = output_ids[:, :, :-1]
                output_str = postprocess(
-                   output_ids, np.array([[n_token]], dtype=np.uint32))
                    output_ids, np.array([[n_token]], dtype=np.uint32),
                    np.array([[int(skip_special_tokens)]], dtype=np.int32))
                text = output_str[0].decode()
                # utf-8 char at the end means it's a potential unfinished
                # byte sequence, continue to concatenate it with the next
...
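The Chatbot now accepts sampling defaults in its constructor and threads skip_special_tokens down to the Triton postprocessor. A usage sketch follows; the gRPC address is a placeholder, and the stream_infer method and its parameter names are assumptions about the Chatbot API rather than something shown in this hunk.

# Sketch of driving the Triton-backed Chatbot with the new decoding switch.
from lmdeploy.serve.turbomind.chatbot import Chatbot

bot = Chatbot('0.0.0.0:33337',          # assumed Triton gRPC address
              temperature=0.7,
              top_p=0.9)
for status, text, n_token in bot.stream_infer(   # method and argument names assumed
        session_id=1,
        prompt='Tell me a joke.',
        request_output_len=128,
        sequence_start=True,
        sequence_end=True,
        skip_special_tokens=True):
    print(text, end='', flush=True)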
lmdeploy/serve/turbomind/triton_models/postprocessing/1/model.py  View file @ d7117b95
...
@@ -84,10 +84,13 @@ class TritonPythonModel:
                request, 'TOKENS_BATCH').as_numpy()
            sequence_length = pb_utils.get_input_tensor_by_name(
                request, 'sequence_length').as_numpy()
            skip_special_tokens = pb_utils.get_input_tensor_by_name(
                request, 'skip_special_tokens').as_numpy()

            # Postprocessing output data.
            outputs = self._postprocessing(tokens_batch.tolist(),
-                                          sequence_length)
                                           sequence_length,
                                           skip_special_tokens)

            # Create output tensors. You need pb_utils.Tensor
            # objects to create pb_utils.InferenceResponse.
...
@@ -118,12 +121,16 @@ class TritonPythonModel:
        """
        print('Cleaning up...')

-   def _postprocessing(self, tokens_batch, sequence_length):
    def _postprocessing(self, tokens_batch, sequence_length,
                        skip_special_tokens):
        """decode token ids into texts."""
        outputs = []
-       for beam_tokens, beam_len in zip(tokens_batch, sequence_length):
-           for tokens, _len in zip(beam_tokens, beam_len):
-               output = self.tokenizer.decode(tokens, _len)
        for beam_tokens, beam_len, beam_skip_special in zip(
                tokens_batch, sequence_length, skip_special_tokens):
            for tokens, _len, skip_special in zip(beam_tokens, beam_len,
                                                  beam_skip_special):
                output = self.tokenizer.decode(
                    tokens, _len, skip_special_tokens=bool(skip_special))
                output = output.encode('utf8')
                outputs.append(output)
        return outputs
lmdeploy/serve/turbomind/triton_models/postprocessing/config.pbtxt  View file @ d7117b95
...
@@ -11,6 +11,11 @@ input [
    name: "sequence_length"
    data_type: TYPE_UINT32
    dims: [ -1 ]
  },
  {
    name: "skip_special_tokens"
    data_type: TYPE_INT32
    dims: [ -1 ]
  }
]
output [
...
lmdeploy/serve/turbomind/utils.py  View file @ d7117b95
...
@@ -72,22 +72,29 @@ class Postprocessor:
    def __call__(self, *args, **kwargs):
        return self.infer(*args, **kwargs)

-   def infer(self, output_ids: np.ndarray, seqlen: np.ndarray):
    def infer(self,
              output_ids: np.ndarray,
              seqlen: np.ndarray,
              skip_special_tokens: bool = True):
        """De-tokenize tokens for text.

        Args:
            output_ids(np.ndarray): tokens' id
            seqlen(np.ndarray): sequence length
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.

        Returns:
            str: decoded tokens
        """
        inputs = [
            prepare_tensor('TOKENS_BATCH', output_ids),
-           prepare_tensor('sequence_length', seqlen)
            prepare_tensor('sequence_length', seqlen),
            prepare_tensor('skip_special_tokens', skip_special_tokens)
        ]
        inputs[0].set_data_from_numpy(output_ids)
        inputs[1].set_data_from_numpy(seqlen)
        inputs[2].set_data_from_numpy(skip_special_tokens)
        model_name = 'postprocessing'
        with grpcclient.InferenceServerClient(self.tritonserver_addr) \
                as client:
...
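Because config.pbtxt declares skip_special_tokens as an INT32 tensor, callers pass it as a numpy array shaped like sequence_length, exactly the way chatbot.py builds it. A sketch, assuming the Postprocessor constructor takes the Triton server address and that a server is running:

# Sketch: invoking the postprocessing model with the new flag.
import numpy as np
from lmdeploy.serve.turbomind.utils import Postprocessor

postprocessor = Postprocessor('0.0.0.0:33337')                  # assumed Triton gRPC address
output_ids = np.array([[[9843, 2354, 112]]], dtype=np.uint32)   # fake token ids
seqlen = np.array([[3]], dtype=np.uint32)
skip = np.array([[1]], dtype=np.int32)                          # 1 -> drop special tokens
decoded = postprocessor(output_ids, seqlen, skip)
print(decoded[0].decode())                                      # same unpacking as chatbot.py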
lmdeploy/tokenizer.py
View file @
d7117b95
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
json
import
json
import
os
import
os.path
as
osp
import
os.path
as
osp
from
typing
import
Optional
,
Sequence
,
Union
from
collections
import
deque
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
torch
import
torch
from
lmdeploy.utils
import
get_logger
# this file will be copied to triton server, make sure all
# importing are starting from the package root lmdeploy
@
dataclass
class
DetokenizeState
:
"""A state collection of incrementally detekenization.
Args:
ids_offset (int): offset to all input ids. In LMDeploy, the output
ids length is not one by one. It could be random by random.
prev_tokens (List[str] | None): for incrementally decoding.
Default to None, which means the first round.
prefix_offset (int): the start index of tokens to be converted to
string (prev + new tokens). Default to 0 for the first round.
read_offset (int): the end index of tokens to be converted to
string (prev token). Default to 0 for the first round.
"""
ids_offset
:
int
=
0
prev_tokens
:
Optional
[
List
[
str
]]
=
None
prefix_offset
:
int
=
0
read_offset
:
int
=
0
def
as_tuple
(
self
)
->
Tuple
:
"""Return a tuple of states."""
return
(
self
.
ids_offset
,
self
.
prev_tokens
,
self
.
prefix_offset
,
self
.
read_offset
)
class
SentencePieceTokenizer
:
class
SentencePieceTokenizer
:
"""Tokenizer of sentencepiece.
"""Tokenizer of sentencepiece.
...
@@ -18,6 +49,12 @@ class SentencePieceTokenizer:
...
@@ -18,6 +49,12 @@ class SentencePieceTokenizer:
from
sentencepiece
import
SentencePieceProcessor
from
sentencepiece
import
SentencePieceProcessor
self
.
model
=
SentencePieceProcessor
(
model_file
=
model_file
)
self
.
model
=
SentencePieceProcessor
(
model_file
=
model_file
)
self
.
_prefix_space_tokens
=
None
self
.
_prefix_space_tokens
=
None
# for stop words
self
.
_maybe_decode_bytes
:
bool
=
None
# TODO maybe lack a constant.py
self
.
_indexes_tokens_deque
=
deque
(
maxlen
=
10
)
self
.
max_indexes_num
=
5
self
.
logger
=
get_logger
(
'lmdeploy'
)
@
property
@
property
def
vocab_size
(
self
):
def
vocab_size
(
self
):
...
@@ -53,6 +90,27 @@ class SentencePieceTokenizer:
...
@@ -53,6 +90,27 @@ class SentencePieceTokenizer:
else
:
else
:
return
decoded
return
decoded
def
indexes_containing_token
(
self
,
token
:
str
):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
# traversing vocab is time consuming, can not be accelerated with
# multi threads (computation) or multi process (can't pickle tokenizer)
# so, we maintain latest 10 stop words and return directly if matched
for
_token
,
_indexes
in
self
.
_indexes_tokens_deque
:
if
token
==
_token
:
return
_indexes
if
token
==
' '
:
# ' ' is special
token
=
'▁'
vocab
=
self
.
model
.
IdToPiece
(
list
(
range
(
self
.
vocab_size
)))
indexes
=
[
i
for
i
,
voc
in
enumerate
(
vocab
)
if
token
in
voc
]
if
len
(
indexes
)
>
self
.
max_indexes_num
:
indexes
=
self
.
encode
(
token
,
add_bos
=
False
)[
-
1
:]
self
.
logger
.
warning
(
f
'There are too many(>
{
self
.
max_indexes_num
}
) possible '
f
'indexes may decoding
{
token
}
, we will use
{
indexes
}
only'
)
self
.
_indexes_tokens_deque
.
append
((
token
,
indexes
))
return
indexes
def
encode
(
self
,
s
:
str
,
add_bos
:
bool
=
True
,
**
kwargs
):
def
encode
(
self
,
s
:
str
,
add_bos
:
bool
=
True
,
**
kwargs
):
"""Tokenize a prompt.
"""Tokenize a prompt.
...
@@ -63,13 +121,18 @@ class SentencePieceTokenizer:
...
@@ -63,13 +121,18 @@ class SentencePieceTokenizer:
"""
"""
return
self
.
model
.
Encode
(
s
,
add_bos
=
add_bos
,
**
kwargs
)
return
self
.
model
.
Encode
(
s
,
add_bos
=
add_bos
,
**
kwargs
)
def
decode
(
self
,
t
:
Sequence
[
int
],
offset
:
Optional
[
int
]
=
None
):
def
decode
(
self
,
t
:
Sequence
[
int
],
offset
:
Optional
[
int
]
=
None
,
skip_special_tokens
:
bool
=
True
,
**
kwargs
):
"""De-tokenize.
"""De-tokenize.
Args:
Args:
t (List[int]): a list of token ids
t (List[int]): a list of token ids
offset (int): for incrementally decoding. Default to None, which
offset (int): for incrementally decoding. Default to None, which
means not applied.
means not applied.
skip_special_tokens (boo): not used in SentencePieceTokenizer.
Returns:
Returns:
str: text of decoding tokens
str: text of decoding tokens
"""
"""
...
@@ -81,6 +144,34 @@ class SentencePieceTokenizer:
...
@@ -81,6 +144,34 @@ class SentencePieceTokenizer:
out_string
=
self
.
_maybe_add_prefix_space
(
t
,
out_string
)
out_string
=
self
.
_maybe_add_prefix_space
(
t
,
out_string
)
return
out_string
return
out_string
def
detokenize_incrementally
(
self
,
all_input_ids
:
Sequence
[
int
],
state
:
DetokenizeState
,
skip_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
):
"""Incrementally detokenize the input indexes.
Args:
all_input_ids (List[int]): a list of token ids. Expected to be
different sections of a long sequence.
state (DetokenizeState): an instance of DetokenizeState. Consists
of incrementally decoding states.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
spaces_between_special_tokens (bool): Whether or not to add spaces
between special tokens. Default to be True.
Returns:
str: decoding output string of the current round.
state (DetokenizeState): an instance of DetokenizeState. Consists
of incrementally decoding states.
"""
out_string
=
self
.
model
.
Decode
(
all_input_ids
)
if
state
.
prev_tokens
is
not
None
:
out_string
=
self
.
_maybe_add_prefix_space
(
all_input_ids
,
out_string
)
state
.
prev_tokens
=
[]
# not None for the above condition
return
out_string
,
state
def
__call__
(
self
,
s
:
Union
[
str
,
Sequence
[
str
]]):
def
__call__
(
self
,
s
:
Union
[
str
,
Sequence
[
str
]]):
"""Tokenize prompts.
"""Tokenize prompts.
...
@@ -106,20 +197,10 @@ class HuggingFaceTokenizer:
...
@@ -106,20 +197,10 @@ class HuggingFaceTokenizer:
def
__init__
(
self
,
model_dir
:
str
):
def
__init__
(
self
,
model_dir
:
str
):
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
model_file
=
osp
.
join
(
model_dir
,
'tokenizer.model'
)
self
.
logger
=
get_logger
(
'lmdeploy'
)
backend_tokenizer_file
=
osp
.
join
(
model_dir
,
'tokenizer.json'
)
model_file_exists
=
osp
.
exists
(
model_file
)
if
not
osp
.
exists
(
backend_tokenizer_file
)
and
model_file_exists
:
print
(
'WARNING: Can not find tokenizer.json. '
'It may take long time to initialize the tokenizer.'
)
self
.
model
=
AutoTokenizer
.
from_pretrained
(
model_dir
,
self
.
model
=
AutoTokenizer
.
from_pretrained
(
model_dir
,
trust_remote_code
=
True
)
trust_remote_code
=
True
)
self
.
_prefix_space_tokens
=
None
self
.
_prefix_space_tokens
=
None
# save tokenizer.json to reuse
if
not
osp
.
exists
(
backend_tokenizer_file
)
and
model_file_exists
:
if
hasattr
(
self
.
model
,
'backend_tokenizer'
):
if
os
.
access
(
model_dir
,
os
.
W_OK
):
self
.
model
.
backend_tokenizer
.
save
(
backend_tokenizer_file
)
if
self
.
model
.
eos_token_id
is
None
:
if
self
.
model
.
eos_token_id
is
None
:
generation_config_file
=
osp
.
join
(
model_dir
,
generation_config_file
=
osp
.
join
(
model_dir
,
...
@@ -131,11 +212,27 @@ class HuggingFaceTokenizer:
...
@@ -131,11 +212,27 @@ class HuggingFaceTokenizer:
elif
hasattr
(
self
.
model
,
'eod_id'
):
# Qwen remote
elif
hasattr
(
self
.
model
,
'eod_id'
):
# Qwen remote
self
.
model
.
eos_token_id
=
self
.
model
.
eod_id
self
.
model
.
eos_token_id
=
self
.
model
.
eod_id
# for stop words
self
.
_vocab_size_with_added
:
int
=
None
self
.
_maybe_decode_bytes
:
bool
=
None
# TODO maybe lack a constant.py
self
.
_indexes_tokens_deque
=
deque
(
maxlen
=
10
)
self
.
max_indexes_num
=
5
self
.
token2id
=
{}
@
property
@
property
def
vocab_size
(
self
):
def
vocab_size
(
self
):
"""vocabulary size."""
"""vocabulary size."""
return
self
.
model
.
vocab_size
return
self
.
model
.
vocab_size
@
property
def
vocab_size_with_added
(
self
):
"""vocabulary size with added vocab."""
if
self
.
_vocab_size_with_added
is
not
None
:
return
self
.
_vocab_size_with_added
self
.
_vocab_size_with_added
=
len
(
self
.
model
.
get_vocab
())
return
self
.
_vocab_size_with_added
@
property
@
property
def
bos_token_id
(
self
):
def
bos_token_id
(
self
):
"""begine of the sentence token id."""
"""begine of the sentence token id."""
...
@@ -159,7 +256,7 @@ class HuggingFaceTokenizer:
...
@@ -159,7 +256,7 @@ class HuggingFaceTokenizer:
}
}
return
self
.
_prefix_space_tokens
return
self
.
_prefix_space_tokens
def
_maybe_add_prefix_space
(
self
,
tokens
,
decoded
):
def
_maybe_add_prefix_space
(
self
,
tokens
:
List
[
int
]
,
decoded
:
str
):
"""maybe add prefix space for incremental decoding."""
"""maybe add prefix space for incremental decoding."""
if
len
(
tokens
)
and
not
decoded
.
startswith
(
' '
)
and
\
if
len
(
tokens
)
and
not
decoded
.
startswith
(
' '
)
and
\
tokens
[
0
]
in
self
.
prefix_space_tokens
:
tokens
[
0
]
in
self
.
prefix_space_tokens
:
...
@@ -167,6 +264,66 @@ class HuggingFaceTokenizer:
...
@@ -167,6 +264,66 @@ class HuggingFaceTokenizer:
else
:
else
:
return
decoded
return
decoded
@
property
def
maybe_decode_bytes
(
self
):
"""Check if self.model.convert_ids_to_tokens return not a str value."""
if
self
.
_maybe_decode_bytes
is
None
:
self
.
_maybe_decode_bytes
=
False
vocab
=
self
.
model
.
convert_ids_to_tokens
(
list
(
range
(
self
.
vocab_size
)))
for
tok
in
vocab
:
if
not
isinstance
(
tok
,
str
):
self
.
_maybe_decode_bytes
=
True
break
return
self
.
_maybe_decode_bytes
def
indexes_containing_token
(
self
,
token
:
str
):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
# traversing vocab is time consuming, can not be accelerated with
# multi threads (computation) or multi process (can't pickle tokenizer)
# so, we maintain latest 10 stop words and return directly if matched
for
_token
,
_indexes
in
self
.
_indexes_tokens_deque
:
if
token
==
_token
:
return
_indexes
if
self
.
token2id
==
{}:
# decode is slower than convert_ids_to_tokens
if
self
.
maybe_decode_bytes
:
try
:
self
.
token2id
=
{
self
.
model
.
decode
(
i
):
i
for
i
in
range
(
self
.
vocab_size
)
}
except
Exception
as
e
:
# qwen-vl
assert
str
(
e
)
==
'Unclosed image token'
else
:
self
.
token2id
=
{
self
.
model
.
convert_ids_to_tokens
(
i
):
i
for
i
in
range
(
self
.
vocab_size
)
}
if
token
==
' '
:
# ' ' is special
token
=
'▁'
indexes
=
[
i
for
_token
,
i
in
self
.
token2id
.
items
()
if
token
in
_token
]
if
len
(
indexes
)
>
self
.
max_indexes_num
:
# multiple id decode to same token
indexes
=
[
i
for
i
in
indexes
if
self
.
decode
([
i
])
==
token
]
indexes
=
indexes
[:
self
.
max_indexes_num
]
self
.
logger
.
warning
(
f
'There are too many(>
{
self
.
max_indexes_num
}
) possible '
f
'indexes may decoding
{
token
}
, we will use
{
indexes
}
only'
)
# there might be token id that exceeds self.vocab_size
if
len
(
indexes
)
==
0
:
indexes
=
self
.
encode
(
token
,
False
)
if
len
(
indexes
)
!=
1
:
self
.
logger
.
warning
(
f
'The token
{
token
}
, its length of indexes
{
indexes
}
is '
'not 1. Currently, it can not be used as stop words'
)
indexes
=
[]
self
.
_indexes_tokens_deque
.
append
((
token
,
indexes
))
return
indexes
def
encode
(
self
,
s
:
str
,
add_bos
:
bool
=
True
,
**
kwargs
):
def
encode
(
self
,
s
:
str
,
add_bos
:
bool
=
True
,
**
kwargs
):
"""Tokenize a prompt.
"""Tokenize a prompt.
...
@@ -182,7 +339,10 @@ class HuggingFaceTokenizer:
...
@@ -182,7 +339,10 @@ class HuggingFaceTokenizer:
encoded
=
encoded
[
1
:]
encoded
=
encoded
[
1
:]
return
encoded
return
encoded
def
decode
(
self
,
t
:
Sequence
[
int
],
offset
:
Optional
[
int
]
=
None
):
def
decode
(
self
,
t
:
Sequence
[
int
],
offset
:
Optional
[
int
]
=
None
,
skip_special_tokens
:
bool
=
True
):
"""De-tokenize.
"""De-tokenize.
Args:
Args:
...
@@ -192,14 +352,121 @@ class HuggingFaceTokenizer:
...
@@ -192,14 +352,121 @@ class HuggingFaceTokenizer:
Returns:
Returns:
str: text of decoding tokens
str: text of decoding tokens
"""
"""
skip_special_tokens
=
True
t
=
t
[
offset
:]
t
=
t
[
offset
:]
out_string
=
self
.
model
.
decode
(
t
,
out_string
=
self
.
model
.
decode
(
t
,
skip_special_tokens
=
skip_special_tokens
)
skip_special_tokens
=
skip_special_tokens
)
if
offset
:
if
offset
:
logger
=
get_logger
(
'lmdeploy'
)
logger
.
warning
(
'For incrementally detokenization, please try '
'detokenize_incrementally function instead.'
)
out_string
=
self
.
_maybe_add_prefix_space
(
t
,
out_string
)
out_string
=
self
.
_maybe_add_prefix_space
(
t
,
out_string
)
return
out_string
return
out_string
    @staticmethod
    def _convert_tokens_to_string_with_added_encoders(
        tokenizer,
        output_tokens: List[str],
        skip_special_tokens: bool,
        spaces_between_special_tokens: bool,
    ) -> str:
        if tokenizer.is_fast or not tokenizer.get_added_vocab():
            return tokenizer.convert_tokens_to_string(output_tokens)
        # Adapted from
        # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L68-L99
        sub_texts = []
        current_sub_text = []
        all_special_tokens = set(tokenizer.all_special_tokens)
        for token in output_tokens:
            if skip_special_tokens and token in all_special_tokens:
                continue
            if token in tokenizer.get_added_vocab():
                if current_sub_text:
                    sub_text = tokenizer.convert_tokens_to_string(
                        current_sub_text)
                    sub_texts.append(sub_text)
                    current_sub_text = []
                sub_texts.append(token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
            sub_texts.append(sub_text)
        if spaces_between_special_tokens:
            return ' '.join(sub_texts)
        else:
            return ''.join(sub_texts)
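The helper above only changes behaviour for slow tokenizers that carry added vocabulary: added tokens are kept out of convert_tokens_to_string, so the token stream is split into runs around them. The toy sketch below reproduces that grouping with plain lists; it has no transformers dependency and all names are illustrative only:

def group_around_added_tokens(tokens, added_vocab, joiner=' '.join):
    """Join normal tokens run-by-run, passing added-vocab tokens through verbatim."""
    sub_texts, current = [], []
    for token in tokens:
        if token in added_vocab:
            if current:
                sub_texts.append(joiner(current))
                current = []
            sub_texts.append(token)          # added token kept as-is
        else:
            current.append(token)
    if current:
        sub_texts.append(joiner(current))
    return ' '.join(sub_texts)

print(group_around_added_tokens(['Hello', 'world', '<IMG>', 'again'], {'<IMG>'}))
# -> 'Hello world <IMG> again'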
    # Based on
    # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L105-L165
    def detokenize_incrementally(self,
                                 all_input_ids: Sequence[int],
                                 state: DetokenizeState,
                                 skip_special_tokens: bool = True,
                                 spaces_between_special_tokens: bool = True):
        """Incrementally detokenize the input indexes.

        Args:
            all_input_ids (List[int]): a list of token ids. Expected to be
                different sections of a long sequence.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.
            spaces_between_special_tokens (bool): Whether or not to add spaces
                between special tokens. Default to be True.

        Returns:
            str: decoding output string of the current round.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
        """
        tokenizer = self.model
        ids_offset, prev_tokens, prefix_offset, read_offset = state.as_tuple()
        # This is the first iteration for this sequence
        new_tokens = tokenizer.convert_ids_to_tokens(
            all_input_ids[ids_offset:],
            skip_special_tokens=skip_special_tokens)
        if prev_tokens is None:
            # Please notice that in VLLM, indexes are detokenized one by one
            # while in LMDeploy, every turn, the detokenized indexes length
            # can be different.
            if skip_special_tokens and new_tokens and new_tokens[
                    0] in tokenizer.all_special_ids:
                read_offset = 1  # skip special token
            output_tokens = new_tokens
            prev_tokens = new_tokens
        else:
            # Put new_token_id in a list so skip_special_tokens is respected
            output_tokens = prev_tokens + new_tokens
            prev_tokens += new_tokens

        prefix_text = self._convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )
        new_text = self._convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

        # update state and get final decoded output
        if len(new_text) > len(prefix_text) and not new_text.endswith('�'):
            # utf-8 char at the end means it's a potential unfinished byte
            # sequence from byte fallback tokenization.
            # If it's in the middle, it's probably a real invalid id generated
            # by the model
            prefix_offset = read_offset
            read_offset = len(output_tokens)
            new_text = new_text[len(prefix_text):]
        else:
            new_text = ''

        return new_text, DetokenizeState(len(all_input_ids), prev_tokens,
                                         prefix_offset, read_offset)
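A minimal streaming sketch of the API above, assuming DetokenizeState() default-constructs the initial state and that token ids arrive a few at a time; the model path and the token ids are made up for illustration:

from lmdeploy.tokenizer import DetokenizeState, Tokenizer

tok = Tokenizer('/path/to/hf/model')            # hypothetical model directory
state = DetokenizeState()
all_ids = []
for chunk in ([101, 7592], [2088], [102]):      # made-up token ids streamed in chunks
    all_ids.extend(chunk)
    piece, state = tok.detokenize_incrementally(all_ids, state)
    print(piece, end='', flush=True)            # emit only the newly finalized text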
    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.
        ...

@@ -230,7 +497,7 @@ class Tokenizer:
        model_file_exists = osp.exists(model_file)
        config_exists = osp.exists(tokenizer_config_file)
        use_hf_model = config_exists or not model_file_exists
        self.logger = get_logger('lmdeploy')
        if not use_hf_model:
            self.model = SentencePieceTokenizer(model_file)
        else:
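In other words, the wrapper prefers the Hugging Face backend whenever a tokenizer config is present and only falls back to SentencePiece when a bare model file exists without one. A standalone sketch of that rule, assuming the conventional file names tokenizer.model and tokenizer_config.json; the helper name is made up:

import os.path as osp

def pick_backend(model_folder: str) -> str:
    """Return which tokenizer backend the folder would select (sketch only)."""
    model_file = osp.join(model_folder, 'tokenizer.model')         # SentencePiece model (assumed name)
    config_file = osp.join(model_folder, 'tokenizer_config.json')  # HF tokenizer config (assumed name)
    use_hf_model = osp.exists(config_file) or not osp.exists(model_file)
    return 'HuggingFaceTokenizer' if use_hf_model else 'SentencePieceTokenizer'

print(pick_backend('/path/to/model'))   # hypothetical directory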
        ...

@@ -261,7 +528,12 @@ class Tokenizer:
        """
        return self.model.encode(s, add_bos, **kwargs)
    def decode(
        self,
        t: Sequence[int],
        offset: Optional[int] = None,
        skip_special_tokens: bool = True,
    ):
        """De-tokenize.

        Args:
        ...

@@ -271,7 +543,34 @@ class Tokenizer:
        Returns:
            str: text of decoding tokens
        """
        return self.model.decode(t, offset, skip_special_tokens)
    def detokenize_incrementally(self,
                                 all_input_ids: Sequence[int],
                                 state: DetokenizeState,
                                 skip_special_tokens: bool = True,
                                 spaces_between_special_tokens: bool = True):
        """Incrementally detokenize the input indexes.

        Args:
            all_input_ids (List[int]): a list of token ids. Expected to be
                different sections of a long sequence.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.
            spaces_between_special_tokens (bool): Whether or not to add spaces
                between special tokens. Default to be True.

        Returns:
            str: decoding output string of the current round.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
        """
        return self.model.detokenize_incrementally(
            all_input_ids,
            state=state,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens)
    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.
        ...

@@ -282,3 +581,14 @@ class Tokenizer:
            list[int]: token ids
        """
        return self.model(s)
    def indexes_containing_token(self, token):
        """Return all the possible indexes, whose decoding output may contain
        the input token."""
        encoded = self.encode(token, add_bos=False)
        if len(encoded) > 1:
            self.logger.warning(
                f'The token {token}, its length of indexes {encoded} is over '
                'than 1. Currently, it can not be used as stop words')
            return []
        return self.model.indexes_containing_token(token)
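A short sketch of the stop-word guard above: a phrase that tokenizes into several ids is rejected with a warning and an empty list, while a token that encodes to a single id is forwarded to the underlying tokenizer. The model path and stop words are placeholders:

from lmdeploy.tokenizer import Tokenizer

tok = Tokenizer('/path/to/hf/model')
print(tok.indexes_containing_token('<|im_end|>'))     # one id (for many chat models) -> candidate indexes
print(tok.indexes_containing_token('stop sequence'))  # several ids -> warning + []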