Unverified commit 41c26237, authored by OlivierDehaene, committed by GitHub

feat: allow any supported payload on /invocations (#2683)

* feat: allow any supported payload on /invocations

* update openAPI

* update doc
parent 27ff1871
@@ -98,7 +98,7 @@ curl 127.0.0.1:8080/generate_stream \
 You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.

 ```bash
-curl localhost:3000/v1/chat/completions \
+curl localhost:8080/v1/chat/completions \
     -X POST \
     -d '{
   "model": "tgi",
...
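For reference, a minimal Python sketch of the same Messages API call; the port (8080), the model name ("tgi"), and the example prompt are assumptions taken from the surrounding docs, not part of this diff:

```python
# Minimal sketch: call TGI's OpenAI-compatible Messages API over plain HTTP.
# Assumes a TGI server listening on localhost:8080, as in the updated docs.
import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is deep learning?"}],
        "stream": False,
    },
    timeout=60,
)
print(response.json()["choices"][0]["message"]["content"])
```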
@@ -3,7 +3,7 @@ use std::collections::HashMap;
 use std::path::PathBuf;
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackend;
-use text_generation_router::server;
+use text_generation_router::{server, usage_stats};
 use tokenizers::{FromPretrainedParameters, Tokenizer};

 /// App Configuration
@@ -48,14 +48,14 @@ struct Args {
     otlp_service_name: String,
     #[clap(long, env)]
     cors_allow_origin: Option<Vec<String>>,
-    #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
     #[clap(long, env)]
     auth_token: Option<String>,
     #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
     executor_worker: PathBuf,
+    #[clap(default_value = "on", long, env)]
+    usage_stats: usage_stats::UsageStatsLevel,
 }

 #[tokio::main]
@@ -83,10 +83,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         otlp_endpoint,
         otlp_service_name,
         cors_allow_origin,
-        messages_api_enabled,
         max_client_batch_size,
         auth_token,
         executor_worker,
+        usage_stats,
     } = args;

     // Launch Tokio runtime
@@ -155,11 +155,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         false,
         None,
         None,
-        messages_api_enabled,
         true,
         max_client_batch_size,
-        false,
-        false,
+        usage_stats,
     )
     .await?;
     Ok(())
...
@@ -63,8 +63,6 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
-    #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
@@ -110,7 +108,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -190,7 +187,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
...
@@ -63,8 +63,6 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
     #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
-    #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
@@ -110,7 +108,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
@@ -190,7 +187,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok,
         ngrok_authtoken,
         ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
...
@@ -316,6 +316,98 @@
}
}
},
"/invocations": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens from Sagemaker request",
"operationId": "sagemaker_compatibility",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SagemakerRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Chat Completion",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SagemakerResponse"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/SagemakerStreamResponse"
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error",
"error_type": "validation"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation",
"error_type": "generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded",
"error_type": "overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation",
"error_type": "incomplete_generation"
}
}
}
}
}
}
},
"/metrics": { "/metrics": {
"get": { "get": {
"tags": [ "tags": [
...@@ -1865,6 +1957,45 @@ ...@@ -1865,6 +1957,45 @@
"type": "string" "type": "string"
} }
}, },
"SagemakerRequest": {
"oneOf": [
{
"$ref": "#/components/schemas/CompatGenerateRequest"
},
{
"$ref": "#/components/schemas/ChatRequest"
},
{
"$ref": "#/components/schemas/CompletionRequest"
}
]
},
"SagemakerResponse": {
"oneOf": [
{
"$ref": "#/components/schemas/GenerateResponse"
},
{
"$ref": "#/components/schemas/ChatCompletion"
},
{
"$ref": "#/components/schemas/CompletionFinal"
}
]
},
"SagemakerStreamResponse": {
"oneOf": [
{
"$ref": "#/components/schemas/StreamResponse"
},
{
"$ref": "#/components/schemas/ChatCompletionChunk"
},
{
"$ref": "#/components/schemas/Chunk"
}
]
},
"SimpleToken": { "SimpleToken": {
"type": "object", "type": "object",
"required": [ "required": [
......
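The `SagemakerRequest` oneOf above is what makes `/invocations` accept any supported payload. A minimal sketch of the three request shapes is shown below; the endpoint URL and parameter values are illustrative assumptions, not taken from this diff:

```python
# Sketch: the /invocations route now accepts any of the three request shapes
# from the SagemakerRequest oneOf (generate, chat, or completion payloads).
import requests

BASE = "http://localhost:8080"

# CompatGenerateRequest-style payload (classic generate API)
generate_payload = {"inputs": "What is deep learning?", "parameters": {"max_new_tokens": 20}}

# ChatRequest-style payload (OpenAI Chat Completions API)
chat_payload = {
    "model": "tgi",
    "messages": [{"role": "user", "content": "What is deep learning?"}],
}

# CompletionRequest-style payload (OpenAI Completions API)
completion_payload = {"model": "tgi", "prompt": "What is deep learning?", "max_tokens": 20}

for payload in (generate_payload, chat_payload, completion_payload):
    r = requests.post(f"{BASE}/invocations", json=payload, timeout=60)
    r.raise_for_status()
    print(r.json())
```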
@@ -141,9 +141,7 @@ TGI can be deployed on various cloud providers for scalable and robust text gene
 ## Amazon SageMaker

-To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`.
-This will modify the `/invocations` route to accept Messages dictonaries consisting out of role and content. See the example below on how to deploy Llama with the new Messages API.
+Amazon Sagemaker natively supports the message API:

 ```python
 import json
@@ -161,12 +159,11 @@ except ValueError:
 hub = {
     'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
     'SM_NUM_GPUS': json.dumps(1),
-    'MESSAGES_API_ENABLED': True
 }

 # create Hugging Face Model Class
 huggingface_model = HuggingFaceModel(
-    image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
+    image_uri=get_huggingface_llm_image_uri("huggingface",version="2.3.2"),
     env=hub,
     role=role,
 )
...
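Once the model is deployed, the endpoint's `/invocations` route accepts chat, completion, or generate payloads directly, with no `MESSAGES_API_ENABLED` flag. A minimal sketch is shown below; the `deploy()` arguments and the predictor usage follow the standard Hugging Face SageMaker workflow and are assumptions, not part of this diff:

```python
# Sketch: deploy the model defined above, then send different payload shapes
# to the same SageMaker endpoint (forwarded unchanged to /invocations).
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    container_startup_health_check_timeout=300,
)

# Messages API payload
print(predictor.predict({
    "messages": [{"role": "user", "content": "What is deep learning?"}],
    "max_tokens": 64,
}))

# The classic generate payload keeps working on the same endpoint
print(predictor.predict({
    "inputs": "What is deep learning?",
    "parameters": {"max_new_tokens": 64},
}))
```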
@@ -26,7 +26,6 @@ As of release 2.1.2 this is an example of the data collected:
     "max_top_n_tokens": 5,
     "max_total_tokens": 2048,
     "max_waiting_tokens": 20,
-    "messages_api_enabled": false,
     "model_config": {
         "model_type": "Bloom"
     },
...
@@ -8,6 +8,7 @@ pub mod validation;
 mod kserve;
 pub mod logging;
+mod sagemaker;
 pub mod usage_stats;
 mod vertex;
...
New file: router/src/sagemaker.rs
use crate::infer::Infer;
use crate::server::{chat_completions, compat_generate, completions, ComputeType};
use crate::{
ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest,
CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse,
};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::response::Response;
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;
#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerRequest {
Generate(CompatGenerateRequest),
Chat(ChatRequest),
Completion(CompletionRequest),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerResponse {
Generate(GenerateResponse),
Chat(ChatCompletion),
Completion(CompletionFinal),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerStreamResponse {
Generate(StreamResponse),
Chat(ChatCompletionChunk),
Completion(Chunk),
}
/// Generate tokens from Sagemaker request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/invocations",
request_body = SagemakerRequest,
responses(
(status = 200, description = "Generated Chat Completion",
content(
("application/json" = SagemakerResponse),
("text/event-stream" = SagemakerStreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation", "error_type": "generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error", "error_type": "validation"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation", "error_type": "incomplete_generation"})),
)
)]
#[instrument(skip_all)]
pub(crate) async fn sagemaker_compatibility(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
info: Extension<Info>,
Json(req): Json<SagemakerRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
match req {
SagemakerRequest::Generate(req) => {
compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
}
SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
SagemakerRequest::Completion(req) => {
completions(infer, compute_type, info, Json(req)).await
}
}
}
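When a chat or completion payload sent to `/invocations` sets `"stream": true`, the handler above returns the `SagemakerStreamResponse` variants as server-sent events. A minimal Python sketch of consuming that stream follows; the local URL and the `data:`/`[DONE]` framing are assumptions based on TGI's OpenAI-compatible streaming, not guaranteed by this diff:

```python
# Sketch: stream chat chunks from /invocations and print the generated text.
import json
import requests

with requests.post(
    "http://localhost:8080/invocations",
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is deep learning?"}],
        "stream": True,
    },
    stream=True,
    timeout=60,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        text = line.decode("utf-8")
        if not text.startswith("data:"):
            continue
        data = text[len("data:"):].strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)  # a ChatCompletionChunk for chat requests
        delta = chunk["choices"][0]["delta"].get("content", "")
        print(delta, end="", flush=True)
```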
@@ -7,6 +7,10 @@ use crate::kserve::{
     kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
     kserve_model_metadata, kserve_model_metadata_ready,
 };
+use crate::sagemaker::{
+    sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
+    __path_sagemaker_compatibility,
+};
 use crate::validation::ValidationError;
 use crate::vertex::vertex_compatibility;
 use crate::ChatTokenizeResponse;
@@ -83,7 +87,7 @@ example = json ! ({"error": "Incomplete generation"})),
     )
 )]
 #[instrument(skip(infer, req))]
-async fn compat_generate(
+pub(crate) async fn compat_generate(
     Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
     compute_type: Extension<ComputeType>,
@@ -678,7 +682,7 @@ time_per_token,
 seed,
     )
 )]
-async fn completions(
+pub(crate) async fn completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
@@ -1202,7 +1206,7 @@ time_per_token,
 seed,
     )
 )]
-async fn chat_completions(
+pub(crate) async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
@@ -1513,11 +1517,13 @@ completions,
 tokenize,
 metrics,
 openai_get_model_info,
+sagemaker_compatibility,
 ),
 components(
 schemas(
 Info,
 CompatGenerateRequest,
+SagemakerRequest,
 GenerateRequest,
 GrammarType,
 ChatRequest,
@@ -1540,6 +1546,8 @@ ChatCompletionTopLogprob,
 ChatCompletion,
 CompletionRequest,
 CompletionComplete,
+SagemakerResponse,
+SagemakerStreamResponse,
 Chunk,
 Completion,
 CompletionFinal,
@@ -1607,7 +1615,6 @@ pub async fn run(
     ngrok: bool,
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: usage_stats::UsageStatsLevel,
@@ -1836,7 +1843,6 @@ pub async fn run(
         // max_batch_size,
         revision.clone(),
         validation_workers,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats_level,
@@ -1878,7 +1884,6 @@ pub async fn run(
         ngrok,
         _ngrok_authtoken,
         _ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         model_info,
@@ -1938,7 +1943,6 @@ async fn start(
     ngrok: bool,
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     model_info: HubModelInfo,
@@ -2253,6 +2257,7 @@ async fn start(
         .route("/v1/chat/completions", post(chat_completions))
         .route("/v1/completions", post(completions))
         .route("/vertex", post(vertex_compatibility))
+        .route("/invocations", post(sagemaker_compatibility))
         .route("/tokenize", post(tokenize));

     if let Some(api_key) = api_key {
@@ -2288,13 +2293,6 @@ async fn start(
         .route("/metrics", get(metrics))
         .route("/v1/models", get(openai_get_model_info));

-    // Conditional AWS Sagemaker route
-    let aws_sagemaker_route = if messages_api_enabled {
-        Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
-    } else {
-        Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
-    };
-
     let compute_type =
         ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
@@ -2302,8 +2300,7 @@ async fn start(
     let mut app = Router::new()
         .merge(swagger_ui)
         .merge(base_routes)
-        .merge(info_routes)
-        .merge(aws_sagemaker_route);
+        .merge(info_routes);

     #[cfg(feature = "google")]
     {
...
@@ -93,7 +93,6 @@ pub struct Args {
     // max_batch_size: Option<usize>,
     revision: Option<String>,
     validation_workers: usize,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: UsageStatsLevel,
@@ -117,7 +116,6 @@ impl Args {
     // max_batch_size: Option<usize>,
     revision: Option<String>,
     validation_workers: usize,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: UsageStatsLevel,
@@ -138,7 +136,6 @@ impl Args {
     // max_batch_size,
     revision,
     validation_workers,
-    messages_api_enabled,
     disable_grammar_support,
     max_client_batch_size,
     usage_stats_level,
...
@@ -172,6 +172,8 @@ def check_openapi(check: bool):
             # allow for trailing whitespace since it's not significant
             # and the precommit hook will remove it
             "lint",
+            "--skip-rule",
+            "security-defined",
             filename,
         ],
         capture_output=True,
...