ModelZoo / Qwen_lmdeploy · Commits · ac3500b5

Unverified commit ac3500b5, authored Oct 25, 2023 by AllentDan, committed by GitHub on Oct 25, 2023.

support inference a batch of prompts (#467)

* support inference a batch of prompts
* docstring and assert

Parent: 169d5169

Showing 3 changed files with 75 additions and 13 deletions (+75 -13):

  lmdeploy/serve/async_engine.py       +61  -7
  lmdeploy/serve/openai/api_server.py  +13  -5
  lmdeploy/serve/openai/protocol.py     +1  -1
lmdeploy/serve/async_engine.py  (view file @ ac3500b5)

@@ -4,7 +4,7 @@ import dataclasses
 import os.path as osp
 import random
 from contextlib import contextmanager
-from typing import Literal, Optional
+from typing import List, Literal, Optional
 
 from lmdeploy.model import MODELS, BaseModel

@@ -46,6 +46,7 @@ class AsyncEngine:
         self.available = [True] * instance_num
         self.starts = [None] * instance_num
         self.steps = {}
+        self.loop = asyncio.get_event_loop()
 
     def stop_session(self, session_id: int):
         instance_id = session_id % self.instance_num

@@ -82,6 +83,59 @@ class AsyncEngine:
                 await asyncio.sleep(0.1)
         return self.generators[instance_id]
 
+    def batch_infer(self,
+                    prompts: List[str],
+                    request_output_len=512,
+                    top_k=40,
+                    top_p=0.8,
+                    temperature=0.8,
+                    repetition_penalty=1.0,
+                    ignore_eos=False,
+                    **kwargs):
+        """Inference a batch of prompts.
+
+        Args:
+            prompts (List[str]): a batch of prompts
+            request_output_len (int): output token nums
+            top_k (int): The number of the highest probability vocabulary
+                tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+                probable tokens with probabilities that add up to top_p or
+                higher are kept for generation.
+            temperature (float): to modulate the next token probability
+            repetition_penalty (float): The parameter for repetition penalty.
+                1.0 means no penalty
+            ignore_eos (bool): indicator for ignoring eos
+        """
+        assert isinstance(prompts, List), 'prompts should be a list'
+        batch_size = len(prompts)
+        outputs = [''] * batch_size
+        generators = []
+        for i, prompt in enumerate(prompts):
+            generators.append(
+                self.generate(prompt,
+                              i,
+                              stream_response=True,
+                              sequence_start=True,
+                              sequence_end=True,
+                              request_output_len=request_output_len,
+                              top_k=top_k,
+                              top_p=top_p,
+                              temperature=temperature,
+                              ignore_eos=ignore_eos,
+                              repetition_penalty=repetition_penalty))
+
+        async def _inner_call(i, generator):
+            async for out in generator:
+                outputs[i] += out.response
+
+        async def gather():
+            await asyncio.gather(
+                *[_inner_call(i, generators[i]) for i in range(batch_size)])
+
+        self.loop.run_until_complete(gather())
+        return outputs
+
     async def generate(
             self,
             messages,

@@ -109,11 +163,11 @@ class AsyncEngine:
             sequence_end (bool): indicator for ending a sequence
             step (int): the offset of the k/v cache
             stop (bool): whether stop inference
-            top_p (float): If set to float < 1, only the smallest set of most
-                probable tokens with probabilities that add up to top_p or
-                higher are kept for generation.
             top_k (int): The number of the highest probability vocabulary
                 tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+                probable tokens with probabilities that add up to top_p or
+                higher are kept for generation.
             temperature (float): to modulate the next token probability
             repetition_penalty (float): The parameter for repetition penalty.
                 1.0 means no penalty

@@ -195,11 +249,11 @@ class AsyncEngine:
             renew_session (bool): renew the session
             request_output_len (int): output token nums
             stop (bool): whether stop inference
-            top_p (float): If set to float < 1, only the smallest set of most
-                probable tokens with probabilities that add up to top_p or
-                higher are kept for generation.
             top_k (int): The number of the highest probability vocabulary
                 tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+                probable tokens with probabilities that add up to top_p or
+                higher are kept for generation.
             temperature (float): to modulate the next token probability
             repetition_penalty (float): The parameter for repetition penalty.
                 1.0 means no penalty
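For reference, a minimal usage sketch of the new batch_infer API. The AsyncEngine constructor arguments (model_path, instance_num) are assumptions here; verify them against AsyncEngine.__init__ in your lmdeploy version.

from lmdeploy.serve.async_engine import AsyncEngine

# Assumed constructor arguments; check AsyncEngine.__init__ for the exact
# signature in your version of lmdeploy.
engine = AsyncEngine(model_path='./workspace', instance_num=2)

prompts = ['Hi, how are you?', 'Summarize the KV cache in one sentence.']

# batch_infer fans each prompt out to generate(...) with its own session id
# (the enumerate index), drains the streamed chunks concurrently on the
# engine's event loop, and returns one concatenated response string per
# prompt, in input order.
responses = engine.batch_infer(prompts,
                               request_output_len=256,
                               temperature=0.7)
for prompt, response in zip(prompts, responses):
    print(f'{prompt!r} -> {response!r}')

Note that batch_infer is a synchronous entry point: it drives the async generators with self.loop.run_until_complete, so it should not be called from inside an already running event loop.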
lmdeploy/serve/openai/api_server.py  (view file @ ac3500b5)

@@ -229,11 +229,19 @@ async def create_embeddings(request: EmbeddingsRequest,
     error_check_ret = await check_request(request)
     if error_check_ret is not None:
         return error_check_ret
 
-    embedding = await VariableInterface.async_engine.get_embeddings(
-        request.input)
-    data = [{'object': 'embedding', 'embedding': embedding, 'index': 0}]
-    token_num = len(embedding)
+    if isinstance(request.input, str):
+        request.input = [request.input]
+
+    data = []
+    token_num = 0
+    for i, prompt in enumerate(request.input):
+        embedding = await VariableInterface.async_engine.get_embeddings(prompt)
+        data.append({
+            'object': 'embedding',
+            'embedding': embedding,
+            'index': i
+        })
+        token_num += len(embedding)
     return EmbeddingsResponse(data=data,
                               model=request.model,
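With this change the embeddings endpoint accepts either a single string or a list of strings. A hedged client sketch follows; the port 23333, the route /v1/embeddings, and the model name 'qwen-7b' are assumptions, so confirm them against the FastAPI decorator on create_embeddings and how your api_server was launched.

import requests

# Assumed host, port, and route; confirm against the @app.post(...) path on
# create_embeddings and the api_server launch arguments.
url = 'http://localhost:23333/v1/embeddings'

# 'qwen-7b' is a placeholder model name. After this commit, `input` may be a
# single string or a list of strings; each list element produces one entry in
# `data`, with `index` matching its position in the batch.
payload = {'model': 'qwen-7b', 'input': ['hello', 'world']}

resp = requests.post(url, json=payload)
for item in resp.json()['data']:
    print(item['index'], len(item['embedding']))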
lmdeploy/serve/openai/protocol.py  (view file @ ac3500b5)

@@ -175,7 +175,7 @@ class CompletionStreamResponse(BaseModel):
 
 class EmbeddingsRequest(BaseModel):
     """Embedding request."""
     model: str = None
-    input: Union[str, List[Any]]
+    input: Union[str, List[str]]
     user: Optional[str] = None
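Narrowing the annotation from List[Any] to List[str] lets pydantic reject non-string batch elements before they reach the engine. A self-contained sketch of that behaviour, using a standalone copy of the field layout rather than importing lmdeploy:

from typing import List, Optional, Union

from pydantic import BaseModel, ValidationError


class EmbeddingsRequest(BaseModel):
    """Standalone copy of the schema, for illustration only."""
    model: str = None
    input: Union[str, List[str]]
    user: Optional[str] = None


EmbeddingsRequest(input='hello')              # ok: a single prompt
EmbeddingsRequest(input=['hello', 'world'])   # ok: a batch of prompts

try:
    # A list of dicts would have passed List[Any]; List[str] rejects it.
    EmbeddingsRequest(input=[{'text': 'hello'}])
except ValidationError as err:
    print(err)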