OpenDAS/text-generation-inference, commit 55bd4fed (unverified)
Authored Mar 09, 2023 by OlivierDehaene, committed via GitHub on Mar 09, 2023
Parent: e8bfe199

feat(router): add best_of parameter (#117)
Showing 7 changed files with 367 additions and 117 deletions (+367 −117):

  docs/openapi.json        +90  −7
  launcher/src/main.rs     +5   −0
  router/src/infer.rs      +38  −0
  router/src/lib.rs        +30  −4
  router/src/main.rs       +4   −0
  router/src/server.rs     +151 −104
  router/src/validation.rs +49  −2
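For orientation, a client request that exercises the new parameter puts `best_of` alongside the other generation parameters. The sketch below (Rust, assuming the serde_json crate) only builds and prints such a payload; the field names follow the GenerateParameters schema added in docs/openapi.json below, while the prompt and the concrete values are illustrative.

// Sketch: build a /generate request body that uses the new `best_of` parameter.
// Field names follow the GenerateParameters schema from this commit; the prompt
// and values are illustrative only.
use serde_json::json;

fn main() {
    let body = json!({
        "inputs": "What is Deep Learning?",
        "parameters": {
            // run 2 sampled generations and return the one with the best
            // mean log probability per token
            "best_of": 2,
            // best_of > 1 requires sampling (see router/src/validation.rs)
            "do_sample": true,
            "temperature": 0.7,
            "max_new_tokens": 20,
            // details must be true to get `best_of_sequences` back
            "details": true
        }
    });

    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}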
docs/openapi.json (view file @ 55bd4fed)
...
...
@@ -210,13 +210,63 @@
  },
  "components": {
    "schemas": {
      "BestOfSequence": {
        "type": "object",
        "required": ["generated_text", "finish_reason", "generated_tokens", "prefill", "tokens"],
        "properties": {
          "finish_reason": { "$ref": "#/components/schemas/FinishReason" },
          "generated_text": { "type": "string", "example": "test" },
          "generated_tokens": { "type": "integer", "format": "int32", "example": 1 },
          "prefill": { "type": "array", "items": { "$ref": "#/components/schemas/PrefillToken" } },
          "seed": { "type": "integer", "format": "int64", "example": 42, "nullable": true },
          "tokens": { "type": "array", "items": { "$ref": "#/components/schemas/Token" } }
        }
      },
      "Details": {
        "type": "object",
        "required": ["finish_reason", "generated_tokens", "prefill", "tokens"],
        "properties": {
          "best_of_sequences": { "type": "array", "items": { "$ref": "#/components/schemas/BestOfSequence" } },
          "finish_reason": { "$ref": "#/components/schemas/FinishReason" },
...
...
@@ -234,7 +284,8 @@
          "seed": { "type": "integer", "format": "int64", "example": 42, "nullable": true },
          "tokens": { "type": "array",
...
...
@@ -247,11 +298,15 @@
"ErrorResponse"
:
{
"type"
:
"object"
,
"required"
:
[
"error"
"error"
,
"error_type"
],
"properties"
:
{
"error"
:
{
"type"
:
"string"
},
"error_type"
:
{
"type"
:
"string"
}
}
},
...
...
@@ -266,6 +321,13 @@
      "GenerateParameters": {
        "type": "object",
        "properties": {
          "best_of": { "type": "integer", "default": "null", "example": 1, "nullable": true, "exclusiveMinimum": 0.0 },
          "details": { "type": "boolean", "default": "true"
...
...
@@ -292,12 +354,17 @@
          },
          "return_full_text": { "type": "boolean", "default": "null", "example": false, "nullable": true },
          "seed": { "type": "integer", "format": "int64", "default": "null", "example": "null", "nullable": true, "exclusiveMinimum": 0.0 },
          "stop": { "type": "array",
...
...
@@ -334,6 +401,21 @@
            "maximum": 1.0,
            "exclusiveMinimum": 0.0
          },
          "truncate": { "type": "integer", "default": "null", "example": "null", "nullable": true },
          "typical_p": { "type": "number", "format": "float", "default": "null", "example": 0.95, "nullable": true, "maximum": 1.0, "exclusiveMinimum": 0.0 },
          "watermark": { "type": "boolean", "default": "false",
...
...
@@ -414,7 +496,8 @@
          "seed": { "type": "integer", "format": "int64", "example": 42, "nullable": true }
        }
      },
...
...
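To make the updated schema concrete, the sketch below constructs (Rust, assuming the serde_json crate) a response body of the shape the Details and BestOfSequence schemas above describe when `best_of` and `details` are used. All values are invented for illustration; only the field names come from the schema changes.

// Sketch of a /generate response with `details: true` and `best_of: 2`,
// following the updated Details / BestOfSequence schemas. Values are illustrative.
use serde_json::json;

fn main() {
    let response = json!({
        "generated_text": " test output",
        "details": {
            "finish_reason": "length",
            "generated_tokens": 1,
            "seed": 42,
            "prefill": [],
            "tokens": [],
            // the best sequence is promoted to the top-level fields, so only
            // best_of - 1 entries appear here
            "best_of_sequences": [
                {
                    "generated_text": " other output",
                    "finish_reason": "length",
                    "generated_tokens": 1,
                    "seed": 1337,
                    "prefill": [],
                    "tokens": []
                }
            ]
        }
    });
    println!("{}", serde_json::to_string_pretty(&response).unwrap());
}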
launcher/src/main.rs (view file @ 55bd4fed)
...
...
@@ -31,6 +31,8 @@ struct Args {
    quantize: bool,
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "1000", long, env)]
...
...
@@ -86,6 +88,7 @@ fn main() -> ExitCode {
        num_shard,
        quantize,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
...
...
@@ -363,6 +366,8 @@ fn main() -> ExitCode {
        "text-generation-router".to_string(),
        "--max-concurrent-requests".to_string(),
        max_concurrent_requests.to_string(),
        "--max-best-of".to_string(),
        max_best_of.to_string(),
        "--max-stop-sequences".to_string(),
        max_stop_sequences.to_string(),
        "--max-input-length".to_string(),
...
...
router/src/infer.rs (view file @ 55bd4fed)
...
...
@@ -2,6 +2,7 @@
use crate::validation::{Validation, ValidationError};
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_client::{
...
...
@@ -177,6 +178,43 @@ impl Infer {
            Err(err)
        }
    }

    /// Add best_of new requests to the queue and return a InferResponse of the sequence with
    /// the highest log probability per token
    #[instrument(skip(self))]
    pub(crate) async fn generate_best_of(
        &self,
        request: GenerateRequest,
        best_of: usize,
    ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
        // validate best_of parameter separately
        let best_of = self.validation.validate_best_of(best_of)?;

        // create multiple generate requests
        let mut infer_responses: Vec<InferResponse> =
            try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;

        // get the sequence with the highest log probability per token
        let mut max_index = 0;
        let mut max_logprob: f32 = f32::MIN;

        for (i, response) in infer_responses.iter().enumerate() {
            // mean logprobs of the generated tokens
            let sequence_logprob = response
                .tokens
                .iter()
                .map(|token| token.logprob)
                .sum::<f32>()
                / response.tokens.len() as f32;

            // set best sequence
            if sequence_logprob > max_logprob {
                max_index = i;
                max_logprob = sequence_logprob;
            }
        }
        let best_response = infer_responses.remove(max_index);
        Ok((best_response, infer_responses))
    }
}

/// Batching logic
...
...
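The selection rule in generate_best_of is simply the mean token log probability of each candidate sequence. Below is a standalone sketch of that rule in plain Rust, with no TGI types; `Candidate` is a hypothetical stand-in for InferResponse, used only for illustration.

// Standalone sketch of the best-of selection rule: score each candidate by the
// mean log probability of its generated tokens and keep the highest-scoring one.
struct Candidate {
    logprobs: Vec<f32>,
}

fn pick_best(mut candidates: Vec<Candidate>) -> (Candidate, Vec<Candidate>) {
    let mut max_index = 0;
    let mut max_logprob = f32::MIN;

    for (i, candidate) in candidates.iter().enumerate() {
        // mean log probability per generated token
        let score =
            candidate.logprobs.iter().sum::<f32>() / candidate.logprobs.len() as f32;
        if score > max_logprob {
            max_index = i;
            max_logprob = score;
        }
    }

    // return the best candidate plus the remaining ones, mirroring the
    // (InferResponse, Vec<InferResponse>) return type of generate_best_of
    let best = candidates.remove(max_index);
    (best, candidates)
}

fn main() {
    let candidates = vec![
        Candidate { logprobs: vec![-1.2, -0.8] }, // mean -1.0
        Candidate { logprobs: vec![-0.3, -0.5] }, // mean -0.4 (best)
    ];
    let (best, rest) = pick_best(candidates);
    assert_eq!(best.logprobs, vec![-0.3, -0.5]);
    assert_eq!(rest.len(), 1);
    println!("best mean-logprob candidate selected");
}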
router/src/lib.rs (view file @ 55bd4fed)
...
...
@@ -12,6 +12,9 @@ use validation::Validation;
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters {
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
    pub best_of: Option<usize>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
...
...
@@ -56,13 +59,13 @@ pub(crate) struct GenerateParameters {
    #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
    pub max_new_tokens: u32,
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = false)]
    pub return_full_text: Option<bool>,
    #[serde(default)]
    #[schema(inline, max_items = 4, example = json!(["photographer"]))]
    pub stop: Vec<String>,
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub truncate: Option<usize>,
    #[serde(default)]
    #[schema(default = "false", example = true)]
...
...
@@ -71,6 +74,12 @@ pub(crate) struct GenerateParameters {
    #[schema(default = "true")]
    pub details: bool,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0,
        nullable = true,
        default = "null",
        example = "null"
    )]
    pub seed: Option<u64>,
}
...
...
@@ -80,6 +89,7 @@ fn default_max_new_tokens() -> u32 {
fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        best_of: None,
        temperature: None,
        repetition_penalty: None,
        top_k: None,
...
...
@@ -158,16 +168,32 @@ pub(crate) enum FinishReason {
    StopSequence,
}

#[derive(Serialize, ToSchema)]
pub(crate) struct BestOfSequence {
    #[schema(example = "test")]
    pub generated_text: String,
    #[schema(example = "length")]
    pub finish_reason: FinishReason,
    #[schema(example = 1)]
    pub generated_tokens: u32,
    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,
    pub prefill: Vec<PrefillToken>,
    pub tokens: Vec<Token>,
}

#[derive(Serialize, ToSchema)]
pub(crate) struct Details {
    #[schema(example = "length")]
    pub finish_reason: FinishReason,
    #[schema(example = 1)]
    pub generated_tokens: u32,
    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,
    pub prefill: Vec<PrefillToken>,
    pub tokens: Vec<Token>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of_sequences: Option<Vec<BestOfSequence>>,
}

#[derive(Serialize, ToSchema)]
...
...
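The skip_serializing_if attribute on best_of_sequences keeps existing responses unchanged when best_of is not used: the field is simply absent from the JSON rather than serialized as null. A minimal standalone illustration follows (assuming the serde and serde_json crates; `DetailsDemo` is a cut-down stand-in, not the real Details struct).

// Minimal stand-in showing the effect of skip_serializing_if on `best_of_sequences`.
use serde::Serialize;

#[derive(Serialize)]
struct DetailsDemo {
    generated_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    best_of_sequences: Option<Vec<String>>,
}

fn main() {
    // Without best_of: the field is omitted entirely.
    let plain = DetailsDemo { generated_tokens: 1, best_of_sequences: None };
    assert_eq!(
        serde_json::to_string(&plain).unwrap(),
        r#"{"generated_tokens":1}"#
    );

    // With best_of: the extra sequences are included.
    let with_best_of = DetailsDemo {
        generated_tokens: 1,
        best_of_sequences: Some(vec!["alt".to_string()]),
    };
    assert_eq!(
        serde_json::to_string(&with_best_of).unwrap(),
        r#"{"generated_tokens":1,"best_of_sequences":["alt"]}"#
    );
}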
@@ -184,7 +210,7 @@ pub(crate) struct StreamDetails {
    pub finish_reason: FinishReason,
    #[schema(example = 1)]
    pub generated_tokens: u32,
    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,
}
...
...
router/src/main.rs (view file @ 55bd4fed)
...
...
@@ -23,6 +23,8 @@ use tracing_subscriber::{EnvFilter, Layer};
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "1000", long, env)]
...
...
@@ -55,6 +57,7 @@ fn main() -> Result<(), std::io::Error> {
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
...
...
@@ -145,6 +148,7 @@ fn main() -> Result<(), std::io::Error> {
    server::run(
        compat_return_full_text,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
...
...
router/src/server.rs (view file @ 55bd4fed)
/// HTTP Server logic
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
    BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
    GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
    StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
...
...
@@ -64,6 +65,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
        .generate(GenerateRequest {
            inputs: "liveness".to_string(),
            parameters: GenerateParameters {
                best_of: None,
                temperature: None,
                repetition_penalty: None,
                top_k: None,
...
...
@@ -128,17 +130,51 @@ async fn generate(
    let details = req.0.parameters.details;

    // Inference
    let (response, best_of_responses) = match req.0.parameters.best_of {
        Some(best_of) if best_of > 1 => {
            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
            (response, Some(best_of_responses))
        }
        _ => (infer.generate(req.0).await?, None),
    };

    // Token details
    let details = match details {
        true => {
            // convert best_of_responses
            let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
                responses
                    .into_iter()
                    .map(|response: InferResponse| {
                        // Add prompt if return_full_text
                        let mut output_text = response.generated_text.text;
                        if let Some(prompt) = &add_prompt {
                            output_text = prompt.clone() + &output_text;
                        }

                        BestOfSequence {
                            generated_text: output_text,
                            finish_reason: FinishReason::from(
                                response.generated_text.finish_reason,
                            ),
                            generated_tokens: response.generated_text.generated_tokens,
                            prefill: response.prefill,
                            tokens: response.tokens,
                            seed: response.generated_text.seed,
                        }
                    })
                    .collect()
            });

            Some(Details {
                finish_reason: FinishReason::from(response.generated_text.finish_reason),
                generated_tokens: response.generated_text.generated_tokens,
                prefill: response.prefill,
                tokens: response.tokens,
                seed: response.generated_text.seed,
                best_of_sequences,
            })
        }
        false => None,
    };
...
...
@@ -279,107 +315,115 @@ async fn generate_stream(
        }

        let details = req.0.parameters.details;

        let best_of = req.0.parameters.best_of.unwrap_or(1);
        if best_of == 1 {
            match infer
                .generate_stream(req.0)
                .instrument(info_span!(parent: &span, "async_stream"))
                .await
            {
                Ok(mut response_stream) => {
                    // Server-Sent Event stream
                    while let Some(response) = response_stream.next().await {
                        match response {
                            Ok(response) => {
                                match response {
                                    // Prefill is ignored
                                    InferStreamResponse::Prefill(_) => {}
                                    // Yield event for every new token
                                    InferStreamResponse::Token(token) => {
                                        // StreamResponse
                                        let stream_token = StreamResponse {
                                            token,
                                            generated_text: None,
                                            details: None,
                                        };

                                        yield Ok(Event::default().json_data(stream_token).unwrap())
                                    }
                                    // Yield event for last token and compute timings
                                    InferStreamResponse::End {
                                        token,
                                        generated_text,
                                        start,
                                        queued,
                                    } => {
                                        // Token details
                                        let details = match details {
                                            true => Some(StreamDetails {
                                                finish_reason: FinishReason::from(generated_text.finish_reason),
                                                generated_tokens: generated_text.generated_tokens,
                                                seed: generated_text.seed,
                                            }),
                                            false => None,
                                        };

                                        // Timings
                                        let total_time = start_time.elapsed();
                                        let validation_time = queued - start_time;
                                        let queue_time = start - queued;
                                        let inference_time = Instant::now() - start;
                                        let time_per_token = inference_time / generated_text.generated_tokens;

                                        // Tracing metadata
                                        span.record("total_time", format!("{total_time:?}"));
                                        span.record("validation_time", format!("{validation_time:?}"));
                                        span.record("queue_time", format!("{queue_time:?}"));
                                        span.record("inference_time", format!("{inference_time:?}"));
                                        span.record("time_per_token", format!("{time_per_token:?}"));
                                        span.record("seed", format!("{:?}", generated_text.seed));
                                        tracing::info!(parent: &span, "Output: {}", generated_text.text);

                                        // Metrics
                                        metrics::increment_counter!("tgi_request_success");
                                        metrics::histogram!("tgi_request_duration", total_time);
                                        metrics::histogram!("tgi_request_validation_duration", validation_time);
                                        metrics::histogram!("tgi_request_queue_duration", queue_time);
                                        metrics::histogram!("tgi_request_inference_duration", inference_time);
                                        metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
                                        metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);

                                        // StreamResponse
                                        end_reached = true;

                                        let mut output_text = generated_text.text;
                                        if let Some(prompt) = add_prompt {
                                            output_text = prompt + &output_text;
                                        }

                                        let stream_token = StreamResponse {
                                            token,
                                            generated_text: Some(output_text),
                                            details
                                        };

                                        yield Ok(Event::default().json_data(stream_token).unwrap());
                                        break;
                                    }
                                }
                            }
                            // yield error
                            Err(err) => {
                                error = true;
                                yield Ok(Event::from(err));
                                break;
                            }
                        }
                    }
                },
                // yield error
                Err(err) => {
                    error = true;
                    yield Ok(Event::from(err));
                }
            }
            // Check if generation reached the end
            // Skip if we already sent an error
            if !end_reached && !error {
                let err = InferError::IncompleteGeneration;
                metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
                tracing::error!("{err}");
                yield Ok(Event::from(err));
            }
        } else {
            let err = InferError::from(ValidationError::BestOfStream);
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        }
...
...
@@ -404,6 +448,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
pub async fn run(
    compat_return_full_text: bool,
    max_concurrent_requests: usize,
    max_best_of: usize,
    max_stop_sequences: usize,
    max_input_length: usize,
    max_total_tokens: usize,
...
...
@@ -430,6 +475,7 @@ pub async fn run(
            PrefillToken,
            Token,
            GenerateResponse,
            BestOfSequence,
            Details,
            FinishReason,
            StreamResponse,
...
...
@@ -454,6 +500,7 @@ pub async fn run(
    let validation = Validation::new(
        validation_workers,
        tokenizer,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
...
...
router/src/validation.rs (view file @ 55bd4fed)
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
/// Payload validation logic
use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng;
...
...
@@ -13,6 +13,9 @@ use tracing::{instrument, Span};
/// Validation
#[derive(Debug, Clone)]
pub struct Validation {
    /// maximum value for the best_of parameter
    #[allow(dead_code)]
    max_best_of: usize,
    /// Channel to communicate with the background validation task
    sender: mpsc::UnboundedSender<ValidationRequest>,
}
...
...
@@ -21,6 +24,7 @@ impl Validation {
pub
(
crate
)
fn
new
(
workers
:
usize
,
tokenizer
:
Tokenizer
,
max_best_of
:
usize
,
max_stop_sequences
:
usize
,
max_input_length
:
usize
,
max_total_tokens
:
usize
,
...
...
@@ -39,6 +43,7 @@ impl Validation {
        ));

        Self {
            max_best_of,
            sender: validation_sender,
        }
    }
...
...
@@ -60,6 +65,20 @@ impl Validation {
        // Unwrap is safe here
        receiver.await.unwrap()
    }

    /// Validate the best_of parameter
    #[instrument(skip_all)]
    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        if self.max_best_of == 1 && best_of != 1 {
            return Err(ValidationError::BestOfDisabled);
        }

        if best_of > self.max_best_of {
            return Err(ValidationError::BestOf(self.max_best_of, best_of));
        }

        Ok(best_of)
    }
}

/// Validation task
...
...
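Taken together, the validation rules for the new parameter are: best_of must not exceed max_best_of, it is rejected outright when max_best_of is 1, it requires sampling, and it cannot be combined with an explicit seed or with streaming. Below is a standalone sketch of the first two rules in plain Rust; it mirrors validate_best_of rather than calling the real method, and the error enum is a local stand-in.

// Standalone mirror of the validate_best_of rules above, for illustration only.
#[derive(Debug, PartialEq)]
enum BestOfError {
    Disabled,               // `best_of` != 1 is not allowed for this endpoint
    TooLarge(usize, usize), // (max_best_of, requested)
}

fn validate_best_of(max_best_of: usize, best_of: usize) -> Result<usize, BestOfError> {
    if max_best_of == 1 && best_of != 1 {
        return Err(BestOfError::Disabled);
    }
    if best_of > max_best_of {
        return Err(BestOfError::TooLarge(max_best_of, best_of));
    }
    Ok(best_of)
}

fn main() {
    // with the launcher default max_best_of = 2:
    assert_eq!(validate_best_of(2, 2), Ok(2));
    assert_eq!(validate_best_of(2, 3), Err(BestOfError::TooLarge(2, 3)));
    // with best_of support disabled (max_best_of = 1):
    assert_eq!(validate_best_of(1, 2), Err(BestOfError::Disabled));
    println!("best_of validation rules hold");
}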
@@ -150,6 +169,7 @@ fn validate(
    rng: &mut ThreadRng,
) -> Result<ValidGenerateRequest, ValidationError> {
    let GenerateParameters {
        best_of,
        temperature,
        repetition_penalty,
        top_k,
...
...
@@ -164,6 +184,18 @@ fn validate(
        ..
    } = request.parameters;

    // sampling must be true when best_of > 1
    let best_of = best_of.unwrap_or(1);
    let sampling = do_sample
        || temperature.is_some()
        || top_k.is_some()
        || top_p.is_some()
        || typical_p.is_some();
    if best_of > 1 && !sampling {
        return Err(BestOfSampling);
    }

    let temperature = temperature.unwrap_or(1.0);
    if temperature <= 0.0 {
        return Err(ValidationError::Temperature);
...
...
@@ -217,7 +249,12 @@ fn validate(
    // If seed is None, assign a random one
    let seed = match seed {
        None => rng.gen(),
        Some(seed) => {
            if best_of > 1 {
                return Err(BestOfSeed);
            }
            seed
        }
    };

    // Check if inputs is empty
...
...
@@ -307,6 +344,16 @@ pub(crate) struct ValidGenerateRequest {
#[derive(Error,
Debug)]
pub
enum
ValidationError
{
#[error(
"`best_of` must be > 0 and <= {0}. Given: {1}"
)]
BestOf
(
usize
,
usize
),
#[error(
"`best_of` != 1 is not allowed for this endpoint"
)]
BestOfDisabled
,
#[error(
"you must use sampling when `best_of` is > 1"
)]
BestOfSampling
,
#[error(
"`seed` must not be set when `best_of` > 1"
)]
BestOfSeed
,
#[error(
"`best_of` != 1 is not supported when streaming tokens"
)]
BestOfStream
,
#[error(
"`temperature` must be strictly positive"
)]
Temperature
,
#[error(
"`repetition_penalty` must be strictly positive"
)]
...
...