OpenDAS / Lmdeploy · Commits

Commit d7117b95, authored Mar 22, 2024 by zhouxiang
Parent: 5f83e392

    Sync 0.2.6 code

Changes: 151
Showing 20 changed files with 1849 additions and 1015 deletions (+1849, -1015)
lmdeploy/pytorch/dist.py                                              +0    -98
lmdeploy/pytorch/model.py                                             +0    -159
lmdeploy/pytorch/modules/__init__.py                                  +0    -4
lmdeploy/pytorch/modules/linear.py                                    +0    -154
lmdeploy/pytorch/session.py                                           +0    -81
lmdeploy/pytorch/utils.py                                             +18   -93
lmdeploy/serve/async_engine.py                                        +513  -189
lmdeploy/serve/gradio/api_server_backend.py                           +22   -6
lmdeploy/serve/gradio/app.py                                          +28   -2
lmdeploy/serve/gradio/constants.py                                    +2    -2
lmdeploy/serve/gradio/triton_server_backend.py                        +21   -3
lmdeploy/serve/gradio/turbomind_coupled.py                            +90   -46
lmdeploy/serve/openai/api_client.py                                   +92   -41
lmdeploy/serve/openai/api_server.py                                   +583  -77
lmdeploy/serve/openai/protocol.py                                     +76   -2
lmdeploy/serve/turbomind/chatbot.py                                   +48   -31
lmdeploy/serve/turbomind/triton_models/postprocessing/1/model.py      +12   -5
lmdeploy/serve/turbomind/triton_models/postprocessing/config.pbtxt    +5    -0
lmdeploy/serve/turbomind/utils.py                                     +9    -2
lmdeploy/tokenizer.py                                                 +330  -20
lmdeploy/pytorch/dist.py  (deleted, 100644 → 0)
# Copyright (c) OpenMMLab. All rights reserved.
"""Helpers for parallel and distributed inference."""
import functools
import os

import torch
from torch.distributed import broadcast, broadcast_object_list, is_initialized


def get_local_rank():
    """Get local rank of current process.

    Assume environment variable ``LOCAL_RANK`` is properly set by some launcher.
    See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
    """  # noqa: E501
    return int(os.getenv('LOCAL_RANK', '0'))


def get_rank():
    """Get rank of current process.

    Assume environment variable ``RANK`` is properly set by some launcher.
    See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
    """  # noqa: E501
    return int(os.getenv('RANK', '0'))


def get_world_size():
    """Get world size of the current process group.

    Assume environment variable ``WORLD_SIZE`` is properly set by some launcher.
    See: https://pytorch.org/docs/stable/elastic/run.html#environment-variables
    """  # noqa: E501
    return int(os.getenv('WORLD_SIZE', '1'))


def master_only(func):
    """Decorator to run a function only on the master process."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_initialized():
            if get_rank() != 0:
                return None
        return func(*args, **kwargs)

    return wrapper


def master_only_and_broadcast_general(func):
    """Decorator to run a function only on the master process and broadcast
    the result to all processes."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_initialized():
            if get_rank() == 0:
                result = [func(*args, **kwargs)]
            else:
                result = [None]
            broadcast_object_list(result, src=0)
            result = result[0]
        else:
            result = func(*args, **kwargs)
        return result

    return wrapper


def master_only_and_broadcast_tensor(func):
    """Decorator to run a function only on the master process and broadcast
    the result to all processes.

    Note: Requires a CUDA tensor.
    Note: Not really usable because the shape is not known beforehand;
    for CPU tensors, use master_only_and_broadcast_general.
    """

    @functools.wraps(func)
    def wrapper(*args, size, dtype, **kwargs):
        if is_initialized():
            if get_rank() == 0:
                result = func(*args, **kwargs)
            else:
                result = torch.empty(size=size,
                                     dtype=dtype,
                                     device=get_local_rank())
            broadcast(result, src=0)
            # print(f'rank {get_rank()} received {result}')
        else:
            result = func(*args, **kwargs)
        return result

    return wrapper
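For orientation, a minimal usage sketch of the decorators defined above (this file is removed by the commit, so the import path only applies to the pre-0.2.6 layout; the two decorated functions below are hypothetical examples, not part of the module):

# Hypothetical usage of the removed dist.py decorators.
# Assumes torch.distributed was initialized by a launcher such as torchrun.
from lmdeploy.pytorch.dist import master_only, master_only_and_broadcast_general


@master_only
def log_progress(msg):
    # Only rank 0 prints; other ranks simply return None.
    print(msg)


@master_only_and_broadcast_general
def read_prompt():
    # Only rank 0 reads stdin; the resulting string is broadcast so that
    # every rank continues with the same prompt object.
    return input('prompt >>> ')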
lmdeploy/pytorch/model.py  (deleted, 100644 → 0)
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import time
import warnings
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .dist import get_local_rank

logger = logging.getLogger(__name__)


class LoadWoInit:
    """Context manager that disables parameter initialization."""

    def __init__(self):
        self.constant_ = torch.nn.init.constant_
        self.zeros_ = torch.nn.init.zeros_
        self.ones_ = torch.nn.init.ones_
        self.uniform_ = torch.nn.init.uniform_
        self.normal_ = torch.nn.init.normal_
        self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_
        self.kaiming_normal_ = torch.nn.init.kaiming_normal_

    def __enter__(self, *args, **kwargs):
        torch.nn.init.constant_ = lambda *args, **kwargs: None
        torch.nn.init.zeros_ = lambda *args, **kwargs: None
        torch.nn.init.ones_ = lambda *args, **kwargs: None
        torch.nn.init.uniform_ = lambda *args, **kwargs: None
        torch.nn.init.normal_ = lambda *args, **kwargs: None
        torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None
        torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None

    def __exit__(self, *args, **kwargs):
        torch.nn.init.constant_ = self.constant_
        torch.nn.init.zeros_ = self.zeros_
        torch.nn.init.ones_ = self.ones_
        torch.nn.init.uniform_ = self.uniform_
        torch.nn.init.normal_ = self.normal_
        torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_
        torch.nn.init.kaiming_normal_ = self.kaiming_normal_


def init_model(model_path: str,
               tokenizer_path: Optional[str] = None,
               use_fast_tokenizer=True):
    """Initialize model and tokenizer from given model path.

    Args:
        model_path (str): Path to model.
        tokenizer_path (str): Path to tokenizer.
        use_fast_tokenizer (bool): Whether to use fast tokenizer.

    Note:
        If the model is converted from a new version of transformers,
        use_fast_tokenizer should be True.
        If using depodaca/llama-xb-hf, use_fast_tokenizer should be False.
    """
    start = time.monotonic()
    if not tokenizer_path:
        tokenizer_path = model_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                              use_fast=use_fast_tokenizer,
                                              trust_remote_code=True)

    with LoadWoInit():
        model = AutoModelForCausalLM.from_pretrained(model_path,
                                                     torch_dtype=torch.float16,
                                                     trust_remote_code=True)
    logger.info(f'Model loaded in {time.monotonic() - start:.1f} seconds')
    logger.info(f'Model loaded from {model_path}')
    logger.debug(model)

    return model, tokenizer


def accel_model(model,
                accel: Optional[str] = None,
                gpu_id=None,
                max_alloc=2048,
                tp_size=1):
    """Accelerate model with given accelerator.

    Note:
        Currently only deepspeed or no acceleration is supported.
    """
    logger.info(f'Accelerate model with {accel}')

    if accel is None:
        # No acceleration, just to cuda
        # assume single gpu single process
        # user is responsible to assign the gpu id via CUDA_VISIBLE_DEVICES  # noqa: E501
        gpu_id = gpu_id if gpu_id is not None else get_local_rank()
        model = model.cuda(gpu_id)
    elif accel.lower() == 'deepspeed':
        # Use deepspeed inference to inject fast kernels and/or tensor parallel
        try:
            import deepspeed
        except ImportError as e:
            raise ImportError('--accel=deepspeed is specified but '
                              'deepspeed is not installed.\n'
                              'Install with `pip install deepspeed`.') from e

        config = dict(
            tensor_parallel=dict(tp_size=tp_size),  # Use world size in general
            dtype=torch.float16,
            replace_with_kernel_inject=True,
            max_out_tokens=max_alloc,
        )

        if 'InternLM' in model.__class__.__name__:
            try:
                # Use customized deepspeed supporting InternLM
                # https://github.com/wangruohui/DeepSpeed/tree/support_internlm_0.10.0 (commit cdef2ce)  # noqa: E501
                from deepspeed.module_inject.containers.internlm import \
                    InternLMLayerPolicy  # noqa: E501
            except ImportError:
                # InternLM is not officially supported by DeepSpeed
                # Set replace_with_kernel_inject=False to use AutoTP
                config.update({'replace_with_kernel_inject': False})
                warnings.warn(
                    '\033[0;93m'
                    'Current installation of deepspeed does not '
                    'support InternLM. Disable kernel injection. '
                    'To support InternLM, install customized deepspeed with '
                    '`pip install git+https://github.com/wangruohui/DeepSpeed@support_internlm_0.10.0`'  # noqa: E501
                    '\033[0m')
            else:
                for module in model.modules():
                    # Since remote code is dynamically located,
                    # we need to do this dynamically
                    if module.__class__.__name__ == 'InternLMDecoderLayer':
                        InternLMLayerPolicy._orig_layer_class = module.__class__  # noqa: E501
                        break

        logger.debug(f'Using deepspeed config\n{config}')

        model = deepspeed.init_inference(
            model=model,  # Transformers models
            config=config,
        )
        # for k, v in model.named_parameters():
        #     logger.debug(f"{k}: v.device")
    else:
        raise ValueError(f'Unsupported accelerator {accel}.')

    logger.debug(model)

    return model
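For reference, a short sketch of how the two removed helpers above were meant to be chained; the model id and parameter values are placeholders, and the snippet only illustrates the old API as defined above:

# Usage sketch of the removed init_model/accel_model helpers (placeholder path).
model, tokenizer = init_model('internlm/internlm-chat-7b',
                              use_fast_tokenizer=True)
# Either plain CUDA placement on one GPU ...
model = accel_model(model, accel=None, gpu_id=0)
# ... or DeepSpeed inference with tensor parallelism across the launched ranks:
# model = accel_model(model, accel='deepspeed', tp_size=2, max_alloc=2048)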
lmdeploy/pytorch/modules/__init__.py  (deleted, 100644 → 0)
# Copyright (c) OpenMMLab. All rights reserved.
from .linear import WeightOnlyQLinear

__all__ = ['WeightOnlyQLinear']
lmdeploy/pytorch/modules/linear.py  (deleted, 100644 → 0)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Type, TypeVar

import torch
from torch import nn

try:
    import awq_inference_engine
except ModuleNotFoundError:
    awq_inference_engine = None


class WeightOnlyQLinear(nn.Module):
    """This class implements a weight-only quantized linear layer.

    Args:
        w_bit (int): number of bits for quantization.
        symmetry (bool): If true, use symmetric quantization,
            otherwise use asymmetric quantization.
        group_size (int): size of the quantization group.
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (Tensor, optional): Defaults to None.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: Optional[torch.Tensor] = True,
        w_bit: int = 4,
        symmetry: bool = False,
        group_size: int = 128,
    ) -> None:
        super().__init__()

        if w_bit not in [2, 4, 8]:
            raise NotImplementedError('Only 2, 4, 8 bit are supported for now.')

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features

        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0

        w_pack_oc = out_features // (32 // self.w_bit)
        w_inc = in_features
        weight = torch.zeros((w_inc, w_pack_oc), dtype=torch.int32)
        self.register_buffer('qweight', weight)

        if bias:
            self.register_buffer('bias', torch.zeros(out_features))
        else:
            self.bias = None

        s_inc = in_features // self.group_size
        s_oc = out_features
        scales = torch.zeros((s_inc, s_oc), dtype=torch.float16)
        self.register_buffer('scales', scales)

        if not symmetry:
            z_inc = in_features // self.group_size
            z_oc = out_features // (32 // self.w_bit)
            zeros = torch.zeros((z_inc, z_oc), dtype=torch.int32)
            self.register_buffer('qzeros', zeros)
        else:
            self.qzeros = None

    @classmethod
    def from_linear(cls: Type['WeightOnlyQLinear'],
                    linear: nn.Linear,
                    quantizer: TypeVar('Quantizer'),
                    awq_layout: bool = True) -> 'WeightOnlyQLinear':
        """Create a WeightOnlyQLinear object from a PyTorch Linear object.

        Args:
            linear (nn.Linear): PyTorch Linear object.
            quantizer (Quantizer): Object that handles quantization.
            awq_layout (bool): AWQ layout. Defaults to True.

        Returns:
            WeightOnlyQLinear: A WeightOnlyQLinear object.
        """
        device = linear.weight.device

        w_bit = quantizer.bits
        pack_num = 32 // w_bit
        if awq_layout:
            assert w_bit == 4
            pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
        else:
            pack_order = torch.arange(pack_num)
        group_size = quantizer.group_size
        symmetry = quantizer.symmetry

        in_features = linear.in_features
        out_features = linear.out_features
        bias = False if linear.bias is None else True

        qlinear = cls(in_features, out_features, bias, w_bit, symmetry,
                      group_size)
        qlinear.bias = linear.bias

        qparams = quantizer.calculate_qparams(linear.weight)
        i32_w = quantizer.quant(linear.weight, qparams, real=True)
        i32_w = i32_w.t().contiguous()

        pack_int_w = torch.zeros_like(qlinear.qweight).to(device)

        for col in range(pack_int_w.shape[1]):
            for i in range(pack_num):
                pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
                pack_int_w[:, col] |= pack_int_w_col << (i * w_bit)

        qlinear.qweight = pack_int_w
        qlinear.scales = qparams.scales.squeeze(-1).t().contiguous()

        if qparams.zero_points is not None:
            zeros = qparams.zero_points.to(torch.int32).to(device)
            zeros = zeros.squeeze(-1).t().contiguous()
            pack_int_zeros = torch.zeros_like(qlinear.qzeros).to(device)

            for col in range(pack_int_zeros.shape[1]):
                for i in range(pack_num):
                    qzero_col = zeros[:, col * pack_num + pack_order[i]]
                    pack_int_zeros[:, col] |= qzero_col << (i * w_bit)
            qlinear.qzeros = pack_int_zeros

        qlinear.to('cpu')

        return qlinear

    @torch.no_grad()
    def forward(self, x):
        if awq_inference_engine is None:
            raise RuntimeError(
                'Run the following command to install '
                'the kernel for 4bit inference\n\n'
                'git clone https://github.com/mit-han-lab/llm-awq.git\n'
                'cd awq/kernels\n'
                'python setup.py install\n')
        out_shape = x.shape[:-1] + (self.out_features, )
        inputs = x.reshape(-1, x.shape[-1])

        out = awq_inference_engine.gemm_forward_cuda(inputs.half(),
                                                     self.qweight,
                                                     self.scales.half(),
                                                     self.qzeros,
                                                     self.group_size)
        out = out + self.bias if self.bias is not None else out

        return out.reshape(out_shape)
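The packing loop in from_linear stores 32 // w_bit quantized values per int32 column. A small self-contained sketch of the same bit-packing arithmetic, independent of the class above and assuming the 4-bit AWQ order shown there:

import torch

w_bit = 4
pack_num = 32 // w_bit                 # 8 values per int32
pack_order = [0, 2, 4, 6, 1, 3, 5, 7]  # AWQ layout for 4-bit

# one row of eight 4-bit integers to be packed into a single int32
vals = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)

packed = torch.zeros((1, 1), dtype=torch.int32)
for i in range(pack_num):
    # the i-th nibble of the int32 holds the value at position pack_order[i]
    packed[:, 0] |= vals[:, pack_order[i]] << (i * w_bit)

# unpack to verify the round trip
unpacked = [(int(packed[0, 0]) >> (i * w_bit)) & 0xF for i in range(pack_num)]
assert unpacked == [int(vals[0, pack_order[i]]) for i in range(pack_num)]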
lmdeploy/pytorch/session.py  (deleted, 100644 → 0)
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch
from transformers.generation.utils import ModelOutput

logger = logging.getLogger(__name__)


class BasicSessionManager:
    """Basic session manager without history."""

    def prepend_history(self, input_ids):
        return input_ids

    def add_to_history(self, output):
        pass


class BasicSessionManagerWithHistory:
    """Basic session manager with chat history.

    Args:
        max_session_len (int): Maximum number of tokens allowed for all chat sessions.
        reduce_size (int): Number of tokens to be trimmed when reaching maximum
            session length. Default: 256.
        start_ids (list[int]): Sequence of ids at the start of the chat session.
        sep_ids (list[int]): Sequence of ids separating chat sessions.
    """  # noqa: E501

    bs = 1

    def __init__(self,
                 max_session_len=2048,
                 reduce_size=256,
                 start_ids=[1],
                 sep_ids=[13]) -> None:
        self.start_ids = torch.tensor(start_ids, dtype=torch.long)
        self.sep_ids = torch.tensor(sep_ids, dtype=torch.long)
        assert self.start_ids.ndim == 1
        assert self.sep_ids.ndim == 1

        self.max_session_len = max(len(start_ids), max_session_len)
        self.reduce_size = min(reduce_size, max_session_len - len(start_ids))
        assert self.max_session_len > self.reduce_size

        self.new_session()

    def new_session(self):
        self.history_ids = self.start_ids.repeat(self.bs, 1)

    def prepend_history(self, input_ids: torch.Tensor):
        """Prepend history ids to input ids and trim if over-length."""
        input_ids = input_ids.to(self.history_ids.device).long()
        sep_ids = self.sep_ids.to(self.history_ids.device).long().repeat(1, 1)
        input_ids = torch.cat([self.history_ids, sep_ids, input_ids], dim=1)

        if input_ids.shape[1] > self.max_session_len:
            input_ids = input_ids[:, (self.reduce_size - self.max_session_len):]
            input_ids[:, :len(self.start_ids)] = self.start_ids.repeat(
                self.bs, 1)

        return input_ids

    def add_to_history(self, output):
        """Save history output ids.

        Note:
            Output returned by HuggingFace generator contains both input
            and output ids.
        """
        if isinstance(output, ModelOutput):
            self.history_ids = output.sequences
        elif isinstance(output, torch.Tensor):
            self.history_ids = output
        else:
            raise ValueError(f'Unknown output type {type(output)}')
lmdeploy/pytorch/utils.py  (modified)
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
import inspect
from inspect import Parameter, Signature
from typing import Dict, Sequence

import logging

import psutil
from transformers.generation.streamers import BaseStreamer

from .dist import get_rank, master_only, master_only_and_broadcast_general


def get_gpu_memory(id: int = 0) -> int:
    """Returns the free and total physical memory of the GPU in bytes."""
    import torch
    return torch.cuda.mem_get_info(id)


try:
    import readline  # To support command line history  # noqa: F401
except ImportError:
    # readline not available
    pass

logger = logging.getLogger(__name__)


def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    return psutil.virtual_memory().total


class TerminalIO:
    """Terminal input and output."""

    end_of_output = '\n'

    @master_only_and_broadcast_general
    def input(self):
        """Read input from terminal."""
        print('\ndouble enter to end input >>> ', end='')
        sentinel = ''  # ends when this string is seen
        try:
            return '\n'.join(iter(input, sentinel))
        except EOFError:
            print('Detect EOF, exit')
            exit()

    @master_only
    def output(self, string):
        """Output to terminal with flush."""
        print(string, end='', flush=True)


class BasicStreamer(BaseStreamer):
    """Basic streamer for HuggingFace models."""

    def __init__(self,
                 decode_func,
                 output_func,
                 end_of_output='\n',
                 skip_prompt=True):
        self.decode = decode_func
        self.output = output_func
        self.end_of_output = end_of_output
        self.skip_prompt = skip_prompt
        self.gen_len = 0

    def put(self, value):
        """Callback before forwarding current token id to model."""
        if self.gen_len == 0 and self.skip_prompt:
            pass
        else:
            token = self.decode(value)
            self.output(token)
        self.gen_len += 1

    def end(self):
        """Callback at the end of generation."""
        self.output(self.end_of_output)
        self.gen_len = 0


def control(prompt, gen_config, sm):
    """Allow user to control generation config and session manager.

    Return:
        True if control command applied, False otherwise.
    """
    if prompt == 'exit':
        exit(0)
    if prompt == 'clear':
        sm.new_session()
        logger.info('Session cleared')
        return True
    # Re-config during runtime
    if prompt.startswith('config set'):
        try:
            keqv = prompt.split()[-1]
            k, v = keqv.split('=')
            v = eval(v)
            gen_config.__setattr__(k, v)
            logger.info(f'Worker {get_rank()} set {k} to {repr(v)}')
            logger.info(f'Generator config changed to: {gen_config}')
            return True
        except:  # noqa
            logger.info('illegal instruction, treated as normal conversation. ')
    return False


def bind_sigature(input_names: str, args: Sequence, kwargs: Dict):
    """Bind args and kwargs to given input names."""
    kind = inspect._ParameterKind.POSITIONAL_OR_KEYWORD
    sig = Signature([Parameter(name, kind) for name in input_names])
    bind = sig.bind(*args, **kwargs)
    return bind.arguments
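A quick illustration of what the new bind_sigature helper returns; the call below is a hypothetical example with placeholder input names, not taken from the engine code:

# Illustration only: map mixed positional/keyword arguments onto named inputs.
bound = bind_sigature(['input_ids', 'attention_mask'], ([1, 2, 3],),
                      dict(attention_mask=None))
# bound == {'input_ids': [1, 2, 3], 'attention_mask': None}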
lmdeploy/serve/async_engine.py  (modified)
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import dataclasses
import os
import random
from contextlib import contextmanager
from typing import List, Literal, Optional, Union
from argparse import ArgumentError
from contextlib import asynccontextmanager
from itertools import count
from queue import Empty, Queue
from threading import Thread
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from lmdeploy.messages import (EngineGenerationConfig, GenerationConfig,
                               PytorchEngineConfig, Response,
                               TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig, best_match_model
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import _stop_words, get_logger

logger = get_logger('lmdeploy')


def get_model_name_from_workspace_model(model_dir: str):
    """Get model name from workspace model."""
    from configparser import ConfigParser
    triton_model_path = os.path.join(model_dir, 'triton_models', 'weights')
    if not os.path.exists(triton_model_path):
        return None
    ini_path = os.path.join(triton_model_path, 'config.ini')
    # load cfg
    with open(ini_path, 'r') as f:
        parser = ConfigParser()
        parser.read_file(f)
    return parser['llama']['model_name']


def deduce_a_name(model_path: str,
                  model_name: Optional[str] = None,
                  backend_config: Optional[Union[TurbomindEngineConfig,
                                                 PytorchEngineConfig]] = None,
                  chat_template_config: Optional[ChatTemplateConfig] = None
                  ) -> str:
    """Deduce a model name from all the possible arguments."""

    def _config_model_name(config):
        if config and config.model_name:
            return config.model_name
        return None

    backend_config_model_name = _config_model_name(backend_config)
    chat_template_config_model_name = _config_model_name(chat_template_config)
    model_name = model_name or chat_template_config_model_name \
        or backend_config_model_name  # noqa
    if model_name is None:
        # model maybe from workspace for turbomind
        model_name = get_model_name_from_workspace_model(model_path)
    # may get a model name from model_path
    if model_name is None:
        model_name = best_match_model(model_path)
        if model_name is None:
            raise ArgumentError(None,
                                f'Please set model_name for {model_path}')
        else:
            logger.info(f'matched chat template name: {model_name}')
    return model_name


@dataclasses.dataclass
...
@@ -16,6 +74,55 @@ class GenOut:
    finish_reason: Optional[Literal['stop', 'length']] = None


class Session:
    """Session for AsyncEngine.chat.

    Args:
        _id (int): session_id for internal use.
        _step (int): the offset of the k/v cache for internal use.
        _prompt (Any): input prompt for internal use.
        _response (Response): model output for prompt.
        _engine (Any): engine for internal use.
        history (List[Any, str]): chat history.
    """

    _ids = count(0)

    def __init__(self):
        self._id: int = next(self._ids)
        self._step: int = 0
        self._prompt: Any = None
        self._response: Response = None
        self._engine: Any = None
        self.history: List[Tuple[Any, str]] = []

    def _merge_response(self, resp: Response, step: Union[Response, GenOut]):
        """merge response."""
        resp.text += step.text if isinstance(step, Response) else step.response
        resp.input_token_len = step.input_token_len
        resp.generate_token_len = step.generate_token_len
        resp.finish_reason = step.finish_reason
        return resp

    @property
    def response(self) -> Response:
        """return response."""
        return self._response

    def close(self):
        """release engine storage for this session."""
        if self._engine:
            inst = self._engine.create_instance()
            inst.end(self._id)

    def __repr__(self) -> str:
        res = ''
        for user, assistant in self.history:
            if isinstance(user, list):
                user = str(user)
            res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n'
        return res
class AsyncEngine:
    """Async inference engine. Maintaining a bunch of tm_model instances.
    ...
@@ -30,51 +137,150 @@ class AsyncEngine:
            "InternLM/internlm-chat-20b-4bit",
            "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo
                on huggingface.co, such as "InternLM/internlm-chat-7b",
                on huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "InternLM/internlm-chat-7b",
            huggingface.co, such as "internlm/internlm-chat-7b",
            "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
        instance_num (int): instance numbers to be created
        backend (str): either `turbomind` or `pytorch` backend. Default to
            `turbomind` backend.
        backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
            config instance. Default to None.
        chat_template_config (ChatTemplateConfig): chat template configuration.
            Default to None.
        tp (int): tensor parallel
    """
    def __init__(self,
                 model_path: str,
                 model_name: Optional[str] = None,
                 instance_num: int = 32,
                 backend: Literal['turbomind', 'pytorch'] = 'turbomind',
                 backend_config: Optional[Union[TurbomindEngineConfig,
                                                PytorchEngineConfig]] = None,
                 chat_template_config: Optional[ChatTemplateConfig] = None,
                 tp: int = 1,
                 **kwargs) -> None:
        from lmdeploy import turbomind as tm
        self.tm_model = tm.TurboMind.from_pretrained(model_path,
                                                     model_name=model_name,
                                                     tp=tp,
                                                     **kwargs)
        self.tokenizer = self.tm_model.tokenizer
        self.instance_num = instance_num
        self.model = self.tm_model.model
        logger.info(
            f'input backend={backend}, backend_config={backend_config}')
        logger.info(f'input chat_template_config={chat_template_config}')
        self.model_name = deduce_a_name(model_path, model_name, backend_config,
                                        chat_template_config)
        # build chat template config
        if chat_template_config is None:
            chat_template_config = ChatTemplateConfig(self.model_name)
        elif chat_template_config.model_name is None:
            chat_template_config.model_name = self.model_name
        self.chat_template = chat_template_config.chat_template
        # prevent bc
        for k in list(kwargs.keys()):
            if hasattr(chat_template_config, k):
                logger.warning(f'{k} was deprecated. Please use '
                               'chat_template_config instead')
                v = kwargs.pop(k)
                setattr(chat_template_config, k, v)
        logger.info(f'updated chat_template_config={chat_template_config}')
        # build backend engine
        if backend == 'turbomind':
            self._build_turbomind(model_path=model_path,
                                  backend_config=backend_config,
                                  chat_template_config=chat_template_config,
                                  tp=tp,
                                  **kwargs)
        elif backend == 'pytorch':
            self._build_pytorch(model_path=model_path,
                                backend_config=backend_config,
                                **kwargs)
        else:
            raise ValueError(f'unsupported backend {backend}')
        logger.info(f'updated backend_config={self.backend_config}')
        # parameters for member functions
        self.session_len = self.backend_config.session_len
        self.stop_words = _stop_words(self.chat_template.stop_words,
                                      self.engine.tokenizer)
        if self.stop_words is not None:
            self.stop_words = self.stop_words[0][0].tolist()
        self.backend = backend
        self.instance_num = self.backend_config.max_batch_size
        self.tokenizer = self.engine.tokenizer
        self.id2step = {}
        self.id2generator = {}
        self.loop = asyncio.get_event_loop()
        self.running_session_ids = set()
        self.gens_set = set()
        for i in range(instance_num):
            self.gens_set.add(self.tm_model.create_instance())
        for i in range(self.instance_num):
            self.gens_set.add(self.engine.create_instance())
    def _build_turbomind(
            self,
            model_path: str,
            backend_config: Optional[Union[TurbomindEngineConfig,
                                           PytorchEngineConfig]] = None,
            chat_template_config: Optional[ChatTemplateConfig] = None,
            tp: int = 1,
            **kwargs):
        """Inner build method for turbomind backend."""
        if backend_config is None:
            backend_config = TurbomindEngineConfig(model_name=self.model_name,
                                                   tp=tp)
        assert isinstance(backend_config, TurbomindEngineConfig), 'Please' \
            ' use TurbomindEngineConfig imported from lmdeploy.messages for ' \
            'turbomind backend'
        if backend_config.session_len is None:
            backend_config.session_len = self.chat_template.session_len
        from lmdeploy import turbomind as tm
        self.engine = tm.TurboMind.from_pretrained(
            model_path,
            engine_config=backend_config,
            chat_template_config=chat_template_config,
            **kwargs)
        self.backend_config = backend_config

    def _build_pytorch(
            self,
            model_path: str,
            backend_config: Optional[Union[TurbomindEngineConfig,
                                           PytorchEngineConfig]] = None,
            **kwargs):
        """Inner build method for pytorch backend."""
        from lmdeploy.pytorch.engine import Engine
        if backend_config is None:
            backend_config = PytorchEngineConfig(self.model_name)
        assert isinstance(backend_config, PytorchEngineConfig), 'Please ' \
            'use PytorchEngineConfig imported from lmdeploy.messages for ' \
            'pytorch backend'
        if backend_config.session_len is None:
            backend_config.session_len = self.chat_template.session_len
        self.engine = Engine(model_path=model_path,
                             engine_config=backend_config)
        self.backend_config = backend_config
    def __call__(self,
                 prompts: List[str],
                 prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
                 gen_config: Optional[GenerationConfig] = None,
                 request_output_len=512,
                 top_k=40,
                 top_p=0.8,
                 temperature=0.8,
                 repetition_penalty=1.0,
                 ignore_eos=False,
                 do_preprocess=True,
                 top_k: int = 40,
                 top_p: float = 0.8,
                 temperature: float = 0.8,
                 repetition_penalty: float = 1.0,
                 ignore_eos: bool = False,
                 do_preprocess: bool = True,
                 **kwargs):
        """Inference a batch of prompts.

        Args:
            prompts (List[str]): a batch of prompts
            prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
                prompts. It accepts: string prompt, a list of string prompts,
                a chat history in OpenAI format or a list of chat history.
            gen_config (GenerationConfig | None): a instance of
                GenerationConfig. Default to None.
            chat_template_config (ChatTemplateConfig | None): a instance of
                ChatTemplateConfig. Default to None.
            request_output_len (int): output token nums
            top_k (int): The number of the highest probability vocabulary
                tokens to keep for top-k-filtering
            ...
@@ -85,245 +291,363 @@ class AsyncEngine:
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            ignore_eos (bool): indicator for ignoring eos
            do_preprocess (bool): whether pre-process the messages.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
        """
        if gen_config is None:
            gen_config = GenerationConfig(
                max_new_tokens=request_output_len,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                ignore_eos=ignore_eos)
        return self.batch_infer(prompts,
                                request_output_len=request_output_len,
                                top_k=top_k,
                                top_p=top_p,
                                temperature=temperature,
                                repetition_penalty=repetition_penalty,
                                ignore_eos=ignore_eos,
                                gen_config=gen_config,
                                do_preprocess=do_preprocess,
                                **kwargs)
    def stop_session(self, session_id: int):
    async def stop_session(self, session_id: int):
        """Stop a session by a session_id."""
        input_ids = [self.tm_model.eos_id]
        stop_generator = self.tm_model.create_instance()
        for outputs in stop_generator.stream_infer(session_id,
                                                   input_ids,
                                                   request_output_len=0,
                                                   sequence_start=False,
                                                   sequence_end=False,
                                                   stop=True):
            pass
        if str(session_id) in self.id2generator and self.id2generator[str(
                session_id)] not in self.gens_set:
        if str(session_id) in self.id2generator:
            await self.id2generator[str(session_id)].async_cancel(session_id)
            self.gens_set.add(self.id2generator[str(session_id)])
    def end_session(self, session_id: int):
        self.running_session_ids.discard(session_id)

    async def end_session(self, session_id: int):
        """Clear a session by a session_id."""
        input_ids = [self.tm_model.eos_id]
        end_generator = self.tm_model.create_instance()
        for outputs in end_generator.stream_infer(session_id,
                                                  input_ids,
                                                  request_output_len=0,
                                                  sequence_start=False,
                                                  sequence_end=True):
            pass
        self.id2step[str(session_id)] = 0
        if str(session_id) in self.id2generator and self.id2generator[str(
                session_id)] not in self.gens_set:
        if str(session_id) in self.id2generator:
            await self.id2generator[str(session_id)].async_end(session_id)
            self.id2step[str(session_id)] = 0
            self.gens_set.add(self.id2generator[str(session_id)])
    @contextmanager
    def safe_run(self, session_id: Optional[int] = None):
        self.running_session_ids.discard(session_id)

    @asynccontextmanager
    async def safe_run(self, session_id: Optional[int] = None):
        """A context manager to make sure server's safe running."""
        try:
            yield
        except (Exception, asyncio.CancelledError) as e:  # noqa
            self.stop_session(session_id)
            await self.stop_session(session_id)
            raise e
        if str(session_id) in self.id2generator and self.id2generator[str(
                session_id)] not in self.gens_set:
        if str(session_id) in self.id2generator:
            self.gens_set.add(self.id2generator[str(session_id)])
    async def get_embeddings(self, prompt, do_prerpocess=False):
        if do_prerpocess:
            prompt = self.model.get_prompt(prompt)
        input_ids = self.tokenizer.encode(prompt)
        return input_ids
        self.running_session_ids.discard(session_id)
    async def get_generator(self, stop: bool, session_id: int):
        """Only return the model instance if it is available."""
        if stop:
            return self.tm_model.create_instance()
        while self.gens_set == set():
            await asyncio.sleep(0)
        return self.engine.create_instance()
        # waiting no generator is available or the same session_id is running
        while self.gens_set == set() \
                or session_id in self.running_session_ids:
            await asyncio.sleep(0.1)
        generator = self.gens_set.pop()
        self.id2generator[str(session_id)] = generator
        self.running_session_ids.add(session_id)
        return generator
    def batch_infer(self,
                    prompts: Union[List[str], str],
                    request_output_len=512,
                    top_k=40,
                    top_p=0.8,
                    temperature=0.8,
                    repetition_penalty=1.0,
                    ignore_eos=False,
                    do_preprocess=True,
                    prompts: Union[List[str], str, List[Dict],
                                   List[List[Dict]]],
                    gen_config: Optional[Union[GenerationConfig,
                                               EngineGenerationConfig]] = None,
                    do_preprocess: bool = True,
                    **kwargs):
        """Inference a batch of prompts.

        Args:
            prompts (List[str] | str): a batch of prompts
            request_output_len (int): output token nums
            top_k (int): The number of the highest probability vocabulary
                tokens to keep for top-k-filtering
            top_p (float): If set to float < 1, only the smallest set of most
                probable tokens with probabilities that add up to top_p or
                higher are kept for generation.
            temperature (float): to modulate the next token probability
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            ignore_eos (bool): indicator for ignoring eos
            do_preprocess (bool): whether pre-process the messages.
            prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
                prompts. It accepts: string prompt, a list of string prompts,
                a chat history in OpenAI format or a list of chat history.
            gen_config (GenerationConfig | None): a instance of
                GenerationConfig. Default to None.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
        """
        input_str = isinstance(prompts, str)
        prompts = [prompts] if input_str else prompts
        need_list_wrap = isinstance(prompts, str) or isinstance(
            prompts[0], Dict)
        prompts = [prompts] if need_list_wrap else prompts
        assert isinstance(prompts, List), 'prompts should be a list'
        batch_size = len(prompts)
        outputs = [''] * batch_size
        generators = []
        for i, prompt in enumerate(prompts):
            generators.append(
                self.generate(prompt,
                              i,
                              stream_response=True,
                              sequence_start=True,
                              sequence_end=True,
                              request_output_len=request_output_len,
                              top_k=top_k,
                              top_p=top_p,
                              temperature=temperature,
                              ignore_eos=ignore_eos,
                              repetition_penalty=repetition_penalty,
                              do_preprocess=do_preprocess,
                              **kwargs))

        async def _inner_call(i, generator):
            async for out in generator:
                outputs[i] += out.response

        async def gather():
            await asyncio.gather(
                *[_inner_call(i, generators[i]) for i in range(batch_size)])

        self.loop.run_until_complete(gather())
        outputs = outputs[0] if input_str else outputs
        if gen_config is None:
            gen_config = GenerationConfig()
        if type(gen_config) is GenerationConfig:
            gen_config = EngineGenerationConfig.From(gen_config,
                                                     self.tokenizer)
        # set random if it is not set
        if gen_config.random_seed is None:
            gen_config.random_seed = random.getrandbits(64)
        prompt_num = len(prompts)
        outputs = [Response('', 0, 0, i) for i in range(prompt_num)]
        for j in range(0, prompt_num, self.instance_num):
            batch_prompts = prompts[j:j + self.instance_num]
            generators = []
            for i, prompt in enumerate(batch_prompts):
                generators.append(
                    self.generate(prompt,
                                  i,
                                  gen_config=gen_config,
                                  stream_response=True,
                                  sequence_start=True,
                                  sequence_end=True,
                                  do_preprocess=do_preprocess,
                                  **kwargs))

            async def _inner_call(i, generator):
                async for out in generator:
                    outputs[i + j].text += out.response
                    outputs[i + j].generate_token_len = out.generate_token_len
                    outputs[i + j].input_token_len = out.input_token_len
                    outputs[i + j].finish_reason = out.finish_reason

            async def gather():
                await asyncio.gather(*[
                    _inner_call(i, generators[i])
                    for i in range(len(batch_prompts))
                ])

            self.loop.run_until_complete(gather())
        outputs = outputs[0] if need_list_wrap else outputs
        return outputs
    def stream_infer(self,
                     prompts: Union[List[str], str, List[Dict],
                                    List[List[Dict]]],
                     gen_config: Optional[Union[GenerationConfig,
                                                EngineGenerationConfig]] = None,
                     do_preprocess: bool = True,
                     **kwargs):
        """Inference a batch of prompts with stream mode.

        Args:
            prompts (List[str] | str | List[Dict] | List[Dict]): a batch of
                prompts. It accepts: string prompt, a list of string prompts,
                a chat history in OpenAI format or a list of chat history.
            gen_config (GenerationConfig | None): a instance of
                GenerationConfig. Default to None.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
        """
        need_list_wrap = isinstance(prompts, str) or isinstance(
            prompts[0], Dict)
        prompts = [prompts] if need_list_wrap else prompts
        assert isinstance(prompts, List), 'prompts should be a list'
        if gen_config is None:
            gen_config = GenerationConfig()
        if type(gen_config) is GenerationConfig:
            gen_config = EngineGenerationConfig.From(gen_config,
                                                     self.tokenizer)
        # set random if it is not set
        if gen_config.random_seed is None:
            gen_config.random_seed = random.getrandbits(64)
        prompt_num = len(prompts)
        outputs = Queue()
        generators = []
        for j in range(0, prompt_num, self.instance_num):
            batch_prompts = prompts[j:j + self.instance_num]
            generators = []
            for i, prompt in enumerate(batch_prompts):
                generators.append(
                    self.generate(prompt,
                                  i,
                                  gen_config=gen_config,
                                  stream_response=True,
                                  sequence_start=True,
                                  sequence_end=True,
                                  do_preprocess=do_preprocess,
                                  **kwargs))

            async def _inner_call(i, generator):
                async for out in generator:
                    outputs.put(
                        Response(out.response, out.generate_token_len,
                                 out.input_token_len, i + j,
                                 out.finish_reason))

            async def gather():
                await asyncio.gather(*[
                    _inner_call(i, generators[i])
                    for i in range(len(batch_prompts))
                ])
                outputs.put(None)

            proc = Thread(
                target=lambda: self.loop.run_until_complete(gather()))
            proc.start()
            while True:
                try:
                    out = outputs.get(timeout=0.001)
                    if out is None:
                        break
                    yield out
                except Empty:
                    pass
            proc.join()
    async def _get_prompt_input(self, prompt: str, do_preprocess: bool,
                                sequence_start: bool):
        if do_preprocess:
            prompt = self.chat_template.messages2prompt(prompt, sequence_start)
        input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
        return {'prompt': prompt, 'input_ids': input_ids}
    async def generate(
            self,
            messages,
            session_id,
            stream_response=True,
            sequence_start=True,
            sequence_end=True,  # no interactive mode by default
            step=0,
            request_output_len=512,
            stop=False,
            top_k=40,
            top_p=0.8,
            temperature=0.8,
            repetition_penalty=1.0,
            ignore_eos=False,
            do_preprocess=True,
            session_id: int,
            gen_config: Optional[Union[GenerationConfig,
                                       EngineGenerationConfig]] = None,
            stream_response: bool = True,
            sequence_start: bool = True,
            sequence_end: bool = True,  # no interactive mode by default
            step: int = 0,
            do_preprocess: bool = True,
            **kwargs):
        """Generate responses.

        Args:
            messages (str | List): chat history or prompt
            session_id (int): the session id
            gen_config (GenerationConfig | None): a instance of
                GenerationConfig. Default to None.
            stream_response (bool): whether return responses streamingly
            request_output_len (int): output token nums
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            step (int): the offset of the k/v cache
            stop (bool): whether stop inference
            top_k (int): The number of the highest probability vocabulary
                tokens to keep for top-k-filtering
            top_p (float): If set to float < 1, only the smallest set of most
                probable tokens with probabilities that add up to top_p or
                higher are kept for generation.
            temperature (float): to modulate the next token probability
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            ignore_eos (bool): indicator for ignoring eos
            do_preprocess (bool): whether pre-process the messages.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
        """
        if str(session_id) not in self.id2step:
            self.id2step[str(session_id)] = 0
        if step != 0:
            self.id2step[str(session_id)] = step
        seed = random.getrandbits(64)
        if gen_config is None:
            gen_config = GenerationConfig()
        if type(gen_config) is GenerationConfig:
            gen_config = EngineGenerationConfig.From(gen_config,
                                                     self.tokenizer)
        if gen_config.stop_words is None:
            gen_config.stop_words = self.stop_words
        # set random if it is not set and sequence_start is True
        if gen_config.random_seed is None and sequence_start:
            gen_config.random_seed = random.getrandbits(64)
        prompt = messages
        if do_preprocess:
            prompt = self.model.messages2prompt(prompt, sequence_start)
        input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
        finish_reason = None
        request_output_len = min(
            request_output_len,
            self.tm_model.session_len - self.id2step[str(session_id)] -
            len(input_ids))
        prompt_input = await self._get_prompt_input(prompt, do_preprocess,
                                                    sequence_start)
        prompt = prompt_input['prompt']
        logger.info(f'Prompt with applied chat template:\n{prompt}')
        input_ids = prompt_input['input_ids']
        if gen_config.max_new_tokens is None:
            # for interactive endpoint, will try maximum possible token num
            gen_config.max_new_tokens = max(
                128, self.session_len - self.id2step[str(session_id)] -
                len(input_ids))
        request_output_len = max(0, request_output_len)
        if stop is True:
            self.stop_session(session_id)
            yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
                         finish_reason)
        elif self.id2step[str(session_id)] + len(
                input_ids) + request_output_len > self.tm_model.session_len:
        finish_reason = None
        logger.info(f'session_id={session_id}, '
                    f'history_tokens={self.id2step[str(session_id)]}, '
                    f'input_tokens={len(input_ids)}, '
                    f'max_new_tokens={gen_config.max_new_tokens}, '
                    f'seq_start={sequence_start}, seq_end={sequence_end}, '
                    f'step={step}, prep={do_preprocess}')
        if self.id2step[str(session_id)] + len(
                input_ids) + gen_config.max_new_tokens > self.session_len:
            logger.warning(f'run out of tokens. session_id={session_id}')
            finish_reason = 'length'
            yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
                         finish_reason)
            if sequence_end is True and sequence_start is False:
                self.end_session(session_id)
                await self.end_session(session_id)
        else:
            generator = await self.get_generator(stop, session_id)
            with self.safe_run(session_id):
                response_size = 0
            generator = await self.get_generator(False, session_id)
            async with self.safe_run(session_id):
                state = DetokenizeState()
                async for outputs in generator.async_stream_infer(
                        session_id=session_id,
                        input_ids=[input_ids],
                        **prompt_input,
                        gen_config=gen_config,
                        stream_output=stream_response,
                        request_output_len=request_output_len,
                        sequence_start=(sequence_start),
                        sequence_start=sequence_start,
                        sequence_end=sequence_end,
                        step=self.id2step[str(session_id)],
                        stop=stop,
                        top_k=top_k,
                        top_p=top_p,
                        temperature=temperature,
                        repetition_penalty=repetition_penalty,
                        ignore_eos=ignore_eos,
                        random_seed=seed if sequence_start else None):
                    res, tokens = outputs[0]
                        step=self.id2step[str(session_id)]):
                    _, res, tokens = outputs
                    # decode res
                    response = self.tokenizer.decode(res.tolist(),
                                                     offset=response_size)
                    # utf-8 char at the end means it's a potential unfinished
                    # byte sequence, continue to concate it with the next
                    # sequence and decode them together
                    if response.endswith('�'):
                        continue
                    response, state = self.tokenizer.detokenize_incrementally(
                        res,
                        state,
                        skip_special_tokens=gen_config.skip_special_tokens)
                    # response, history token len,
                    # input token len, gen token len
                    yield GenOut(response, self.id2step[str(session_id)],
                                 len(input_ids), tokens, finish_reason)
                    response_size = tokens

                finish_reason = 'length' \
                    if tokens >= request_output_len else 'stop'
                    if tokens >= gen_config.max_new_tokens else 'stop'
                # `response_size` might be not updated since
                # `if response.endswith('�')`
                if response_size == tokens:
                # utf-8 char at the end means it's a potential unfinished
                # byte sequence
                if not response.endswith('�'):
                    response = ''  # avoid returning the last response twice
                yield GenOut(response, self.id2step[str(session_id)],
                             len(input_ids), tokens, finish_reason)
                # update step
                self.id2step[str(session_id)] += len(input_ids) + tokens
                if sequence_end or stop:
                if sequence_end:
                    self.id2step[str(session_id)] = 0
                # manually end pytorch session
                # TODO modify pytorch or turbomind api
                if self.backend == 'pytorch' and sequence_end:
                    await self.end_session(session_id)
    def chat(self,
             prompt: str,
             session=None,
             gen_config: Optional[Union[GenerationConfig,
                                        EngineGenerationConfig]] = None,
             do_preprocess: bool = True,
             **kwargs) -> Session:
        """Chat.

        Args:
            prompt (str): prompt
            session (Session): the chat session
            gen_config (GenerationConfig | None): a instance of
                GenerationConfig. Default to None.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
            **kwargs (dict): ad hoc parametrization of `gen_config`
        """
        if session is None:
            session = Session()
            session._engine = self.engine

        # sync & init
        session._prompt = prompt
        session._response = None

        sequence_start = session._step == 0

        async def _work():
            resp = Response('', -1, -1, session._id)
            async for output in self.generate(prompt,
                                              session_id=session._id,
                                              gen_config=gen_config,
                                              stream_response=False,
                                              sequence_start=sequence_start,
                                              sequence_end=False,
                                              step=session._step,
                                              do_preprocess=do_preprocess,
                                              **kwargs):
                resp = session._merge_response(resp, output)
            return resp

        from lmdeploy.pytorch.engine.request import _run_until_complete
        resp = _run_until_complete(_work())

        session._response = resp
        session._step += resp.generate_token_len + resp.input_token_len
        session.history.append((session._prompt, resp.text))

        return session
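For orientation, a sketch of how the reworked AsyncEngine is driven after this change: construction selects a backend and its config, batch_infer/stream_infer take a GenerationConfig, and chat threads a Session through successive turns. The model id and all parameter values below are placeholders, and the config fields used are only those shown in the diff above.

from lmdeploy.messages import GenerationConfig, TurbomindEngineConfig
from lmdeploy.serve.async_engine import AsyncEngine

# placeholder model path; any supported HF repo or converted workspace works
engine = AsyncEngine('internlm/internlm-chat-7b',
                     backend='turbomind',
                     backend_config=TurbomindEngineConfig(session_len=4096))

gen_config = GenerationConfig(max_new_tokens=256, top_p=0.8, temperature=0.7)

# batch inference returns a list of Response objects
responses = engine.batch_infer(['Hi, please introduce yourself'],
                               gen_config=gen_config)

# multi-turn chat carries the k/v-cache offset in the returned Session
session = engine.chat('What is LMDeploy?', gen_config=gen_config)
session = engine.chat('Summarize that in one sentence',
                      session=session,
                      gen_config=gen_config)
print(session.response.text)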
lmdeploy/serve/gradio/api_server_backend.py  (modified)
...
@@ -17,7 +17,8 @@ class InterFace:

def chat_stream_restful(instruction: str, state_chatbot: Sequence,
                        cancel_btn: gr.Button, reset_btn: gr.Button,
                        session_id: int):
                        session_id: int, top_p: float, temperature: float,
                        request_output_len: int):
    """Chat with AI assistant.

    Args:
...
@@ -33,9 +34,11 @@ def chat_stream_restful(instruction: str, state_chatbot: Sequence,
            instruction,
            f'{InterFace.api_server_url}/v1/chat/interactive',
            session_id=session_id,
            request_output_len=512,
            interactive_mode=True):
        if finish_reason == 'length':
            request_output_len=request_output_len,
            interactive_mode=True,
            top_p=top_p,
            temperature=temperature):
        if finish_reason == 'length' and tokens == 0:
            gr.Warning('WARNING: exceed session max length.'
                       ' Please restart the session by reset button.')
        if tokens < 0:
...
@@ -94,7 +97,7 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
            f'{InterFace.api_server_url}/v1/chat/interactive',
            session_id=session_id,
            request_output_len=0,
            stop=True,
            cancel=True,
            interactive_mode=True):
        pass
    # end the session
...
@@ -106,6 +109,7 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
            interactive_mode=False):
        pass
    # resume the session
    # TODO this is not proper if api server is running pytorch backend
    messages = []
    for qa in state_chatbot:
        messages.append(dict(role='user', content=qa[0]))
...
@@ -155,10 +159,22 @@ def run_api_server(api_server_url: str,
        with gr.Row():
            cancel_btn = gr.Button(value='Cancel', interactive=False)
            reset_btn = gr.Button(value='Reset')
        with gr.Row():
            request_output_len = gr.Slider(1,
                                           2048,
                                           value=512,
                                           step=1,
                                           label='Maximum new tokens')
            top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
            temperature = gr.Slider(0.01,
                                    1.5,
                                    value=0.7,
                                    step=0.01,
                                    label='Temperature')

        send_event = instruction_txtbox.submit(chat_stream_restful, [
            instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
            state_session_id
            state_session_id, top_p, temperature, request_output_len
        ], [state_chatbot, chatbot, cancel_btn, reset_btn])
        instruction_txtbox.submit(
            lambda: gr.Textbox.update(value=''),
...
lmdeploy/serve/gradio/app.py  (modified)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Literal, Optional, Union

from lmdeploy.archs import get_task
from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.model import ChatTemplateConfig


def run(model_path_or_server: str,
        server_name: str = '0.0.0.0',
        server_port: int = 6006,
        batch_size: int = 32,
        backend: Literal['turbomind', 'pytorch'] = 'turbomind',
        backend_config: Optional[Union[PytorchEngineConfig,
                                       TurbomindEngineConfig]] = None,
        chat_template_config: Optional[ChatTemplateConfig] = None,
        tp: int = 1,
        model_name: str = None,
        **kwargs):
...
@@ -19,6 +28,12 @@ def run(model_path_or_server: str,
        server_name (str): the ip address of gradio server
        server_port (int): the port of gradio server
        batch_size (int): batch size for running Turbomind directly
        backend (str): either `turbomind` or `pytorch` backend. Default to
            `turbomind` backend.
        backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
            config instance. Default to None.
        chat_template_config (ChatTemplateConfig): chat template configuration.
            Default to None.
        tp (int): tensor parallel for Turbomind
    """
    if ':' in model_path_or_server:
...
@@ -31,11 +46,22 @@ def run(model_path_or_server: str,
            run_triton_server
        run_triton_server(model_path_or_server, server_name, server_port)
    else:
        from lmdeploy.serve.gradio.turbomind_coupled import run_local
        pipeline_type, _ = get_task(model_path_or_server)
        if pipeline_type == 'vlm':
            from lmdeploy.serve.gradio.vl import run_local
            assert backend == 'turbomind', 'vlm only support turbomind backend'
            if backend_config is not None and \
                    backend_config.session_len is None:
                backend_config.session_len = 8192
        else:
            from lmdeploy.serve.gradio.turbomind_coupled import run_local
        run_local(model_path_or_server,
                  model_name=model_name,
                  server_name=server_name,
                  server_port=server_port,
                  backend=backend,
                  backend_config=backend_config,
                  chat_template_config=chat_template_config,
                  model_name=model_name,
                  batch_size=batch_size,
                  tp=tp,
                  **kwargs)
...
lmdeploy/serve/gradio/constants.py  (modified)
...
@@ -24,5 +24,5 @@ THEME = gr.themes.Soft(
    secondary_hue=gr.themes.colors.sky,
    font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])

enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
enable_btn = gr.update(interactive=True)
disable_btn = gr.update(interactive=False)
lmdeploy/serve/gradio/triton_server_backend.py  (modified)
...
@@ -16,7 +16,8 @@ class InterFace:

def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
                cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int):
                cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int,
                top_p: float, temperature: float, request_output_len: int):
    """Chat with AI assistant.

    Args:
...
@@ -30,7 +31,12 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
    instruction = state_chatbot[-1][0]

    bot_response = llama_chatbot.stream_infer(
        session_id, instruction, f'{session_id}-{len(state_chatbot)}')
        session_id,
        instruction,
        f'{session_id}-{len(state_chatbot)}',
        request_output_len=request_output_len,
        top_p=top_p,
        temperature=temperature)

    for status, tokens, _ in bot_response:
        state_chatbot[-1] = (state_chatbot[-1][0], tokens)
...
@@ -108,12 +114,24 @@ def run_triton_server(triton_server_addr: str,
        with gr.Row():
            cancel_btn = gr.Button(value='Cancel', interactive=False)
            reset_btn = gr.Button(value='Reset')
        with gr.Row():
            request_output_len = gr.Slider(1,
                                           2048,
                                           value=512,
                                           step=1,
                                           label='Maximum new tokens')
            top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
            temperature = gr.Slider(0.01,
                                    1.5,
                                    value=0.7,
                                    step=0.01,
                                    label='Temperature')

        send_event = instruction_txtbox.submit(
            add_instruction, [instruction_txtbox, state_chatbot],
            [instruction_txtbox, state_chatbot]).then(chat_stream, [
                state_chatbot, llama_chatbot, cancel_btn, reset_btn,
                state_session_id
                state_session_id, top_p, temperature, request_output_len
            ], [state_chatbot, chatbot, cancel_btn, reset_btn])

        cancel_btn.click(cancel_func,
...
lmdeploy/serve/gradio/turbomind_coupled.py  (modified)
# Copyright (c) OpenMMLab. All rights reserved.
import random
from threading import Lock
from typing import Optional, Sequence
from typing import Literal, Optional, Sequence, Union

import gradio as gr

from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
                               TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
...
@@ -14,13 +18,10 @@ class InterFace:
    lock = Lock()


async def chat_stream_local(
    instruction: str,
    state_chatbot: Sequence,
    cancel_btn: gr.Button,
    reset_btn: gr.Button,
    session_id: int,
):
async def chat_stream_local(instruction: str, state_chatbot: Sequence,
                            cancel_btn: gr.Button, reset_btn: gr.Button,
                            session_id: int, top_p: float, temperature: float,
                            request_output_len: int):
    """Chat with AI assistant.

    Args:
...
@@ -33,15 +34,23 @@ async def chat_stream_local(
    state_chatbot = state_chatbot + [(instruction, None)]

    yield (state_chatbot, state_chatbot, disable_btn, enable_btn)

    gen_config = GenerationConfig(max_new_tokens=request_output_len,
                                  top_p=top_p,
                                  top_k=40,
                                  temperature=temperature,
                                  random_seed=random.getrandbits(64)
                                  if len(state_chatbot) == 1 else None)
    async for outputs in InterFace.async_engine.generate(
            instruction,
            session_id,
            gen_config=gen_config,
            stream_response=True,
            sequence_start=(len(state_chatbot) == 1),
            sequence_end=False):
        response = outputs.response
        if outputs.finish_reason == 'length':
        if outputs.finish_reason == 'length' and \
                outputs.generate_token_len == 0:
            gr.Warning('WARNING: exceed session max length.'
                       ' Please restart the session by reset button.')
        if outputs.generate_token_len < 0:
...
@@ -69,7 +78,7 @@ async def reset_local_func(instruction_txtbox: gr.Textbox,
    """
    state_chatbot = []
    # end the session
    InterFace.async_engine.end_session(session_id)
    await InterFace.async_engine.end_session(session_id)
    return (state_chatbot, state_chatbot, gr.Textbox.update(value=''))
...
@@ -85,28 +94,36 @@ async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
        session_id (int): the session id
    """
    yield (state_chatbot, disable_btn, disable_btn)
    InterFace.async_engine.stop_session(session_id)
    InterFace.async_engine.end_session(session_id)
    messages = []
    for qa in state_chatbot:
        messages.append(dict(role='user', content=qa[0]))
        if qa[1] is not None:
            messages.append(dict(role='assistant', content=qa[1]))
    async for out in InterFace.async_engine.generate(messages,
                                                     session_id,
                                                     request_output_len=0,
                                                     stream_response=True,
                                                     sequence_start=True,
                                                     sequence_end=False):
        pass
    yield (state_chatbot, disable_btn, enable_btn)
    await InterFace.async_engine.stop_session(session_id)
    # pytorch backend does not support resume chat history now
    if InterFace.async_engine.backend == 'pytorch':
        yield (state_chatbot, disable_btn, enable_btn)
    else:
        await InterFace.async_engine.end_session(session_id)
        messages = []
        for qa in state_chatbot:
            messages.append(dict(role='user', content=qa[0]))
            if qa[1] is not None:
                messages.append(dict(role='assistant', content=qa[1]))
        gen_config = GenerationConfig(max_new_tokens=0)
        async for out in InterFace.async_engine.generate(messages,
                                                         session_id,
                                                         gen_config=gen_config,
                                                         stream_response=True,
                                                         sequence_start=True,
                                                         sequence_end=False):
            pass
        yield (state_chatbot, disable_btn, enable_btn)
def
run_local
(
model_path
:
str
,
model_name
:
Optional
[
str
]
=
None
,
server_name
:
str
=
'localhost'
,
backend
:
Literal
[
'turbomind'
,
'pytorch'
]
=
'turbomind'
,
backend_config
:
Optional
[
Union
[
PytorchEngineConfig
,
TurbomindEngineConfig
]]
=
None
,
chat_template_config
:
Optional
[
ChatTemplateConfig
]
=
None
,
server_name
:
str
=
'0.0.0.0'
,
server_port
:
int
=
6006
,
batch_size
:
int
=
4
,
tp
:
int
=
1
,
**
kwargs
):
"""chat with AI assistant through web ui.
...
...
@@ -122,22 +139,32 @@ def run_local(model_path: str,
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "
I
ntern
LM
/internlm-chat-7b",
on huggingface.co, such as "
i
ntern
lm
/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "
I
ntern
LM
/internlm-chat-7b",
huggingface.co, such as "
i
ntern
lm
/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
server_name (str): the ip address of gradio server
backend (str): either `turbomind` or `pytorch` backend. Default to
`turbomind` backend.
backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend
config instance. Default to none.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
server_name (str): the ip address of gradio server. Default to
"0.0.0.0". For huggingface space demo, it should be
"huggingface-space".
server_port (int): the port of gradio server
batch_size (int): batch size for running Turbomind directly
tp (int): tensor parallel for Turbomind
"""
InterFace
.
async_engine
=
AsyncEngine
(
model_path
=
model_path
,
model_name
=
model_name
,
instance_num
=
batch_size
,
tp
=
tp
,
**
kwargs
)
InterFace
.
async_engine
=
AsyncEngine
(
model_path
=
model_path
,
backend
=
backend
,
backend_config
=
backend_config
,
chat_template_config
=
chat_template_config
,
model_name
=
model_name
,
tp
=
tp
,
**
kwargs
)
with
gr
.
Blocks
(
css
=
CSS
,
theme
=
THEME
)
as
demo
:
state_chatbot
=
gr
.
State
([])
...
...
@@ -148,17 +175,29 @@ def run_local(model_path: str,
chatbot
=
gr
.
Chatbot
(
elem_id
=
'chatbot'
,
label
=
InterFace
.
async_engine
.
tm_model
.
model_name
)
label
=
InterFace
.
async_engine
.
engine
.
model_name
)
instruction_txtbox
=
gr
.
Textbox
(
placeholder
=
'Please input the instruction'
,
label
=
'Instruction'
)
with
gr
.
Row
():
cancel_btn
=
gr
.
Button
(
value
=
'Cancel'
,
interactive
=
False
)
reset_btn
=
gr
.
Button
(
value
=
'Reset'
)
with
gr
.
Row
():
request_output_len
=
gr
.
Slider
(
1
,
2048
,
value
=
512
,
step
=
1
,
label
=
'Maximum new tokens'
)
top_p
=
gr
.
Slider
(
0.01
,
1
,
value
=
0.8
,
step
=
0.01
,
label
=
'Top_p'
)
temperature
=
gr
.
Slider
(
0.01
,
1.5
,
value
=
0.7
,
step
=
0.01
,
label
=
'Temperature'
)
send_event
=
instruction_txtbox
.
submit
(
chat_stream_local
,
[
instruction_txtbox
,
state_chatbot
,
cancel_btn
,
reset_btn
,
state_session_id
state_session_id
,
top_p
,
temperature
,
request_output_len
],
[
state_chatbot
,
chatbot
,
cancel_btn
,
reset_btn
])
instruction_txtbox
.
submit
(
lambda
:
gr
.
Textbox
.
update
(
value
=
''
),
...
...
@@ -184,14 +223,19 @@ def run_local(model_path: str,
demo
.
load
(
init
,
inputs
=
None
,
outputs
=
[
state_session_id
])
print
(
f
'server is gonna mount on: http://
{
server_name
}
:
{
server_port
}
'
)
demo
.
queue
(
concurrency_count
=
batch_size
,
max_size
=
100
,
api_open
=
True
).
launch
(
max_threads
=
10
,
share
=
True
,
server_port
=
server_port
,
server_name
=
server_name
,
)
if
server_name
==
'huggingface-space'
:
demo
.
queue
(
concurrency_count
=
InterFace
.
async_engine
.
instance_num
,
max_size
=
100
).
launch
()
else
:
print
(
f
'server is gonna mount on: http://
{
server_name
}
:
{
server_port
}
'
)
demo
.
queue
(
concurrency_count
=
InterFace
.
async_engine
.
instance_num
,
max_size
=
100
,
api_open
=
True
).
launch
(
max_threads
=
10
,
share
=
True
,
server_port
=
server_port
,
server_name
=
server_name
,
)
if
__name__
==
'__main__'
:
...
...
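A minimal sketch of launching this coupled Gradio demo with the reworked signature; the model id is a placeholder and the backend config is optional:

# Illustrative usage of run_local() with the new backend arguments.
from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.serve.gradio.turbomind_coupled import run_local

run_local('internlm/internlm-chat-7b',
          backend='turbomind',
          backend_config=TurbomindEngineConfig(),  # optional; defaults are fine
          server_name='0.0.0.0',
          server_port=6006)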
lmdeploy/serve/openai/api_client.py
...
...
@@ -4,8 +4,11 @@ from typing import Any, Dict, Iterable, List, Optional, Union
import requests

+from lmdeploy.utils import get_logger


def get_model_list(api_url: str):
    """Get model list from api server."""
    response = requests.get(api_url)
    if hasattr(response, 'text'):
        model_list = json.loads(response.text)
...
...
@@ -14,15 +17,31 @@ def get_model_list(api_url: str):
    return None


+def json_loads(content):
+    """Loads content to json format."""
+    try:
+        content = json.loads(content)
+        return content
+    except:  # noqa
+        logger = get_logger('lmdeploy')
+        logger.warning(f'weird json content {content}')
+        return ''


class APIClient:
    """Chatbot for LLaMA series models with turbomind as inference engine.

    Args:
        api_server_url (str): communicating address 'http://<ip>:<port>' of
            api_server
+        api_key (str | None): api key. Default to None, which means no
+            api key will be used.
    """

-    def __init__(self, api_server_url: str, **kwargs):
+    def __init__(self,
+                 api_server_url: str,
+                 api_key: Optional[str] = None,
+                 **kwargs):
        self.api_server_url = api_server_url
        self.chat_intractive_v1_url = f'{api_server_url}/v1/chat/interactive'
        self.chat_completions_v1_url = f'{api_server_url}/v1/chat/completions'
...
...
@@ -30,6 +49,10 @@ class APIClient:
        self.models_v1_url = f'{api_server_url}/v1/models'
        self.encode_v1_url = f'{api_server_url}/v1/encode'
        self._available_models = None
+        self.api_key = api_key
+        self.headers = {'content-type': 'application/json'}
+        if api_key is not None:
+            self.headers['Authorization'] = f'Bearer {api_key}'

    @property
    def available_models(self):
...
...
@@ -38,7 +61,7 @@ class APIClient:
            return self._available_models
        response = requests.get(self.models_v1_url)
        if hasattr(response, 'text'):
-            model_list = json.loads(response.text)
+            model_list = json_loads(response.text)
            model_list = model_list.pop('data', [])
            self._available_models = [item['id'] for item in model_list]
        return self._available_models
...
...
@@ -57,15 +80,14 @@ class APIClient:
            when it is not. Default to True.
        Return: (input_ids, length)
        """
-        headers = {'content-type': 'application/json'}
        response = requests.post(self.encode_v1_url,
-                                 headers=headers,
+                                 headers=self.headers,
                                 json=dict(input=input,
                                           do_preprocess=do_preprocess,
                                           add_bos=add_bos),
                                 stream=False)
        if hasattr(response, 'text'):
-            output = json.loads(response.text)
+            output = json_loads(response.text)
            return output['input_ids'], output['length']
        return None, None
...
...
@@ -75,8 +97,8 @@ class APIClient:
                           temperature: Optional[float] = 0.7,
                           top_p: Optional[float] = 1.0,
                           n: Optional[int] = 1,
-                           max_tokens: Optional[int] = 512,
-                           stop: Optional[bool] = False,
+                           max_tokens: Optional[int] = None,
+                           stop: Optional[Union[str, List[str]]] = None,
                           stream: Optional[bool] = False,
                           presence_penalty: Optional[float] = 0.0,
                           frequency_penalty: Optional[float] = 0.0,
...
...
@@ -84,12 +106,14 @@ class APIClient:
                           repetition_penalty: Optional[float] = 1.0,
                           session_id: Optional[int] = -1,
                           ignore_eos: Optional[bool] = False,
+                           skip_special_tokens: Optional[bool] = True,
                           **kwargs):
        """Chat completion v1.

        Args:
            model: model name. Available from self.available_models.
-            messages: string prompt or chat history in OpenAI format.
+            messages: string prompt or chat history in OpenAI format. Chat
+                history example: `[{"role": "user", "content": "hi"}]`.
            temperature (float): to modulate the next token probability
            top_p (float): If set to float < 1, only the smallest set of most
                probable tokens with probabilities that add up to top_p or
...
...
@@ -97,11 +121,15 @@ class APIClient:
            n (int): How many chat completion choices to generate for each
                input message. Only support one here.
            stream: whether to stream the results or not. Default to false.
-            max_tokens (int): output token nums
+            max_tokens (int | None): output token nums. Default to None.
+            stop (str | List[str] | None): To stop generating further
+                tokens. Only accept stop words that's encoded to one token index.
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            ignore_eos (bool): indicator for ignoring eos
-            session_id (int): if not specified, will set random value
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be True.
+            session_id (int): Deprecated.

        Yields:
            json objects in openai formats
...
...
@@ -111,9 +139,8 @@ class APIClient:
            for k, v in locals().copy().items()
            if k[:2] != '__' and k not in ['self']
        }
-        headers = {'content-type': 'application/json'}
        response = requests.post(self.chat_completions_v1_url,
-                                 headers=headers,
+                                 headers=self.headers,
                                 json=pload,
                                 stream=stream)
        for chunk in response.iter_lines(chunk_size=8192,
...
...
@@ -126,11 +153,11 @@ class APIClient:
                    continue
                if decoded[:6] == 'data: ':
                    decoded = decoded[6:]
-                output = json.loads(decoded)
+                output = json_loads(decoded)
                yield output
            else:
                decoded = chunk.decode('utf-8')
-                output = json.loads(decoded)
+                output = json_loads(decoded)
                yield output

    def chat_interactive_v1(self,
...
...
@@ -138,13 +165,14 @@ class APIClient:
                            session_id: int = -1,
                            interactive_mode: bool = False,
                            stream: bool = False,
-                            stop: bool = False,
-                            request_output_len: int = 512,
+                            stop: Optional[Union[str, List[str]]] = None,
+                            request_output_len: Optional[int] = None,
                            top_p: float = 0.8,
                            top_k: int = 40,
                            temperature: float = 0.8,
                            repetition_penalty: float = 1.0,
                            ignore_eos: bool = False,
+                            skip_special_tokens: Optional[bool] = True,
                            **kwargs):
        """Interactive completions.
...
...
@@ -162,8 +190,10 @@ class APIClient:
                interactive mode, session history is kept on the server (and
                vice versa).
            stream: whether to stream the results or not.
-            stop: whether to stop the session response or not.
-            request_output_len (int): output token nums
+            stop (str | List[str] | None): To stop generating further tokens.
+                Only accept stop words that's encoded to one token index.
+            request_output_len (int): output token nums. If not specified,
+                will use maximum possible number for a session.
            top_p (float): If set to float < 1, only the smallest set of most
                probable tokens with probabilities that add up to top_p or
                higher are kept for generation.
...
...
@@ -173,18 +203,20 @@ class APIClient:
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            ignore_eos (bool): indicator for ignoring eos
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be True.

        Yields:
-            json objects consist of text, tokens, finish_reason
+            json objects consist of text, tokens, input_tokens,
+            history_tokens, finish_reason
        """
        pload = {
            k: v
            for k, v in locals().copy().items()
            if k[:2] != '__' and k not in ['self']
        }
-        headers = {'content-type': 'application/json'}
        response = requests.post(self.chat_intractive_v1_url,
-                                 headers=headers,
+                                 headers=self.headers,
                                 json=pload,
                                 stream=stream)
        for chunk in response.iter_lines(chunk_size=8192,
...
...
@@ -192,7 +224,7 @@ class APIClient:
                                         delimiter=b'\n'):
            if chunk:
                decoded = chunk.decode('utf-8')
-                output = json.loads(decoded)
+                output = json_loads(decoded)
                yield output

    def completions_v1(
...
...
@@ -204,12 +236,15 @@ class APIClient:
            n: Optional[int] = 1,
            max_tokens: Optional[int] = 16,
            stream: Optional[bool] = False,
+            stop: Optional[Union[str, List[str]]] = None,
            top_p: Optional[float] = 1.0,
+            top_k: Optional[int] = 40,
            user: Optional[str] = None,
            # additional argument of lmdeploy
            repetition_penalty: Optional[float] = 1.0,
            session_id: Optional[int] = -1,
            ignore_eos: Optional[bool] = False,
+            skip_special_tokens: Optional[bool] = True,
            **kwargs):
        """Chat completion v1.
...
...
@@ -223,14 +258,20 @@ class APIClient:
            top_p (float): If set to float < 1, only the smallest set of most
                probable tokens with probabilities that add up to top_p or
                higher are kept for generation.
+            top_k (int): The number of the highest probability vocabulary
+                tokens to keep for top-k-filtering
            n (int): How many chat completion choices to generate for each
                input message. Only support one here.
            stream: whether to stream the results or not. Default to false.
+            stop (str | List[str] | None): To stop generating further
+                tokens. Only accept stop words that's encoded to one token index.
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            user (str): A unique identifier representing your end-user.
            ignore_eos (bool): indicator for ignoring eos
-            session_id (int): if not specified, will set random value
+            skip_special_tokens (bool): Whether or not to remove special tokens
+                in the decoding. Default to be True.
+            session_id (int): Deprecated.

        Yields:
            json objects in openai formats
...
...
@@ -240,9 +281,8 @@ class APIClient:
            for k, v in locals().copy().items()
            if k[:2] != '__' and k not in ['self']
        }
-        headers = {'content-type': 'application/json'}
        response = requests.post(self.completions_v1_url,
-                                 headers=headers,
+                                 headers=self.headers,
                                 json=pload,
                                 stream=stream)
        for chunk in response.iter_lines(chunk_size=8192,
...
...
@@ -250,16 +290,16 @@ class APIClient:
                                         delimiter=b'\n'):
            if chunk:
                if stream:
-                    decoded = chunk.decode('utf-8')[6:]
+                    decoded = chunk.decode('utf-8')
                    if decoded == 'data: [DONE]':
                        continue
+                    if decoded[:6] == 'data: ':
+                        decoded = decoded[6:]
-                    output = json.loads(decoded)
+                    output = json_loads(decoded)
                    yield output
                else:
                    decoded = chunk.decode('utf-8')
-                    output = json.loads(decoded)
+                    output = json_loads(decoded)
                    yield output

    def chat(self,
...
...
@@ -307,7 +347,7 @@ class APIClient:
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                ignore_eos=ignore_eos):
-            if outputs['finish_reason'] == 'length':
+            if outputs['finish_reason'] == 'length' and outputs['tokens'] == 0:
                print('WARNING: exceed session max length.'
                      ' Please end the session.')
            yield outputs['text'], outputs['tokens'], outputs['finish_reason']
...
...
@@ -334,15 +374,21 @@ def input_prompt():
    return '\n'.join(iter(input, sentinel))


-def get_streaming_response(prompt: str,
-                           api_url: str,
-                           session_id: int,
-                           request_output_len: int = 512,
-                           stream: bool = True,
-                           interactive_mode: bool = False,
-                           ignore_eos: bool = False,
-                           stop: bool = False) -> Iterable[List[str]]:
+def get_streaming_response(prompt: str,
+                           api_url: str,
+                           session_id: int,
+                           request_output_len: int = 512,
+                           stream: bool = True,
+                           interactive_mode: bool = False,
+                           ignore_eos: bool = False,
+                           cancel: bool = False,
+                           top_p: float = 0.8,
+                           temperature: float = 0.7,
+                           api_key: Optional[str] = None) -> Iterable[List[str]]:
    headers = {'User-Agent': 'Test Client'}
+    if api_key is not None:
+        headers['Authorization'] = f'Bearer {api_key}'
    pload = {
        'prompt': prompt,
        'stream': stream,
...
...
@@ -350,7 +396,9 @@ def get_streaming_response(prompt: str,
        'request_output_len': request_output_len,
        'interactive_mode': interactive_mode,
        'ignore_eos': ignore_eos,
-        'stop': stop
+        'cancel': cancel,
+        'top_p': top_p,
+        'temperature': temperature
    }
    response = requests.post(api_url,
                             headers=headers,
...
...
@@ -360,15 +408,18 @@ def get_streaming_response(prompt: str,
                                         decode_unicode=False,
                                         delimiter=b'\n'):
        if chunk:
-            data = json.loads(chunk.decode('utf-8'))
+            data = json_loads(chunk.decode('utf-8'))
            output = data.pop('text', '')
            tokens = data.pop('tokens', 0)
            finish_reason = data.pop('finish_reason', None)
            yield output, tokens, finish_reason


-def main(api_server_url: str, session_id: int = 0):
-    api_client = APIClient(api_server_url)
+def main(api_server_url: str,
+         session_id: int = 0,
+         api_key: Optional[str] = None):
+    """Main function to chat in terminal."""
+    api_client = APIClient(api_server_url, api_key=api_key)
    while True:
        prompt = input_prompt()
        if prompt in ['exit', 'end']:
...
...
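A minimal sketch of the updated client: pass `api_key` once and the Bearer header is attached to every request. The server address, key and model name are placeholders:

# Illustrative usage of APIClient with an api key.
from lmdeploy.serve.openai.api_client import APIClient

client = APIClient('http://0.0.0.0:23333', api_key='my-secret-key')
model_name = client.available_models[0]
for chunk in client.chat_completions_v1(
        model=model_name,
        messages=[{'role': 'user', 'content': 'hi'}],
        stream=False):
    print(chunk)  # an OpenAI-style json object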
lmdeploy/serve/openai/api_server.py
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
import random
import time
from http import HTTPStatus
-from typing import AsyncGenerator, List, Optional
+from typing import AsyncGenerator, List, Literal, Optional, Union

import uvicorn
-from fastapi import FastAPI, Request
+from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer

+from lmdeploy.archs import get_task
+from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
+                               TurbomindEngineConfig)
+from lmdeploy.model import ChatTemplateConfig
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.openai.protocol import (  # noqa: E501
-    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionRequest, ChatCompletionRequestQos, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, CompletionRequest,
-    CompletionResponse, CompletionResponseChoice,
+    CompletionRequestQos, CompletionResponse, CompletionResponseChoice,
    CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage,
    EmbeddingsRequest, EncodeRequest, EncodeResponse, ErrorResponse,
-    GenerateRequest, GenerateResponse, ModelCard, ModelList,
-    ModelPermission, UsageInfo)
+    GenerateRequest, GenerateRequestQos, GenerateResponse, ModelCard,
+    ModelList, ModelPermission, UsageInfo)
+from lmdeploy.serve.qos_engine.qos_engine import QosEngine
+from lmdeploy.utils import get_logger


class VariableInterface:
    """An IO interface maintaining variables."""
    async_engine: AsyncEngine = None
+    session_id: int = 0
+    api_keys: Optional[List[str]] = None
+    qos_engine: QosEngine = None
    request_hosts = []


app = FastAPI(docs_url='/')
+get_bearer_token = HTTPBearer(auto_error=False)


+async def check_api_key(
+    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
+) -> str:
+    """Check if the client provides a valid api key.
+
+    Adopted from https://github.com/lm-sys/FastChat/blob/v0.2.35/fastchat/serve/openai_api_server.py#L108-L127
+    """  # noqa
+    if VariableInterface.api_keys:
+        if auth is None or (token :=
+                            auth.credentials) not in VariableInterface.api_keys:
+            raise HTTPException(
+                status_code=401,
+                detail={
+                    'error': {
+                        'message': 'Please request with valid api key!',
+                        'type': 'invalid_request_error',
+                        'param': None,
+                        'code': 'invalid_api_key',
+                    }
+                },
+            )
+        return token
+    else:
+        # api_keys not set; allow all
+        return None


def get_model_list():
...
...
@@ -37,10 +74,10 @@ def get_model_list():
    Only provided one now.
    """
-    return [VariableInterface.async_engine.tm_model.model_name]
+    return [VariableInterface.async_engine.model_name]


-@app.get('/v1/models')
+@app.get('/v1/models', dependencies=[Depends(check_api_key)])
def available_models():
    """Show available models."""
    model_cards = []
...
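Once `api_keys` is configured on the server, the protected routes expect an `Authorization: Bearer <key>` header. A minimal sketch with `requests`; the URL and key are placeholders:

# Illustrative call against a server configured with api keys.
import requests

headers = {'Authorization': 'Bearer my-secret-key'}
resp = requests.get('http://0.0.0.0:23333/v1/models', headers=headers)
# A missing or wrong key yields a 401 response with code 'invalid_api_key'.
print(resp.status_code, resp.json())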
...
@@ -74,17 +111,149 @@ async def check_request(request) -> Optional[JSONResponse]:
    return ret


-def ip2id(host_ip: str):
-    """Convert host ip address to session id."""
-    if '.' in host_ip:  # IPv4
-        return int(host_ip.replace('.', '')[-8:])
-    if ':' in host_ip:  # IPv6
-        return int(host_ip.replace(':', '')[-8:], 16)
-    print('Warning, could not get session id from ip, set it 0')
-    return 0
+@app.post('/v1/chat/completions_qos')
+async def chat_completions_v1_qos(request: ChatCompletionRequestQos,
+                                  raw_request: Request = None):
+    """Completion API similar to OpenAI's API.
+
+    Refer to `https://platform.openai.com/docs/api-reference/chat/create`
+    for the API specification.
+
+    The request should be a JSON object with the following fields:
+    - model: model name. Available from /v1/models.
+    - messages: string prompt or chat history in OpenAI format.
+    - temperature (float): to modulate the next token probability
+    - top_p (float): If set to float < 1, only the smallest set of most
+        probable tokens with probabilities that add up to top_p or higher
+        are kept for generation.
+    - n (int): How many chat completion choices to generate for each input
+        message. Only support one here.
+    - stream: whether to stream the results or not. Default to false.
+    - max_tokens (int): output token nums
+    - repetition_penalty (float): The parameter for repetition penalty.
+        1.0 means no penalty
+
+    Additional arguments supported by LMDeploy:
+    - ignore_eos (bool): indicator for ignoring eos
+    - user_id (str): for qos; if not specified, will set to "default"
+
+    Currently we do not support the following features:
+    - function_call (Users should implement this by themselves)
+    - logit_bias (not supported yet)
+    - presence_penalty (replaced with repetition_penalty)
+    - frequency_penalty (replaced with repetition_penalty)
+    """
+    VariableInterface.session_id += 1
+    request.session_id = VariableInterface.session_id
+    error_check_ret = await check_request(request)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    model_name = request.model
+    request_id = str(request.session_id)
+    created_time = int(time.time())
+
+    if VariableInterface.qos_engine is None:
+        return create_error_response(
+            HTTPStatus.NOT_FOUND,
+            'cannot parse qos engine config, this api does not work')
+
+    result_generator = await VariableInterface.qos_engine.generate_with_qos(
+        request)
+
+    if result_generator is None:
+        return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR,
+                                     'Failed to generate completions')
+
+    def create_stream_response_json(
+        index: int,
+        text: str,
+        finish_reason: Optional[str] = None,
+    ) -> str:
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=index,
+            delta=DeltaMessage(role='assistant', content=text),
+            finish_reason=finish_reason,
+        )
+        response = ChatCompletionStreamResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            choices=[choice_data],
+        )
+        response_json = response.model_dump_json()
+        return response_json
+
+    async def completion_stream_generator() -> AsyncGenerator[str, None]:
+        # First chunk with role
+        for i in range(request.n):
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=i,
+                delta=DeltaMessage(role='assistant'),
+                finish_reason=None,
+            )
+            chunk = ChatCompletionStreamResponse(id=request_id,
+                                                 choices=[choice_data],
+                                                 model=model_name)
+            data = chunk.model_dump_json(exclude_unset=True)
+            yield f'data: {data}\n\n'
+
+        async for res in result_generator:
+            response_json = create_stream_response_json(
+                index=0,
+                text=res.response,
+            )
+            yield f'data: {response_json}\n\n'
+        yield 'data: [DONE]\n\n'
+
+    # Streaming response
+    if request.stream:
+        return StreamingResponse(completion_stream_generator(),
+                                 media_type='text/event-stream')
+
+    # Non-streaming response
+    final_res = None
+    text = ''
+    async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await VariableInterface.async_engine.stop_session(
+                request.session_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         'Client disconnected')
+        final_res = res
+        text += res.response
+    assert final_res is not None
+    choices = []
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role='assistant', content=text),
+        finish_reason=final_res.finish_reason,
+    )
+    choices.append(choice_data)
+
+    total_tokens = sum([
+        final_res.history_token_len, final_res.input_token_len,
+        final_res.generate_token_len
+    ])
+    usage = UsageInfo(
+        prompt_tokens=final_res.input_token_len,
+        completion_tokens=final_res.generate_token_len,
+        total_tokens=total_tokens,
+    )
+    response = ChatCompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+    return response


-@app.post('/v1/chat/completions')
+@app.post('/v1/chat/completions', dependencies=[Depends(check_api_key)])
async def chat_completions_v1(request: ChatCompletionRequest,
                              raw_request: Request = None):
    """Completion API similar to OpenAI's API.
...
...
@@ -94,7 +263,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    The request should be a JSON object with the following fields:
    - model: model name. Available from /v1/models.
-    - messages: string prompt or chat history in OpenAI format.
+    - messages: string prompt or chat history in OpenAI format. Chat history
+        example: `[{"role": "user", "content": "hi"}]`.
    - temperature (float): to modulate the next token probability
    - top_p (float): If set to float < 1, only the smallest set of most
        probable tokens with probabilities that add up to top_p or higher
...
...
@@ -102,13 +272,18 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    - n (int): How many chat completion choices to generate for each input
        message. Only support one here.
    - stream: whether to stream the results or not. Default to false.
-    - max_tokens (int): output token nums
+    - max_tokens (int | None): output token nums. Default to None.
    - repetition_penalty (float): The parameter for repetition penalty.
        1.0 means no penalty
+    - stop (str | List[str] | None): To stop generating further
+        tokens. Only accept stop words that's encoded to one token index.

    Additional arguments supported by LMDeploy:
+    - top_k (int): The number of the highest probability vocabulary
+        tokens to keep for top-k-filtering
    - ignore_eos (bool): indicator for ignoring eos
-    - session_id (int): if not specified, will set random value
+    - skip_special_tokens (bool): Whether or not to remove special tokens
+        in the decoding. Default to be True.

    Currently we do not support the following features:
    - function_call (Users should implement this by themselves)
...
...
@@ -116,8 +291,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    - presence_penalty (replaced with repetition_penalty)
    - frequency_penalty (replaced with repetition_penalty)
    """
-    if request.session_id == -1:
-        request.session_id = random.randint(1, 10086)
+    VariableInterface.session_id += 1
+    request.session_id = VariableInterface.session_id
    error_check_ret = await check_request(request)
    if error_check_ret is not None:
        return error_check_ret
...
...
@@ -126,18 +301,26 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    request_id = str(request.session_id)
    created_time = int(time.time())

+    if isinstance(request.stop, str):
+        request.stop = [request.stop]
+    gen_config = GenerationConfig(
+        max_new_tokens=request.max_tokens,
+        top_k=request.top_k,
+        top_p=request.top_p,
+        temperature=request.temperature,
+        repetition_penalty=request.repetition_penalty,
+        ignore_eos=request.ignore_eos,
+        stop_words=request.stop,
+        skip_special_tokens=request.skip_special_tokens)
    result_generator = VariableInterface.async_engine.generate(
        request.messages,
        request.session_id,
-        True,  # always use stream to enable batching
+        gen_config=gen_config,
+        stream_response=True,  # always use stream to enable batching
        sequence_start=True,
        sequence_end=True,
-        request_output_len=request.max_tokens if request.max_tokens else 512,
-        stop=request.stop,
-        top_p=request.top_p,
-        temperature=request.temperature,
-        repetition_penalty=request.repetition_penalty,
-        ignore_eos=request.ignore_eos,
        do_preprocess=not isinstance(request.messages,
                                     str),  # text completion for string input
    )
...
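The endpoints above now collect all sampling options into a single `GenerationConfig` before calling the engine. A sketch of that mapping outside the server, using only the fields shown in this diff (all values are placeholders):

# Illustrative: the GenerationConfig built from an OpenAI-style request body.
from lmdeploy.messages import GenerationConfig

gen_config = GenerationConfig(
    max_new_tokens=512,
    top_k=40,
    top_p=0.8,
    temperature=0.7,
    repetition_penalty=1.0,
    ignore_eos=False,
    stop_words=['<|im_end|>'],  # placeholder; each stop word must encode to one token
    skip_special_tokens=True)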
...
@@ -196,7 +379,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    async for res in result_generator:
        if await raw_request.is_disconnected():
            # Abort the request if the client disconnects.
-            VariableInterface.async_engine.stop_session(request.session_id)
+            await VariableInterface.async_engine.stop_session(
+                request.session_id)
            return create_error_response(HTTPStatus.BAD_REQUEST,
                                         'Client disconnected')
        final_res = res
...
...
@@ -230,7 +414,155 @@ async def chat_completions_v1(request: ChatCompletionRequest,
    return response


-@app.post('/v1/completions')
+@app.post('/v1/completions_qos')
+async def completions_v1_qos(request: CompletionRequestQos,
+                             raw_request: Request = None):
+    """Completion API similar to OpenAI's API.
+
+    Go to `https://platform.openai.com/docs/api-reference/completions/create`
+    for the API specification.
+
+    The request should be a JSON object with the following fields:
+    - model (str): model name. Available from /v1/models.
+    - prompt (str): the input prompt.
+    - suffix (str): The suffix that comes after a completion of inserted text.
+    - max_tokens (int): output token nums
+    - temperature (float): to modulate the next token probability
+    - top_p (float): If set to float < 1, only the smallest set of most
+        probable tokens with probabilities that add up to top_p or higher
+        are kept for generation.
+    - n (int): How many chat completion choices to generate for each input
+        message. Only support one here.
+    - stream: whether to stream the results or not. Default to false.
+    - repetition_penalty (float): The parameter for repetition penalty.
+        1.0 means no penalty
+    - user (str): A unique identifier representing your end-user.
+
+    Additional arguments supported by LMDeploy:
+    - top_k (int): The number of the highest probability vocabulary
+        tokens to keep for top-k-filtering
+    - ignore_eos (bool): indicator for ignoring eos
+    - user_id (str): for qos; if not specified, will set to "default"
+
+    Currently we do not support the following features:
+    - logprobs (not supported yet)
+    - presence_penalty (replaced with repetition_penalty)
+    - frequency_penalty (replaced with repetition_penalty)
+    """
+    VariableInterface.session_id += 1
+    request.session_id = VariableInterface.session_id
+    error_check_ret = await check_request(request)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    model_name = request.model
+    request_id = str(request.session_id)
+    created_time = int(time.time())
+    if isinstance(request.prompt, str):
+        request.prompt = [request.prompt]
+
+    if VariableInterface.qos_engine is None:
+        return create_error_response(
+            HTTPStatus.NOT_FOUND,
+            'cannot parse qos engine config, this api does not work')
+
+    generators = await VariableInterface.qos_engine.generate_with_qos(request)
+
+    def create_stream_response_json(
+        index: int,
+        text: str,
+        finish_reason: Optional[str] = None,
+    ) -> str:
+        choice_data = CompletionResponseStreamChoice(
+            index=index,
+            text=text,
+            finish_reason=finish_reason,
+        )
+        response = CompletionStreamResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            choices=[choice_data],
+        )
+        response_json = response.model_dump_json()
+        return response_json
+
+    async def completion_stream_generator() -> AsyncGenerator[str, None]:
+        # First chunk with role
+        for generator in generators:
+            for i in range(request.n):
+                choice_data = CompletionResponseStreamChoice(
+                    index=i,
+                    text='',
+                    finish_reason=None,
+                )
+                chunk = CompletionStreamResponse(id=request_id,
+                                                 choices=[choice_data],
+                                                 model=model_name)
+                data = chunk.model_dump_json(exclude_unset=True)
+                yield f'data: {data}\n\n'
+
+            async for res in generator:
+                response_json = create_stream_response_json(
+                    index=0,
+                    text=res.response,
+                )
+                yield f'data: {response_json}\n\n'
+        yield 'data: [DONE]\n\n'
+
+    # Streaming response
+    if request.stream:
+        return StreamingResponse(completion_stream_generator(),
+                                 media_type='text/event-stream')
+
+    # Non-streaming response
+    usage = UsageInfo()
+    choices = []
+
+    async def _inner_call(i, generator):
+        final_res = None
+        text = ''
+        async for res in generator:
+            if await raw_request.is_disconnected():
+                # Abort the request if the client disconnects.
+                await VariableInterface.async_engine.stop_session(
+                    request.session_id)
+                return create_error_response(HTTPStatus.BAD_REQUEST,
+                                             'Client disconnected')
+            final_res = res
+            text += res.response
+        assert final_res is not None
+        choice_data = CompletionResponseChoice(
+            index=0,
+            text=text,
+            finish_reason=final_res.finish_reason,
+        )
+        choices.append(choice_data)
+
+        total_tokens = sum([
+            final_res.history_token_len, final_res.input_token_len,
+            final_res.generate_token_len
+        ])
+        usage.prompt_tokens += final_res.input_token_len
+        usage.completion_tokens += final_res.generate_token_len
+        usage.total_tokens += total_tokens
+
+    await asyncio.gather(
+        *[_inner_call(i, generators[i]) for i in range(len(generators))])
+
+    response = CompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+    return response


+@app.post('/v1/completions', dependencies=[Depends(check_api_key)])
async def completions_v1(request: CompletionRequest,
                         raw_request: Request = None):
    """Completion API similar to OpenAI's API.
...
...
@@ -242,7 +574,7 @@ async def completions_v1(request: CompletionRequest,
    - model (str): model name. Available from /v1/models.
    - prompt (str): the input prompt.
    - suffix (str): The suffix that comes after a completion of inserted text.
-    - max_tokens (int): output token nums
+    - max_tokens (int): output token nums. Default to 16.
    - temperature (float): to modulate the next token probability
    - top_p (float): If set to float < 1, only the smallest set of most
        probable tokens with probabilities that add up to top_p or higher
...
...
@@ -253,18 +585,23 @@ async def completions_v1(request: CompletionRequest,
    - repetition_penalty (float): The parameter for repetition penalty.
        1.0 means no penalty
    - user (str): A unique identifier representing your end-user.
+    - stop (str | List[str] | None): To stop generating further
+        tokens. Only accept stop words that's encoded to one token index.

    Additional arguments supported by LMDeploy:
    - ignore_eos (bool): indicator for ignoring eos
-    - session_id (int): if not specified, will set random value
+    - skip_special_tokens (bool): Whether or not to remove special tokens
+        in the decoding. Default to be True.
+    - top_k (int): The number of the highest probability vocabulary
+        tokens to keep for top-k-filtering

    Currently we do not support the following features:
    - logprobs (not supported yet)
    - presence_penalty (replaced with repetition_penalty)
    - frequency_penalty (replaced with repetition_penalty)
    """
-    if request.session_id == -1:
-        request.session_id = random.randint(1, 10086)
+    VariableInterface.session_id += 1
+    request.session_id = VariableInterface.session_id
    error_check_ret = await check_request(request)
    if error_check_ret is not None:
        return error_check_ret
...
...
@@ -274,21 +611,26 @@ async def completions_v1(request: CompletionRequest,
    created_time = int(time.time())
    if isinstance(request.prompt, str):
        request.prompt = [request.prompt]
+    if isinstance(request.stop, str):
+        request.stop = [request.stop]
+    gen_config = GenerationConfig(
+        max_new_tokens=request.max_tokens if request.max_tokens else 512,
+        top_k=request.top_k,
+        top_p=request.top_p,
+        temperature=request.temperature,
+        repetition_penalty=request.repetition_penalty,
+        ignore_eos=request.ignore_eos,
+        stop_words=request.stop,
+        skip_special_tokens=request.skip_special_tokens)
    generators = []
    for i in range(len(request.prompt)):
        result_generator = VariableInterface.async_engine.generate(
            request.prompt[i],
            request.session_id + i,
-            True,  # always use stream to enable batching
+            gen_config=gen_config,
+            stream_response=True,  # always use stream to enable batching
            sequence_start=True,
            sequence_end=True,
-            request_output_len=request.max_tokens if request.max_tokens else 512,
-            stop=False,
-            top_p=request.top_p,
-            temperature=request.temperature,
-            repetition_penalty=request.repetition_penalty,
-            ignore_eos=request.ignore_eos,
            do_preprocess=False)
        generators.append(result_generator)
...
...
@@ -351,7 +693,8 @@ async def completions_v1(request: CompletionRequest,
        async for res in generator:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
-                VariableInterface.async_engine.stop_session(request.session_id)
+                await VariableInterface.async_engine.stop_session(
+                    request.session_id)
                return create_error_response(HTTPStatus.BAD_REQUEST,
                                             'Client disconnected')
            final_res = res
...
...
@@ -394,7 +737,7 @@ async def create_embeddings(request: EmbeddingsRequest,
                                'Unsupported by turbomind.')


-@app.post('/v1/encode')
+@app.post('/v1/encode', dependencies=[Depends(check_api_key)])
async def encode(request: EncodeRequest, raw_request: Request = None):
    """Encode prompts.
...
...
@@ -407,7 +750,7 @@ async def encode(request: EncodeRequest, raw_request: Request = None):
    def encode(prompt: str, do_preprocess: bool, add_bos: bool):
        if do_preprocess:
-            prompt = VariableInterface.async_engine.model.get_prompt(
+            prompt = VariableInterface.async_engine.chat_template.get_prompt(
                prompt, sequence_start=add_bos)
        input_ids = VariableInterface.async_engine.tokenizer.encode(
            prompt, add_bos=add_bos)
...
...
@@ -425,12 +768,9 @@ async def encode(request: EncodeRequest, raw_request: Request = None):
    return EncodeResponse(input_ids=encoded, length=length)


-@app.post('/generate',
-          tags=['deprecated'],
-          description='please use /v1/chat/interactive')
-@app.post('/v1/chat/interactive')
-async def chat_interactive_v1(request: GenerateRequest,
-                              raw_request: Request = None):
+@app.post('/v1/chat/interactive_qos')
+async def chat_interactive_v1_qos(request: GenerateRequestQos,
+                                  raw_request: Request = None):
    """Generate completion for the request.

    - On interactive mode, the chat history is kept on the server. Please set
...
...
@@ -456,33 +796,134 @@ async def chat_interactive_v1(request: GenerateRequest,
    - repetition_penalty (float): The parameter for repetition penalty.
        1.0 means no penalty
    - ignore_eos (bool): indicator for ignoring eos
+    - user_id (str): for qos; if not specified, will set to "default"
    """
-    if request.session_id == -1:
-        request.session_id = random.randint(10087, 23333)
+    VariableInterface.session_id += 1
+    request.session_id = VariableInterface.session_id
+
+    if VariableInterface.qos_engine is None:
+        return create_error_response(
+            HTTPStatus.NOT_FOUND,
+            'cannot parse qos engine config, this api does not work')
+
+    generation = await VariableInterface.qos_engine.generate_with_qos(request)
+
+    # Streaming case
+    async def stream_results() -> AsyncGenerator[bytes, None]:
+        async for out in generation:
+            chunk = GenerateResponse(text=out.response,
+                                     tokens=out.generate_token_len,
+                                     input_tokens=out.input_token_len,
+                                     history_tokens=out.history_token_len,
+                                     finish_reason=out.finish_reason)
+            data = chunk.model_dump_json()
+            yield f'{data}\n'
+
+    if request.stream:
+        return StreamingResponse(stream_results(),
+                                 media_type='text/event-stream')
+    else:
+        ret = {}
+        text = ''
+        tokens = 0
+        finish_reason = None
+        async for out in generation:
+            if await raw_request.is_disconnected():
+                # Abort the request if the client disconnects.
+                await VariableInterface.qos_engine.stop_session(
+                    request.session_id)
+                return create_error_response(HTTPStatus.BAD_REQUEST,
+                                             'Client disconnected')
+            text += out.response
+            tokens = out.generate_token_len
+            finish_reason = out.finish_reason
+        ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason}
+        return JSONResponse(ret)


+@app.post('/v1/chat/interactive', dependencies=[Depends(check_api_key)])
+async def chat_interactive_v1(request: GenerateRequest,
+                              raw_request: Request = None):
+    """Generate completion for the request.
+
+    - On interactive mode, the chat history is kept on the server. Please set
+      `interactive_mode = True`.
+    - On normal mode, no chat history is kept on the server. Set
+      `interactive_mode = False`.
+
+    The request should be a JSON object with the following fields:
+    - prompt: the prompt to use for the generation.
+    - session_id: determine which instance will be called. If not specified
+        with a value other than -1, using random value directly.
+    - interactive_mode (bool): turn on interactive mode or not. On interactive
+        mode, session history is kept on the server (and vice versa).
+    - stream: whether to stream the results or not.
+    - stop (str | List[str] | None): To stop generating further
+        tokens. Only accept stop words that's encoded to one token index.
+    - request_output_len (int): output token nums. If not specified, will use
+        maximum possible number for a session.
+    - top_p (float): If set to float < 1, only the smallest set of most
+        probable tokens with probabilities that add up to top_p or higher
+        are kept for generation.
+    - top_k (int): The number of the highest probability vocabulary
+        tokens to keep for top-k-filtering
+    - temperature (float): to modulate the next token probability
+    - repetition_penalty (float): The parameter for repetition penalty.
+        1.0 means no penalty
+    - ignore_eos (bool): indicator for ignoring eos
+    - skip_special_tokens (bool): Whether or not to remove special tokens
+        in the decoding. Default to be True.
+    """
+    if request.cancel:
+        if request.session_id != -1:
+            await VariableInterface.async_engine.stop_session(
+                request.session_id)
+            return {
+                'text': '',
+                'tokens': 0,
+                'input_tokens': 0,
+                'history_tokens': 0,
+                'finish_reason': 'stop'
+            }
+        else:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                'please set a session_id to cancel a request')
+    if request.session_id == -1:
+        VariableInterface.session_id += 1
+        request.session_id = VariableInterface.session_id

    async_engine = VariableInterface.async_engine
    sequence_start = async_engine.id2step.get(str(request.session_id), 0) == 0
    sequence_end = not request.interactive_mode

+    if isinstance(request.stop, str):
+        request.stop = [request.stop]
+    gen_config = GenerationConfig(
+        max_new_tokens=request.request_output_len,
+        top_p=request.top_p,
+        top_k=request.top_k,
+        temperature=request.temperature,
+        repetition_penalty=request.repetition_penalty,
+        ignore_eos=request.ignore_eos,
+        stop_words=request.stop,
+        skip_special_tokens=request.skip_special_tokens)
    generation = async_engine.generate(
        request.prompt,
        request.session_id,
+        gen_config=gen_config,
        stream_response=True,  # always use stream to enable batching
        sequence_start=sequence_start,
-        sequence_end=sequence_end,
-        request_output_len=request.request_output_len,
-        top_p=request.top_p,
-        top_k=request.top_k,
-        stop=request.stop,
-        temperature=request.temperature,
-        repetition_penalty=request.repetition_penalty,
-        ignore_eos=request.ignore_eos)
+        sequence_end=sequence_end)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for out in generation:
            chunk = GenerateResponse(text=out.response,
                                     tokens=out.generate_token_len,
                                     input_tokens=out.input_token_len,
                                     history_tokens=out.history_token_len,
                                     finish_reason=out.finish_reason)
            data = chunk.model_dump_json()
            yield f'{data}\n'
...
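The new `cancel` flag of `/v1/chat/interactive` lets a client abort a running session. A minimal sketch with `requests`; the URL and session id are placeholders, and an `Authorization` header would be needed if api keys are enabled:

# Illustrative: stop generation for session 1 via the cancel flag.
import requests

resp = requests.post('http://0.0.0.0:23333/v1/chat/interactive',
                     json={'prompt': '', 'session_id': 1, 'cancel': True})
print(resp.json())  # e.g. {'text': '', 'tokens': 0, ..., 'finish_reason': 'stop'}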
...
@@ -493,32 +934,46 @@ async def chat_interactive_v1(request: GenerateRequest,
    else:
        ret = {}
        text = ''
-        tokens = 0
+        tokens, input_tokens, history_tokens = 0, 0, 0
        finish_reason = None
        async for out in generation:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
-                async_engine.stop_session(request.session_id)
+                await async_engine.stop_session(request.session_id)
                return create_error_response(HTTPStatus.BAD_REQUEST,
                                             'Client disconnected')
            text += out.response
            tokens = out.generate_token_len
+            input_tokens = out.input_token_len
+            history_tokens = out.history_token_len
            finish_reason = out.finish_reason
-        ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason}
+        ret = {
+            'text': text,
+            'tokens': tokens,
+            'input_tokens': input_tokens,
+            'history_tokens': history_tokens,
+            'finish_reason': finish_reason
+        }
        return JSONResponse(ret)


def serve(model_path: str,
          model_name: Optional[str] = None,
+          backend: Literal['turbomind', 'pytorch'] = 'turbomind',
+          backend_config: Optional[Union[PytorchEngineConfig,
+                                         TurbomindEngineConfig]] = None,
+          chat_template_config: Optional[ChatTemplateConfig] = None,
          server_name: str = '0.0.0.0',
          server_port: int = 23333,
-          instance_num: int = 64,
          tp: int = 1,
          allow_origins: List[str] = ['*'],
          allow_credentials: bool = True,
          allow_methods: List[str] = ['*'],
          allow_headers: List[str] = ['*'],
          log_level: str = 'ERROR',
+          api_keys: Optional[Union[List[str], str]] = None,
+          ssl: bool = False,
+          qos_config_path: str = '',
          **kwargs):
    """An example to perform model inference through the command line
    interface.
...
...
@@ -534,22 +989,34 @@ def serve(model_path: str,
            "InternLM/internlm-chat-20b-4bit",
            "lmdeploy/llama2-chat-70b-4bit", etc.
        - iii) The model_id of a model hosted inside a model repo
-            on huggingface.co, such as "InternLM/internlm-chat-7b",
+            on huggingface.co, such as "internlm/internlm-chat-7b",
            "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
            and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "InternLM/internlm-chat-7b"
+        backend (str): either `turbomind` or `pytorch` backend. Default to
+            `turbomind` backend.
+        backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
+            config instance. Default to none.
+        chat_template_config (ChatTemplateConfig): chat template configuration.
+            Default to None.
        server_name (str): host ip for serving
        server_port (int): server port
-        instance_num (int): number of instances of turbomind model
        tp (int): tensor parallel
        allow_origins (List[str]): a list of allowed origins for CORS
        allow_credentials (bool): whether to allow credentials for CORS
        allow_methods (List[str]): a list of allowed HTTP methods for CORS
        allow_headers (List[str]): a list of allowed HTTP headers for CORS
+        log_level(str): set log level whose value among
+            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
+        api_keys (List[str] | str | None): Optional list of API keys. Accepts
+            string type as a single api_key. Default to None, which means no
+            api key applied.
+        ssl (bool): Enable SSL. Requires OS Environment variables 'SSL_KEYFILE'
+            and 'SSL_CERTFILE'.
+        qos_config_path (str): qos policy config path
    """  # noqa E501
-    os.environ['TM_LOG_LEVEL'] = log_level
+    if os.getenv('TM_LOG_LEVEL') is None:
+        os.environ['TM_LOG_LEVEL'] = log_level
+    logger = get_logger('lmdeploy')
+    logger.setLevel(log_level)

    if allow_origins:
        app.add_middleware(
...
...
@@ -559,16 +1026,55 @@ def serve(model_path: str,
            allow_methods=allow_methods,
            allow_headers=allow_headers,
        )
+    if api_keys is not None:
+        if isinstance(api_keys, str):
+            api_keys = api_keys.split(',')
+        VariableInterface.api_keys = api_keys
+    ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http'
+    if ssl:
+        ssl_keyfile = os.environ['SSL_KEYFILE']
+        ssl_certfile = os.environ['SSL_CERTFILE']
+        http_or_https = 'https'
+
+    pipeline_type, pipeline_class = get_task(model_path)
+
+    VariableInterface.async_engine = pipeline_class(
+        model_path=model_path,
+        model_name=model_name,
+        backend=backend,
+        backend_config=backend_config,
+        chat_template_config=chat_template_config,
+        tp=tp,
+        **kwargs)
+
+    if qos_config_path:
+        try:
+            with open(qos_config_path, 'r') as file:
+                qos_config_str = file.read()
+            VariableInterface.qos_engine = QosEngine(
+                qos_tag=qos_config_str,
+                engine=VariableInterface.async_engine,
+                **kwargs)
+            VariableInterface.qos_engine.start()
+        except FileNotFoundError:
+            VariableInterface.qos_engine = None
+    else:
+        # hide qos functions if not applied
+        for i in range(len(app.router.routes)):
+            if 'qos' in app.router.routes[i].path:
+                app.router.routes[i].include_in_schema = False

-    VariableInterface.async_engine = AsyncEngine(model_path=model_path,
-                                                 model_name=model_name,
-                                                 instance_num=instance_num,
-                                                 tp=tp,
-                                                 **kwargs)
    for i in range(3):
-        print(f'HINT: Please open \033[93m\033[1mhttp://{server_name}:'
-              f'{server_port}\033[0m in a browser for detailed api usage!!!')
-    uvicorn.run(app=app, host=server_name, port=server_port, log_level='info')
+        print(f'HINT: Please open \033[93m\033[1m{http_or_https}://'
+              f'{server_name}:{server_port}\033[0m in a browser for detailed api'
+              ' usage!!!')
+    uvicorn.run(app=app,
+                host=server_name,
+                port=server_port,
+                log_level='info',
+                ssl_keyfile=ssl_keyfile,
+                ssl_certfile=ssl_certfile)


if __name__ == '__main__':
...
...
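A minimal sketch of starting the server with the new authentication and SSL options; the model id, key and certificate paths are placeholders:

# Illustrative: serve() with an api key and SSL certificates taken from env vars.
import os
from lmdeploy.serve.openai.api_server import serve

os.environ['SSL_KEYFILE'] = '/path/to/server.key'
os.environ['SSL_CERTFILE'] = '/path/to/server.crt'
serve('internlm/internlm-chat-7b',
      backend='turbomind',
      server_name='0.0.0.0',
      server_port=23333,
      api_keys='my-secret-key',  # a single key or a comma-separated string
      ssl=True)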
lmdeploy/serve/openai/protocol.py
...
...
@@ -55,23 +55,48 @@ class UsageInfo(BaseModel):
    completion_tokens: Optional[int] = 0


-class ChatCompletionRequest(BaseModel):
+class ChatCompletionRequestQos(BaseModel):
    """Chat completion request."""
    model: str
    messages: Union[str, List[Dict[str, str]]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
-    max_tokens: Optional[int] = 512
+    max_tokens: Optional[int] = Field(default=None, examples=[None])
    stop: Optional[bool] = False
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
+    user_id: Optional[str] = None
    # additional argument of lmdeploy
    repetition_penalty: Optional[float] = 1.0
    session_id: Optional[int] = -1
    ignore_eos: Optional[bool] = False
+    top_k: Optional[int] = 40


+class ChatCompletionRequest(BaseModel):
+    """Chat completion request."""
+    model: str
+    # yapf: disable
+    messages: Union[str, List[Dict[str, Any]]] = Field(examples=[[{
+        'role': 'user',
+        'content': 'hi'
+    }]])  # noqa
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    max_tokens: Optional[int] = Field(default=None, examples=[None])
+    stop: Optional[Union[str, List[str]]] = Field(default=None,
+                                                  examples=[None])  # noqa
+    # yapf: enable
+    stream: Optional[bool] = False
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    user: Optional[str] = None
+    # additional argument of lmdeploy
+    repetition_penalty: Optional[float] = 1.0
+    session_id: Optional[int] = -1
+    ignore_eos: Optional[bool] = False
+    skip_special_tokens: Optional[bool] = True
+    top_k: Optional[int] = 40


class ChatMessage(BaseModel):
...
...
@@ -120,6 +145,31 @@ class ChatCompletionStreamResponse(BaseModel):

class CompletionRequest(BaseModel):
    """Completion request."""
    model: str
    prompt: Union[str, List[Any]]
    suffix: Optional[str] = None
    temperature: Optional[float] = 0.7
    n: Optional[int] = 1
    max_tokens: Optional[int] = 16
+    stop: Optional[Union[str, List[str]]] = Field(default=None,
+                                                  examples=[None])
    stream: Optional[bool] = False
    top_p: Optional[float] = 1.0
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    # additional argument of lmdeploy
    repetition_penalty: Optional[float] = 1.0
    session_id: Optional[int] = -1
    ignore_eos: Optional[bool] = False
+    skip_special_tokens: Optional[bool] = True
+    top_k: Optional[int] = 40  # for opencompass


+class CompletionRequestQos(BaseModel):
+    """Completion request."""
+    model: str
+    prompt: Union[str, List[Any]]
...
...
@@ -136,9 +186,11 @@ class CompletionRequest(BaseModel):
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    # additional argument of lmdeploy
+    top_k: Optional[int] = 40
    repetition_penalty: Optional[float] = 1.0
    session_id: Optional[int] = -1
    ignore_eos: Optional[bool] = False
+    user_id: Optional[str] = None


class CompletionResponseChoice(BaseModel):
...
...
@@ -205,6 +257,25 @@ class EncodeResponse(BaseModel):

class GenerateRequest(BaseModel):
    """Generate request."""
    prompt: Union[str, List[Dict[str, Any]]]
    session_id: int = -1
    interactive_mode: bool = False
    stream: bool = False
+    stop: Optional[Union[str, List[str]]] = Field(default=None,
+                                                  examples=[None])
+    request_output_len: Optional[int] = Field(default=None,
+                                              examples=[None])  # noqa
    top_p: float = 0.8
    top_k: int = 40
    temperature: float = 0.8
    repetition_penalty: float = 1.0
    ignore_eos: bool = False
+    skip_special_tokens: Optional[bool] = True
+    cancel: Optional[bool] = False  # cancel a responding request


+class GenerateRequestQos(BaseModel):
+    """Generate request."""
+    prompt: Union[str, List[Dict[str, str]]]
+    session_id: int = -1
...
...
@@ -217,10 +288,13 @@ class GenerateRequest(BaseModel):
    temperature: float = 0.8
    repetition_penalty: float = 1.0
    ignore_eos: bool = False
+    user_id: Optional[str] = None


class GenerateResponse(BaseModel):
    """Generate response."""
    text: str
    tokens: int
+    input_tokens: int
+    history_tokens: int
    finish_reason: Optional[Literal['stop', 'length']] = None
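For reference, a sketch of request payloads the reworked models accept; the values are placeholders and only fields defined above are used:

# Illustrative payloads for the updated pydantic models.
from lmdeploy.serve.openai.protocol import ChatCompletionRequest, GenerateRequest

chat_req = ChatCompletionRequest(
    model='internlm-chat-7b',
    messages=[{'role': 'user', 'content': 'hi'}],
    max_tokens=None,       # default changed from 512 to None
    stop=['observation'],  # placeholder; a string or a list of single-token stop words
    skip_special_tokens=True)

gen_req = GenerateRequest(prompt='hi', session_id=1, cancel=False)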
lmdeploy/serve/turbomind/chatbot.py
...
...
@@ -18,7 +18,7 @@ from tritonclient.grpc.service_pb2 import ModelInferResponse
from
lmdeploy.model
import
MODELS
from
lmdeploy.serve.turbomind.utils
import
(
Postprocessor
,
Preprocessor
,
prepare_tensor
)
from
lmdeploy.utils
import
filter_suffix
from
lmdeploy.utils
import
filter_suffix
,
get_logger
@
dataclass
...
...
@@ -51,13 +51,6 @@ def stream_callback(que, result, error):
que
.
put
(
result
.
get_response
(
as_json
=
True
))
def
get_logger
(
log_file
=
None
,
log_level
=
logging
.
INFO
):
"""Return the logger."""
from
lmdeploy.utils
import
get_logger
logger
=
get_logger
(
'service.ft'
,
log_file
=
log_file
,
log_level
=
log_level
)
return
logger
class
Chatbot
:
"""Chatbot for LLaMA series models with turbomind as inference engine.
...
...
@@ -75,6 +68,10 @@ class Chatbot:
                 ignore_eos: bool = False,
                 log_level: int = logging.INFO,
                 display: bool = False,
                 top_p: float = 1.0,
                 top_k: int = 1,
                 temperature: float = 0.8,
                 repetition_penalty: float = 1.0,
                 **model_kwargs):
        self.tritonserver_addr = tritonserver_addr
        self.model_name = model_name
...
...
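Since the sampling parameters are now constructor arguments, each Chatbot instance can carry its own defaults; a hedged sketch (the triton server address and values are placeholders):

# Illustrative construction; not part of this commit.
from lmdeploy.serve.turbomind.chatbot import Chatbot

bot = Chatbot('0.0.0.0:33337',
              top_p=0.9,
              top_k=40,
              temperature=0.7,
              repetition_penalty=1.05)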
@@ -97,10 +94,10 @@ class Chatbot:
        self.eos_id = -1
        self.cfg = mmengine.Config(
            dict(session_len=self.model.session_len,
                 top_p=self.model.top_p,
                 top_k=self.model.top_k,
                 temperature=self.model.temperature,
                 repetition_penalty=self.model.repetition_penalty,
                 top_p=top_p,
                 top_k=top_k,
                 temperature=temperature,
                 repetition_penalty=repetition_penalty,
                 stop_words=stop_words,
                 bad_words=bad_words))
        self.log_level = log_level
...
...
@@ -113,6 +110,7 @@ class Chatbot:
                     request_output_len: int = None,
                     sequence_start: bool = False,
                     sequence_end: bool = False,
                     skip_special_tokens: bool = True,
                     *args,
                     **kwargs):
        """Start a new round conversation of a session.
...
...
@@ -124,13 +122,15 @@ class Chatbot:
            request_output_len (int): the expected generated token numbers
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.

        Returns:
            iterator: The generated content by chatbot
        """
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'
        logger = get_logger(log_level=self.log_level)
        logger = get_logger('service.ft', log_level=self.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'request_output_len {request_output_len}')
...
...
@@ -149,11 +149,13 @@ class Chatbot:
        self.cfg.update(**kwargs)
        self._session.prompt = self._get_prompt(prompt, sequence_start)
        for status, res, tokens in self._stream_infer(self._session,
                                                      self._session.prompt,
                                                      request_output_len,
                                                      sequence_start,
                                                      sequence_end):
        for status, res, tokens in self._stream_infer(
                self._session,
                self._session.prompt,
                request_output_len,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            if status == StatusCode.TRITON_STREAM_END:  # remove stop_words
                res = filter_suffix(res, self.model.stop_words)
            if status.value < 0:
...
...
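With the keyword threaded through to _stream_infer, a caller can now keep special tokens in the streamed text; a minimal sketch, assuming the `bot` instance sketched above:

# Illustrative streaming call; not part of this commit.
for status, text, n_token in bot.stream_infer(session_id=1,
                                              prompt='Hello',
                                              request_output_len=64,
                                              sequence_start=True,
                                              skip_special_tokens=False):
    print(text, end='', flush=True)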
@@ -180,7 +182,7 @@ class Chatbot:
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'
        logger = get_logger(log_level=self.log_level)
        logger = get_logger('service.ft', log_level=self.log_level)
        logger.info(f'end session: {session_id}')
        if self._session is None:
...
...
@@ -218,7 +220,7 @@ class Chatbot:
"""
assert
isinstance
(
session_id
,
int
),
\
f
'INT session id is required, but got
{
type
(
session_id
)
}
'
logger
=
get_logger
(
log_level
=
self
.
log_level
)
logger
=
get_logger
(
'service.ft'
,
log_level
=
self
.
log_level
)
logger
.
info
(
f
'cancel session:
{
session_id
}
'
)
if
self
.
_session
is
None
:
...
...
@@ -267,7 +269,7 @@ class Chatbot:
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'
        logger = get_logger(log_level=self.log_level)
        logger = get_logger('service.ft', log_level=self.log_level)
        logger.info(f'resume session: {session_id}')
        if self._session is None:
...
...
@@ -301,6 +303,7 @@ class Chatbot:
               request_output_len: int = None,
               sequence_start: bool = False,
               sequence_end: bool = False,
               skip_special_tokens: bool = True,
               *args,
               **kwargs):
        """Start a new round conversation of a session. Return the chat
...
...
@@ -313,6 +316,8 @@ class Chatbot:
            request_output_len (int): the expected generated token numbers
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.

        Returns:
            tuple(Status, str, int): status, text/chat completion,
                generated token number
...
...
@@ -320,7 +325,7 @@ class Chatbot:
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'
        logger = get_logger(log_level=self.log_level)
        logger = get_logger('service.ft', log_level=self.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'request_output_len {request_output_len}')
...
...
@@ -338,11 +343,13 @@ class Chatbot:
        self._session.prompt = self._get_prompt(prompt, sequence_start)
        status, res, tokens = None, '', 0
        for status, res, tokens in self._stream_infer(self._session,
                                                      self._session.prompt,
                                                      request_output_len,
                                                      sequence_start,
                                                      sequence_end):
        for status, res, tokens in self._stream_infer(
                self._session,
                self._session.prompt,
                request_output_len,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            if status.value < 0:
                break
        if status == StatusCode.TRITON_STREAM_END:
            # remove stop_words
...
...
@@ -420,6 +427,7 @@ class Chatbot:
                      request_output_len: int = 512,
                      sequence_start: bool = True,
                      sequence_end: bool = False,
                      skip_special_tokens: bool = True,
                      cancel: bool = False):
        """communicate with inference server to chat, or cancel a session, or
        end a session.
...
...
@@ -431,10 +439,12 @@ class Chatbot:
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            cancel (bool): indicator for cancelling the session
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.

        Yields:
            tuple: status, text, generated token number
        """
        logger = get_logger(log_level=self.log_level)
        logger = get_logger('service.ft', log_level=self.log_level)
        logger.info(f'session {session.session_id}, '
                    f'request id {session.request_id}, '
                    f'request_output_len {request_output_len}, '
...
@@ -498,7 +508,8 @@ class Chatbot:
producer
.
start
()
for
status
,
res
,
n_token
in
self
.
stream_consumer
(
self
.
postprocess
,
que
,
session
,
input_tokens
,
preseq_length
,
cancel
,
logger
,
self
.
display
,
self
.
eos_id
):
cancel
,
logger
,
self
.
display
,
self
.
eos_id
,
skip_special_tokens
):
yield
status
,
res
,
n_token
producer
.
join
()
...
...
@@ -591,7 +602,8 @@ class Chatbot:
    @staticmethod
    def stream_consumer(postprocess, res_queue, session, n_input_token,
                        preseq_length, cancel, logger, display, eos_id):
                        preseq_length, cancel, logger, display, eos_id,
                        skip_special_tokens):
        """Consume the response from the triton inference server.

        Args:
...
...
@@ -605,11 +617,15 @@ class Chatbot:
            logger (util.Logger):
            display (bool): display the text in the console interface or not
            eos_id (int): eos token id
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.

        Yields:
            tuple: status, text, generated token number
        """
        status, res, n_token = None, '', 0
        output_ids = np.zeros((1, 1, 0), dtype=np.uint32)
        text = ''
        while True:
            result = res_queue.get()
            if result is None:
...
...
@@ -648,7 +664,8 @@ class Chatbot:
                    output_ids = output_ids[:, :, :-1]
                output_str = postprocess(
                    output_ids, np.array([[n_token]], dtype=np.uint32))
                    output_ids, np.array([[n_token]], dtype=np.uint32),
                    np.array([[int(skip_special_tokens)]], dtype=np.int32))
                text = output_str[0].decode()
                # utf-8 char at the end means it's a potential unfinished
                # byte sequence, continue to concatenate it with the next
...
...
lmdeploy/serve/turbomind/triton_models/postprocessing/1/model.py View file @ d7117b95
...
...
@@ -84,10 +84,13 @@ class TritonPythonModel:
                request, 'TOKENS_BATCH').as_numpy()
            sequence_length = pb_utils.get_input_tensor_by_name(
                request, 'sequence_length').as_numpy()
            skip_special_tokens = pb_utils.get_input_tensor_by_name(
                request, 'skip_special_tokens').as_numpy()

            # Postprocessing output data.
            outputs = self._postprocessing(tokens_batch.tolist(),
                                           sequence_length)
                                           sequence_length,
                                           skip_special_tokens)

            # Create output tensors. You need pb_utils.Tensor
            # objects to create pb_utils.InferenceResponse.
...
...
@@ -118,12 +121,16 @@ class TritonPythonModel:
"""
print
(
'Cleaning up...'
)
def
_postprocessing
(
self
,
tokens_batch
,
sequence_length
):
def
_postprocessing
(
self
,
tokens_batch
,
sequence_length
,
skip_special_tokens
):
"""decode token ids into texts."""
outputs
=
[]
for
beam_tokens
,
beam_len
in
zip
(
tokens_batch
,
sequence_length
):
for
tokens
,
_len
in
zip
(
beam_tokens
,
beam_len
):
output
=
self
.
tokenizer
.
decode
(
tokens
,
_len
)
for
beam_tokens
,
beam_len
,
beam_skip_special
in
zip
(
tokens_batch
,
sequence_length
,
skip_special_tokens
):
for
tokens
,
_len
,
skip_special
in
zip
(
beam_tokens
,
beam_len
,
beam_skip_special
):
output
=
self
.
tokenizer
.
decode
(
tokens
,
_len
,
skip_special_tokens
=
bool
(
skip_special
))
output
=
output
.
encode
(
'utf8'
)
outputs
.
append
(
output
)
return
outputs
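The new flag arrives per beam, so its shape mirrors sequence_length; roughly, the inputs look like the sketch below (shapes and token ids are assumptions for illustration, not part of this commit):

import numpy as np

# Illustrative input shapes for _postprocessing.
tokens_batch = np.array([[[9906, 1917]]], dtype=np.uint32)   # [batch, beam, seq]
sequence_length = np.array([[2]], dtype=np.uint32)           # [batch, beam]
skip_special_tokens = np.array([[1]], dtype=np.int32)        # [batch, beam]; 1 -> True
# outputs = model._postprocessing(tokens_batch.tolist(),
#                                 sequence_length, skip_special_tokens)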
lmdeploy/serve/turbomind/triton_models/postprocessing/config.pbtxt View file @ d7117b95
...
...
@@ -11,6 +11,11 @@ input [
name: "sequence_length"
data_type: TYPE_UINT32
dims: [ -1 ]
},
{
name: "skip_special_tokens"
data_type: TYPE_INT32
dims: [ -1 ]
}
]
output [
...
...
lmdeploy/serve/turbomind/utils.py View file @ d7117b95
...
...
@@ -72,22 +72,29 @@ class Postprocessor:
    def __call__(self, *args, **kwargs):
        return self.infer(*args, **kwargs)

    def infer(self, output_ids: np.ndarray, seqlen: np.ndarray):
    def infer(self,
              output_ids: np.ndarray,
              seqlen: np.ndarray,
              skip_special_tokens: bool = True):
        """De-tokenize tokens for text.

        Args:
            output_ids(np.ndarray): tokens' id
            seqlen(np.ndarray): sequence length
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.

        Returns:
            str: decoded tokens
        """
        inputs = [
            prepare_tensor('TOKENS_BATCH', output_ids),
            prepare_tensor('sequence_length', seqlen)
            prepare_tensor('sequence_length', seqlen),
            prepare_tensor('skip_special_tokens', skip_special_tokens)
        ]
        inputs[0].set_data_from_numpy(output_ids)
        inputs[1].set_data_from_numpy(seqlen)
        inputs[2].set_data_from_numpy(skip_special_tokens)
        model_name = 'postprocessing'
        with grpcclient.InferenceServerClient(self.tritonserver_addr) \
                as client:
...
...
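On the client side the new input is just another tensor; a hedged usage sketch (the server address and arrays are placeholders). Although the annotation says bool, the value is forwarded to set_data_from_numpy, so in practice an int32 array shaped like seqlen is what the Triton config above expects:

import numpy as np
from lmdeploy.serve.turbomind.utils import Postprocessor

# Illustrative call; not part of this commit.
postprocessor = Postprocessor('0.0.0.0:33337')
output_ids = np.array([[[9906, 1917]]], dtype=np.uint32)
seqlen = np.array([[2]], dtype=np.uint32)
skip = np.array([[1]], dtype=np.int32)   # 1 -> skip special tokens
text = postprocessor(output_ids, seqlen, skip)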
lmdeploy/tokenizer.py View file @ d7117b95
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import os.path as osp
from typing import Optional, Sequence, Union
from collections import deque
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union

import torch

from lmdeploy.utils import get_logger

# this file will be copied to triton server, make sure all
# importing are starting from the package root lmdeploy


@dataclass
class DetokenizeState:
    """A state collection of incremental detokenization.

    Args:
        ids_offset (int): offset to all input ids. In LMDeploy, the output
            ids are not produced one by one; each round may return a varying
            number of ids.
        prev_tokens (List[str] | None): for incrementally decoding.
            Default to None, which means the first round.
        prefix_offset (int): the start index of tokens to be converted to
            string (prev + new tokens). Default to 0 for the first round.
        read_offset (int): the end index of tokens to be converted to
            string (prev token). Default to 0 for the first round.
    """
    ids_offset: int = 0
    prev_tokens: Optional[List[str]] = None
    prefix_offset: int = 0
    read_offset: int = 0

    def as_tuple(self) -> Tuple:
        """Return a tuple of states."""
        return (self.ids_offset, self.prev_tokens, self.prefix_offset,
                self.read_offset)


class SentencePieceTokenizer:
    """Tokenizer of sentencepiece.
...
...
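A fresh state simply starts at zero offsets; the tokenizer then returns an updated state on every round. A minimal sketch (not part of this commit):

# Illustrative use of DetokenizeState.
from lmdeploy.tokenizer import DetokenizeState

state = DetokenizeState()   # first round: no previous tokens, offsets at 0
print(state.as_tuple())     # (0, None, 0, 0)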
@@ -18,6 +49,12 @@ class SentencePieceTokenizer:
        from sentencepiece import SentencePieceProcessor
        self.model = SentencePieceProcessor(model_file=model_file)
        self._prefix_space_tokens = None
        # for stop words
        self._maybe_decode_bytes: bool = None
        # TODO maybe lack a constant.py
        self._indexes_tokens_deque = deque(maxlen=10)
        self.max_indexes_num = 5
        self.logger = get_logger('lmdeploy')

    @property
    def vocab_size(self):
...
@@ -53,6 +90,27 @@ class SentencePieceTokenizer:
else
:
return
decoded
def
indexes_containing_token
(
self
,
token
:
str
):
"""Return all the possible indexes, whose decoding output may contain
the input token."""
# traversing vocab is time consuming, can not be accelerated with
# multi threads (computation) or multi process (can't pickle tokenizer)
# so, we maintain latest 10 stop words and return directly if matched
for
_token
,
_indexes
in
self
.
_indexes_tokens_deque
:
if
token
==
_token
:
return
_indexes
if
token
==
' '
:
# ' ' is special
token
=
'▁'
vocab
=
self
.
model
.
IdToPiece
(
list
(
range
(
self
.
vocab_size
)))
indexes
=
[
i
for
i
,
voc
in
enumerate
(
vocab
)
if
token
in
voc
]
if
len
(
indexes
)
>
self
.
max_indexes_num
:
indexes
=
self
.
encode
(
token
,
add_bos
=
False
)[
-
1
:]
self
.
logger
.
warning
(
f
'There are too many(>
{
self
.
max_indexes_num
}
) possible '
f
'indexes may decoding
{
token
}
, we will use
{
indexes
}
only'
)
self
.
_indexes_tokens_deque
.
append
((
token
,
indexes
))
return
indexes
def
encode
(
self
,
s
:
str
,
add_bos
:
bool
=
True
,
**
kwargs
):
"""Tokenize a prompt.
...
...
@@ -63,13 +121,18 @@ class SentencePieceTokenizer:
"""
return
self
.
model
.
Encode
(
s
,
add_bos
=
add_bos
,
**
kwargs
)
def
decode
(
self
,
t
:
Sequence
[
int
],
offset
:
Optional
[
int
]
=
None
):
def
decode
(
self
,
t
:
Sequence
[
int
],
offset
:
Optional
[
int
]
=
None
,
skip_special_tokens
:
bool
=
True
,
**
kwargs
):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
offset (int): for incrementally decoding. Default to None, which
means not applied.
skip_special_tokens (boo): not used in SentencePieceTokenizer.
Returns:
str: text of decoding tokens
"""
...
...
@@ -81,6 +144,34 @@ class SentencePieceTokenizer:
        out_string = self._maybe_add_prefix_space(t, out_string)
        return out_string

    def detokenize_incrementally(self,
                                 all_input_ids: Sequence[int],
                                 state: DetokenizeState,
                                 skip_special_tokens: bool = True,
                                 spaces_between_special_tokens: bool = True):
        """Incrementally detokenize the input indexes.

        Args:
            all_input_ids (List[int]): a list of token ids. Expected to be
                different sections of a long sequence.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.
            spaces_between_special_tokens (bool): Whether or not to add spaces
                between special tokens. Default to be True.

        Returns:
            str: decoding output string of the current round.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
        """
        out_string = self.model.Decode(all_input_ids)
        if state.prev_tokens is not None:
            out_string = self._maybe_add_prefix_space(all_input_ids,
                                                      out_string)
        state.prev_tokens = []  # not None for the above condition
        return out_string, state

    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.
...
...
@@ -106,20 +197,10 @@ class HuggingFaceTokenizer:
    def __init__(self, model_dir: str):
        from transformers import AutoTokenizer
        model_file = osp.join(model_dir, 'tokenizer.model')
        backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
        model_file_exists = osp.exists(model_file)
        if not osp.exists(backend_tokenizer_file) and model_file_exists:
            print('WARNING: Can not find tokenizer.json. '
                  'It may take long time to initialize the tokenizer.')
        self.logger = get_logger('lmdeploy')
        self.model = AutoTokenizer.from_pretrained(model_dir,
                                                   trust_remote_code=True)
        self._prefix_space_tokens = None
        # save tokenizer.json to reuse
        if not osp.exists(backend_tokenizer_file) and model_file_exists:
            if hasattr(self.model, 'backend_tokenizer'):
                if os.access(model_dir, os.W_OK):
                    self.model.backend_tokenizer.save(backend_tokenizer_file)

        if self.model.eos_token_id is None:
            generation_config_file = osp.join(model_dir,
...
...
@@ -131,11 +212,27 @@ class HuggingFaceTokenizer:
        elif hasattr(self.model, 'eod_id'):  # Qwen remote
            self.model.eos_token_id = self.model.eod_id

        # for stop words
        self._vocab_size_with_added: int = None
        self._maybe_decode_bytes: bool = None
        # TODO maybe lack a constant.py
        self._indexes_tokens_deque = deque(maxlen=10)
        self.max_indexes_num = 5
        self.token2id = {}

    @property
    def vocab_size(self):
        """vocabulary size."""
        return self.model.vocab_size

    @property
    def vocab_size_with_added(self):
        """vocabulary size with added vocab."""
        if self._vocab_size_with_added is not None:
            return self._vocab_size_with_added
        self._vocab_size_with_added = len(self.model.get_vocab())
        return self._vocab_size_with_added

    @property
    def bos_token_id(self):
        """beginning of the sentence token id."""
...
...
@@ -159,7 +256,7 @@ class HuggingFaceTokenizer:
        }
        return self._prefix_space_tokens

    def _maybe_add_prefix_space(self, tokens, decoded):
    def _maybe_add_prefix_space(self, tokens: List[int], decoded: str):
        """maybe add prefix space for incremental decoding."""
        if len(tokens) and not decoded.startswith(' ') and \
                tokens[0] in self.prefix_space_tokens:
...
...
@@ -167,6 +264,66 @@ class HuggingFaceTokenizer:
        else:
            return decoded

    @property
    def maybe_decode_bytes(self):
        """Check if self.model.convert_ids_to_tokens return not a str value."""
        if self._maybe_decode_bytes is None:
            self._maybe_decode_bytes = False
            vocab = self.model.convert_ids_to_tokens(
                list(range(self.vocab_size)))
            for tok in vocab:
                if not isinstance(tok, str):
                    self._maybe_decode_bytes = True
                    break
        return self._maybe_decode_bytes

    def indexes_containing_token(self, token: str):
        """Return all the possible indexes, whose decoding output may contain
        the input token."""
        # traversing vocab is time consuming, can not be accelerated with
        # multi threads (computation) or multi process (can't pickle tokenizer)
        # so, we maintain latest 10 stop words and return directly if matched
        for _token, _indexes in self._indexes_tokens_deque:
            if token == _token:
                return _indexes

        if self.token2id == {}:
            # decode is slower than convert_ids_to_tokens
            if self.maybe_decode_bytes:
                try:
                    self.token2id = {
                        self.model.decode(i): i
                        for i in range(self.vocab_size)
                    }
                except Exception as e:
                    # qwen-vl
                    assert str(e) == 'Unclosed image token'
            else:
                self.token2id = {
                    self.model.convert_ids_to_tokens(i): i
                    for i in range(self.vocab_size)
                }
        if token == ' ':  # ' ' is special
            token = '▁'
        indexes = [i for _token, i in self.token2id.items() if token in _token]
        if len(indexes) > self.max_indexes_num:
            # multiple id decode to same token
            indexes = [i for i in indexes if self.decode([i]) == token]
            indexes = indexes[:self.max_indexes_num]
            self.logger.warning(
                f'There are too many(>{self.max_indexes_num}) possible '
                f'indexes may decoding {token}, we will use {indexes} only')
        # there might be token id that exceeds self.vocab_size
        if len(indexes) == 0:
            indexes = self.encode(token, False)
            if len(indexes) != 1:
                self.logger.warning(
                    f'The token {token}, its length of indexes {indexes} is '
                    'not 1. Currently, it can not be used as stop words')
                indexes = []
        self._indexes_tokens_deque.append((token, indexes))
        return indexes

    def encode(self, s: str, add_bos: bool = True, **kwargs):
        """Tokenize a prompt.
...
...
@@ -182,7 +339,10 @@ class HuggingFaceTokenizer:
            encoded = encoded[1:]
        return encoded

    def decode(self, t: Sequence[int], offset: Optional[int] = None):
    def decode(self,
               t: Sequence[int],
               offset: Optional[int] = None,
               skip_special_tokens: bool = True):
        """De-tokenize.

        Args:
...
...
@@ -192,14 +352,121 @@ class HuggingFaceTokenizer:
        Returns:
            str: text of decoding tokens
        """
        skip_special_tokens = True
        t = t[offset:]
        out_string = self.model.decode(t,
                                       skip_special_tokens=skip_special_tokens)
        if offset:
            logger = get_logger('lmdeploy')
            logger.warning('For incrementally detokenization, please try '
                           'detokenize_incrementally function instead.')
            out_string = self._maybe_add_prefix_space(t, out_string)
        return out_string

    @staticmethod
    def _convert_tokens_to_string_with_added_encoders(
        tokenizer,
        output_tokens: List[str],
        skip_special_tokens: bool,
        spaces_between_special_tokens: bool,
    ) -> str:
        if tokenizer.is_fast or not tokenizer.get_added_vocab():
            return tokenizer.convert_tokens_to_string(output_tokens)
        # Adapted from
        # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L68-L99
        sub_texts = []
        current_sub_text = []
        all_special_tokens = set(tokenizer.all_special_tokens)
        for token in output_tokens:
            if skip_special_tokens and token in all_special_tokens:
                continue
            if token in tokenizer.get_added_vocab():
                if current_sub_text:
                    sub_text = tokenizer.convert_tokens_to_string(
                        current_sub_text)
                    sub_texts.append(sub_text)
                    current_sub_text = []
                sub_texts.append(token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
            sub_texts.append(sub_text)
        if spaces_between_special_tokens:
            return ' '.join(sub_texts)
        else:
            return ''.join(sub_texts)

    # Based on
    # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L105-L165
    def detokenize_incrementally(self,
                                 all_input_ids: Sequence[int],
                                 state: DetokenizeState,
                                 skip_special_tokens: bool = True,
                                 spaces_between_special_tokens: bool = True):
        """Incrementally detokenize the input indexes.

        Args:
            all_input_ids (List[int]): a list of token ids. Expected to be
                different sections of a long sequence.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.
            spaces_between_special_tokens (bool): Whether or not to add spaces
                between special tokens. Default to be True.

        Returns:
            str: decoding output string of the current round.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
        """
        tokenizer = self.model
        ids_offset, prev_tokens, prefix_offset, read_offset = state.as_tuple()
        # This is the first iteration for this sequence
        new_tokens = tokenizer.convert_ids_to_tokens(
            all_input_ids[ids_offset:],
            skip_special_tokens=skip_special_tokens)
        if prev_tokens is None:
            # Please notice that in VLLM, indexes are detokenized one by one
            # while in LMDeploy, every turn, the detokenized indexes length
            # can be different.
            if skip_special_tokens and new_tokens and new_tokens[
                    0] in tokenizer.all_special_ids:
                read_offset = 1  # skip special token
            output_tokens = new_tokens
            prev_tokens = new_tokens
        else:
            # Put new_token_id in a list so skip_special_tokens is respected
            output_tokens = prev_tokens + new_tokens
            prev_tokens += new_tokens

        prefix_text = self._convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )
        new_text = self._convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

        # update state and get final decoded output
        if len(new_text) > len(prefix_text) and not new_text.endswith('�'):
            # utf-8 char at the end means it's a potential unfinished byte
            # sequence from byte fallback tokenization.
            # If it's in the middle, it's probably a real invalid id generated
            # by the model
            prefix_offset = read_offset
            read_offset = len(output_tokens)
            new_text = new_text[len(prefix_text):]
        else:
            new_text = ''

        return new_text, DetokenizeState(len(all_input_ids), prev_tokens,
                                         prefix_offset, read_offset)

    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.
...
...
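Putting it together, a streaming consumer feeds the accumulated ids plus the previous state into detokenize_incrementally on each step and prints only the newly finished text; a minimal sketch assuming a local HF model directory (the path and ids are placeholders, not part of this commit):

# Illustrative incremental decoding loop.
from lmdeploy.tokenizer import DetokenizeState, HuggingFaceTokenizer

tok = HuggingFaceTokenizer('/path/to/model_dir')   # placeholder path
state = DetokenizeState()
for step_ids in ([9906], [9906, 11], [9906, 11, 1917]):  # accumulated ids per step
    new_text, state = tok.detokenize_incrementally(
        step_ids, state, skip_special_tokens=True)
    print(new_text, end='', flush=True)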
@@ -230,7 +497,7 @@ class Tokenizer:
        model_file_exists = osp.exists(model_file)
        config_exists = osp.exists(tokenizer_config_file)
        use_hf_model = config_exists or not model_file_exists
        self.logger = get_logger('lmdeploy')
        if not use_hf_model:
            self.model = SentencePieceTokenizer(model_file)
        else:
...
@@ -261,7 +528,12 @@ class Tokenizer:
"""
return
self
.
model
.
encode
(
s
,
add_bos
,
**
kwargs
)
def
decode
(
self
,
t
:
Sequence
[
int
],
offset
:
Optional
[
int
]
=
None
):
def
decode
(
self
,
t
:
Sequence
[
int
],
offset
:
Optional
[
int
]
=
None
,
skip_special_tokens
:
bool
=
True
,
):
"""De-tokenize.
Args:
...
...
@@ -271,7 +543,34 @@ class Tokenizer:
        Returns:
            str: text of decoding tokens
        """
        return self.model.decode(t, offset)
        return self.model.decode(t, offset, skip_special_tokens)

    def detokenize_incrementally(self,
                                 all_input_ids: Sequence[int],
                                 state: DetokenizeState,
                                 skip_special_tokens: bool = True,
                                 spaces_between_special_tokens: bool = True):
        """Incrementally detokenize the input indexes.

        Args:
            all_input_ids (List[int]): a list of token ids. Expected to be
                different sections of a long sequence.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be True.
            spaces_between_special_tokens (bool): Whether or not to add spaces
                between special tokens. Default to be True.

        Returns:
            str: decoding output string of the current round.
            state (DetokenizeState): an instance of DetokenizeState. Consists
                of incrementally decoding states.
        """
        return self.model.detokenize_incrementally(
            all_input_ids,
            state=state,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens)

    def __call__(self, s: Union[str, Sequence[str]]):
        """Tokenize prompts.
...
...
@@ -282,3 +581,14 @@ class Tokenizer:
            list[int]: token ids
        """
        return self.model(s)

    def indexes_containing_token(self, token):
        """Return all the possible indexes, whose decoding output may contain
        the input token."""
        encoded = self.encode(token, add_bos=False)
        if len(encoded) > 1:
            self.logger.warning(
                f'The token {token}, its length of indexes {encoded} is '
                'longer than 1. Currently, it can not be used as stop words')
            return []
        return self.model.indexes_containing_token(token)
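The facade only registers stop words that map to a single id before delegating to the backend tokenizer; a hedged sketch (path and token are placeholders):

# Illustrative stop-word lookup; not part of this commit.
from lmdeploy.tokenizer import Tokenizer

tok = Tokenizer('/path/to/model_dir')            # placeholder path
stop_ids = tok.indexes_containing_token('</s>')  # [] plus a warning if the
                                                 # token encodes to >1 id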