OpenDAS / text-generation-inference · Commits

Commit ebecc061 (unverified), authored Jan 26, 2024 by Nicolas Patry; committed via GitHub on Jan 26, 2024.
Parent: 50a20a83

    Update the docs to include newer models. (#1492)

Showing 3 changed files, with 41 additions and 14 deletions:

    docs/openapi.json      +1  -1
    router/src/lib.rs      +25 -6
    router/src/server.rs   +15 -7
docs/openapi.json

    (This diff is collapsed.)
router/src/lib.rs
@@ -188,18 +188,20 @@ fn default_parameters() -> GenerateParameters {
     }
 }

-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletion {
     pub id: String,
     pub object: String,
+    #[schema(example = "1706270835")]
     pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     pub model: String,
     pub system_fingerprint: String,
     pub choices: Vec<ChatCompletionComplete>,
     pub usage: Usage,
 }

-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionComplete {
     pub index: u32,
     pub message: Message,

@@ -248,17 +250,19 @@ impl ChatCompletion {
     }
 }

-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChunk {
     pub id: String,
     pub object: String,
+    #[schema(example = "1706270978")]
     pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     pub model: String,
     pub system_fingerprint: String,
     pub choices: Vec<ChatCompletionChoice>,
 }

-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChoice {
     pub index: u32,
     pub delta: ChatCompletionDelta,

@@ -266,9 +270,11 @@ pub(crate) struct ChatCompletionChoice {
     pub finish_reason: Option<String>,
 }

-#[derive(Clone, Debug, Deserialize, Serialize)]
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionDelta {
+    #[schema(example = "user")]
     pub role: String,
+    #[schema(example = "What is Deep Learning?")]
     pub content: String,
 }
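For context: the `ToSchema` derives and `#[schema(example = ...)]` attributes added above come from the `utoipa` crate, which the router uses to generate docs/openapi.json. A minimal, self-contained sketch (not code from this commit; assumes utoipa 4.x, serde 1.x, and serde_json) of how these example values surface in the generated document:

    use serde::{Deserialize, Serialize};
    use utoipa::{OpenApi, ToSchema};

    // Same shape as the ChatCompletionDelta struct annotated above.
    #[derive(Clone, Deserialize, Serialize, ToSchema)]
    struct ChatCompletionDelta {
        #[schema(example = "user")]
        role: String,
        #[schema(example = "What is Deep Learning?")]
        content: String,
    }

    #[derive(OpenApi)]
    #[openapi(components(schemas(ChatCompletionDelta)))]
    struct ApiDoc;

    fn main() {
        // The printed JSON contains components.schemas.ChatCompletionDelta with
        // per-field "example" values carried over from the #[schema] attributes.
        println!("{}", ApiDoc::openapi().to_pretty_json().unwrap());
    }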
@@ -311,7 +317,7 @@ fn default_request_messages() -> Vec<Message> {

 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct ChatRequest {
     /// UNUSED
-    #[schema(example = "bigscience/blomm-560m")]
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
     pub model: String, /* NOTE: UNUSED */

@@ -322,6 +328,7 @@ pub(crate) struct ChatRequest {
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
     /// decreasing the model's likelihood to repeat the same line verbatim.
     #[serde(default)]
+    #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,

     /// UNUSED

@@ -336,28 +343,33 @@ pub(crate) struct ChatRequest {
     /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
     /// output token returned in the content of message.
     #[serde(default)]
+    #[schema(example = "false")]
     pub logprobs: Option<bool>,

     /// UNUSED
     /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
     /// an associated log probability. logprobs must be set to true if this parameter is used.
     #[serde(default)]
+    #[schema(example = "5")]
     pub top_logprobs: Option<u32>,

     /// The maximum number of tokens that can be generated in the chat completion.
     #[serde(default)]
+    #[schema(example = "32")]
     pub max_tokens: Option<u32>,

     /// UNUSED
     /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
     /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
     #[serde(default)]
+    #[schema(nullable = true, example = "2")]
     pub n: Option<u32>,

     /// UNUSED
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
     /// increasing the model's likelihood to talk about new topics
     #[serde(default)]
+    #[schema(nullable = true, example = 0.1)]
     pub presence_penalty: Option<f32>,

     #[serde(default = "bool::default")]

@@ -371,11 +383,13 @@ pub(crate) struct ChatRequest {
     ///
     /// We generally recommend altering this or `top_p` but not both.
     #[serde(default)]
+    #[schema(nullable = true, example = 1.0)]
     pub temperature: Option<f32>,

     /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
     /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
     #[serde(default)]
+    #[schema(nullable = true, example = 0.95)]
     pub top_p: Option<f32>,
 }
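All of these `ChatRequest` fields are `Option`s behind `#[serde(default)]`, so a client may simply omit them; `nullable = true` additionally marks them as nullable in the schema. A standalone sketch of the deserialization behavior (assumes only serde and serde_json; the subset struct is illustrative):

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct ChatRequestSubset {
        #[serde(default)]
        temperature: Option<f32>,
        #[serde(default)]
        top_p: Option<f32>,
    }

    fn main() {
        // Omitted keys deserialize to None rather than erroring; #[serde(default)]
        // makes that fallback explicit.
        let req: ChatRequestSubset = serde_json::from_str(r#"{ "top_p": 0.95 }"#).unwrap();
        assert_eq!(req.temperature, None);
        assert_eq!(req.top_p, Some(0.95));
        println!("{req:?}");
    }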
@@ -458,6 +472,7 @@ pub struct SimpleToken {

 #[derive(Serialize, ToSchema)]
 #[serde(rename_all(serialize = "snake_case"))]
+#[schema(example = "Length")]
 pub(crate) enum FinishReason {
     #[schema(rename = "length")]
     Length,
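Note that `rename_all(serialize = ...)` applies the case convention only on the serialize side, which is why the variant-level `#[schema(rename = "length")]` matches the wire format. A quick standalone check (serde and serde_json only; the extra variant names are illustrative, not taken from this commit):

    use serde::Serialize;

    #[derive(Serialize)]
    #[serde(rename_all(serialize = "snake_case"))]
    enum FinishReason {
        Length,
        EosToken,
    }

    fn main() {
        // Prints "length" then "eos_token": the snake_case rename affects
        // serialization only, which is exactly what the schema documents.
        for reason in [FinishReason::Length, FinishReason::EosToken] {
            println!("{}", serde_json::to_string(&reason).unwrap());
        }
    }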
@@ -518,6 +533,10 @@ pub(crate) struct GenerateResponse {
     pub details: Option<Details>,
 }

+#[derive(Serialize, ToSchema)]
+#[serde(transparent)]
+pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
+
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamDetails {
     #[schema(example = "length")]
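The new `TokenizeResponse` is a transparent newtype, so it serializes as a bare JSON array rather than as an object wrapping one. A standalone sketch (serde and serde_json only; the `SimpleToken` fields here are illustrative stand-ins for the router's type):

    use serde::Serialize;

    // Illustrative stand-in for the router's SimpleToken.
    #[derive(Serialize)]
    struct SimpleToken {
        id: u32,
        text: String,
        start: usize,
        stop: usize,
    }

    #[derive(Serialize)]
    #[serde(transparent)]
    struct TokenizeResponse(Vec<SimpleToken>);

    fn main() {
        let resp = TokenizeResponse(vec![SimpleToken {
            id: 563,
            text: "hello".to_string(),
            start: 0,
            stop: 5,
        }]);
        // Prints [{"id":563,"text":"hello","start":0,"stop":5}]: no wrapper
        // object, because #[serde(transparent)] delegates to the inner Vec.
        println!("{}", serde_json::to_string(&resp).unwrap());
    }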
router/src/server.rs
@@ -3,10 +3,10 @@ use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
-    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, SimpleToken, StreamDetails,
-    Message, StreamResponse, Token, Validation,
+    BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
+    ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
+    GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
+    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};

@@ -677,7 +677,7 @@ async fn chat_completions(
     post,
     tag = "Text Generation Inference",
     path = "/tokenize",
-    request_body = TokenizeRequest,
+    request_body = GenerateRequest,
     responses(
         (status = 200, description = "Tokenized ids", body = TokenizeResponse),
         (status = 404, description = "No tokenizer found", body = ErrorResponse,

@@ -688,7 +688,7 @@ async fn chat_completions(
 async fn tokenize(
     Extension(infer): Extension<Infer>,
     Json(req): Json<GenerateRequest>,
-) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
     let input = req.inputs.clone();
     let encoding = infer.tokenize(req).await?;
     if let Some(encoding) = encoding {

@@ -706,7 +706,7 @@ async fn tokenize(
             })
             .collect();
-        Ok(Json(tokens).into_response())
+        Ok(Json(TokenizeResponse(tokens)))
     } else {
         Err((
             StatusCode::NOT_FOUND,
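With the handler now returning `Json<TokenizeResponse>` instead of a type-erased `Response`, the return type names a concrete schema that the `utoipa::path` annotation can reference, while the body stays the same bare array. A hypothetical client-side call (not part of this commit; assumes a TGI server on localhost:3000 plus the reqwest, tokio, and serde_json crates):

    use serde_json::{json, Value};

    #[tokio::main]
    async fn main() -> Result<(), reqwest::Error> {
        // POST /tokenize now takes the same body as /generate
        // (request_body = GenerateRequest), so "inputs" is the key field.
        let tokens: Value = reqwest::Client::new()
            .post("http://localhost:3000/tokenize")
            .json(&json!({ "inputs": "What is Deep Learning?" }))
            .send()
            .await?
            .json()
            .await?;
        // The response is a bare JSON array of tokens, since TokenizeResponse
        // is #[serde(transparent)].
        println!("{tokens:#}");
        Ok(())
    }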
@@ -774,10 +774,18 @@ pub async fn run(
     Info,
     CompatGenerateRequest,
     GenerateRequest,
+    ChatRequest,
+    Message,
+    ChatCompletionChoice,
+    ChatCompletionDelta,
+    ChatCompletionChunk,
+    ChatCompletion,
     GenerateParameters,
     PrefillToken,
     Token,
     GenerateResponse,
+    TokenizeResponse,
+    SimpleToken,
     BestOfSequence,
     Details,
     FinishReason,