Unverified Commit ac3500b5 authored by AllentDan, committed by GitHub

support inference on a batch of prompts (#467)

* support inference on a batch of prompts

* docstring and assert
parent 169d5169
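For context, here is a minimal usage sketch of the batch_infer API this commit introduces. Only the batch_infer signature comes from the diff below; the import path, workspace path and instance_num value are illustrative assumptions.

# Sketch (not part of the commit): driving the new batch_infer API.
# The constructor arguments below are assumed placeholders, not taken from the diff.
from lmdeploy.serve.async_engine import AsyncEngine

engine = AsyncEngine(model_path='./workspace', instance_num=4)  # hypothetical path/value
outputs = engine.batch_infer(
    ['Please introduce yourself.', 'Shanghai is'],
    request_output_len=128,
    top_k=40,
    top_p=0.8,
    temperature=0.8)
print(outputs)  # one generated string per prompt, in input order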
@@ -4,7 +4,7 @@ import dataclasses
import os.path as osp
import random
from contextlib import contextmanager
from typing import Literal, Optional
from typing import List, Literal, Optional
from lmdeploy.model import MODELS, BaseModel
@@ -46,6 +46,7 @@ class AsyncEngine:
self.available = [True] * instance_num
self.starts = [None] * instance_num
self.steps = {}
self.loop = asyncio.get_event_loop()
def stop_session(self, session_id: int):
instance_id = session_id % self.instance_num
@@ -82,6 +83,59 @@ class AsyncEngine:
await asyncio.sleep(0.1)
return self.generators[instance_id]
def batch_infer(self,
                prompts: List[str],
                request_output_len=512,
                top_k=40,
                top_p=0.8,
                temperature=0.8,
                repetition_penalty=1.0,
                ignore_eos=False,
                **kwargs):
    """Inference a batch of prompts.

    Args:
        prompts (List[str]): a batch of prompts
        request_output_len (int): output token nums
        top_k (int): The number of the highest probability vocabulary
            tokens to keep for top-k-filtering
        top_p (float): If set to float < 1, only the smallest set of most
            probable tokens with probabilities that add up to top_p or higher
            are kept for generation.
        temperature (float): to modulate the next token probability
        repetition_penalty (float): The parameter for repetition penalty.
            1.0 means no penalty
        ignore_eos (bool): indicator for ignoring eos
    """
    assert isinstance(prompts, List), 'prompts should be a list'
    batch_size = len(prompts)
    outputs = [''] * batch_size
    generators = []
    for i, prompt in enumerate(prompts):
        generators.append(
            self.generate(prompt,
                          i,
                          stream_response=True,
                          sequence_start=True,
                          sequence_end=True,
                          request_output_len=request_output_len,
                          top_k=top_k,
                          top_p=top_p,
                          temperature=temperature,
                          ignore_eos=ignore_eos,
                          repetition_penalty=repetition_penalty))

    async def _inner_call(i, generator):
        async for out in generator:
            outputs[i] += out.response

    async def gather():
        await asyncio.gather(
            *[_inner_call(i, generators[i]) for i in range(batch_size)])

    self.loop.run_until_complete(gather())
    return outputs
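batch_infer fans one generate() coroutine out per prompt and drains them concurrently on the engine's event loop. Below is a self-contained sketch of that gather-and-accumulate pattern, with a dummy async generator standing in for self.generate; every name in it is illustrative rather than part of the commit.

import asyncio


async def fake_generate(prompt):
    """Stand-in for AsyncEngine.generate: streams response chunks for one prompt."""
    for word in prompt.split():
        await asyncio.sleep(0)  # simulate an asynchronous decoding step
        yield word + ' '


def batch_collect(prompts):
    """Consume one async generator per prompt concurrently, as batch_infer does."""
    outputs = [''] * len(prompts)

    async def consume(i, gen):
        async for chunk in gen:
            outputs[i] += chunk  # accumulate streamed text per prompt

    async def gather_all():
        await asyncio.gather(
            *[consume(i, fake_generate(p)) for i, p in enumerate(prompts)])

    loop = asyncio.new_event_loop()  # batch_infer reuses the engine's self.loop instead
    try:
        loop.run_until_complete(gather_all())
    finally:
        loop.close()
    return outputs


print(batch_collect(['hello world', 'one two three']))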
async def generate(
self,
messages,
@@ -109,11 +163,11 @@ class AsyncEngine:
sequence_end (bool): indicator for ending a sequence
step (int): the offset of the k/v cache
stop (bool): whether stop inference
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
@@ -195,11 +249,11 @@ class AsyncEngine:
renew_session (bool): renew the session
request_output_len (int): output token nums
stop (bool): whether stop inference
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
......
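For a single session, the async generate() documented above can also be consumed directly. In the sketch below, the keyword arguments mirror the call batch_infer makes, and each yielded item exposes a .response chunk, which is what the accumulation code relies on; the helper name and the commented driver line are illustrative.

import asyncio


async def stream_one(engine, prompt, session_id=0):
    """Collect one streamed answer from AsyncEngine.generate (sketch)."""
    text = ''
    async for out in engine.generate(prompt,
                                     session_id,
                                     stream_response=True,
                                     sequence_start=True,
                                     sequence_end=True,
                                     request_output_len=128):
        text += out.response  # each item carries the newly decoded piece
    return text

# e.g. asyncio.get_event_loop().run_until_complete(stream_one(engine, 'Hello'))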
@@ -229,11 +229,19 @@ async def create_embeddings(request: EmbeddingsRequest,
    error_check_ret = await check_request(request)
    if error_check_ret is not None:
        return error_check_ret
    embedding = await VariableInterface.async_engine.get_embeddings(
        request.input)
    data = [{'object': 'embedding', 'embedding': embedding, 'index': 0}]
    token_num = len(embedding)
    if isinstance(request.input, str):
        request.input = [request.input]
    data = []
    token_num = 0
    for i, prompt in enumerate(request.input):
        embedding = await VariableInterface.async_engine.get_embeddings(prompt)
        data.append({
            'object': 'embedding',
            'embedding': embedding,
            'index': i
        })
        token_num += len(embedding)
    return EmbeddingsResponse(
        data=data,
        model=request.model,
......
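The handler above now accepts either a single string or a list of strings and returns one embedding object per input, with token_num summed across inputs. A hedged client-side sketch follows; the server address, route and model name are placeholders, and only the payload and response shape follow the changed code.

# Sketch of calling the embeddings route after this change.
# URL, port and model name are assumed placeholders, not taken from the diff.
import requests

resp = requests.post(
    'http://localhost:23333/v1/embeddings',  # assumed api_server address and route
    json={
        'model': 'internlm-chat-7b',  # placeholder model name
        'input': ['Hello world', 'How are you today?'],  # a plain str is also accepted
    })
for item in resp.json()['data']:
    print(item['index'], len(item['embedding']))  # one indexed embedding per input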
@@ -175,7 +175,7 @@ class CompletionStreamResponse(BaseModel):
class EmbeddingsRequest(BaseModel):
    """Embedding request."""
    model: str = None
    input: Union[str, List[Any]]
    input: Union[str, List[str]]
    user: Optional[str] = None
......
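As a quick check of the narrowed field type, the fragment below reproduces the model fields shown in the diff (the pydantic import is assumed from the surrounding module) and shows that input now validates either a single string or a list of strings.

from typing import List, Optional, Union

from pydantic import BaseModel


class EmbeddingsRequest(BaseModel):
    """Embedding request (fields as in the diff above)."""
    model: str = None
    input: Union[str, List[str]]
    user: Optional[str] = None


print(EmbeddingsRequest(input='a single prompt').input)
print(EmbeddingsRequest(input=['prompt one', 'prompt two']).input)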