Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
21340f24
Unverified
Commit
21340f24
authored
Feb 27, 2023
by
OlivierDehaene
Committed by
GitHub
Feb 27, 2023
Browse files
feat(router): add legacy route for api-inference support (#88)
parent
65e2f162
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
7 deletions
+47
-7
router/src/lib.rs
router/src/lib.rs
+23
-3
router/src/server.rs
router/src/server.rs
+24
-4
No files found.
router/src/lib.rs
View file @
21340f24
...
@@ -47,7 +47,7 @@ pub(crate) struct GenerateParameters {
...
@@ -47,7 +47,7 @@ pub(crate) struct GenerateParameters {
#[schema(exclusive_minimum
=
0
,
exclusive_maximum
=
512
,
default
=
"20"
)]
#[schema(exclusive_minimum
=
0
,
exclusive_maximum
=
512
,
default
=
"20"
)]
pub
max_new_tokens
:
u32
,
pub
max_new_tokens
:
u32
,
#[serde(default)]
#[serde(default)]
#[schema(inline,
max_items
=
4
,
example
=
json
!
(
[
"photographer"
]
))]
#[schema(inline,
max_items
=
4
,
example
=
json
!
(
[
"photographer"
]
))]
pub
stop
:
Vec
<
String
>
,
pub
stop
:
Vec
<
String
>
,
#[serde(default)]
#[serde(default)]
#[schema(default
=
"true"
)]
#[schema(default
=
"true"
)]
...
@@ -86,13 +86,33 @@ pub(crate) struct GenerateRequest {
...
@@ -86,13 +86,33 @@ pub(crate) struct GenerateRequest {
pub
parameters
:
GenerateParameters
,
pub
parameters
:
GenerateParameters
,
}
}
// Request payload accepted by the legacy (api-inference compatible) route.
// It carries the same `inputs` / `parameters` pair as `GenerateRequest`, plus
// a `stream` flag.
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct CompatGenerateRequest {
    // Prompt text.
    #[schema(example = "My name is Olivier and I")]
    pub inputs: String,

    // Generation parameters; serde calls `default_parameters` when the field
    // is missing from the payload.
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,

    // Streaming flag; deserializes to `false` when the field is missing.
    // NOTE(review): `dead_code` is allowed here although the flag looks like
    // it is read by the compat handler; confirm whether the allow is still
    // needed.
    #[allow(dead_code)]
    #[serde(default)]
    pub stream: bool,
}
impl
From
<
CompatGenerateRequest
>
for
GenerateRequest
{
fn
from
(
req
:
CompatGenerateRequest
)
->
Self
{
Self
{
inputs
:
req
.inputs
,
parameters
:
req
.parameters
,
}
}
}
#[derive(Debug,
Serialize,
ToSchema)]
#[derive(Debug,
Serialize,
ToSchema)]
pub
struct
PrefillToken
{
pub
struct
PrefillToken
{
#[schema(example
=
0
)]
#[schema(example
=
0
)]
id
:
u32
,
id
:
u32
,
#[schema(example
=
"test"
)]
#[schema(example
=
"test"
)]
text
:
String
,
text
:
String
,
#[schema(nullable
=
true
,
example
=
-
0.34
)]
#[schema(nullable
=
true
,
example
=
-
0.34
)]
logprob
:
f32
,
logprob
:
f32
,
}
}
...
@@ -102,7 +122,7 @@ pub struct Token {
...
@@ -102,7 +122,7 @@ pub struct Token {
id
:
u32
,
id
:
u32
,
#[schema(example
=
"test"
)]
#[schema(example
=
"test"
)]
text
:
String
,
text
:
String
,
#[schema(nullable
=
true
,
example
=
-
0.34
)]
#[schema(nullable
=
true
,
example
=
-
0.34
)]
logprob
:
f32
,
logprob
:
f32
,
#[schema(example
=
"false"
)]
#[schema(example
=
"false"
)]
special
:
bool
,
special
:
bool
,
...
...
router/src/server.rs
View file @
21340f24
/// HTTP Server logic
/// HTTP Server logic
use
crate
::
infer
::{
InferError
,
InferStreamResponse
};
use
crate
::
infer
::{
InferError
,
InferStreamResponse
};
use
crate
::{
use
crate
::{
Details
,
ErrorResponse
,
FinishReason
,
GenerateParameters
,
GenerateRequest
,
GenerateResponse
,
CompatGenerateRequest
,
Details
,
ErrorResponse
,
FinishReason
,
GenerateParameters
,
Infer
,
PrefillToken
,
StreamDetails
,
StreamResponse
,
Token
,
Validation
,
GenerateRequest
,
GenerateResponse
,
Infer
,
PrefillToken
,
StreamDetails
,
StreamResponse
,
Token
,
Validation
,
};
};
use
axum
::
extract
::
Extension
;
use
axum
::
extract
::
Extension
;
use
axum
::
http
::{
HeaderMap
,
Method
,
StatusCode
};
use
axum
::
http
::{
HeaderMap
,
Method
,
StatusCode
};
...
@@ -25,6 +26,25 @@ use tracing::{info_span, instrument, Instrument};
...
@@ -25,6 +26,25 @@ use tracing::{info_span, instrument, Instrument};
use
utoipa
::
OpenApi
;
use
utoipa
::
OpenApi
;
use
utoipa_swagger_ui
::
SwaggerUi
;
use
utoipa_swagger_ui
::
SwaggerUi
;
/// Compatibility route with api-inference and AzureML
#[instrument(skip(infer))]
async
fn
compat_generate
(
infer
:
Extension
<
Infer
>
,
req
:
Json
<
CompatGenerateRequest
>
,
)
->
Result
<
impl
IntoResponse
,
(
StatusCode
,
Json
<
ErrorResponse
>
)
>
{
// switch on stream
let
req
=
req
.0
;
if
req
.stream
{
Ok
(
generate_stream
(
infer
,
Json
(
req
.into
()))
.await
.into_response
())
}
else
{
let
(
headers
,
generation
)
=
generate
(
infer
,
Json
(
req
.into
()))
.await
?
;
// wrap generation inside a Vec to match api-inference
Ok
((
headers
,
Json
(
vec!
[
generation
.0
]))
.into_response
())
}
}
/// Health check method
/// Health check method
#[instrument(skip(infer))]
#[instrument(skip(infer))]
async
fn
health
(
infer
:
Extension
<
Infer
>
)
->
Result
<
(),
(
StatusCode
,
Json
<
ErrorResponse
>
)
>
{
async
fn
health
(
infer
:
Extension
<
Infer
>
)
->
Result
<
(),
(
StatusCode
,
Json
<
ErrorResponse
>
)
>
{
...
@@ -84,7 +104,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
...
@@ -84,7 +104,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
async
fn
generate
(
async
fn
generate
(
infer
:
Extension
<
Infer
>
,
infer
:
Extension
<
Infer
>
,
req
:
Json
<
GenerateRequest
>
,
req
:
Json
<
GenerateRequest
>
,
)
->
Result
<
impl
Into
Response
,
(
StatusCode
,
Json
<
ErrorResponse
>
)
>
{
)
->
Result
<
(
HeaderMap
,
Json
<
Generate
Response
>
)
,
(
StatusCode
,
Json
<
ErrorResponse
>
)
>
{
let
span
=
tracing
::
Span
::
current
();
let
span
=
tracing
::
Span
::
current
();
let
start_time
=
Instant
::
now
();
let
start_time
=
Instant
::
now
();
...
@@ -404,7 +424,7 @@ pub async fn run(
...
@@ -404,7 +424,7 @@ pub async fn run(
// Create router
// Create router
let
app
=
Router
::
new
()
let
app
=
Router
::
new
()
.merge
(
SwaggerUi
::
new
(
"/docs"
)
.url
(
"/api-doc/openapi.json"
,
ApiDoc
::
openapi
()))
.merge
(
SwaggerUi
::
new
(
"/docs"
)
.url
(
"/api-doc/openapi.json"
,
ApiDoc
::
openapi
()))
.route
(
"/"
,
post
(
generate
))
.route
(
"/"
,
post
(
compat_
generate
))
.route
(
"/generate"
,
post
(
generate
))
.route
(
"/generate"
,
post
(
generate
))
.route
(
"/generate_stream"
,
post
(
generate_stream
))
.route
(
"/generate_stream"
,
post
(
generate_stream
))
.route
(
"/"
,
get
(
health
))
.route
(
"/"
,
get
(
health
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment