OpenDAS / text-generation-inference · Commits · e8bfe199

Unverified commit e8bfe199, authored Mar 09, 2023 by OlivierDehaene, committed via GitHub on Mar 09, 2023.

feat(router): support left truncation (#115)

closes #111

parent c0795de2
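In short: GenerateParameters gains an optional truncate field. When set, the router truncates the prompt from the left before validation, keeping only the last truncate tokens instead of rejecting over-long inputs. Below is a minimal sketch of a /generate request body using the new parameter; the {"inputs", "parameters"} envelope is the router's usual request shape, and the values are illustrative only.

// Sketch: a /generate request body exercising the new `truncate` parameter.
// Built with serde_json for illustration; any HTTP client can send it.
use serde_json::json;

fn main() {
    let body = json!({
        "inputs": "a very long prompt ...",
        "parameters": {
            // Keep only the last 1000 prompt tokens; validation requires
            // 1 <= truncate <= max_input_length.
            "truncate": 1000,
            "max_new_tokens": 20
        }
    });
    println!("{body}");
}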
Changes: 3 changed files, +72 −43.

router/src/lib.rs         +5  −1
router/src/server.rs      +1  −0
router/src/validation.rs  +66 −42
router/src/lib.rs

@@ -56,12 +56,15 @@ pub(crate) struct GenerateParameters {
     #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
     pub max_new_tokens: u32,
     #[serde(default)]
-    #[schema(default = "None", example = false)]
+    #[schema(default = "null", example = false)]
     pub return_full_text: Option<bool>,
     #[serde(default)]
     #[schema(inline, max_items = 4, example = json!(["photographer"]))]
     pub stop: Vec<String>,
     #[serde(default)]
+    #[schema(default = "null", example = "null")]
+    pub truncate: Option<usize>,
+    #[serde(default)]
     #[schema(default = "false", example = true)]
     pub watermark: bool,
     #[serde(default)]

@@ -86,6 +89,7 @@ fn default_parameters() -> GenerateParameters {
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
         stop: Vec::new(),
+        truncate: None,
         watermark: false,
         details: false,
         seed: None,
router/src/server.rs

@@ -73,6 +73,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         max_new_tokens: 1,
         return_full_text: None,
         stop: Vec::new(),
+        truncate: None,
         watermark: false,
         details: false,
         seed: None,
router/src/validation.rs

@@ -6,6 +6,7 @@ use rand::Rng;
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
+use tokenizers::TruncationDirection;
 use tokio::sync::{mpsc, oneshot};
 use tracing::{instrument, Span};

@@ -157,6 +158,7 @@ fn validate(
         do_sample,
         max_new_tokens,
         stop: stop_sequences,
+        truncate,
         seed,
         watermark,
         ..
@@ -223,21 +225,45 @@ fn validate(
         return Err(EmptyInput);
     }

+    // Check if truncate is strictly positive and less than max_input_length
+    let truncate = truncate
+        .map(|value| {
+            if value == 0 || value > max_input_length {
+                return Err(ValidationError::Truncate(max_input_length, value));
+            }
+            Ok(Some(value))
+        })
+        .unwrap_or(Ok(None))?;
+
     // Get the number of tokens in the input
-    match tokenizer.encode(request.inputs.clone(), true) {
-        Ok(encoding) => {
-            let input_length = encoding.len();
-            let total_tokens = input_length + max_new_tokens as usize;
-            if input_length > max_input_length {
-                Err(ValidationError::InputLength(max_input_length, input_length))
-            } else if total_tokens > max_total_tokens {
-                Err(ValidationError::MaxTotalTokens(
-                    max_total_tokens,
-                    input_length,
-                    max_new_tokens,
-                ))
-            } else {
+    let mut encoding = tokenizer
+        .encode(request.inputs.clone(), true)
+        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+
+    let (inputs, input_length) = if let Some(truncate) = truncate {
+        // truncate encoding and decode new inputs
+        encoding.truncate(truncate, 0, TruncationDirection::Left);
+        let inputs = tokenizer
+            .decode(Vec::from(encoding.get_ids()), false)
+            .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+        (inputs, encoding.len())
+    } else {
+        (request.inputs, encoding.len())
+    };
+
+    if input_length > max_input_length {
+        return Err(ValidationError::InputLength(max_input_length, input_length));
+    }
+
+    let total_tokens = input_length + max_new_tokens as usize;
+    if total_tokens > max_total_tokens {
+        return Err(ValidationError::MaxTotalTokens(
+            max_total_tokens,
+            input_length,
+            max_new_tokens,
+        ));
+    }

     // Return ValidGenerateRequest
     let parameters = NextTokenChooserParameters {
         temperature,
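The core of the hunk above is Encoding::truncate with TruncationDirection::Left: tokens are dropped from the front so the most recent context survives, and the surviving ids are decoded back into the inputs string the model will actually see. A standalone sketch of that call against the tokenizers crate follows; "gpt2" is a placeholder model id, from_pretrained needs the crate's default HTTP feature, and decode's exact signature has shifted across tokenizers versions, so the Vec::from form below simply mirrors this commit.

use tokenizers::{Tokenizer, TruncationDirection};

fn main() -> tokenizers::Result<()> {
    // Placeholder model id; fetches the tokenizer from the Hugging Face Hub.
    let tokenizer = Tokenizer::from_pretrained("gpt2", None)?;

    let mut encoding = tokenizer.encode("one two three four five six seven", true)?;

    // Keep only the last 3 tokens: Left drops from the front of the sequence.
    encoding.truncate(3, 0, TruncationDirection::Left);
    assert!(encoding.len() <= 3);

    // Decode the surviving ids back into the string the model will see.
    let inputs = tokenizer.decode(Vec::from(encoding.get_ids()), false)?;
    println!("kept {} tokens: {inputs}", encoding.len());
    Ok(())
}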
@@ -258,15 +284,11 @@ fn validate(
     metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

     Ok(ValidGenerateRequest {
-        inputs: request.inputs,
+        inputs,
         input_length: input_length as u32,
         parameters,
         stopping_parameters,
     })
-            }
-        }
-        Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
-    }
 }

 type ValidationRequest = (
@@ -293,6 +315,8 @@ pub enum ValidationError {
     TopP,
     #[error("`top_k` must be strictly positive")]
     TopK,
+    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
+    Truncate(usize, usize),
     #[error("`typical_p` must be > 0.0 and < 1.0")]
     TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
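The new Truncate variant follows the file's existing thiserror pattern: the #[error(...)] format string becomes the Display output, with {0} and {1} filled from the tuple fields. A trimmed sketch of how the message renders, with the enum reduced to just the new variant:

use thiserror::Error;

// Trimmed to the new variant; mirrors ValidationError above.
#[derive(Error, Debug)]
enum ValidationError {
    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
    Truncate(usize, usize),
}

fn main() {
    let err = ValidationError::Truncate(1024, 2000);
    // Prints: `truncate` must be strictly positive and less than 1024. Given: 2000
    println!("{err}");
}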