Unverified Commit ac3500b5 authored by AllentDan, committed by GitHub

support inference a batch of prompts (#467)

* support inference a batch of prompts

* docstring and assert
parent 169d5169
@@ -4,7 +4,7 @@ import dataclasses
 import os.path as osp
 import random
 from contextlib import contextmanager
-from typing import Literal, Optional
+from typing import List, Literal, Optional
 
 from lmdeploy.model import MODELS, BaseModel
@@ -46,6 +46,7 @@ class AsyncEngine:
         self.available = [True] * instance_num
         self.starts = [None] * instance_num
         self.steps = {}
+        self.loop = asyncio.get_event_loop()
 
     def stop_session(self, session_id: int):
         instance_id = session_id % self.instance_num
@@ -82,6 +83,59 @@ class AsyncEngine:
                 await asyncio.sleep(0.1)
         return self.generators[instance_id]
 
+    def batch_infer(self,
+                    prompts: List[str],
+                    request_output_len=512,
+                    top_k=40,
+                    top_p=0.8,
+                    temperature=0.8,
+                    repetition_penalty=1.0,
+                    ignore_eos=False,
+                    **kwargs):
+        """Inference a batch of prompts.
+
+        Args:
+            prompts (List[str]): a batch of prompts
+            request_output_len (int): output token nums
+            top_k (int): The number of the highest probability vocabulary
+              tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+              probable tokens with probabilities that add up to top_p or higher
+              are kept for generation.
+            temperature (float): to modulate the next token probability
+            repetition_penalty (float): The parameter for repetition penalty.
+              1.0 means no penalty
+            ignore_eos (bool): indicator for ignoring eos
+        """
+        assert isinstance(prompts, List), 'prompts should be a list'
+        batch_size = len(prompts)
+        outputs = [''] * batch_size
+        generators = []
+        for i, prompt in enumerate(prompts):
+            generators.append(
+                self.generate(prompt,
+                              i,
+                              stream_response=True,
+                              sequence_start=True,
+                              sequence_end=True,
+                              request_output_len=request_output_len,
+                              top_k=top_k,
+                              top_p=top_p,
+                              temperature=temperature,
+                              ignore_eos=ignore_eos,
+                              repetition_penalty=repetition_penalty))
+
+        async def _inner_call(i, generator):
+            async for out in generator:
+                outputs[i] += out.response
+
+        async def gather():
+            await asyncio.gather(
+                *[_inner_call(i, generators[i]) for i in range(batch_size)])
+
+        self.loop.run_until_complete(gather())
+        return outputs
+
     async def generate(
             self,
             messages,
@@ -109,11 +163,11 @@ class AsyncEngine:
             sequence_end (bool): indicator for ending a sequence
             step (int): the offset of the k/v cache
             stop (bool): whether stop inference
-            top_k (int): The number of the highest probability vocabulary
-              tokens to keep for top-k-filtering
             top_p (float): If set to float < 1, only the smallest set of most
               probable tokens with probabilities that add up to top_p or higher
               are kept for generation.
+            top_k (int): The number of the highest probability vocabulary
+              tokens to keep for top-k-filtering
             temperature (float): to modulate the next token probability
             repetition_penalty (float): The parameter for repetition penalty.
               1.0 means no penalty
@@ -195,11 +249,11 @@ class AsyncEngine:
             renew_session (bool): renew the session
             request_output_len (int): output token nums
             stop (bool): whether stop inference
-            top_k (int): The number of the highest probability vocabulary
-              tokens to keep for top-k-filtering
             top_p (float): If set to float < 1, only the smallest set of most
               probable tokens with probabilities that add up to top_p or higher
               are kept for generation.
+            top_k (int): The number of the highest probability vocabulary
+              tokens to keep for top-k-filtering
            temperature (float): to modulate the next token probability
            repetition_penalty (float): The parameter for repetition penalty.
              1.0 means no penalty
...
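
For reference, a minimal usage sketch of the new batch_infer API. The import path and constructor arguments (model_path, instance_num) are assumptions for illustration, not part of this diff:

    # Hypothetical setup: adjust model_path/instance_num to your deployment.
    from lmdeploy.serve.async_engine import AsyncEngine

    engine = AsyncEngine(model_path='./workspace', instance_num=32)
    prompts = ['Hi, please introduce yourself', 'Shanghai is']
    # batch_infer drives one generator per prompt, gathers them on the
    # engine's event loop, and returns the completed texts in input order.
    outputs = engine.batch_infer(prompts, request_output_len=64)
    for prompt, output in zip(prompts, outputs):
        print(prompt, '->', output)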
@@ -229,11 +229,19 @@ async def create_embeddings(request: EmbeddingsRequest,
     error_check_ret = await check_request(request)
     if error_check_ret is not None:
         return error_check_ret
+    if isinstance(request.input, str):
+        request.input = [request.input]
 
-    embedding = await VariableInterface.async_engine.get_embeddings(
-        request.input)
-    data = [{'object': 'embedding', 'embedding': embedding, 'index': 0}]
-    token_num = len(embedding)
+    data = []
+    token_num = 0
+    for i, prompt in enumerate(request.input):
+        embedding = await VariableInterface.async_engine.get_embeddings(
+            prompt)
+        data.append({
+            'object': 'embedding',
+            'embedding': embedding,
+            'index': i
+        })
+        token_num += len(embedding)
     return EmbeddingsResponse(
         data=data,
         model=request.model,
...
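
A client-side sketch of the updated embeddings endpoint, which now returns one embedding object per input string. The route path (/v1/embeddings), host, and port are assumptions based on the OpenAI-style server, not taken from this diff:

    import requests

    # One embedding object per input string, indexed in request order.
    resp = requests.post('http://localhost:23333/v1/embeddings',
                         json={'model': 'llama', 'input': ['hello', 'world']})
    for item in resp.json()['data']:
        print(item['index'], len(item['embedding']))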
@@ -175,7 +175,7 @@ class CompletionStreamResponse(BaseModel):
 
 class EmbeddingsRequest(BaseModel):
     """Embedding request."""
     model: str = None
-    input: Union[str, List[Any]]
+    input: Union[str, List[str]]
     user: Optional[str] = None
...
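
A minimal sketch of how the relaxed input type validates, accepting either a bare string or a list of strings; the import path is an assumption for illustration:

    from lmdeploy.serve.openai.protocol import EmbeddingsRequest

    # Both forms pass validation; the server normalizes a bare string
    # to a one-element list before computing embeddings.
    single = EmbeddingsRequest(model='llama', input='hello world')
    batch = EmbeddingsRequest(model='llama', input=['hello', 'world'])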