Unverified Commit 06c3d4b1 authored by drbh's avatar drbh Committed by GitHub
Browse files

feat: accept list as prompt and use first string (#1702)

This PR allows the `CompletionRequest.prompt` to be sent as a string or
array of strings. When an array is sent the first value will be used if
it's a string; otherwise the according error will be thrown

Fixes:
https://github.com/huggingface/text-generation-inference/issues/1690
Similar to: https://github.com/vllm-project/vllm/pull/323/files
parent e4d31a40
......@@ -59,6 +59,17 @@ class ChatCompletionComplete(BaseModel):
usage: Optional[Any] = None
class CompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
text: str
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
class Function(BaseModel):
name: Optional[str]
arguments: str
......@@ -104,6 +115,16 @@ class ChatComplete(BaseModel):
usage: Any
class Completion(BaseModel):
# Completion details
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[CompletionComplete]
class ChatRequest(BaseModel):
# Model identifier
model: str
......
......@@ -398,6 +398,15 @@ Options:
-e, --env
Display a lot of information about your runtime environment
```
## MAX_CLIENT_BATCH_SIZE
```shell
--max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
Control the maximum number of inputs that a client can send in a single request
[env: MAX_CLIENT_BATCH_SIZE=]
[default: 4]
```
## HELP
```shell
......
......@@ -9,6 +9,7 @@ import json
import math
import time
import random
import re
from docker.errors import NotFound
from typing import Optional, List, Dict
......@@ -26,6 +27,7 @@ from text_generation.types import (
ChatComplete,
ChatCompletionChunk,
ChatCompletionComplete,
Completion,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
......@@ -69,17 +71,22 @@ class ResponseComparator(JSONSnapshotExtension):
data = json.loads(data)
if isinstance(data, Dict) and "choices" in data:
choices = data["choices"]
if (
isinstance(choices, List)
and len(choices) >= 1
and "delta" in choices[0]
):
return ChatCompletionChunk(**data)
if isinstance(choices, List) and len(choices) >= 1:
if "delta" in choices[0]:
return ChatCompletionChunk(**data)
if "text" in choices[0]:
return Completion(**data)
return ChatComplete(**data)
if isinstance(data, Dict):
return Response(**data)
if isinstance(data, List):
if (
len(data) > 0
and "object" in data[0]
and data[0]["object"] == "text_completion"
):
return [Completion(**d) for d in data]
return [Response(**d) for d in data]
raise NotImplementedError
......@@ -161,6 +168,9 @@ class ResponseComparator(JSONSnapshotExtension):
)
)
def eq_completion(response: Completion, other: Completion) -> bool:
return response.choices[0].text == other.choices[0].text
def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
return (
response.choices[0].message.content == other.choices[0].message.content
......@@ -184,6 +194,11 @@ class ResponseComparator(JSONSnapshotExtension):
if not isinstance(snapshot_data, List):
snapshot_data = [snapshot_data]
if isinstance(serialized_data[0], Completion):
return len(snapshot_data) == len(serialized_data) and all(
[eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
)
if isinstance(serialized_data[0], ChatComplete):
return len(snapshot_data) == len(serialized_data) and all(
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
......
{
"choices": [
{
"finish_reason": "eos_token",
"index": 1,
"logprobs": null,
"text": " PR for more information?"
},
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "le Business Incubator is providing a workspace"
},
{
"finish_reason": "length",
"index": 2,
"logprobs": null,
"text": " severely flawed and often has a substandard"
},
{
"finish_reason": "length",
"index": 3,
"logprobs": null,
"text": "hd20220811-"
}
],
"created": 1713284455,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"usage": {
"completion_tokens": 36,
"prompt_tokens": 8,
"total_tokens": 44
}
}
[
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "hd"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "aho"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "2"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "2"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "2"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "ima"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Sarah"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Yes"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " And"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "i"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "'"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": ","
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " what"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "'"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "s"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Moh"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " is"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "m"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Room"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "s"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " the"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " tired"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": ":"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "'"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " capital"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " of"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " She"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " scale"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " of"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " being"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
}
]
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": " PR for flake8"
}
],
"created": 1713284454,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"usage": {
"completion_tokens": 5,
"prompt_tokens": 6,
"total_tokens": 11
}
}
import pytest
import requests
import json
from aiohttp import ClientSession
from text_generation.types import (
Completion,
)
@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
await flash_llama_completion_handle.health(300)
return flash_llama_completion_handle.client
# NOTE: since `v1/completions` is a deprecated inferface/endpoint we do not provide a convience
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
def test_flash_llama_completion_single_prompt(
flash_llama_completion, response_snapshot
):
response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions",
json={
"model": "tgi",
"prompt": "Say this is a test",
"max_tokens": 5,
"seed": 0,
},
headers=flash_llama_completion.headers,
stream=False,
)
response = response.json()
assert len(response["choices"]) == 1
assert response == response_snapshot
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions",
json={
"model": "tgi",
"prompt": ["Say", "this", "is", "a"],
"max_tokens": 10,
"seed": 0,
},
headers=flash_llama_completion.headers,
stream=False,
)
response = response.json()
assert len(response["choices"]) == 4
all_indexes = [choice["index"] for choice in response["choices"]]
all_indexes.sort()
assert all_indexes == [0, 1, 2, 3]
assert response == response_snapshot
async def test_flash_llama_completion_many_prompts_stream(
flash_llama_completion, response_snapshot
):
request = {
"model": "tgi",
"prompt": [
"What color is the sky?",
"Is water wet?",
"What is the capital of France?",
"def mai",
],
"max_tokens": 10,
"seed": 0,
"stream": True,
}
url = f"{flash_llama_completion.base_url}/v1/completions"
chunks = []
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
async for chunk in response.content.iter_any():
# remove "data:"
chunk = chunk.decode().split("\n\n")
# remove "data:" if present
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# parse json
chunk = [json.loads(c) for c in chunk]
for c in chunk:
chunks.append(Completion(**c))
assert "choices" in c
assert 0 <= c["choices"][0]["index"] <= 4
assert response.status == 200
assert chunks == response_snapshot
......@@ -414,6 +414,10 @@ struct Args {
/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
env: bool,
/// Control the maximum number of inputs that a client can send in a single request
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
}
#[derive(Debug)]
......@@ -1078,6 +1082,8 @@ fn spawn_webserver(
// Start webserver
tracing::info!("Starting Webserver");
let mut router_args = vec![
"--max-client-batch-size".to_string(),
args.max_client_batch_size.to_string(),
"--max-concurrent-requests".to_string(),
args.max_concurrent_requests.to_string(),
"--max-best-of".to_string(),
......
......@@ -155,6 +155,8 @@ pub struct Info {
pub max_batch_size: Option<usize>,
#[schema(example = "2")]
pub validation_workers: usize,
#[schema(example = "32")]
pub max_client_batch_size: usize,
/// Router Info
#[schema(example = "0.5.0")]
pub version: &'static str,
......@@ -280,6 +282,34 @@ fn default_parameters() -> GenerateParameters {
}
}
mod prompt_serde {
use serde::{self, Deserialize, Deserializer};
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(vec![s]),
Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
"Empty array detected. Do not use an empty array for the prompt.",
)),
Value::Array(arr) => arr
.iter()
.map(|v| match v {
Value::String(s) => Ok(s.to_owned()),
_ => Err(serde::de::Error::custom("Expected a string")),
})
.collect(),
_ => Err(serde::de::Error::custom(
"Expected a string or an array of strings",
)),
}
}
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub struct CompletionRequest {
/// UNUSED
......@@ -289,7 +319,8 @@ pub struct CompletionRequest {
/// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")]
pub prompt: String,
#[serde(deserialize_with = "prompt_serde::deserialize")]
pub prompt: Vec<String>,
/// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)]
......
......@@ -78,6 +78,8 @@ struct Args {
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
}
#[tokio::main]
......@@ -112,6 +114,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
} = args;
// Launch Tokio runtime
......@@ -393,6 +396,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_config,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
)
.await?;
Ok(())
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment