OpenDAS / text-generation-inference · Commits

Unverified commit 0ac184ce, authored Feb 24, 2023 by OlivierDehaene, committed by GitHub on Feb 24, 2023

feat(server): add special token bool (#85)

parent 4b1c9720
Changes: 11 changed files with 134 additions and 72 deletions (+134 -72)

    launcher/tests/bloom_560m.json                 +47  -27
    launcher/tests/integration_tests.rs             +2   -0
    launcher/tests/mt0_base.json                   +45  -25
    proto/generate.proto                            +3   -1
    router/src/infer.rs                             +4   -3
    router/src/lib.rs                              +13   -1
    router/src/server.rs                            +7   -6
    server/text_generation/models/causal_lm.py      +4   -1
    server/text_generation/models/model.py          +1   -0
    server/text_generation/models/seq2seq_lm.py     +6   -8
    server/text_generation/models/types.py          +2   -0
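Before the per-file diffs, a minimal orientation sketch (not taken from the repository) of how a generated token and a prompt (prefill) token now differ in the router's JSON response; the values are copied from the bloom_560m.json fixture below, and only generated tokens carry the new special flag.

# Illustrative only: the two token shapes after this commit (values from the fixture below).
generated_token = {"id": 17, "text": ".", "logprob": -1.8267672, "special": False}
prefill_token = {"id": 10264, "text": "Test", "logprob": None}  # prefill entries carry no "special" field

assert "special" in generated_token and "special" not in prefill_token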
launcher/tests/bloom_560m.json  (View file @ 0ac184ce)

{
  "generated_text": ".get(\"action\");\nif (action == null) {\nthrow new RuntimeException",
  "details": {
    "finish_reason": "length",
    "generated_tokens": 20,
    "seed": null,
    "prefill": [
      {"id": 10264, "text": "Test", "logprob": null},
      {"id": 8821, "text": " request", "logprob": -11.894989}
    ],
    "tokens": [
      {"id": 17, "text": ".", "logprob": -1.8267672, "special": false},
      {"id": 1587, "text": "get", "logprob": -2.4674969, "special": false},
      {"id": 11, "text": "(", "logprob": -1.906001, "special": false},
      {"id": 5, "text": "\"", "logprob": -1.2279545, "special": false},
      {"id": 4899, "text": "action", "logprob": -4.170299, "special": false},
      {"id": 5, "text": "\"", "logprob": -0.32478866, "special": false},
      {"id": 12, "text": ")", "logprob": -1.0773665, "special": false},
      {"id": 30, "text": ";", "logprob": -0.27640742, "special": false},
      {"id": 837, "text": "\n", "logprob": -1.6970354, "special": false},
      {"id": 1320, "text": " if", "logprob": -1.4495516, "special": false},
      {"id": 375, "text": " (", "logprob": -0.23609057, "special": false},
      {"id": 4899, "text": "action", "logprob": -1.1916996, "special": false},
      {"id": 3535, "text": " ==", "logprob": -0.8918753, "special": false},
      {"id": 5109, "text": " null", "logprob": -0.3933342, "special": false},
      {"id": 12, "text": ")", "logprob": -0.43212673, "special": false},
      {"id": 731, "text": " {", "logprob": -0.17702064, "special": false},
      {"id": 1260, "text": "\n", "logprob": -0.07027565, "special": false},
      {"id": 10519, "text": " throw", "logprob": -1.3915029, "special": false},
      {"id": 2084, "text": " new", "logprob": -0.04201372, "special": false},
      {"id": 150858, "text": " RuntimeException", "logprob": -1.7329919, "special": false}
    ]
  }
}
\ No newline at end of file
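A quick consistency check over a fixture like the one above; a hedged sketch only, not part of the launcher test suite, and the relative path is an assumption.

# Sketch: validate the updated fixture shape (path assumed relative to the repository root).
import json

with open("launcher/tests/bloom_560m.json") as f:
    expected = json.load(f)

details = expected["details"]
assert details["generated_tokens"] == len(details["tokens"])           # 20 generated tokens listed
assert all(isinstance(t["special"], bool) for t in details["tokens"])  # new boolean on every token
assert all("special" not in t for t in details["prefill"])             # prefill entries stay without it
print("fixture looks consistent")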
launcher/tests/integration_tests.rs  (View file @ 0ac184ce)

@@ -14,6 +14,7 @@ pub struct Token {
     id: u32,
     text: String,
     logprob: Option<f32>,
+    special: bool,
 }

 #[derive(Deserialize)]

@@ -136,6 +137,7 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
         {
             assert_eq!(token.id, expected_token.id);
             assert_eq!(token.text, expected_token.text);
+            assert_eq!(token.special, expected_token.special);
             if let Some(logprob) = token.logprob {
                 let expected_logprob = expected_token.logprob.unwrap();
                 assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
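The test compares logprobs with an absolute tolerance of 0.001 via assert_float_eq!; a rough Python equivalent of that comparison, shown only for illustration:

# Sketch of the same absolute-tolerance check (abs <= 0.001).
import math

def logprobs_match(actual: float, expected: float, abs_tol: float = 0.001) -> bool:
    return math.isclose(actual, expected, rel_tol=0.0, abs_tol=abs_tol)

print(logprobs_match(-1.8267672, -1.8269))  # True: difference is about 0.00013
print(logprobs_match(-1.8267672, -1.83))    # False: difference is about 0.0032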
launcher/tests/mt0_base.json  (View file @ 0ac184ce)

{
  "generated_text": "\"\"\" Test the contents of the contents of the contents. \"\"\" test_test",
  "details": {
    "finish_reason": "length",
    "generated_tokens": 20,
    "seed": null,
    "prefill": [
      {"id": 0, "text": "<pad>", "logprob": null}
    ],
    "tokens": [
      {"id": 259, "text": "", "logprob": -1.3656927, "special": false},
      {"id": 215100, "text": "\"\"\"", "logprob": -2.6551573, "special": false},
      {"id": 46138, "text": "Test", "logprob": -1.8059857, "special": false},
      {"id": 287, "text": "the", "logprob": -1.2102449, "special": false},
      {"id": 259, "text": "", "logprob": -1.6057279, "special": false},
      {"id": 49076, "text": "contents", "logprob": -3.6060903, "special": false},
      {"id": 304, "text": "of", "logprob": -0.5270343, "special": false},
      {"id": 287, "text": "the", "logprob": -0.62522805, "special": false},
      {"id": 259, "text": "", "logprob": -1.4069618, "special": false},
      {"id": 49076, "text": "contents", "logprob": -2.621994, "special": false},
      {"id": 304, "text": "of", "logprob": -1.3172221, "special": false},
      {"id": 287, "text": "the", "logprob": -0.3501925, "special": false},
      {"id": 259, "text": "", "logprob": -0.7219573, "special": false},
      {"id": 49076, "text": "contents", "logprob": -1.0494149, "special": false},
      {"id": 260, "text": ".", "logprob": -1.0803378, "special": false},
      {"id": 259, "text": "", "logprob": -0.32933083, "special": false},
      {"id": 215100, "text": "\"\"\"", "logprob": -0.11268901, "special": false},
      {"id": 2978, "text": "test", "logprob": -1.5846587, "special": false},
      {"id": 290, "text": "_", "logprob": -0.49796978, "special": false},
      {"id": 4125, "text": "test", "logprob": -2.0026445, "special": false}
    ]
  }
}
\ No newline at end of file
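In this seq2seq fixture the prefill holds a single entry, the decoder start token, with a null logprob. A small sketch of that check (the path is an assumption):

# Sketch: the seq2seq prefill is just the decoder start token ("<pad>" for mt0).
import json

with open("launcher/tests/mt0_base.json") as f:
    mt0 = json.load(f)

prefill = mt0["details"]["prefill"]
assert len(prefill) == 1
assert prefill[0]["text"] == "<pad>" and prefill[0]["logprob"] is None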
proto/generate.proto  (View file @ 0ac184ce)

@@ -108,8 +108,10 @@ message Generation {
     float token_logprob = 4;
     /// Text
     string token_text = 5;
+    /// Is it a special token
+    bool token_is_special = 6;
     /// Complete generated text
-    GeneratedText generated_text = 6;
+    GeneratedText generated_text = 7;
 }

 message PrefillRequest {
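The Generation message gains token_is_special as field 6, so generated_text moves to field 7. A hedged sketch of building the updated message from Python, assuming a generate_pb2 module has been generated from this file with protoc or grpcio-tools (the module name and location are assumptions):

# Sketch: construct the updated Generation message; field names follow the proto above.
import generate_pb2  # assumed to be generated from proto/generate.proto

generation = generate_pb2.Generation(
    token_id=17,
    token_logprob=-1.8267672,
    token_text=".",
    token_is_special=False,  # new field, tag 6
)
print(generation)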
router/src/infer.rs  (View file @ 0ac184ce)

 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
-use crate::GenerateRequest;
 use crate::{Entry, Queue, Token};
+use crate::{GenerateRequest, PrefillToken};
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_client::{

@@ -138,7 +138,7 @@ impl Infer {
                 .into_iter()
                 .zip(tokens.logprobs.into_iter())
                 .zip(tokens.texts.into_iter())
-                .map(|((id, logprob), text)| Token { id, text, logprob })
+                .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
                 .collect();
         }

         // Push last token

@@ -372,6 +372,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
             id: generation.token_id,
             text: generation.token_text,
             logprob: generation.token_logprob,
+            special: generation.token_is_special,
         };

         if let Some(generated_text) = generation.generated_text {

@@ -420,7 +421,7 @@ pub(crate) enum InferStreamResponse {
 #[derive(Debug)]
 pub(crate) struct InferResponse {
-    pub(crate) prefill: Vec<Token>,
+    pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
     pub(crate) queued: Instant,
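The prefill mapping above zips ids, logprobs and texts into PrefillToken values; the same zip pattern rendered in Python, with illustrative names and values taken from the bloom fixture:

# Hedged Python analog of the prefill zip/map in Infer.
ids = [10264, 8821]
logprobs = [None, -11.894989]
texts = ["Test", " request"]

prefill_tokens = [
    {"id": i, "text": t, "logprob": lp}
    for (i, lp), t in zip(zip(ids, logprobs), texts)
]
print(prefill_tokens)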
router/src/lib.rs  (View file @ 0ac184ce)

@@ -86,6 +86,16 @@ pub(crate) struct GenerateRequest {
     pub parameters: GenerateParameters,
 }

+#[derive(Debug, Serialize, ToSchema)]
+pub struct PrefillToken {
+    #[schema(example = 0)]
+    id: u32,
+    #[schema(example = "test")]
+    text: String,
+    #[schema(nullable = true, example = -0.34)]
+    logprob: f32,
+}
+
 #[derive(Debug, Serialize, ToSchema)]
 pub struct Token {
     #[schema(example = 0)]

@@ -94,6 +104,8 @@ pub struct Token {
     text: String,
     #[schema(nullable = true, example = -0.34)]
     logprob: f32,
+    #[schema(example = "false")]
+    special: bool,
 }

 #[derive(Serialize, ToSchema)]

@@ -116,7 +128,7 @@ pub(crate) struct Details {
     pub generated_tokens: u32,
     #[schema(example = 42)]
     pub seed: Option<u64>,
-    pub prefill: Option<Vec<Token>>,
+    pub prefill: Option<Vec<PrefillToken>>,
     pub tokens: Option<Vec<Token>>,
 }
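A minimal Python mirror of the two structs, purely to illustrate the split between prompt tokens and generated tokens; the real definitions are the Rust structs above.

# Hedged illustration of the router's two token shapes after this commit.
from dataclasses import dataclass
from typing import Optional

@dataclass
class PrefillToken:
    id: int
    text: str
    logprob: Optional[float]  # None/null for the very first prompt token

@dataclass
class Token:
    id: int
    text: str
    logprob: float
    special: bool  # new: true for tokens such as </s> or <pad>

print(Token(id=17, text=".", logprob=-1.8267672, special=False))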
router/src/server.rs  (View file @ 0ac184ce)

@@ -2,7 +2,7 @@
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
     Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    Infer, StreamDetails, StreamResponse, Token, Validation,
+    Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};

@@ -255,11 +255,11 @@ async fn generate_stream(
     let time_per_token = inference_time / generated_text.generated_tokens;

     // Tracing metadata
-    span.record("total_time", format!("{:?}", total_time));
-    span.record("validation_time", format!("{:?}", validation_time));
-    span.record("queue_time", format!("{:?}", queue_time));
-    span.record("inference_time", format!("{:?}", inference_time));
-    span.record("time_per_token", format!("{:?}", time_per_token));
+    span.record("total_time", format!("{total_time:?}"));
+    span.record("validation_time", format!("{validation_time:?}"));
+    span.record("queue_time", format!("{queue_time:?}"));
+    span.record("inference_time", format!("{inference_time:?}"));
+    span.record("time_per_token", format!("{time_per_token:?}"));
     span.record("seed", format!("{:?}", generated_text.seed));
     tracing::info!(parent: &span, "Output: {}", generated_text.text);

@@ -349,6 +349,7 @@ pub async fn run(
     schemas(
         GenerateRequest,
         GenerateParameters,
+        PrefillToken,
         Token,
         GenerateResponse,
         Details,
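The handler divides the inference duration by the number of generated tokens before recording it on the tracing span; a worked example of that arithmetic with made-up numbers:

# Hedged numeric illustration of time_per_token = inference_time / generated_tokens.
from datetime import timedelta

inference_time = timedelta(milliseconds=900)  # made-up duration
generated_tokens = 20                         # matches the fixtures above
time_per_token = inference_time / generated_tokens
print(time_per_token)  # 0:00:00.045000, i.e. 45 ms per token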
server/text_generation/models/causal_lm.py  (View file @ 0ac184ce)

@@ -172,7 +172,9 @@ class CausalLMBatch(Batch):
             # and to remove unused allocated space
             left_offset = max_sequence_length - batch.max_sequence_length
             batch_left_offset = (
-                batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
+                batch.attention_mask.shape[1]
+                - batch.max_sequence_length
+                - batch.padding_right_offset
             )
             attention_mask[
                 start_index:end_index,

@@ -443,6 +445,7 @@ class CausalLM(Model):
                 next_token_id_squeezed,
                 next_token_logprob,
                 next_token_text,
+                next_token_id_squeezed in self.all_special_ids,
                 generated_text,
             )
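A worked example of the batch_left_offset arithmetic above, with made-up sizes:

# Hedged numeric illustration of the batch_left_offset expression.
attention_mask_width = 32  # batch.attention_mask.shape[1] (made-up)
max_sequence_length = 20   # longest sequence in the incoming batch (made-up)
padding_right_offset = 4   # slots reserved on the right for tokens still to be generated (made-up)

batch_left_offset = attention_mask_width - max_sequence_length - padding_right_offset
print(batch_left_offset)  # 8 unused left-padding columns that can be skipped when concatenating batches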
server/text_generation/models/model.py  (View file @ 0ac184ce)

@@ -12,6 +12,7 @@ B = TypeVar("B", bound=Batch)
 class Model(ABC):
     def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
         self.tokenizer = tokenizer
+        self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device

     @property
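Every Hugging Face tokenizer exposes the ids of its special tokens; the model now caches them as a set so the per-token membership test is cheap. A sketch with a concrete tokenizer (the model name is only an example and downloading its tokenizer files is required):

# Hedged sketch: building the special-id set the same way Model.__init__ now does.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")  # example model
all_special_ids = set(tokenizer.all_special_ids)

next_token_id = 3  # assumed to be bloom's <pad> id, for illustration
print(next_token_id in all_special_ids)  # True means the token would be flagged as special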
server/text_generation/models/seq2seq_lm.py  (View file @ 0ac184ce)

@@ -205,7 +205,8 @@ class Seq2SeqLMBatch(Batch):
             else:
                 batch_left_offset = (
                     batch.decoder_attention_mask.shape[1]
-                    - batch.max_decoder_input_length - batch.padding_right_offset
+                    - batch.max_decoder_input_length
+                    - batch.padding_right_offset
                 )
             decoder_attention_mask[
                 start_index:end_index,

@@ -494,14 +495,10 @@ class Seq2SeqLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 1:
-                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
                 prefill_tokens = PrefillTokens(
-                    prefill_token_ids, [float("nan")], prefill_texts
+                    [self.tokenizer.bos_token_id],
+                    [float("nan")],
+                    [self.tokenizer.bos_token],
                 )
             else:
                 prefill_tokens = None

@@ -512,6 +509,7 @@ class Seq2SeqLM(Model):
                 next_token_id_squeezed,
                 next_token_logprob,
                 next_token_text,
+                next_token_id_squeezed in self.all_special_ids,
                 generated_text,
             )
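After this change the seq2seq prefill is reported as a single decoder start token with an undefined logprob, which is exactly what the mt0_base.json fixture shows. A small sketch of the values involved, where NaN stands in for the missing logprob:

# Hedged sketch of the simplified seq2seq prefill: one decoder start token, NaN logprob.
import math

bos_token_id, bos_token = 0, "<pad>"  # mt0's decoder start token, per the fixture above
prefill = {
    "ids": [bos_token_id],
    "logprobs": [float("nan")],
    "texts": [bos_token],
}
assert math.isnan(prefill["logprobs"][0])
print(prefill)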
server/text_generation/models/types.py  (View file @ 0ac184ce)

@@ -73,6 +73,7 @@ class Generation:
     token_id: int
     token_logprob: float
     token_text: str
+    token_is_special: bool
    generated_text: Optional[GeneratedText]

     def to_pb(self) -> generate_pb2.Generation:

@@ -84,6 +85,7 @@ class Generation:
             token_id=self.token_id,
             token_logprob=self.token_logprob,
             token_text=self.token_text,
+            token_is_special=self.token_is_special,
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,