OpenDAS / text-generation-inference / Commits

Unverified commit df230625, authored Feb 20, 2024 by drbh, committed by GitHub on Feb 20, 2024

improve endpoint support (#1577)
Small PR to add a new interface endpoint behind a feature flag.

parent d19c768c
Showing 4 changed files with 170 additions and 7 deletions (+170 -7)
router/Cargo.toml     +1   -0
router/src/lib.rs     +21  -2
router/src/main.rs    +9   -0
router/src/server.rs  +139 -5
router/Cargo.toml

@@ -52,3 +52,4 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
 [features]
 default = ["ngrok"]
 ngrok = ["dep:ngrok"]
+google = []
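
A note on the feature declaration: `google = []` pulls in no optional dependencies; it exists purely to gate conditional compilation (enabled with `cargo build --features google`). A minimal sketch of how such a flag is consumed, using illustrative names that are not part of this diff:

// Compiled only when built with `--features google`.
#[cfg(feature = "google")]
fn feature_banner() -> &'static str {
    "google feature enabled"
}

// Fallback compiled when the feature is off.
#[cfg(not(feature = "google"))]
fn feature_banner() -> &'static str {
    "google feature disabled"
}

fn main() {
    println!("{}", feature_banner());
    // cfg!() exposes the same flag as a compile-time bool.
    println!("enabled: {}", cfg!(feature = "google"));
}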
router/src/lib.rs

@@ -20,6 +20,25 @@ pub(crate) type GenerateStreamResponse = (
     UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
 );

+#[derive(Clone, Deserialize, ToSchema)]
+pub(crate) struct VertexInstance {
+    #[schema(example = "What is Deep Learning?")]
+    pub inputs: String,
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub parameters: Option<GenerateParameters>,
+}
+
+#[derive(Deserialize, ToSchema)]
+pub(crate) struct VertexRequest {
+    #[serde(rename = "instances")]
+    pub instances: Vec<VertexInstance>,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+pub(crate) struct VertexResponse {
+    pub predictions: Vec<String>,
+}
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
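The three structs above define the Vertex wire format: a request carries a list of instances, and the response a list of string predictions. A quick serde round-trip sketch of that JSON shape, using standalone mirror types so it compiles on its own (assumes the serde and serde_json crates; GenerateParameters is reduced to two fields for brevity):

use serde::{Deserialize, Serialize};

// Mirror types for illustration only; the real structs live in router/src/lib.rs.
#[derive(Serialize, Deserialize)]
struct GenerateParameters {
    max_new_tokens: Option<u32>,
    seed: Option<u64>,
}

#[derive(Serialize, Deserialize)]
struct VertexInstance {
    inputs: String,
    parameters: Option<GenerateParameters>,
}

#[derive(Serialize, Deserialize)]
struct VertexRequest {
    instances: Vec<VertexInstance>,
}

fn main() {
    // The shape Vertex AI POSTs to the prediction route.
    let body = r#"{"instances":[{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20,"seed":null}}]}"#;
    let req: VertexRequest = serde_json::from_str(body).unwrap();
    assert_eq!(req.instances[0].inputs, "What is Deep Learning?");
    println!("{}", serde_json::to_string_pretty(&req).unwrap());
}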
@@ -70,7 +89,7 @@ mod json_object_or_string_to_string {
     }
 }

-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, ToSchema)]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
     #[serde(
@@ -153,7 +172,7 @@ pub struct Info {
     pub docker_label: Option<&'static str>,
 }

-#[derive(Clone, Debug, Deserialize, ToSchema)]
+#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
 pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
router/src/main.rs

@@ -328,6 +328,15 @@ async fn main() -> Result<(), RouterError> {
     tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
     tracing::info!("Connected");

+    // Determine the server port based on the feature and environment variable.
+    let port = if cfg!(feature = "google") {
+        std::env::var("AIP_HTTP_PORT")
+            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
+            .unwrap_or(port)
+    } else {
+        port
+    };
+
     let addr = match hostname.parse() {
         Ok(ip) => SocketAddr::new(ip, port),
         Err(_) => {
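The new port logic falls back to the CLI-configured port twice: when AIP_HTTP_PORT is unset and when it fails to parse as a u16. A standalone sketch of that resolution order (illustrative, not part of the diff):

use std::env;

// Prefer AIP_HTTP_PORT when set and valid; otherwise keep the default.
fn resolve_port(default_port: u16) -> u16 {
    env::var("AIP_HTTP_PORT")
        .map(|v| v.parse::<u16>().unwrap_or(default_port))
        .unwrap_or(default_port)
}

fn main() {
    env::set_var("AIP_HTTP_PORT", "8080");
    assert_eq!(resolve_port(3000), 8080); // set and valid: override wins

    env::set_var("AIP_HTTP_PORT", "not-a-port");
    assert_eq!(resolve_port(3000), 3000); // set but invalid: default wins

    env::remove_var("AIP_HTTP_PORT");
    assert_eq!(resolve_port(3000), 3000); // unset: default wins
}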
router/src/server.rs

@@ -5,9 +5,9 @@ use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
     ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
-    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
+    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
-    StreamResponse, Token, TokenizeResponse, Validation,
+    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -16,8 +16,10 @@ use axum::response::{IntoResponse, Response};
 use axum::routing::{get, post};
 use axum::{http, Json, Router};
 use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
+use futures::stream::FuturesUnordered;
 use futures::stream::StreamExt;
 use futures::Stream;
+use futures::TryStreamExt;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
@@ -693,6 +695,97 @@ async fn chat_completions(
     }
 }

+/// Generate tokens from Vertex request
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/vertex",
+    request_body = VertexRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = VertexResponse),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json!({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json!({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json!({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+            example = json!({"error": "Incomplete generation"})),
+    )
+)]
+#[instrument(
+    skip_all,
+    fields(
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
+)]
+async fn vertex_compatibility(
+    Extension(infer): Extension<Infer>,
+    Extension(compute_type): Extension<ComputeType>,
+    Json(req): Json<VertexRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("tgi_request_count");
+
+    // check that there's at least one instance
+    if req.instances.is_empty() {
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Input validation error".to_string(),
+                error_type: "Input validation error".to_string(),
+            }),
+        ));
+    }
+
+    // Process all instances
+    let predictions = req
+        .instances
+        .iter()
+        .map(|instance| {
+            let generate_request = GenerateRequest {
+                inputs: instance.inputs.clone(),
+                parameters: GenerateParameters {
+                    do_sample: true,
+                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
+                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
+                    details: true,
+                    decoder_input_details: true,
+                    ..Default::default()
+                },
+            };
+
+            async {
+                generate(
+                    Extension(infer.clone()),
+                    Extension(compute_type.clone()),
+                    Json(generate_request),
+                )
+                .await
+                .map(|(_, Json(generation))| generation.generated_text)
+                .map_err(|_| {
+                    (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(ErrorResponse {
+                            error: "Incomplete generation".into(),
+                            error_type: "Incomplete generation".into(),
+                        }),
+                    )
+                })
+            }
+        })
+        .collect::<FuturesUnordered<_>>()
+        .try_collect::<Vec<_>>()
+        .await?;
+
+    let response = VertexResponse { predictions };
+    Ok((HeaderMap::new(), Json(response)).into_response())
+}
+
 /// Tokenize inputs
 #[utoipa::path(
     post,
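In the handler above, each instance becomes its own future and the batch is gathered through FuturesUnordered, so the first failing instance aborts the whole request via try_collect and the trailing ?. A minimal standalone sketch of that pattern (assumes the futures and tokio crates; the doubling logic is illustrative):

use futures::stream::{FuturesUnordered, TryStreamExt};

// Map fallible work over a batch and fail fast on the first Err.
async fn run_batch(values: Vec<i32>) -> Result<Vec<i32>, String> {
    values
        .into_iter()
        .map(|v| async move {
            if v >= 0 {
                Ok(v * 2)
            } else {
                Err(format!("negative input: {v}"))
            }
        })
        .collect::<FuturesUnordered<_>>()
        .try_collect::<Vec<_>>()
        .await
}

#[tokio::main]
async fn main() {
    // FuturesUnordered yields results in completion order, so the Ok vector
    // is not guaranteed to preserve input order.
    println!("{:?}", run_batch(vec![1, 2, 3]).await);
    println!("{:?}", run_batch(vec![1, -2]).await); // Err("negative input: -2")
}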
@@ -818,6 +911,7 @@ pub async fn run(
             StreamResponse,
             StreamDetails,
             ErrorResponse,
+            GrammarType,
         )
     ),
     tags(
@@ -942,8 +1036,30 @@ pub async fn run(
         docker_label: option_env!("DOCKER_LABEL"),
     };

+    // Define VertextApiDoc conditionally only if the "google" feature is enabled
+    #[cfg(feature = "google")]
+    #[derive(OpenApi)]
+    #[openapi(
+        paths(vertex_compatibility),
+        components(schemas(VertexInstance, VertexRequest, VertexResponse))
+    )]
+    struct VertextApiDoc;
+
+    let doc = {
+        // avoid `mut` if possible
+        #[cfg(feature = "google")]
+        {
+            // limiting mutability to the smallest scope necessary
+            let mut doc = ApiDoc::openapi();
+            doc.merge(VertextApiDoc::openapi());
+            doc
+        }
+        #[cfg(not(feature = "google"))]
+        ApiDoc::openapi()
+    };
+
     // Configure Swagger UI
-    let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi());
+    let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);

     // Define base and health routes
     let base_routes = Router::new()
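The `let doc = { ... }` block above relies on #[cfg] pruning: exactly one branch survives compilation and becomes the block's value, so `doc` never needs to be mutable outside the smallest scope. The same pattern in a standalone sketch (feature name reused purely for illustration):

fn main() {
    let value = {
        #[cfg(feature = "google")]
        {
            // Only this branch exists in a `--features google` build.
            let mut v = String::from("base");
            v.push_str("+google");
            v
        }
        #[cfg(not(feature = "google"))]
        String::from("base")
    };
    println!("{value}");
}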
@@ -953,6 +1069,7 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
+        .route("/vertex", post(vertex_compatibility))
         .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))
@@ -969,10 +1086,27 @@ pub async fn run(
         ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));

     // Combine routes and layers
-    let app = Router::new()
+    let mut app = Router::new()
         .merge(swagger_ui)
         .merge(base_routes)
-        .merge(aws_sagemaker_route)
+        .merge(aws_sagemaker_route);
+
+    #[cfg(feature = "google")]
+    {
+        tracing::info!("Built with `google` feature");
+        tracing::info!(
+            "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
+        );
+        if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
+            app = app.route(&env_predict_route, post(vertex_compatibility));
+        }
+        if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
+            app = app.route(&env_health_route, get(health));
+        }
+    }
+
+    // add layers after routes
+    app = app
         .layer(Extension(info))
         .layer(Extension(health_ext.clone()))
         .layer(Extension(compat_return_full_text))
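Because axum's Router::route consumes and returns the router, registering env-var-driven paths like AIP_PREDICT_ROUTE requires rebinding, hence the switch to `let mut app`. A minimal sketch of the same conditional registration (assumes axum 0.6, where axum::Server is available; route values are illustrative):

use axum::{routing::get, Router};

async fn health() -> &'static str {
    "ok"
}

#[tokio::main]
async fn main() {
    let mut app = Router::new().route("/health", get(health));

    // Add a platform-provided alias only when the variable is set,
    // e.g. AIP_HEALTH_ROUTE=/v1/endpoints/123/health on Vertex AI.
    if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
        app = app.route(&env_health_route, get(health));
    }

    let addr = std::net::SocketAddr::from(([0, 0, 0, 0], 3000));
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}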