Unverified Commit 41c26237 authored by OlivierDehaene, committed by GitHub

feat: allow any supported payload on /invocations (#2683)

* feat: allow any supported payload on /invocations

* update openAPI

* update doc
parent 27ff1871
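
In short: `/invocations` now accepts any request shape TGI already serves (generate, chat completion, or completion) and dispatches on the payload itself, instead of a single format fixed at startup via `MESSAGES_API_ENABLED`. A minimal sketch of the new behavior, assuming a local TGI server on 127.0.0.1:8080 and the `requests` package (both illustrative, not part of this diff):

```python
# Sketch only: the three payload shapes /invocations now accepts.
import requests

URL = "http://127.0.0.1:8080/invocations"  # assumed local TGI server

payloads = [
    # generate-style (CompatGenerateRequest): keyed by "inputs"
    {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
    # chat-style (ChatRequest, Messages API): keyed by "messages"
    {"model": "tgi", "messages": [{"role": "user", "content": "What is Deep Learning?"}], "max_tokens": 20},
    # completion-style (CompletionRequest): keyed by "prompt"
    {"model": "tgi", "prompt": "What is Deep Learning?", "max_tokens": 20},
]

for payload in payloads:
    print(requests.post(URL, json=payload).json())
```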
@@ -98,7 +98,7 @@ curl 127.0.0.1:8080/generate_stream \
 You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.
 ```bash
-curl localhost:3000/v1/chat/completions \
+curl localhost:8080/v1/chat/completions \
     -X POST \
     -d '{
   "model": "tgi",
......
@@ -3,7 +3,7 @@ use std::collections::HashMap;
 use std::path::PathBuf;
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackend;
-use text_generation_router::server;
+use text_generation_router::{server, usage_stats};
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 /// App Configuration
@@ -48,14 +48,14 @@ struct Args {
     otlp_service_name: String,
     #[clap(long, env)]
     cors_allow_origin: Option<Vec<String>>,
-    #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
     #[clap(long, env)]
     auth_token: Option<String>,
     #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
     executor_worker: PathBuf,
+    #[clap(default_value = "on", long, env)]
+    usage_stats: usage_stats::UsageStatsLevel,
 }
#[tokio::main]
@@ -83,10 +83,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         otlp_endpoint,
         otlp_service_name,
         cors_allow_origin,
-        messages_api_enabled,
         max_client_batch_size,
         auth_token,
         executor_worker,
+        usage_stats,
     } = args;
     // Launch Tokio runtime
@@ -155,11 +155,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         false,
         None,
         None,
-        messages_api_enabled,
         true,
         max_client_batch_size,
-        false,
-        false,
+        usage_stats,
     )
     .await?;
     Ok(())
......
@@ -63,8 +63,6 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
-    #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
@@ -110,7 +108,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -190,7 +187,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
......
@@ -63,8 +63,6 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
-    #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
@@ -110,7 +108,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -190,7 +187,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
......
@@ -316,6 +316,98 @@
         }
       }
     },
+    "/invocations": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Generate tokens from Sagemaker request",
+        "operationId": "sagemaker_compatibility",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/SagemakerRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Generated Chat Completion",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/SagemakerResponse"
+                }
+              },
+              "text/event-stream": {
+                "schema": {
+                  "$ref": "#/components/schemas/SagemakerStreamResponse"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Input validation error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Input validation error",
+                  "error_type": "validation"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Generation Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Request failed during generation",
+                  "error_type": "generation"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                }
+              }
+            }
+          },
+          "500": {
+            "description": "Incomplete generation",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Incomplete generation",
+                  "error_type": "incomplete_generation"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/metrics": {
       "get": {
         "tags": [
@@ -1865,6 +1957,45 @@
           "type": "string"
         }
       },
+      "SagemakerRequest": {
+        "oneOf": [
+          {
+            "$ref": "#/components/schemas/CompatGenerateRequest"
+          },
+          {
+            "$ref": "#/components/schemas/ChatRequest"
+          },
+          {
+            "$ref": "#/components/schemas/CompletionRequest"
+          }
+        ]
+      },
+      "SagemakerResponse": {
+        "oneOf": [
+          {
+            "$ref": "#/components/schemas/GenerateResponse"
+          },
+          {
+            "$ref": "#/components/schemas/ChatCompletion"
+          },
+          {
+            "$ref": "#/components/schemas/CompletionFinal"
+          }
+        ]
+      },
+      "SagemakerStreamResponse": {
+        "oneOf": [
+          {
+            "$ref": "#/components/schemas/StreamResponse"
+          },
+          {
+            "$ref": "#/components/schemas/ChatCompletionChunk"
+          },
+          {
+            "$ref": "#/components/schemas/Chunk"
+          }
+        ]
+      },
       "SimpleToken": {
         "type": "object",
         "required": [
......
@@ -141,9 +141,7 @@ TGI can be deployed on various cloud providers for scalable and robust text gene
 ## Amazon SageMaker
-To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`.
-This will modify the `/invocations` route to accept Messages dictonaries consisting out of role and content. See the example below on how to deploy Llama with the new Messages API.
+Amazon Sagemaker natively supports the message API:
```python
import json
@@ -161,12 +159,11 @@ except ValueError:
 hub = {
     'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
     'SM_NUM_GPUS': json.dumps(1),
-    'MESSAGES_API_ENABLED': True
 }
 # create Hugging Face Model Class
 huggingface_model = HuggingFaceModel(
-    image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
+    image_uri=get_huggingface_llm_image_uri("huggingface",version="2.3.2"),
     env=hub,
     role=role,
 )
......
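
Once deployed, the SageMaker endpoint forwards requests to TGI's `/invocations` route, so any of the supported payloads can be sent without the removed `MESSAGES_API_ENABLED` flag. A hedged follow-up sketch, assuming the `deploy(...)` step elided above; the instance type and timeout are illustrative, not part of this diff:

```python
# Sketch only: deploy step and a Messages API call (values are illustrative).
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    container_startup_health_check_timeout=300,
)

# /invocations now accepts Messages API payloads natively.
response = predictor.predict({
    "model": "tgi",
    "messages": [{"role": "user", "content": "What is Deep Learning?"}],
    "max_tokens": 64,
})
print(response)
```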
@@ -26,7 +26,6 @@ As of release 2.1.2 this is an example of the data collected:
   "max_top_n_tokens": 5,
   "max_total_tokens": 2048,
   "max_waiting_tokens": 20,
-  "messages_api_enabled": false,
   "model_config": {
     "model_type": "Bloom"
   },
......
@@ -8,6 +8,7 @@ pub mod validation;
 mod kserve;
 pub mod logging;
+mod sagemaker;
 pub mod usage_stats;
 mod vertex;
......
use crate::infer::Infer;
use crate::server::{chat_completions, compat_generate, completions, ComputeType};
use crate::{
ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest,
CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse,
};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::response::Response;
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;
#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerRequest {
Generate(CompatGenerateRequest),
Chat(ChatRequest),
Completion(CompletionRequest),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerResponse {
Generate(GenerateResponse),
Chat(ChatCompletion),
Completion(CompletionFinal),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerStreamResponse {
Generate(StreamResponse),
Chat(ChatCompletionChunk),
Completion(Chunk),
}
/// Generate tokens from Sagemaker request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/invocations",
request_body = SagemakerRequest,
responses(
(status = 200, description = "Generated Chat Completion",
content(
("application/json" = SagemakerResponse),
("text/event-stream" = SagemakerStreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation", "error_type": "generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error", "error_type": "validation"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation", "error_type": "incomplete_generation"})),
)
)]
#[instrument(skip_all)]
pub(crate) async fn sagemaker_compatibility(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
info: Extension<Info>,
Json(req): Json<SagemakerRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
match req {
SagemakerRequest::Generate(req) => {
compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
}
SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
SagemakerRequest::Completion(req) => {
completions(infer, compute_type, info, Json(req)).await
}
}
}
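
For context, `#[serde(untagged)]` makes serde try the `SagemakerRequest` variants in declaration order and keep the first one that deserializes, so requests are effectively routed by their discriminating fields: `inputs` (generate), `messages` (chat), `prompt` (completion). A rough Python analogue of that resolution order, with `route_invocation` as a hypothetical helper:

```python
# Rough analogue of the untagged-enum resolution above: variants are tried
# in declaration order and the first shape that fits wins.
def route_invocation(body: dict) -> str:
    if "inputs" in body:  # CompatGenerateRequest -> compat_generate
        return "generate"
    if "messages" in body:  # ChatRequest -> chat_completions
        return "chat"
    if "prompt" in body:  # CompletionRequest -> completions
        return "completion"
    raise ValueError("no SagemakerRequest variant matches this payload")

assert route_invocation({"inputs": "Hi"}) == "generate"
assert route_invocation({"messages": [{"role": "user", "content": "Hi"}]}) == "chat"
assert route_invocation({"prompt": "Hi"}) == "completion"
```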
@@ -7,6 +7,10 @@ use crate::kserve::{
     kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
     kserve_model_metadata, kserve_model_metadata_ready,
 };
+use crate::sagemaker::{
+    sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
+    __path_sagemaker_compatibility,
+};
 use crate::validation::ValidationError;
 use crate::vertex::vertex_compatibility;
 use crate::ChatTokenizeResponse;
@@ -83,7 +87,7 @@ example = json ! ({"error": "Incomplete generation"})),
 )
 )]
 #[instrument(skip(infer, req))]
-async fn compat_generate(
+pub(crate) async fn compat_generate(
     Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
     compute_type: Extension<ComputeType>,
@@ -678,7 +682,7 @@ time_per_token,
 seed,
 )
 )]
-async fn completions(
+pub(crate) async fn completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
@@ -1202,7 +1206,7 @@ time_per_token,
 seed,
 )
 )]
-async fn chat_completions(
+pub(crate) async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
@@ -1513,11 +1517,13 @@ completions,
 tokenize,
 metrics,
 openai_get_model_info,
+sagemaker_compatibility,
 ),
 components(
 schemas(
 Info,
 CompatGenerateRequest,
+SagemakerRequest,
 GenerateRequest,
 GrammarType,
 ChatRequest,
@@ -1540,6 +1546,8 @@ ChatCompletionTopLogprob,
 ChatCompletion,
 CompletionRequest,
 CompletionComplete,
+SagemakerResponse,
+SagemakerStreamResponse,
 Chunk,
 Completion,
 CompletionFinal,
@@ -1607,7 +1615,6 @@ pub async fn run(
     ngrok: bool,
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: usage_stats::UsageStatsLevel,
@@ -1836,7 +1843,6 @@ pub async fn run(
         // max_batch_size,
         revision.clone(),
         validation_workers,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats_level,
@@ -1878,7 +1884,6 @@ pub async fn run(
         ngrok,
         _ngrok_authtoken,
         _ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         model_info,
@@ -1938,7 +1943,6 @@ async fn start(
     ngrok: bool,
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     model_info: HubModelInfo,
@@ -2253,6 +2257,7 @@ async fn start(
         .route("/v1/chat/completions", post(chat_completions))
         .route("/v1/completions", post(completions))
         .route("/vertex", post(vertex_compatibility))
+        .route("/invocations", post(sagemaker_compatibility))
         .route("/tokenize", post(tokenize));
     if let Some(api_key) = api_key {
@@ -2288,13 +2293,6 @@ async fn start(
         .route("/metrics", get(metrics))
         .route("/v1/models", get(openai_get_model_info));
-    // Conditional AWS Sagemaker route
-    let aws_sagemaker_route = if messages_api_enabled {
-        Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
-    } else {
-        Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
-    };
     let compute_type =
         ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
@@ -2302,8 +2300,7 @@ async fn start(
     let mut app = Router::new()
         .merge(swagger_ui)
         .merge(base_routes)
-        .merge(info_routes)
-        .merge(aws_sagemaker_route);
+        .merge(info_routes);
     #[cfg(feature = "google")]
     {
......
@@ -93,7 +93,6 @@ pub struct Args {
     // max_batch_size: Option<usize>,
     revision: Option<String>,
     validation_workers: usize,
-    messages_api_enabled: bool,
    disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: UsageStatsLevel,
@@ -117,7 +116,6 @@ impl Args {
         // max_batch_size: Option<usize>,
         revision: Option<String>,
         validation_workers: usize,
-        messages_api_enabled: bool,
         disable_grammar_support: bool,
         max_client_batch_size: usize,
         usage_stats_level: UsageStatsLevel,
@@ -138,7 +136,6 @@ impl Args {
             // max_batch_size,
             revision,
             validation_workers,
-            messages_api_enabled,
             disable_grammar_support,
             max_client_batch_size,
             usage_stats_level,
......
@@ -172,6 +172,8 @@ def check_openapi(check: bool):
             # allow for trailing whitespace since it's not significant
             # and the precommit hook will remove it
             "lint",
+            "--skip-rule",
+            "security-defined",
             filename,
         ],
         capture_output=True,
......